Unverified Commit 9bb36f1a authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add opt `output_cscformat` to examples. (#6773)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent de14619c
...@@ -175,7 +175,11 @@ def create_dataloader(args, graph, features, itemset, is_train=True): ...@@ -175,7 +175,11 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Role]: # [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes. # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################ ############################################################################
datapipe = datapipe.sample_neighbor(graph, args.fanout) datapipe = datapipe.sample_neighbor(
graph,
args.fanout,
output_cscformat=(args.output_cscformat == "True"),
)
############################################################################ ############################################################################
# [Input]: # [Input]:
...@@ -371,6 +375,12 @@ def parse_args(): ...@@ -371,6 +375,12 @@ def parse_args():
choices=["cpu", "cuda"], choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.", help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
) )
parser.add_argument(
"--output_cscformat",
default="False",
choices=["False", "True"],
help="Output type of SampledSubgraph. True for csc_formats, False for node_pairs.",
)
return parser.parse_args() return parser.parse_args()
......
...@@ -50,7 +50,15 @@ from tqdm import tqdm ...@@ -50,7 +50,15 @@ from tqdm import tqdm
def create_dataloader( def create_dataloader(
graph, features, itemset, batch_size, fanout, device, num_workers, job graph,
features,
itemset,
batch_size,
fanout,
device,
num_workers,
job,
output_cscformat,
): ):
""" """
[HIGHLIGHT] [HIGHLIGHT]
...@@ -105,7 +113,9 @@ def create_dataloader( ...@@ -105,7 +113,9 @@ def create_dataloader(
# Initialize a neighbor sampler for sampling the neighborhoods of nodes. # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################ ############################################################################
datapipe = datapipe.sample_neighbor( datapipe = datapipe.sample_neighbor(
graph, fanout if job != "infer" else [-1] graph,
fanout if job != "infer" else [-1],
output_cscformat=(output_cscformat == "True"),
) )
############################################################################ ############################################################################
...@@ -230,6 +240,7 @@ def layerwise_infer( ...@@ -230,6 +240,7 @@ def layerwise_infer(
device=args.device, device=args.device,
num_workers=args.num_workers, num_workers=args.num_workers,
job="infer", job="infer",
output_cscformat=args.output_cscformat,
) )
pred = model.inference(graph, features, dataloader, args.device) pred = model.inference(graph, features, dataloader, args.device)
pred = pred[test_set._items[0]] pred = pred[test_set._items[0]]
...@@ -257,6 +268,7 @@ def evaluate(args, model, graph, features, itemset, num_classes): ...@@ -257,6 +268,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
device=args.device, device=args.device,
num_workers=args.num_workers, num_workers=args.num_workers,
job="evaluate", job="evaluate",
output_cscformat=args.output_cscformat,
) )
for step, data in tqdm(enumerate(dataloader)): for step, data in tqdm(enumerate(dataloader)):
...@@ -283,6 +295,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model): ...@@ -283,6 +295,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
device=args.device, device=args.device,
num_workers=args.num_workers, num_workers=args.num_workers,
job="train", job="train",
output_cscformat=args.output_cscformat,
) )
for epoch in range(args.epochs): for epoch in range(args.epochs):
...@@ -354,6 +367,12 @@ def parse_args(): ...@@ -354,6 +367,12 @@ def parse_args():
choices=["cpu", "cuda"], choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.", help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
) )
parser.add_argument(
"--output_cscformat",
default="False",
choices=["False", "True"],
help="Output type of SampledSubgraph. True for csc_formats, False for node_pairs.",
)
return parser.parse_args() return parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment