# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import copy
import logging
from collections import OrderedDict

import torch

from fastreid.evaluation.evaluator import DatasetEvaluator
from fastreid.utils import comm

logger = logging.getLogger("fastreid.attr_evaluation")


class AttrEvaluator(DatasetEvaluator):
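    """
    Evaluator for multi-label pedestrian attribute predictions: accumulates
    per-batch model outputs and targets, then reports instance-based
    Accu/Prec/Recall/F1 and label-based mA.
    """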
    def __init__(self, cfg, attr_dict, thres=0.5, output_dir=None):
        self.cfg = cfg
        self.attr_dict = attr_dict
        self.thres = thres
        self._output_dir = output_dir

        self._cpu_device = torch.device("cpu")

        self.pred_logits = []
        self.gt_labels = []

    def reset(self):
        self.pred_logits = []
        self.gt_labels = []

    def process(self, inputs, outputs):
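        # Accumulate targets and model outputs on the CPU so GPU memory is
        # not held across the whole evaluation loop.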
        self.gt_labels.extend(inputs["targets"].to(self._cpu_device))
        self.pred_logits.extend(outputs.to(self._cpu_device, torch.float32))

    @staticmethod
    def get_attr_metrics(gt_labels, pred_logits, thres):

        eps = 1e-20

        # Binarize scores at `thres`; scores are expected to lie in [0, 1]
        # (e.g. post-sigmoid), hence the default threshold of 0.5.
        pred_labels = copy.deepcopy(pred_logits)
        pred_labels[pred_logits < thres] = 0
        pred_labels[pred_logits >= thres] = 1

        # Per-attribute true-positive and true-negative counts for the
        # label-based metric
        overlaps = pred_labels * gt_labels
        correct_pos = overlaps.sum(axis=0)
        real_pos = gt_labels.sum(axis=0)
        inv_overlaps = (1 - pred_labels) * (1 - gt_labels)
        correct_neg = inv_overlaps.sum(axis=0)
        real_neg = (1 - gt_labels).sum(axis=0)

        # Instance-based metrics: per-sample Jaccard accuracy, precision,
        # recall, and F1, averaged over samples
        pred_labels = pred_labels.astype(bool)
        gt_labels = gt_labels.astype(bool)
        intersect = (pred_labels & gt_labels).astype(float)
        union = (pred_labels | gt_labels).astype(float)
        ins_acc = (intersect.sum(axis=1) / (union.sum(axis=1) + eps)).mean()
        ins_prec = (intersect.sum(axis=1) / (pred_labels.astype(float).sum(axis=1) + eps)).mean()
        ins_rec = (intersect.sum(axis=1) / (gt_labels.astype(float).sum(axis=1) + eps)).mean()
        ins_f1 = (2 * ins_prec * ins_rec) / (ins_prec + ins_rec + eps)

        # Label-based mean accuracy (mA): per attribute, average the
        # true-positive rate and the true-negative rate, then take the mean
        # over all attributes.
        term1 = correct_pos / (real_pos + eps)
        term2 = correct_neg / (real_neg + eps)
        label_mA_verbose = (term1 + term2) * 0.5
        label_mA = label_mA_verbose.mean()

        results = OrderedDict()
        results["Accu"] = ins_acc * 100
        results["Prec"] = ins_prec * 100
        results["Recall"] = ins_rec * 100
        results["F1"] = ins_f1 * 100
        results["mA"] = label_mA * 100
        results["metric"] = label_mA * 100
        return results

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
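            # Gather the per-worker lists and flatten them; only the main
            # process computes the final metrics.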
            pred_logits = comm.gather(self.pred_logits)
            pred_logits = sum(pred_logits, [])

            gt_labels = comm.gather(self.gt_labels)
            gt_labels = sum(gt_labels, [])

            if not comm.is_main_process():
                return {}
        else:
            pred_logits = self.pred_logits
            gt_labels = self.gt_labels

        # Stack the per-sample tensors into (num_samples, num_attrs) arrays.
        pred_logits = torch.stack(pred_logits, dim=0).numpy()
        gt_labels = torch.stack(gt_labels, dim=0).numpy()

        # Pedestrian attribute metrics
        thres = self.cfg.TEST.THRES
        self._results = self.get_attr_metrics(gt_labels, pred_logits, thres)

        return copy.deepcopy(self._results)
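

# --- Usage sketch (illustrative; not part of the original file) ---
# A minimal, self-contained check of `get_attr_metrics` on synthetic data.
# The shapes, the random scores, and the 0.5 threshold are assumptions made
# for demonstration only, not values from a real fastreid config.
if __name__ == "__main__":
    import numpy as np

    rng = np.random.default_rng(0)
    num_samples, num_attrs = 8, 5
    # Binary ground-truth attribute matrix, one row per sample.
    gt = rng.integers(0, 2, size=(num_samples, num_attrs)).astype(np.float32)
    # Fake prediction scores in [0, 1], standing in for post-sigmoid outputs.
    scores = rng.random((num_samples, num_attrs)).astype(np.float32)

    metrics = AttrEvaluator.get_attr_metrics(gt, scores, thres=0.5)
    for name, value in metrics.items():
        print(f"{name}: {value:.2f}")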