"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8abc7aeb715c0149ee0a9982b2d608ce97f55215"
Unverified Commit ae8cbde5 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[benchmarks] fix torchmetrics accuracy (#5594)

parent b740c3f0
......@@ -61,6 +61,11 @@ def track_acc(dataset, ns_mode):
with torch.no_grad():
logits = model(g)
logits = logits[target_idx]
test_acc = accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()
test_acc = accuracy(
logits[test_idx].argmax(dim=1),
labels[test_idx],
task="multiclass",
num_classes=num_classes,
).item()
return test_acc
......@@ -3,7 +3,6 @@ import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
from .. import rgcn, utils
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment