# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import copy
import logging
from collections import OrderedDict
from typing import List, Optional, Dict

import faiss
import numpy as np
import torch
import torch.nn.functional as F

from fastreid.evaluation import DatasetEvaluator
from fastreid.utils import comm

logger = logging.getLogger("fastreid.retri_evaluator")


@torch.no_grad()
def recall_at_ks(query_features: torch.Tensor,
                 query_labels: np.ndarray,
                 ks: List[int],
                 gallery_features: Optional[torch.Tensor] = None,
                 gallery_labels: Optional[np.ndarray] = None,
                 cosine: bool = False) -> Dict[int, float]:
    """
    Compute the recall between samples at each k. This function uses about 8GB of memory.
    Parameters
    ----------
    query_features : torch.Tensor
        Features for each query sample. shape: (num_queries, num_features)
    query_labels : torch.LongTensor
        Labels corresponding to the query features. shape: (num_queries,)
    ks : List[int]
        Values at which to compute the recall.
    gallery_features : torch.Tensor
        Features for each gallery sample. shape: (num_queries, num_features)
    gallery_labels : torch.LongTensor
        Labels corresponding to the gallery features. shape: (num_queries,)
    cosine : bool
        Use cosine distance between samples instead of euclidean distance.
    Returns
    -------
    recalls : Dict[int, float]
        Values of the recall at each k.
    """
    offset = 0
    if gallery_features is None and gallery_labels is None:
        # Query the set against itself; the offset skips the trivial self-match.
        offset = 1
        gallery_features = query_features
        gallery_labels = query_labels
    elif gallery_features is None or gallery_labels is None:
        raise ValueError('gallery_features and gallery_labels must either both be None or both be provided.')

    if cosine:
        query_features = F.normalize(query_features, p=2, dim=1)
        gallery_features = F.normalize(gallery_features, p=2, dim=1)

    # faiss expects contiguous float32 numpy arrays
    query_features = np.ascontiguousarray(query_features.cpu().numpy(), dtype=np.float32)
    gallery_features = np.ascontiguousarray(gallery_features.cpu().numpy(), dtype=np.float32)

    # Build a flat (exact) GPU index on device 0. With cosine=True the features
    # are already L2-normalized, so inner product equals cosine similarity.
    res = faiss.StandardGpuResources()
    flat_config = faiss.GpuIndexFlatConfig()
    flat_config.device = 0

    max_k = max(ks)
    index_function = faiss.GpuIndexFlatIP if cosine else faiss.GpuIndexFlatL2
    index = index_function(res, gallery_features.shape[1], flat_config)
    index.add(gallery_features)
    # search returns (distances, indices); keep only the neighbor indices.
    closest_indices = index.search(query_features, max_k + offset)[1]

    recalls = {}
    for k in ks:
        indices = closest_indices[:, offset:k + offset]
        # A query scores a hit at k if any of its top-k neighbors shares its label.
        recalls[k] = (query_labels[:, None] == gallery_labels[indices]).any(1).mean()
    return {k: round(v * 100, 2) for k, v in recalls.items()}
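

# Illustrative usage sketch (not part of the original file): tiny,
# hand-checkable inputs for recall_at_ks. Assumes faiss-gpu is installed and
# CUDA device 0 is visible, since recall_at_ks builds a GpuIndexFlat there.
def _demo_recall_at_ks():
    query_features = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])
    query_labels = np.array([0, 1, 0])
    gallery_features = torch.tensor([[1., 0.1], [0.1, 1.], [-1., 0.], [1., 0.9]])
    gallery_labels = np.array([0, 1, 1, 0])
    recalls = recall_at_ks(query_features, query_labels, ks=[1, 2],
                           gallery_features=gallery_features,
                           gallery_labels=gallery_labels,
                           cosine=True)
    # Every query's nearest neighbor shares its label here,
    # so Recall@1 should already be 100.0.
    print(recalls)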


class RetriEvaluator(DatasetEvaluator):
    def __init__(self, cfg, num_query, output_dir=None):
        self.cfg = cfg
        # The first num_query processed samples are queries; the rest form the gallery.
        self._num_query = num_query
        self._output_dir = output_dir

        self.recalls = cfg.TEST.RECALLS

        self.features = []
        self.labels = []

    def reset(self):
        self.features = []
        self.labels = []

    def process(self, inputs, outputs):
        # Accumulate embeddings and their labels batch by batch.
        self.features.append(outputs.cpu())
        self.labels.extend(inputs["targets"])

    def evaluate(self):
        if comm.get_world_size() > 1:
            comm.synchronize()
            # Gather the per-process lists on the main process and flatten them.
            features = comm.gather(self.features)
            features = sum(features, [])

            labels = comm.gather(self.labels)
            labels = sum(labels, [])

            # fmt: off
            if not comm.is_main_process(): return {}
            # fmt: on
        else:
            features = self.features
            labels = self.labels

        features = torch.cat(features, dim=0)
        # query features and labels
        query_features = features[:self._num_query]
        query_labels = np.asarray(labels[:self._num_query])

        # gallery features and labels
        gallery_features = features[self._num_query:]
        gallery_labels = np.asarray(labels[self._num_query:])

        self._results = OrderedDict()

        if self._num_query == len(features):
            # No separate gallery: rank the query set against itself.
            cmc = recall_at_ks(query_features, query_labels, self.recalls, cosine=True)
        else:
            cmc = recall_at_ks(query_features, query_labels, self.recalls,
                               gallery_features, gallery_labels,
                               cosine=True)

        for r in self.recalls:
            self._results['Recall@{}'.format(r)] = cmc[r]
        # Use the first configured recall (typically Recall@1) as the headline metric.
        self._results["metric"] = cmc[self.recalls[0]]

        return copy.deepcopy(self._results)
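

# Illustrative usage sketch (not part of the original file): how a
# DatasetEvaluator subclass like RetriEvaluator is typically driven. The
# cfg/model/data_loader objects are assumed to come from the usual fastreid
# testing setup (cfg must define TEST.RECALLS, e.g. [1, 5, 10]), with the
# data_loader yielding query samples before gallery samples.
def _demo_retri_evaluator(cfg, model, data_loader, num_query):
    evaluator = RetriEvaluator(cfg, num_query=num_query)
    evaluator.reset()
    model.eval()
    with torch.no_grad():
        for inputs in data_loader:
            outputs = model(inputs)  # (batch_size, num_features) embeddings
            evaluator.process(inputs, outputs)
    # e.g. OrderedDict([('Recall@1', 81.3), ..., ('metric', 81.3)])
    return evaluator.evaluate()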