Unverified commit 78fa316a, authored by Muhammed Fatih BALIN and committed by GitHub
Browse files

[GraphBolt][CUDA] Modify multiGPU example to use GPU sampling. (#6961)

parent f6db850d
......@@ -126,9 +126,6 @@ def create_dataloader(
shuffle=shuffle,
drop_uneven_inputs=drop_uneven_inputs,
)
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
############################################################################
# [Note]:
# datapipe.copy_to() / gb.CopyTo()
......@@ -137,8 +134,14 @@ def create_dataloader(
# [Output]:
# A CopyTo object copying data in the datapipe to a specified device.
############################################################################
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers)
if not args.cpu_sampling:
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
if args.cpu_sampling:
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe, args.num_workers)
# Return the fully-initialized DataLoader object.
return dataloader
......@@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset):
rank=rank,
)
graph = dataset.graph
features = dataset.feature
# Pin the graph and features to enable GPU access.
if not args.cpu_sampling:
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
args.fanout = list(map(int, args.fanout.split(",")))
num_classes = dataset.tasks[0].metadata["num_classes"]
in_size = features.size("node", None, "feat")[0]
in_size = dataset.feature.size("node", None, "feat")[0]
hidden_size = 256
out_size = num_classes
......@@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset):
# Create data loaders.
train_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
train_set,
device,
drop_last=False,
......@@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset):
)
valid_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
valid_set,
device,
drop_last=False,
......@@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset):
)
test_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
test_set,
device,
drop_last=False,
......@@ -387,6 +393,11 @@ def parse_args():
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--cpu-sampling",
action="store_true",
help="Disables GPU sampling and utilizes the CPU for dataloading.",
)
return parser.parse_args()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment