Unverified Commit c3a043eb authored by xiangyuzhi's avatar xiangyuzhi Committed by GitHub
Browse files

[Sparse] Fix hetero-rgcn example datasets (#6200)

parent b1153db9
...@@ -182,12 +182,8 @@ def main(args): ...@@ -182,12 +182,8 @@ def main(args):
# Load graph data. # Load graph data.
if args.dataset == "aifb": if args.dataset == "aifb":
dataset = AIFBDataset() dataset = AIFBDataset()
elif args.dataset == "mutag":
dataset = MUTAGDataset()
elif args.dataset == "bgs": elif args.dataset == "bgs":
dataset = BGSDataset() dataset = BGSDataset()
elif args.dataset == "am":
dataset = AMDataset()
else: else:
raise ValueError() raise ValueError()
...@@ -238,7 +234,7 @@ def main(args): ...@@ -238,7 +234,7 @@ def main(args):
# Training loop. # Training loop.
print("start training...") print("start training...")
model.train() model.train()
for epoch in range(20): for epoch in range(10):
optimizer.zero_grad() optimizer.zero_grad()
logits = model(A)[category] logits = model(A)[category]
loss = F.cross_entropy(logits[train_idx], labels[train_idx]) loss = F.cross_entropy(logits[train_idx], labels[train_idx])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment