Unverified Commit c81ff6ad authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Fix fanouts setting in rgcn example (#6959)

parent bd74c44c
......@@ -430,6 +430,7 @@ def evaluate(
else:
evaluator = MAG240MEvaluator()
num_etype = len(g.num_edges)
data_loader = create_dataloader(
name,
g,
......@@ -437,7 +438,7 @@ def evaluate(
item_set,
device,
batch_size=4096,
fanouts=[25, 10],
fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=False,
num_workers=num_workers,
)
......@@ -491,6 +492,7 @@ def train(
print("Start to train...")
category = "paper"
num_etype = len(g.num_edges)
data_loader = create_dataloader(
name,
g,
......@@ -498,7 +500,7 @@ def train(
train_set,
device,
batch_size=1024,
fanouts=[25, 10],
fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=True,
num_workers=num_workers,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment