Unverified commit e60262d3, authored by caojy1998 and committed by GitHub
Browse files

[Benchmark] Modify node_classification test to enable `sample_layer_neighbors`. (#7129)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-21-37.ap-northeast-1.compute.internal>
parent bfd7cee1
...@@ -117,7 +117,7 @@ def create_dataloader( ...@@ -117,7 +117,7 @@ def create_dataloader(
# [Role]: # [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes. # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################ ############################################################################
datapipe = datapipe.sample_neighbor( datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1] graph, fanout if job != "infer" else [-1]
) )
...@@ -157,7 +157,11 @@ def create_dataloader( ...@@ -157,7 +157,11 @@ def create_dataloader(
# [Role]: # [Role]:
# Initialize a multi-process dataloader to load the data in parallel. # Initialize a multi-process dataloader to load the data in parallel.
############################################################################ ############################################################################
dataloader = gb.DataLoader(datapipe, num_workers=num_workers) dataloader = gb.DataLoader(
datapipe,
num_workers=num_workers,
overlap_graph_fetch=args.overlap_graph_fetch,
)
# Return the fully-initialized DataLoader object. # Return the fully-initialized DataLoader object.
return dataloader return dataloader
...@@ -357,6 +361,13 @@ def parse_args(): ...@@ -357,6 +361,13 @@ def parse_args():
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)" help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 10,10,10", " identical with the number of layers in your model. Default: 10,10,10",
) )
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
help="The dataset we can use for node classification example. Currently"
"dataset ogbn-products, ogbn-arxiv, ogbn-papers100M is supported.",
)
parser.add_argument( parser.add_argument(
"--mode", "--mode",
default="pinned-cuda", default="pinned-cuda",
...@@ -364,6 +375,20 @@ def parse_args(): ...@@ -364,6 +375,20 @@ def parse_args():
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM," help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.", " 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
) )
parser.add_argument(
"--sample-mode",
default="sample_neighbor",
choices=["sample_neighbor", "sample_layer_neighbor"],
help="The sampling function when doing layerwise sampling.",
)
parser.add_argument(
"--overlap-graph-fetch",
action="store_true",
help="An option for enabling overlap_graph_fetch in graphbolt dataloader."
"If True, the data loader will overlap the UVA graph fetching operations"
"with the rest of operations by using an alternative CUDA stream. Disabled"
"by default.",
)
return parser.parse_args() return parser.parse_args()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment