Unverified Commit 9bb36f1a authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add opt `output_cscformat` to examples. (#6773)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent de14619c
......@@ -175,7 +175,11 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.sample_neighbor(
graph,
args.fanout,
output_cscformat=(args.output_cscformat == "True"),
)
############################################################################
# [Input]:
......@@ -371,6 +375,12 @@ def parse_args():
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
)
parser.add_argument(
"--output_cscformat",
default="False",
choices=["False", "True"],
help="Output type of SampledSubgraph. True for csc_formats, False for node_pairs.",
)
return parser.parse_args()
......
......@@ -50,7 +50,15 @@ from tqdm import tqdm
def create_dataloader(
graph, features, itemset, batch_size, fanout, device, num_workers, job
graph,
features,
itemset,
batch_size,
fanout,
device,
num_workers,
job,
output_cscformat,
):
"""
[HIGHLIGHT]
......@@ -105,7 +113,9 @@ def create_dataloader(
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(
graph, fanout if job != "infer" else [-1]
graph,
fanout if job != "infer" else [-1],
output_cscformat=(output_cscformat == "True"),
)
############################################################################
......@@ -230,6 +240,7 @@ def layerwise_infer(
device=args.device,
num_workers=args.num_workers,
job="infer",
output_cscformat=args.output_cscformat,
)
pred = model.inference(graph, features, dataloader, args.device)
pred = pred[test_set._items[0]]
......@@ -257,6 +268,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
device=args.device,
num_workers=args.num_workers,
job="evaluate",
output_cscformat=args.output_cscformat,
)
for step, data in tqdm(enumerate(dataloader)):
......@@ -283,6 +295,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
device=args.device,
num_workers=args.num_workers,
job="train",
output_cscformat=args.output_cscformat,
)
for epoch in range(args.epochs):
......@@ -354,6 +367,12 @@ def parse_args():
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
)
parser.add_argument(
"--output_cscformat",
default="False",
choices=["False", "True"],
help="Output type of SampledSubgraph. True for csc_formats, False for node_pairs.",
)
return parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment