Unverified Commit 6178897d authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove opt `output_cscformat` from examples. (#6834)

parent 3645e493
......@@ -175,11 +175,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(
graph,
args.fanout,
output_cscformat=(args.output_cscformat == "True"),
)
datapipe = datapipe.sample_neighbor(graph, args.fanout)
############################################################################
# [Input]:
......@@ -375,12 +371,6 @@ def parse_args():
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
)
parser.add_argument(
"--output_cscformat",
default="False",
choices=["False", "True"],
help="Output type of SampledSubgraph. True for csc_formats, False for node_pairs.",
)
return parser.parse_args()
......
......@@ -50,15 +50,7 @@ from tqdm import tqdm
def create_dataloader(
graph,
features,
itemset,
batch_size,
fanout,
device,
num_workers,
job,
output_cscformat,
graph, features, itemset, batch_size, fanout, device, num_workers, job
):
"""
[HIGHLIGHT]
......@@ -113,9 +105,7 @@ def create_dataloader(
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(
graph,
fanout if job != "infer" else [-1],
output_cscformat=(output_cscformat == "True"),
graph, fanout if job != "infer" else [-1]
)
############################################################################
......@@ -240,7 +230,6 @@ def layerwise_infer(
device=args.device,
num_workers=args.num_workers,
job="infer",
output_cscformat=args.output_cscformat,
)
pred = model.inference(graph, features, dataloader, args.device)
pred = pred[test_set._items[0]]
......@@ -268,7 +257,6 @@ def evaluate(args, model, graph, features, itemset, num_classes):
device=args.device,
num_workers=args.num_workers,
job="evaluate",
output_cscformat=args.output_cscformat,
)
for step, data in tqdm(enumerate(dataloader)):
......@@ -295,7 +283,6 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
device=args.device,
num_workers=args.num_workers,
job="train",
output_cscformat=args.output_cscformat,
)
for epoch in range(args.epochs):
......@@ -367,12 +354,6 @@ def parse_args():
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
)
parser.add_argument(
"--output_cscformat",
default="False",
choices=["False", "True"],
help="Output type of SampledSubgraph. True for csc_formats, False for node_pairs.",
)
return parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment