Unverified Commit 391f513e authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Make sure in-place pinning is used in examples. (#7138)

parent 00f33224
......@@ -392,8 +392,12 @@ def main(args):
dataset = gb.BuiltinDataset("ogbl-citation2").load()
# Move the dataset to the selected storage.
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)
if args.storage_device == "pinned":
graph = dataset.graph.pin_memory_()
features = dataset.feature.pin_memory_()
else:
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)
train_set = dataset.tasks[0].train_set
args.fanout = list(map(int, args.fanout.split(",")))
......
......@@ -404,8 +404,12 @@ def main(args):
dataset = gb.BuiltinDataset("ogbn-products").load()
# Move the dataset to the selected storage.
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)
if args.storage_device == "pinned":
graph = dataset.graph.pin_memory_()
features = dataset.feature.pin_memory_()
else:
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)
train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment