face_evaluator.py 2.17 KB
Newer Older
dengjb's avatar
update  
dengjb committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import copy
import io
import logging
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from PIL import Image

from fastreid.evaluation import DatasetEvaluator
from fastreid.utils import comm
from fastreid.utils.file_io import PathManager
from .verification import evaluate

logger = logging.getLogger("fastreid.face_evaluator")


def gen_plot(fpr, tpr, image_format="jpeg"):
    """Render an ROC curve with pyplot and return it as an in-memory image.

    Args:
        fpr: false-positive rates, plotted on the x-axis.
        tpr: true-positive rates, plotted on the y-axis (same length as ``fpr``).
        image_format: format passed to ``savefig``. Defaults to ``"jpeg"`` to
            keep the previous output; ``"png"`` avoids lossy compression.

    Returns:
        io.BytesIO: buffer containing the encoded image, rewound to offset 0.
    """
    fig = plt.figure()
    try:
        plt.xlabel("FPR", fontsize=14)
        plt.ylabel("TPR", fontsize=14)
        plt.title("ROC Curve", fontsize=14)
        plt.plot(fpr, tpr, linewidth=2)
        buf = io.BytesIO()
        fig.savefig(buf, format=image_format)
    finally:
        # Close exactly the figure created above, even if drawing/saving fails,
        # so repeated evaluations do not leave open pyplot figures behind.
        plt.close(fig)
    buf.seek(0)
    return buf


class FaceEvaluator(DatasetEvaluator):
    """Verification evaluator for face datasets.

    Collects model outputs batch by batch and L2-normalizes them. It then passes
    them, together with ``labels``, to ``evaluate`` imported from
    ``.verification``. It reports mean accuracy and mean best threshold.
    When ``output_dir`` is set, it also writes an ROC curve image to
    ``<output_dir>/<dataset_name>_roc.png``.
    """

    def __init__(self, cfg, labels, dataset_name, output_dir=None):
        """
        Args:
            cfg: config object; stored on the instance but not read by this class.
            labels: ground truth forwarded unchanged to ``verification.evaluate``
                (presumably one same/different flag per feature pair -- confirm
                against ``verification.evaluate``).
            dataset_name: prefix of the saved ROC image file name.
            output_dir: directory for the ROC image. When ``None``, no image
                is written.
        """
        self.cfg = cfg
        self.labels = labels
        self.dataset_name = dataset_name
        self._output_dir = output_dir

        # Per-batch output tensors accumulated by ``process``.
        self.features = []

    def reset(self):
        """Discard all features collected so far."""
        self.features = []

    def process(self, inputs, outputs):
        """Record one batch of model outputs (moved to CPU); ``inputs`` is unused."""
        self.features.append(outputs.cpu())

    def evaluate(self):
        """Compute verification metrics over every collected feature.

        In multi-process runs, the features of all ranks are gathered to the main
        process; every other process returns an empty dict.

        Returns:
            A deep copy of an OrderedDict with keys ``"Accuracy"`` (mean accuracy
            in percent), ``"Threshold"`` (mean best threshold) and ``"metric"``
            (same value as ``"Accuracy"``). Returns ``{}`` on non-main processes
            and when no features were collected.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            features = comm.gather(self.features)
            # ``gather`` yields one list per rank; flatten them into a single list.
            features = sum(features, [])

            # fmt: off
            if not comm.is_main_process(): return {}
            # fmt: on
        else:
            features = self.features

        if not features:
            # torch.cat raises on an empty list; report instead of crashing.
            logger.warning(
                "No features were collected for %s; skipping face evaluation.",
                self.dataset_name,
            )
            return {}

        features = torch.cat(features, dim=0)
        # L2-normalize each row (dim=1) before handing the embeddings over.
        features = F.normalize(features, p=2, dim=1).numpy()

        self._results = OrderedDict()
        # ``evaluate`` below is the module-level function imported from
        # ``.verification``: names in a method body resolve to globals, not to
        # this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(features, self.labels)

        mean_accuracy = accuracy.mean() * 100
        self._results["Accuracy"] = mean_accuracy
        self._results["Threshold"] = best_thresholds.mean()
        self._results["metric"] = mean_accuracy

        # Writing the ROC image needs a destination; skip it when none was given
        # (the default), instead of failing on mkdirs(None) / os.path.join(None).
        if self._output_dir is not None:
            self._save_roc_curve(fpr, tpr)

        return copy.deepcopy(self._results)

    def _save_roc_curve(self, fpr, tpr):
        """Render the ROC curve and save it as ``<output_dir>/<dataset_name>_roc.png``."""
        buf = gen_plot(fpr, tpr)
        PathManager.mkdirs(self._output_dir)
        # Close the decoded image once it has been re-encoded to disk.
        with Image.open(buf) as roc_curve:
            roc_curve.save(os.path.join(self._output_dir, self.dataset_name + "_roc.png"))