Unverified Commit ae8cbde5 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[benchmarks] fix torchmetrics accuracy (#5594)

parent b740c3f0
...@@ -61,6 +61,11 @@ def track_acc(dataset, ns_mode): ...@@ -61,6 +61,11 @@ def track_acc(dataset, ns_mode):
with torch.no_grad(): with torch.no_grad():
logits = model(g) logits = model(g)
logits = logits[target_idx] logits = logits[target_idx]
test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item() test_acc = accuracy(
logits[test_idx].argmax(dim=1),
labels[test_idx],
task="multiclass",
num_classes=num_classes,
).item()
return test_acc return test_acc
...@@ -3,7 +3,6 @@ import time ...@@ -3,7 +3,6 @@ import time
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchmetrics.functional import accuracy
from .. import rgcn, utils from .. import rgcn, utils
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment