# _utils.py 15.4 KB
# Newer Older
1
import math
2
from collections import OrderedDict
3
from typing import List, Tuple
4

5
import torch
6
from torch import Tensor, nn
7
from torchvision.ops.misc import FrozenBatchNorm2d
8
9


10
class BalancedPositiveNegativeSampler:
    """
    This class samples batches, ensuring that they contain a fixed proportion of positives
    """

    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
        """
        Args:
            batch_size_per_image (int): number of elements to be selected per image
            positive_fraction (float): fraction of the selected elements per image
                that should be positives (e.g. ``0.25``)
        """
        self.batch_size_per_image = batch_size_per_image
        self.positive_fraction = positive_fraction

    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
            matched_idxs: list of tensors containing -1, 0 or positive values.
                Each tensor corresponds to a specific image.
                -1 values are ignored, 0 are considered as negatives and > 0 as
                positives.

        Returns:
            pos_idx (list[tensor])
            neg_idx (list[tensor])

        Returns two lists of binary masks (dtype ``torch.uint8``), one mask per
        image, each with the same shape as the corresponding input tensor.
        The first list marks the positive elements that were selected,
        and the second list marks the selected negative elements.
        """
        pos_idx = []
        neg_idx = []
        for matched_idxs_per_image in matched_idxs:
            positive = torch.where(matched_idxs_per_image >= 1)[0]
            negative = torch.where(matched_idxs_per_image == 0)[0]

            num_pos = int(self.batch_size_per_image * self.positive_fraction)
            # protect against not enough positive examples
            num_pos = min(positive.numel(), num_pos)
            # the negative budget is whatever is left after the positives
            num_neg = self.batch_size_per_image - num_pos
            # protect against not enough negative examples
            num_neg = min(negative.numel(), num_neg)

            # randomly select positive and negative examples
            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

            pos_idx_per_image = positive[perm1]
            neg_idx_per_image = negative[perm2]

            # create binary mask from indices
            pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
            neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)

            pos_idx_per_image_mask[pos_idx_per_image] = 1
            neg_idx_per_image_mask[neg_idx_per_image] = 1

            pos_idx.append(pos_idx_per_image_mask)
            neg_idx.append(neg_idx_per_image_mask)

        return pos_idx, neg_idx


73
@torch.jit._script_if_tracing
def encode_boxes(reference_boxes: Tensor, proposals: Tensor, weights: Tensor) -> Tensor:
    """
    Encode a set of proposals with respect to some
    reference boxes

    Boxes are read as ``(x1, y1, x2, y2)`` from columns 0..3. For each row the
    encoding is::

        dx = wx * (gt_ctr_x - ex_ctr_x) / ex_width
        dy = wy * (gt_ctr_y - ex_ctr_y) / ex_height
        dw = ww * log(gt_width / ex_width)
        dh = wh * log(gt_height / ex_height)

    where ``ex_*`` refers to ``proposals`` and ``gt_*`` to ``reference_boxes``.

    Args:
        reference_boxes (Tensor): reference boxes
        proposals (Tensor): boxes to be encoded
        weights (Tensor[4]): the weights for ``(x, y, w, h)``

    Returns:
        Tensor: the regression targets ``(dx, dy, dw, dh)`` concatenated along
        dimension 1, one row per box.
    """

    # perform some unpacking to make it JIT-fusion friendly
    wx = weights[0]
    wy = weights[1]
    ww = weights[2]
    wh = weights[3]

    # keep every coordinate as a column vector so the final cat along dim=1 works
    proposals_x1 = proposals[:, 0].unsqueeze(1)
    proposals_y1 = proposals[:, 1].unsqueeze(1)
    proposals_x2 = proposals[:, 2].unsqueeze(1)
    proposals_y2 = proposals[:, 3].unsqueeze(1)

    reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1)
    reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1)
    reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1)
    reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1)

    # implementation starts here
    ex_widths = proposals_x2 - proposals_x1
    ex_heights = proposals_y2 - proposals_y1
    ex_ctr_x = proposals_x1 + 0.5 * ex_widths
    ex_ctr_y = proposals_y1 + 0.5 * ex_heights

    gt_widths = reference_boxes_x2 - reference_boxes_x1
    gt_heights = reference_boxes_y2 - reference_boxes_y1
    gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths
    gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights

    targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
    targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
    targets_dw = ww * torch.log(gt_widths / ex_widths)
    targets_dh = wh * torch.log(gt_heights / ex_heights)

    targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)
    return targets


121
class BoxCoder:
    """
    This class encodes and decodes a set of bounding boxes into
    the representation used for training the regressors.
    """

    def __init__(
        self, weights: Tuple[float, float, float, float], bbox_xform_clip: float = math.log(1000.0 / 16)
    ) -> None:
        """
        Args:
            weights (4-element tuple): scaling applied to ``(dx, dy, dw, dh)``;
                multiplied in when encoding and divided out when decoding.
            bbox_xform_clip (float): upper bound applied to ``dw`` and ``dh``
                before they are passed to ``torch.exp`` while decoding.
        """
        self.weights = weights
        self.bbox_xform_clip = bbox_xform_clip

    def encode(self, reference_boxes: List[Tensor], proposals: List[Tensor]) -> List[Tensor]:
        """
        Encode per-image lists of proposals with respect to per-image
        reference boxes.

        All images are concatenated, encoded in one call to
        ``encode_single`` and split back per image.

        Args:
            reference_boxes (List[Tensor]): reference boxes, one tensor per image
            proposals (List[Tensor]): boxes to be encoded, one tensor per image

        Returns:
            The encoded targets split back into one tensor per image, using
            the number of reference boxes of each image as split sizes.
        """
        boxes_per_image = [len(b) for b in reference_boxes]
        reference_boxes = torch.cat(reference_boxes, dim=0)
        proposals = torch.cat(proposals, dim=0)
        targets = self.encode_single(reference_boxes, proposals)
        return targets.split(boxes_per_image, 0)

    def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
        """
        Encode a set of proposals with respect to some
        reference boxes

        Args:
            reference_boxes (Tensor): reference boxes
            proposals (Tensor): boxes to be encoded

        Returns:
            Tensor: the encoded targets produced by ``encode_boxes``.
        """
        # build the weights with the dtype/device of the boxes they are applied to
        dtype = reference_boxes.dtype
        device = reference_boxes.device
        weights = torch.as_tensor(self.weights, dtype=dtype, device=device)
        targets = encode_boxes(reference_boxes, proposals, weights)

        return targets

    def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
        """
        Decode relative box offsets for a batch of images.

        Args:
            rel_codes (Tensor): encoded boxes for all images stacked together
            boxes (List[Tensor]): reference boxes, one tensor per image

        Returns:
            Tensor: decoded boxes. When there is at least one reference box
            the result is reshaped to ``(total_boxes, -1, 4)``; otherwise the
            output of ``decode_single`` is returned as is.
        """
        assert isinstance(boxes, (list, tuple)), "This function expects boxes of type list or tuple."
        assert isinstance(rel_codes, torch.Tensor), "This function expects rel_codes of type torch.Tensor."
        boxes_per_image = [b.size(0) for b in boxes]
        concat_boxes = torch.cat(boxes, dim=0)
        # total number of reference boxes over all images
        box_sum = 0
        for val in boxes_per_image:
            box_sum += val
        # reshaping with -1 is only well defined when there is at least one box
        if box_sum > 0:
            rel_codes = rel_codes.reshape(box_sum, -1)
        pred_boxes = self.decode_single(rel_codes, concat_boxes)
        if box_sum > 0:
            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
        return pred_boxes

    def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
        """
        From a set of original boxes and encoded relative box offsets,
        get the decoded boxes.

        Args:
            rel_codes (Tensor): encoded boxes
            boxes (Tensor): reference boxes.

        Returns:
            Tensor: decoded boxes as ``(x1, y1, x2, y2)``, flattened so that
            every group of 4 consecutive columns of ``rel_codes`` yields 4
            consecutive output columns.
        """

        boxes = boxes.to(rel_codes.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        # every 4th column holds the same kind of offset; undo the encoding weights
        wx, wy, ww, wh = self.weights
        dx = rel_codes[:, 0::4] / wx
        dy = rel_codes[:, 1::4] / wy
        dw = rel_codes[:, 2::4] / ww
        dh = rel_codes[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.bbox_xform_clip)
        dh = torch.clamp(dh, max=self.bbox_xform_clip)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        # Distance from center to box's corner.
        c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
        c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w

        pred_boxes1 = pred_ctr_x - c_to_c_w
        pred_boxes2 = pred_ctr_y - c_to_c_h
        pred_boxes3 = pred_ctr_x + c_to_c_w
        pred_boxes4 = pred_ctr_y + c_to_c_h
        pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
        return pred_boxes


220
class Matcher:
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be assigned to zero or more predicted elements.

    Matching is based on the MxN match_quality_matrix, that characterizes how well
    each (ground-truth, predicted)-pair match. For example, if the elements are
    boxes, the matrix may contain box IoU overlap values.

    The matcher returns a tensor of size N containing the index of the ground-truth
    element m that matches to prediction n. If there is no match, a negative value
    is returned.
    """

    # Sentinel values written into the result for unmatched predictions.
    BELOW_LOW_THRESHOLD = -1
    BETWEEN_THRESHOLDS = -2

    __annotations__ = {
        "BELOW_LOW_THRESHOLD": int,
        "BETWEEN_THRESHOLDS": int,
    }

    def __init__(self, high_threshold: float, low_threshold: float, allow_low_quality_matches: bool = False) -> None:
        """
        Args:
            high_threshold (float): quality values greater than or equal to
                this value are candidate matches.
            low_threshold (float): a lower quality threshold used to stratify
                matches into three levels:
                1) matches >= high_threshold
                2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold)
                3) BELOW_LOW_THRESHOLD matches in [0, low_threshold)
            allow_low_quality_matches (bool): if True, produce additional matches
                for predictions that have only low-quality match candidates. See
                set_low_quality_matches_ for more details.
        """
        # mirror the class-level sentinels on the instance
        self.BELOW_LOW_THRESHOLD = -1
        self.BETWEEN_THRESHOLDS = -2
        assert low_threshold <= high_threshold, "low_threshold should be <= high_threshold"
        self.high_threshold = high_threshold
        self.low_threshold = low_threshold
        self.allow_low_quality_matches = allow_low_quality_matches

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
            pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
            [0, M - 1] or a negative value indicating that prediction i could not
            be matched.

        Raises:
            ValueError: if ``match_quality_matrix`` has no elements (no
                ground-truth boxes or no proposal boxes).
        """
        if match_quality_matrix.numel() == 0:
            # empty targets or proposals not supported during training
            if match_quality_matrix.shape[0] == 0:
                raise ValueError("No ground-truth boxes available for one of the images during training")
            else:
                raise ValueError("No proposal boxes available for one of the images during training")

        # match_quality_matrix is M (gt) x N (predicted)
        # Max over gt elements (dim 0) to find best gt candidate for each prediction
        matched_vals, matches = match_quality_matrix.max(dim=0)
        if self.allow_low_quality_matches:
            # keep the un-thresholded assignment so low-quality matches can be restored
            all_matches = matches.clone()
        else:
            all_matches = None  # type: ignore[assignment]

        # Assign candidate matches with low quality to negative (unassigned) values
        below_low_threshold = matched_vals < self.low_threshold
        between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
        matches[between_thresholds] = self.BETWEEN_THRESHOLDS

        if self.allow_low_quality_matches:
            assert all_matches is not None
            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

        return matches

    def set_low_quality_matches_(self, matches: Tensor, all_matches: Tensor, match_quality_matrix: Tensor) -> None:
        """
        Produce additional matches for predictions that have only low-quality matches.
        Specifically, for each ground-truth find the set of predictions that have
        maximum overlap with it (including ties); for each prediction in that set, if
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.

        ``matches`` is modified in place; nothing is returned.
        """
        # For each gt, find the prediction with which it has highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
        # Find highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
        # torch.where returns a tuple (gt indices, prediction indices) of equal
        # length; position k of both tensors describes one (gt, prediction) pair.
        # Example pairs, written as (gt index, prediction index):
        #   (0, 39796),
        #   (1, 32055),
        #   (1, 32070),
        #   (2, 39190),
        #   (2, 40255),
        #   (3, 40390),
        #   (3, 41455),
        #   (4, 45470),
        #   (5, 45325),
        #   (5, 46390)
        # Note how gt items 1, 2, 3, and 5 each have two ties

        # restore the original (pre-threshold) assignment for those predictions
        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]
330
331


332
class SSDMatcher(Matcher):
    """
    Matcher variant that uses a single threshold and additionally assigns, for
    every ground-truth element, the prediction with which it has the highest
    quality to that ground-truth element.
    """

    def __init__(self, threshold: float) -> None:
        """
        Args:
            threshold (float): used as both the high and the low threshold of
                the base matcher; low-quality matches are disabled.
        """
        super().__init__(threshold, threshold, allow_low_quality_matches=False)

    def __call__(self, match_quality_matrix: Tensor) -> Tensor:
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
            pairwise quality between M ground-truth elements and N predicted elements.

        Returns:
            matches (Tensor[int64]): the result of the base matcher, with the
            best prediction of each ground-truth element overwritten by that
            ground-truth element's index.
        """
        matches = super().__call__(match_quality_matrix)

        # For each gt, find the prediction with which it has the highest quality
        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
        # Entry k of the arange is gt index k, written at the position of its best prediction
        matches[highest_quality_pred_foreach_gt] = torch.arange(
            highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
        )

        return matches


348
def overwrite_eps(model: nn.Module, eps: float) -> None:
    """
    This method overwrites the default eps values of all the
    FrozenBatchNorm2d layers of the model with the provided value.
    This is necessary to address the BC-breaking change introduced
    by the bug-fix at pytorch/vision#2933. The overwrite is applied
    only when the pretrained weights are loaded to maintain compatibility
    with previous versions.

    Args:
        model (nn.Module): The model on which we perform the overwrite.
        eps (float): The new value of eps.
    """
    # model.modules() walks the model itself and all of its sub-modules
    for module in model.modules():
        if isinstance(module, FrozenBatchNorm2d):
            module.eps = eps
364
365


366
def retrieve_out_channels(model: nn.Module, size: Tuple[int, int]) -> List[int]:
    """
    This method retrieves the number of output channels of a specific model.

    The model is temporarily switched to eval mode and run once on a dummy
    all-zeros image under ``torch.no_grad()``. Its previous training mode is
    restored afterwards, even if the forward pass raises.

    Args:
        model (nn.Module): The model for which we estimate the out_channels.
            It should return a single Tensor or an OrderedDict[Tensor].
        size (Tuple[int, int]): The size (wxh) of the input.

    Returns:
        out_channels (List[int]): A list of the output channels of the model.
    """
    in_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            # Use dummy data to retrieve the feature map sizes to avoid hard-coding their values
            # A model without parameters has no device to copy; fall back to the
            # CPU instead of failing with a bare StopIteration.
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device("cpu")
            # size is (w, h) while the image tensor is laid out as (N, C, H, W)
            tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
            features = model(tmp_img)
            if isinstance(features, torch.Tensor):
                features = OrderedDict([("0", features)])
            out_channels = [x.size(1) for x in features.values()]
    finally:
        # put the model back into the mode it was in, also when the forward pass failed
        if in_training:
            model.train()

    return out_channels