Unverified Commit c81ff6ad authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt] Fix fanouts setting in rgcn example (#6959)

parent bd74c44c
...@@ -430,6 +430,7 @@ def evaluate( ...@@ -430,6 +430,7 @@ def evaluate(
else: else:
evaluator = MAG240MEvaluator() evaluator = MAG240MEvaluator()
num_etype = len(g.num_edges)
data_loader = create_dataloader( data_loader = create_dataloader(
name, name,
g, g,
...@@ -437,7 +438,7 @@ def evaluate( ...@@ -437,7 +438,7 @@ def evaluate(
item_set, item_set,
device, device,
batch_size=4096, batch_size=4096,
fanouts=[25, 10], fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=False, shuffle=False,
num_workers=num_workers, num_workers=num_workers,
) )
...@@ -491,6 +492,7 @@ def train( ...@@ -491,6 +492,7 @@ def train(
print("Start to train...") print("Start to train...")
category = "paper" category = "paper"
num_etype = len(g.num_edges)
data_loader = create_dataloader( data_loader = create_dataloader(
name, name,
g, g,
...@@ -498,7 +500,7 @@ def train( ...@@ -498,7 +500,7 @@ def train(
train_set, train_set,
device, device,
batch_size=1024, batch_size=1024,
fanouts=[25, 10], fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=True, shuffle=True,
num_workers=num_workers, num_workers=num_workers,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment