Unverified Commit 701b746b authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Example][Bugfix] graphsage node classification example (#4260)



* Fix node_classification.py

* Minor update
Co-authored-by: default avatarnv-dlasalle <63612878+nv-dlasalle@users.noreply.github.com>
parent 4dc5728a
...@@ -72,9 +72,9 @@ def evaluate(model, graph, dataloader): ...@@ -72,9 +72,9 @@ def evaluate(model, graph, dataloader):
def layerwise_infer(device, graph, nid, model, batch_size): def layerwise_infer(device, graph, nid, model, batch_size):
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
pred = model.inference(graph, device, batch_size).to(device) pred = model.inference(graph, device, batch_size) # pred in buffer_device
pred = pred[nid] pred = pred[nid]
label = graph.ndata['label'][nid] label = graph.ndata['label'][nid].to(pred.device)
return MF.accuracy(pred, label) return MF.accuracy(pred, label)
def train(args, device, g, dataset, model): def train(args, device, g, dataset, model):
...@@ -141,5 +141,5 @@ if __name__ == '__main__': ...@@ -141,5 +141,5 @@ if __name__ == '__main__':
# test the model # test the model
print('Testing...') print('Testing...')
acc = layerwise_infer(device, g, dataset.test_idx.to(device), model, batch_size=4096) acc = layerwise_infer(device, g, dataset.test_idx, model, batch_size=4096)
print("Test Accuracy {:.4f}".format(acc.item())) print("Test Accuracy {:.4f}".format(acc.item()))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment