Unverified commit 78fa316a, authored by Muhammed Fatih BALIN and committed by GitHub
Browse files

[GraphBolt][CUDA] Modify multiGPU example to use GPU sampling. (#6961)

parent f6db850d
...@@ -126,9 +126,6 @@ def create_dataloader( ...@@ -126,9 +126,6 @@ def create_dataloader(
shuffle=shuffle, shuffle=shuffle,
drop_uneven_inputs=drop_uneven_inputs, drop_uneven_inputs=drop_uneven_inputs,
) )
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
############################################################################ ############################################################################
# [Note]: # [Note]:
# datapipe.copy_to() / gb.CopyTo() # datapipe.copy_to() / gb.CopyTo()
...@@ -137,8 +134,14 @@ def create_dataloader( ...@@ -137,8 +134,14 @@ def create_dataloader(
# [Output]: # [Output]:
# A CopyTo object copying data in the datapipe to a specified device.\ # A CopyTo object copying data in the datapipe to a specified device.\
############################################################################ ############################################################################
datapipe = datapipe.copy_to(device) if not args.cpu_sampling:
dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers) datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
if args.cpu_sampling:
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe, args.num_workers)
# Return the fully-initialized DataLoader object. # Return the fully-initialized DataLoader object.
return dataloader return dataloader
...@@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset): ...@@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset):
rank=rank, rank=rank,
) )
graph = dataset.graph # Pin the graph and features to enable GPU access.
features = dataset.feature if not args.cpu_sampling:
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()
train_set = dataset.tasks[0].train_set train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set test_set = dataset.tasks[0].test_set
args.fanout = list(map(int, args.fanout.split(","))) args.fanout = list(map(int, args.fanout.split(",")))
num_classes = dataset.tasks[0].metadata["num_classes"] num_classes = dataset.tasks[0].metadata["num_classes"]
in_size = features.size("node", None, "feat")[0] in_size = dataset.feature.size("node", None, "feat")[0]
hidden_size = 256 hidden_size = 256
out_size = num_classes out_size = num_classes
...@@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset): ...@@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset):
# Create data loaders. # Create data loaders.
train_dataloader = create_dataloader( train_dataloader = create_dataloader(
args, args,
graph, dataset.graph,
features, dataset.feature,
train_set, train_set,
device, device,
drop_last=False, drop_last=False,
...@@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset): ...@@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset):
) )
valid_dataloader = create_dataloader( valid_dataloader = create_dataloader(
args, args,
graph, dataset.graph,
features, dataset.feature,
valid_set, valid_set,
device, device,
drop_last=False, drop_last=False,
...@@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset): ...@@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset):
) )
test_dataloader = create_dataloader( test_dataloader = create_dataloader(
args, args,
graph, dataset.graph,
features, dataset.feature,
test_set, test_set,
device, device,
drop_last=False, drop_last=False,
...@@ -387,6 +393,11 @@ def parse_args(): ...@@ -387,6 +393,11 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes." "--num-workers", type=int, default=0, help="The number of processes."
) )
parser.add_argument(
"--cpu-sampling",
action="store_true",
help="Disables GPU sampling and utilizes the CPU for dataloading.",
)
return parser.parse_args() return parser.parse_args()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment