Unverified Commit bdaccc82 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Bug Fix] Fix KeyError for rgcn-hetero Entity Classification w/ minibatch training (#4022)

* Update model.py

* Update entity_classify_mb.py
parent d70a362d
......@@ -152,7 +152,7 @@ def main(args):
output = model.inference(
g, args.batch_size, 'cuda' if use_cuda else 'cpu', 0, node_embed)
test_pred = output[category][test_idx]
test_labels = labels[test_idx]
test_labels = labels[test_idx].to(test_pred.device)
test_acc = (test_pred.argmax(1) == test_labels).float().mean()
print("Test Acc: {:.4f}".format(test_acc))
print()
......
......@@ -365,7 +365,7 @@ class EntityClassify(nn.Module):
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h = layer(block, h)
for k in h.keys():
for k in output_nodes.keys():
y[k][output_nodes[k]] = h[k].cpu()
x = y
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment