accuracy.py 623 Bytes
Newer Older
yuguo960516's avatar
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import numpy as np
import oneflow as flow


class Accuracy(flow.nn.Module):
    """Top-1 accuracy over paired collections of predictions and labels."""

    def __init__(self):
        super().__init__()

    def forward(self, preds, labels):
        """Return the fraction of label elements matched by the arg-max prediction.

        ``preds`` and ``labels`` are iterated in lockstep with ``zip``, so
        any surplus items in the longer one are ignored.  For each pair the
        predicted class is the arg-max over the last dimension of the
        prediction, cast to int32 and compared element-wise with the label.

        Args:
            preds: iterable of prediction tensors.
                NOTE(review): presumably class scores laid out along the
                last dimension -- confirm against the caller.
            labels: iterable of label tensors, one per prediction.

        Returns:
            A float32 tensor with a single element: matched elements divided
            by the total number of label elements seen.

        Note:
            There is no guard for empty input; with no pairs the element
            count stays 0 and the final division is by zero.
        """
        correct = flow.zeros(1, dtype=flow.float32)
        total = 0
        for scores, target in zip(preds, labels):
            # Every element of the label tensor counts as one sample.
            total += np.prod(target.shape).item()
            # Predicted class = index of the largest value on the last axis.
            predicted = scores.argmax(dim=-1).to(flow.int32)
            hits = (predicted == target).sum()
            # Move the match count onto the accumulator's device and dtype
            # before the in-place add.
            correct += hits.to(device=correct.device, dtype=correct.dtype)

        return correct / total