from __future__ import division

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
from torch import nn, Tensor

import torchvision
from torchvision.ops import boxes as box_ops

from . import _utils as det_utils
from .image_list import ImageList

from torch.jit.annotations import List, Optional, Dict, Tuple


@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
eellison's avatar
eellison committed
19
    # type: (Tensor, int) -> Tuple[int, int]
20
21
22
23
24
25
26
27
28
29
30
    from torch.onnx import operators
    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
    # TODO : remove cast to IntTensor/num_anchors.dtype when
    #        ONNX Runtime version is updated with ReduceMin int64 support
    pre_nms_top_n = torch.min(torch.cat(
        (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
         num_anchors), 0).to(torch.int32)).to(num_anchors.dtype)

    return num_anchors, pre_nms_top_n


class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output len(sizes[i]) * len(aspect_ratios[i]) anchors
    per spatial location for feature map i.

    Arguments:
        sizes (Tuple[Tuple[int]]): anchor sizes per feature map. A flat tuple such
            as ``(128, 256, 512)`` is interpreted as one size per feature map.
        aspect_ratios (Tuple[Tuple[float]]): height / width ratios per feature map.
            A flat tuple is reused for every feature map.
    """

    __annotations__ = {
        "cell_anchors": Optional[List[torch.Tensor]],
        "_cache": Dict[str, List[torch.Tensor]]
    }

    def __init__(
        self,
        sizes=(128, 256, 512),
        aspect_ratios=(0.5, 1.0, 2.0),
    ):
        super(AnchorGenerator, self).__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        # Built lazily by set_cell_anchors on the first forward call.
        self.cell_anchors = None
        self._cache = {}

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
        # type: (List[int], List[float], int, Device)  # noqa: F821
        """Return zero-centred base anchors as rows ``(x1, y1, x2, y2)``.

        One row per (aspect ratio, scale) pair, ordered aspect-ratio-major.
        Each anchor has area ``scale ** 2`` and ``height / width == ratio``;
        coordinates are rounded to integers.
        """
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype, device):
        # type: (int, Device) -> None    # noqa: F821
        """Build ``self.cell_anchors`` (one tensor per feature map) for dtype/device.

        Existing anchors are reused only when they already have the requested
        device AND dtype; otherwise they are regenerated, so the anchors always
        match the feature maps passed to ``forward``.
        """
        if self.cell_anchors is not None:
            cell_anchors = self.cell_anchors
            assert cell_anchors is not None
            # suppose that all anchors have the same device and dtype
            # which is a valid assumption in the current state of the codebase
            if cell_anchors[0].device == device and cell_anchors[0].dtype == dtype:
                return

        cell_anchors = [
            self.generate_anchors(sizes, aspect_ratios, dtype, device)
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]
        self.cell_anchors = cell_anchors

    def num_anchors_per_location(self):
        """Return ``len(sizes[i]) * len(aspect_ratios[i])`` for every feature map i."""
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    def grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[int]])
        """Tile each level's cell anchors over that level's feature-map grid.

        For a level with grid ``(H, W)`` and stride ``(sh, sw)``, every cell
        anchor is shifted by ``(x * sw, y * sh)`` for all ``y < H`` and
        ``x < W``. Returns one ``(H * W * A, 4)`` tensor per level, ordered by
        location (row-major) first and by cell anchor second.
        """
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None

        for size, stride, base_anchors in zip(
            grid_sizes, strides, cell_anchors
        ):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            if torchvision._is_tracing():
                # required in ONNX export for mult operation with float32
                stride_width = torch.tensor(stride_width, dtype=torch.float32)
                stride_height = torch.tensor(stride_height, dtype=torch.float32)
            device = base_anchors.device
            shifts_x = torch.arange(
                0, grid_width, dtype=torch.float32, device=device
            ) * stride_width
            shifts_y = torch.arange(
                0, grid_height, dtype=torch.float32, device=device
            ) * stride_height
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # broadcast: every location x every cell anchor
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[int]])
        """Memoize ``grid_anchors`` keyed on the string form of ``grid_sizes + strides``."""
        key = str(grid_sizes + strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def forward(self, image_list, feature_maps):
        # type: (ImageList, List[Tensor])
        """Generate the anchors for every image in ``image_list``.

        Arguments:
            image_list (ImageList): batched images; only ``tensors.shape[-2:]``
                and the number of entries in ``image_sizes`` are used.
            feature_maps (List[Tensor]): one feature map per level; their spatial
                sizes are read, plus the dtype and device of the first one.

        Returns:
            List[Tensor]: one ``(num_anchors_total, 4)`` tensor per image, made of
            the per-level anchors concatenated in level order. Every image gets
            the same anchors because they depend only on the padded batch size
            and the feature-map sizes.
        """
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
        anchors = torch.jit.annotate(List[torch.Tensor], [])
        for _ in image_list.image_sizes:
            anchors.append(torch.cat(anchors_over_all_feature_maps))
        # Clear the cache in case that memory leaks.
        # NOTE(review): because it is cleared on every call and looked up once
        # per call, the cache is never actually reused across calls.
        self._cache.clear()
        return anchors


class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    def __init__(self, in_channels, num_anchors):
        super(RPNHead, self).__init__()
        # shared 3x3 conv applied to every level, followed by two sibling 1x1 convs
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        # one objectness logit per anchor and location
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        # four box-regression deltas per anchor and location
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=1, stride=1
        )

        for layer in self.children():
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        # type: (List[Tensor])
        """Apply the head to every feature level.

        Arguments:
            x (List[Tensor]): one feature map per level, each with
                ``in_channels`` channels.

        Returns:
            logits (List[Tensor]): per level, ``num_anchors`` objectness
                channels.
            bbox_reg (List[Tensor]): per level, ``num_anchors * 4`` box-delta
                channels.
        """
        logits = []
        bbox_reg = []
        for feature in x:
            t = F.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg


def permute_and_flatten(layer, N, A, C, H, W):
    # type: (Tensor, int, int, int, int, int)
    """Reorder one level's head output from ``(N, A*C, H, W)`` to ``(N, H*W*A, C)``.

    The channel dimension is split into anchors of ``C`` values each and moved
    last, so row ``(h * W + w) * A + a`` holds anchor ``a`` at location
    ``(h, w)``. This is the location-major, anchor-minor order that
    ``AnchorGenerator.grid_anchors`` produces.

    ``A`` is not used directly: the anchor dimension is inferred with ``-1``.
    """
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def concat_box_prediction_layers(box_cls, box_regression):
    # type: (List[Tensor], List[Tensor])
    """Flatten and concatenate the per-level RPN head outputs.

    Arguments:
        box_cls (List[Tensor]): one ``(N, A*C, H, W)`` tensor per feature level.
        box_regression (List[Tensor]): one ``(N, A*4, H, W)`` tensor per level;
            ``A`` and ``C`` are inferred from these channel counts.

    Returns:
        box_cls (Tensor): ``(N * total_anchors, C)``, where ``total_anchors`` is
            the sum over levels of ``H * W * A``. All rows of image 0 (level by
            level) come first, then image 1, and so on.
        box_regression (Tensor): ``(N * total_anchors, 4)`` in the same row order.
    """
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs to make them be in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(
        box_cls, box_regression
    ):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(
            box_cls_per_level, N, A, C, H, W
        )
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(
            box_regression_per_level, N, A, 4, H, W
        )
        box_regression_flattened.append(box_regression_per_level)
    # concatenate on the anchor dimension (dim=1, which spans the feature levels)
    # to take into account the way the labels were generated (with all feature
    # maps being concatenated as well), then merge the batch dimension in
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression


class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Arguments:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep per feature level before
            applying NMS. It should contain two keys, 'training' and 'testing', to allow for
            different values depending on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep per image after applying
            NMS. It should contain two keys, 'training' and 'testing', to allow for different
            values depending on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals

    """
    __annotations__ = {
        'box_coder': det_utils.BoxCoder,
        'proposal_matcher': det_utils.Matcher,
        'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
        # The dicts are stored under underscored names; ``pre_nms_top_n`` and
        # ``post_nms_top_n`` are methods that pick the entry for the current mode.
        '_pre_nms_top_n': Dict[str, int],
        '_post_nms_top_n': Dict[str, int],
    }

    def __init__(self,
                 anchor_generator,
                 head,
                 #
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 #
                 pre_nms_top_n, post_nms_top_n, nms_thresh):
        super(RegionProposalNetwork, self).__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
            batch_size_per_image, positive_fraction
        )
        # proposal selection; both modes read their own entry from these dicts
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        # passed to box_ops.remove_small_boxes in filter_proposals
        self.min_size = 1e-3

    def pre_nms_top_n(self):
        """Return the per-level pre-NMS proposal budget for the current mode."""
        if self.training:
            return self._pre_nms_top_n['training']
        return self._pre_nms_top_n['testing']

    def post_nms_top_n(self):
        """Return the per-image post-NMS proposal budget for the current mode."""
        if self.training:
            return self._post_nms_top_n['training']
        return self._post_nms_top_n['testing']

    def assign_targets_to_anchors(self, anchors, targets):
        # type: (List[Tensor], List[Dict[str, Tensor]])
        """Label every anchor of every image and pick the GT box it is matched to.

        Returns:
            labels (List[Tensor]): per image, one float label per anchor:
                1.0 for foreground, 0.0 for background, -1.0 for anchors whose
                IoU falls between the thresholds (ignored by the loss sampler).
            matched_gt_boxes (List[Tensor]): per image, the GT box assigned to
                each anchor; only meaningful where the label is 1.0.
        """
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background-only image: there is nothing to match against, so
                # every anchor is a negative and the matched boxes are zero
                # placeholders (never used, since there are no positives).
                # NOTE(review): det_utils.Matcher is expected to reject an IoU
                # matrix with no GT rows, which is why it is bypassed here.
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(
                    anchors_per_image.shape, dtype=gt_boxes.dtype, device=device
                )
                labels_per_image = torch.zeros(
                    (anchors_per_image.shape[0],), dtype=torch.float32, device=device
                )
            else:
                match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = torch.tensor(0.0)

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = torch.tensor(-1.0)

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness, num_anchors_per_level):
        # type: (Tensor, List[int])
        """Return, per image, the indices of the top-scoring anchors of every level.

        ``objectness`` has one row per image, and its columns are split into
        levels by ``num_anchors_per_level``. Up to ``pre_nms_top_n()`` columns
        are kept per level (fewer if the level is smaller). The indices are
        offset so that they address the full, unsplit row.
        """
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            if torchvision._is_tracing():
                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
            else:
                num_anchors = ob.shape[1]
                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int])
        """Select the final proposals for each image.

        The steps are:
        - keep the top-scoring anchors of each level;
        - apply ``box_ops.clip_boxes_to_image`` and ``box_ops.remove_small_boxes``;
        - run ``box_ops.batched_nms`` with the level index as the group;
        - keep at most ``post_nms_top_n()`` boxes per image.

        Returns:
            final_boxes (List[Tensor]), final_scores (List[Tensor]): one entry
            per image.
        """
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        # level index of every anchor column, broadcast over the batch
        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device)
            for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[:self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]
            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
        # type: (Tensor, Tensor, List[Tensor], List[Tensor])
        """
        Arguments:
            objectness (Tensor): objectness logits for all anchors of all images;
                flattened, its order must match ``torch.cat(labels)``.
            pred_bbox_deltas (Tensor): predicted deltas, one row per anchor.
            labels (List[Tensor]): per-image anchor labels (1, 0 or -1).
            regression_targets (List[Tensor]): per-image encoded box targets.

        Returns:
            objectness_loss (Tensor): binary cross-entropy with logits, averaged
                over the sampled (positive and negative) anchors.
            box_loss (Tensor): L1 loss summed over the sampled positive anchors
                and divided by the total number of sampled anchors.
        """
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = F.l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss

    def forward(self, images, features, targets=None):
        # type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]])
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the list should be a dict containing a field
                `boxes`, with the locations of the ground-truth boxes. An image whose `boxes`
                tensor is empty is treated as background only.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        # o[0] is one image's objectness map for a level; its element count is
        # the level's anchor count (assumes one logit per anchor)
        num_anchors_per_level = [o[0].numel() for o in objectness]
        objectness, pred_bbox_deltas = \
            concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN do not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            assert targets is not None
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets)
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses