matcher.py 13 KB
Newer Older
mibaumgartner's avatar
core  
mibaumgartner committed
1
2
3
4
5
6
7
from typing import Sequence, Callable, Tuple, TypeVar
from abc import ABC

import torch
from torch import Tensor
from loguru import logger

mibaumgartner's avatar
mibaumgartner committed
8
from nndet.core.boxes.ops import box_iou, box_center_dist, center_in_boxes
mibaumgartner's avatar
core  
mibaumgartner committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299

# Sentinel magnitude: ATSSMatcher fills a similarity matrix with ``-INF`` and
# later compares against ``-INF`` to detect anchors which received no match.
# NOTE(review): assumes real similarity values never equal -INF (true for IoU
# style similarities in [0, 1]) - confirm for custom similarity functions.
INF = 100  # not really inf but here it is sufficient


class Matcher(ABC):
    # Value written into `matches` for anchors treated as background
    BELOW_LOW_THRESHOLD: int = -1
    # Value written into `matches` for anchors which should be ignored
    BETWEEN_THRESHOLDS: int = -2

    def __init__(self, similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
        """
        Matches boxes and anchors to each other

        Args:
            similarity_fn: function for similarity computation between
                boxes and anchors
        """
        self.similarity_fn = similarity_fn

    def __call__(self,
                 boxes: torch.Tensor,
                 anchors: torch.Tensor,
                 num_anchors_per_level: Sequence[int],
                 num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches for a single image

        If `boxes` is empty, all anchors are marked as background and
        :meth:`compute_matches` is not called.

        Args:
            boxes: anchors are matches to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]; an empty tensor of shape [0] with
                the dtype and device of `anchors` if there are no boxes
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        if boxes.numel() == 0:
            # no ground truth -> every anchor is background
            num_anchors = anchors.shape[0]
            match_quality_matrix = torch.tensor([]).to(anchors)
            # allocate on the device of the anchors so both returned tensors
            # live on the same device (as in the branch with ground truth)
            matches = torch.full(
                (num_anchors,), self.BELOW_LOW_THRESHOLD,
                dtype=torch.int64, device=anchors.device,
            )
            return match_quality_matrix, matches

        # at least one ground truth
        return self.compute_matches(
            boxes=boxes, anchors=anchors,
            num_anchors_per_level=num_anchors_per_level,
            num_anchors_per_loc=num_anchors_per_loc,
        )

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches (needs to be overwritten by subclasses)

        Args:
            boxes: anchors are matches to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]

        Raises:
            NotImplementedError: always, the base class provides no
                matching strategy
        """
        raise NotImplementedError


class IoUMatcher(Matcher):
    def __init__(self,
                 low_threshold: float,
                 high_threshold: float,
                 allow_low_quality_matches: bool,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou):
        """
        Compute IoU based matching for a single image

        Args:
            low_threshold: threshold used to assign background values
            high_threshold: threshold used to assign foreground values
            allow_low_quality_matches: if enabled, the anchor with the
                highest IoU for each box is assigned to that box
                regardless of the thresholds
            similarity_fn: function for similarity computation between
                boxes and anchors
        """
        super().__init__(similarity_fn=similarity_fn)
        assert low_threshold <= high_threshold
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches according to given iou thresholds
        Adapted from
        (https://github.com/pytorch/vision/blob/c7c2085ec686ccc55e1df85736b240b24
        05d1179/torchvision/models/detection/_utils.py)

        Args:
            boxes: anchors are matches to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            **kwargs: accepted for compatibility with the keyword arguments
                passed by :meth:`Matcher.__call__` (`num_anchors_per_level`,
                `num_anchors_per_loc`); they are not used here

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used
                and if it should be ignored `BETWEEN_THRESHOLDS` is used)
                [M]
        """
        match_quality_matrix = self.similarity_fn(boxes, anchors)

        # match_quality_matrix is [N (boxes), M (anchors)]
        # Max over boxes (dim 0) to find best box candidate for each anchor
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            # keep the unthresholded assignment, `matches` is modified in place below
            all_matches = matches.clone()

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (
                matched_vals < self.high_threshold
        )
        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            matches = self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        # For debugging, `self._debug_logging(match_quality_matrix, matches,
        # matched_vals, below_low_threshold, between_thresholds)` can be called here

        return match_quality_matrix, matches

    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
        """
        Find the best matching prediction for each bounding box
        regardless of its IoU (this implementation excludes ties!)

        Args:
            matches: matched anchors to background and in between [M];
                modified in place
            all_matches: all matches regardless of IoU (not used by this
                implementation)
            match_quality_matrix: [N, M] tensor of IoUs (GroundTruth x NumAnchors)

        Returns:
            Tensor: `matches` where the best anchor of each box is
                assigned to that box [M]
        """
        # For each gt, find the prediction which has the highest quality
        _, best_pred_idx = match_quality_matrix.max(dim=1)  # [N]
        # NOTE(review): if several boxes share the same best anchor,
        # `best_pred_idx` contains duplicates and PyTorch does not define which
        # of the assigned values wins for indexed assignment.
        matches[best_pred_idx] = torch.arange(
            len(best_pred_idx), dtype=matches.dtype, device=matches.device)
        return matches

    @staticmethod
    def _debug_logging(match_quality_matrix, matches, matched_vals,
                       below_low_threshold, between_thresholds):
        """
        Log statistics of the matching result

        Args:
            match_quality_matrix: similarity between boxes and anchors [N, M]
            matches: matched box index (or negative label) per anchor [M]
            matched_vals: best similarity per anchor [M]
            below_low_threshold: boolean mask of background anchors [M]
            between_thresholds: boolean mask of ignored anchors [M]
        """
        logger.info("########## Matcher ##############")
        logger.info(f"Max IoU: {match_quality_matrix.max()}")
        logger.info(f"Foreground IoUs: {matched_vals[matches > -1]}")
        logger.info(f"Num GT: {match_quality_matrix.shape[0]}")

        # reductions over an empty selection raise, report None in that case
        between_vals = matched_vals[between_thresholds]
        has_between = between_vals.nelement() > 0
        match_bet_min = between_vals.min() if has_between else None
        match_bet_max = between_vals.max() if has_between else None
        logger.info(f"Inbetween IoU ranging from {match_bet_min} to {match_bet_max}")

        background_vals = matched_vals[below_low_threshold]
        max_background = background_vals.max() if background_vals.nelement() > 0 else None
        logger.info(f"Max background IoU: {max_background}")
        logger.info("#################################")


class ATSSMatcher(Matcher):
    def __init__(self,
                 num_candidates: int,
                 similarity_fn: Callable[[Tensor, Tensor], Tensor] = box_iou,
                 center_in_gt: bool = True,
                 ):
        """
        Compute matching based on ATSS
        https://arxiv.org/abs/1912.02424
        `Bridging the Gap Between Anchor-based and Anchor-free Detection
        via Adaptive Training Sample Selection`

        Args:
            num_candidates: number of positions to select candidates from
            similarity_fn: function for similarity computation between
                boxes and anchors
            center_in_gt: If disabled, matched anchor center points do not need
                to lie within the ground truth box.
        """
        super().__init__(similarity_fn=similarity_fn)
        self.num_candidates = num_candidates
        # passed as `eps` to `center_in_boxes` when `center_in_gt` is enabled
        self.min_dist = 0.01
        self.center_in_gt = center_in_gt
        logger.info(f"Running ATSS Matching with num_candidates={self.num_candidates} "
                    f"and center_in_gt {self.center_in_gt}.")

    def compute_matches(self,
                        boxes: torch.Tensor,
                        anchors: torch.Tensor,
                        num_anchors_per_level: Sequence[int],
                        num_anchors_per_loc: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute matches according to ATSS for a single image
        Adapted from
        (https://github.com/sfzhang15/ATSS/blob/79dfb28bd1/atss_core/modeling/rpn/atss
        /loss.py#L180-L184)

        Args:
            boxes: anchors are matches to these boxes (e.g. ground truth)
                [N, dims * 2](x1, y1, x2, y2, (z1, z2))
            anchors: anchors to match [M, dims * 2](x1, y1, x2, y2, (z1, z2))
            num_anchors_per_level: number of anchors per feature pyramid level
            num_anchors_per_loc: number of anchors per position

        Returns:
            Tensor: matrix which contains the similarity from each boxes
                to each anchor [N, M]
            Tensor: vector which contains the matched box index for all
                anchors (if background `BELOW_LOW_THRESHOLD` is used;
                `BETWEEN_THRESHOLDS` is never assigned by this matcher)
                [M]
        """
        num_gt = boxes.shape[0]
        num_anchors = anchors.shape[0]

        distances, boxes_center, anchors_center = box_center_dist(boxes, anchors)  # num_boxes x anchors

        # select candidates based on center distance (separately for each level)
        candidate_idx = []
        start_idx = 0
        for apl in num_anchors_per_level:
            end_idx = start_idx + apl

            # a level can not provide more candidates than it has anchors
            topk = min(self.num_candidates * num_anchors_per_loc, apl)
            _, idx = distances[:, start_idx: end_idx].topk(topk, dim=1, largest=False)
            # idx shape [num_boxes x topk], shifted to index into all anchors
            candidate_idx.append(idx + start_idx)

            start_idx = end_idx
        # [num_boxes x num_candidates] (index of candidate anchors)
        candidate_idx = torch.cat(candidate_idx, dim=1)

        match_quality_matrix = self.similarity_fn(boxes, anchors)  # [num_boxes x anchors]
        candidate_ious = match_quality_matrix.gather(1, candidate_idx)  # [num_boxes, n_candidates]

        # compute adaptive iou threshold
        # NOTE(review): with a single candidate per box the (unbiased) std is
        # NaN and no candidate passes the comparison below - confirm this is
        # acceptable for setups with very few anchors.
        iou_mean_per_gt = candidate_ious.mean(dim=1)  # [num_boxes]
        iou_std_per_gt = candidate_ious.std(dim=1)  # [num_boxes]
        iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt  # [num_boxes]
        is_pos = candidate_ious >= iou_thresh_per_gt[:, None]  # [num_boxes x n_candidates]

        if self.center_in_gt:  # can discard all candidates in case of very small objects :/
            # center point of selected anchors needs to lie within the ground truth
            boxes_idx = torch.arange(num_gt, device=boxes.device, dtype=torch.long)[:, None]\
                .expand_as(candidate_idx).contiguous()  # [num_boxes x n_candidates]
            is_in_gt = center_in_boxes(
                anchors_center[candidate_idx.view(-1)], boxes[boxes_idx.view(-1)], eps=self.min_dist)
            is_pos = is_pos & is_in_gt.view_as(is_pos)  # [num_boxes x n_candidates]

        # in case one anchor is assigned to multiple boxes, use box with highest IoU
        # shift the anchor indices of box i by i * num_anchors so they index
        # into the flattened [num_boxes * anchors] similarity matrix
        gt_offsets = torch.arange(
            num_gt, dtype=candidate_idx.dtype, device=candidate_idx.device)[:, None] * num_anchors
        flat_candidate_idx = (candidate_idx + gt_offsets).view(-1)

        # only positive candidates keep their IoU, everything else is set to -INF
        ious_inf = torch.full_like(match_quality_matrix, -INF).view(-1)
        index = flat_candidate_idx[is_pos.view(-1)]
        ious_inf[index] = match_quality_matrix.view(-1)[index]
        ious_inf = ious_inf.view_as(match_quality_matrix)

        matched_vals, matches = ious_inf.max(dim=0)
        # anchors which are not a positive candidate of any box are background
        matches[matched_vals == -INF] = self.BELOW_LOW_THRESHOLD
        return match_quality_matrix, matches


# Type variable bound to Matcher, for annotating code which accepts or returns
# any Matcher subclass.
MatcherType = TypeVar('MatcherType', bound=Matcher)