Unverified Commit c3a043eb authored by xiangyuzhi's avatar xiangyuzhi Committed by GitHub
Browse files

[Sparse] Fix hetero-rgcn example datasets (#6200)

parent b1153db9
......@@ -182,12 +182,8 @@ def main(args):
# Load graph data.
if args.dataset == "aifb":
dataset = AIFBDataset()
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == "bgs":
dataset = BGSDataset()
elif args.dataset == "am":
dataset = AMDataset()
else:
raise ValueError()
......@@ -238,7 +234,7 @@ def main(args):
# Training loop.
print("start training...")
model.train()
for epoch in range(20):
for epoch in range(10):
optimizer.zero_grad()
logits = model(A)[category]
loss = F.cross_entropy(logits[train_idx], labels[train_idx])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment