Unverified Commit 0a42d863 authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] GPUCachedFeature for multiGPU example. (#7074)

parent d5b03bcb
......@@ -284,6 +284,12 @@ def run(rank, world_size, args, devices, dataset):
hidden_size = 256
out_size = num_classes
if args.gpu_cache_size > 0:
dataset.feature._features[("node", None, "feat")] = gb.GPUCachedFeature(
dataset.feature._features[("node", None, "feat")],
args.gpu_cache_size,
)
# Create GraphSAGE model. It should be copied onto a GPU as a replica.
model = SAGE(in_size, hidden_size, out_size).to(device)
model = DDP(model)
......@@ -381,6 +387,12 @@ def parse_args():
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--gpu-cache-size",
type=int,
default=0,
help="The GPU cache size for input features.",
)
parser.add_argument(
"--mode",
default="pinned-cuda",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment