metric.py 660 Bytes
Newer Older
HHL's avatar
v  
HHL committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch


class AccMetric:
    """Element-level accuracy counter.

    Calling an instance with ``preds``, ``labels`` and ``labels_mask``
    returns ``(correct_nums, total_nums)`` as Python floats.

    Only positions where ``labels_mask`` is non-zero and ``labels`` is not
    ``-1`` are counted. ``total_nums`` is floored at ``1e-6``, presumably so
    a caller can compute ``correct_nums / total_nums`` without dividing by
    zero.

    NOTE(review): the three tensors are assumed to be broadcast-compatible
    (typically the same shape) -- confirm against callers.
    """

    @staticmethod
    def _count(flags):
        # Number of True entries in a boolean tensor, as a host-side float.
        return float(flags.sum().detach().cpu().item())

    def __call__(self, preds, labels, labels_mask):
        # Positions that take part in the metric: not masked out and not
        # carrying the -1 label.
        valid = (labels != -1) & (labels_mask != 0)
        hits = valid & (preds == labels)
        return self._count(hits), max(self._count(valid), 1e-6)


class AccMulMetric:
    """Exact-match accuracy counter over dimension 1.

    Calling an instance with ``preds``, ``labels`` and ``labels_mask``
    returns ``(correct_nums, total_nums)`` as Python floats.

    A row along dim 0 counts as correct only if every entry of
    ``preds == labels`` along dim 1 is True. Only rows where ``labels_mask``
    is non-zero are counted. ``total_nums`` is floored at ``1e-6``,
    presumably so a caller can divide by it safely.

    Assumption (confirm with callers): ``preds``/``labels`` are at least 2-D,
    and ``labels_mask`` broadcasts against the dim-1-reduced result, for
    example shape ``(batch,)``.

    NOTE(review): unlike ``AccMetric``, this does not exclude ``labels == -1``
    entries. Confirm whether that difference is intentional.
    """

    def __call__(self, preds, labels, labels_mask):
        mask = labels_mask != 0
        # ``all`` is the explicit boolean "every entry matched" reduction.
        # The earlier ``.min(1)[0]`` gave the same result, but less clearly,
        # and it relied on min/max reductions being supported for Bool.
        row_correct = (preds == labels).all(dim=1)
        correct_nums = float(torch.sum(row_correct & mask).detach().cpu().item())
        total_nums = max(float(torch.sum(mask).detach().cpu().item()), 1e-6)
        return correct_nums, total_nums