rpn.py 15 KB
Newer Older
1
2
import torch
from torch.nn import functional as F
eellison's avatar
eellison committed
3
from torch import nn, Tensor
4

5
import torchvision
6
7
8
from torchvision.ops import boxes as box_ops

from . import _utils as det_utils
eellison's avatar
eellison committed
9
10
from .image_list import ImageList

11
from typing import List, Optional, Dict, Tuple
12

13
14
15
# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator

16

17
18
@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
eellison's avatar
eellison committed
19
    # type: (Tensor, int) -> Tuple[int, int]
20
21
22
23
    from torch.onnx import operators
    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
    pre_nms_top_n = torch.min(torch.cat(
        (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
24
         num_anchors), 0))
25
26
27
28

    return num_anchors, pre_nms_top_n


29
30
31
class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    A shared 3x3 convolution (followed by ReLU) is applied to every feature map,
    then two sibling 1x1 convolutions produce, at every spatial location, one
    objectness logit per anchor and four box-regression deltas per anchor.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    def __init__(self, in_channels, num_anchors):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        # Every child is a Conv2d: small Gaussian weights, zero bias.
        for layer in self.children():
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Args:
            x (List[Tensor]): feature maps, one per level, each with ``in_channels``
                channels.

        Returns:
            logits (List[Tensor]): per-level objectness logits with ``num_anchors``
                channels.
            bbox_reg (List[Tensor]): per-level regression deltas with
                ``num_anchors * 4`` channels.
        """
        logits = []
        bbox_reg = []
        for feature in x:
            # the same conv weights are shared across all levels
            t = F.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg


def permute_and_flatten(layer, N, A, C, H, W):
    # type: (Tensor, int, int, int, int, int) -> Tensor
    """
    Rearrange a per-level head output from (N, A*C, H, W) to (N, H*W*A, C).

    Rows of the result are ordered by (h, w, a), i.e. all anchors of one spatial
    location are contiguous.

    Args:
        layer (Tensor): head output with ``A * C`` channels.
        N, H, W (int): batch size and spatial size of ``layer``.
        A (int): anchors per location; not used directly, it is inferred via ``-1``.
        C (int): values predicted per anchor (e.g. 4 for box deltas).

    Returns:
        Tensor: shape (N, H * W * A, C).
    """
    layer = layer.view(N, -1, C, H, W)      # (N, A, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)     # (N, H, W, A, C)
    layer = layer.reshape(N, -1, C)          # (N, H*W*A, C)
    return layer


def concat_box_prediction_layers(box_cls, box_regression):
    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Flatten and concatenate per-level head outputs into two 2-D tensors.

    Args:
        box_cls (List[Tensor]): per-level classification outputs, each of shape
            (N, A*C, H, W).
        box_regression (List[Tensor]): per-level regression outputs, each of shape
            (N, A*4, H, W), paired level-by-level with ``box_cls``.

    Returns:
        box_cls (Tensor): shape (N * total_anchors, C).
        box_regression (Tensor): shape (N * total_anchors, 4).
        Rows are ordered by image, then by feature level, then by (h, w, a).
    """
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs to make them be in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
        N, AxC, H, W = box_cls_per_level.shape
        # the regression output has 4 channels per anchor, which fixes A,
        # and from it the number of classes C
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        box_regression_flattened.append(box_regression_per_level)
    # Concatenate along dim 1 (the per-image anchor dimension) so that, for each
    # image, the anchors of all feature levels are contiguous -- matching how the
    # labels were generated -- then merge the batch and anchor dimensions.
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression


class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        score_thresh (float): proposals whose objectness probability (sigmoid of the logit)
            is below this value are discarded before NMS. Default: 0.0.
    """
    # Attribute types declared for TorchScript.
    __annotations__ = {
        'box_coder': det_utils.BoxCoder,
        'proposal_matcher': det_utils.Matcher,
        'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler,
        'pre_nms_top_n': Dict[str, int],
        'post_nms_top_n': Dict[str, int],
    }

    def __init__(self,
                 anchor_generator,
                 head,
                 #
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 #
                 pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
            batch_size_per_image, positive_fraction
        )
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        # boxes with a side smaller than this are dropped in filter_proposals
        self.min_size = 1e-3

    def pre_nms_top_n(self):
        """Per-level pre-NMS proposal budget for the current train/eval mode."""
        if self.training:
            return self._pre_nms_top_n['training']
        return self._pre_nms_top_n['testing']

    def post_nms_top_n(self):
        """Per-image post-NMS proposal budget for the current train/eval mode."""
        if self.training:
            return self._post_nms_top_n['training']
        return self._post_nms_top_n['testing']

    def assign_targets_to_anchors(self, anchors, targets):
        # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Match every anchor to a ground-truth box and derive its objectness label.

        Labels are float32: 1.0 for matched (foreground) anchors, 0.0 for background
        anchors, and -1.0 for anchors whose IoU falls between the thresholds.

        Args:
            anchors (List[Tensor]): anchors for each image.
            targets (List[Dict[str, Tensor]]): per-image targets; only ``"boxes"`` is read.

        Returns:
            labels (List[Tensor]): one label per anchor, per image.
            matched_gt_boxes (List[Tensor]): the GT box matched to each anchor
                (all zeros for images without GT boxes).
        """
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness, num_anchors_per_level):
        # type: (Tensor, List[int]) -> Tensor
        """
        Indices of the top ``pre_nms_top_n()`` anchors of every level, per image.

        ``objectness`` is split along dim 1 into per-level chunks; indices from each
        chunk are shifted by the number of anchors in the preceding levels so they
        index the concatenated anchor dimension. Result has one row per image.
        """
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            if torchvision._is_tracing():
                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
            else:
                num_anchors = ob.shape[1]
                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Select the final proposals for each image.

        Per image: keep the top ``pre_nms_top_n()`` anchors of every feature level by
        objectness, clip them to the image, drop tiny and low-scoring boxes, run NMS
        independently per level, and keep at most ``post_nms_top_n()`` survivors.

        Args:
            proposals (Tensor): decoded boxes of shape (num_images, num_anchors, 4).
            objectness (Tensor): objectness logits; reshaped here to
                (num_images, num_anchors).
            image_shapes (List[Tuple[int, int]]): size of each image, used for clipping.
            num_anchors_per_level (List[int]): anchors contributed by each feature level,
                in concatenation order.

        Returns:
            final_boxes (List[Tensor]): kept boxes for each image.
            final_scores (List[Tensor]): objectness probabilities of the kept boxes.
        """
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        # level id of every anchor, broadcast to one row per image
        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device)
            for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for Backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions (batched_nms returns indices
            # sorted by decreasing score, per the torchvision.ops docs)
            keep = keep[:self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)

        The box loss is a smooth-L1 loss over the sampled positive anchors, summed and
        normalised by the total number of sampled anchors (positives + negatives).
        The objectness loss is binary cross-entropy over all sampled anchors.

        NOTE(review): if the sampler returns no anchors at all, the box loss is 0/0
        and both losses become nan.
        """
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        # per-image masks -> flat indices into the concatenated anchor dimension
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction='sum',
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss

    def forward(self,
                images,       # type: ImageList
                features,     # type: Dict[str, Tensor]
                targets=None  # type: Optional[List[Dict[str, Tensor]]]
                ):
        # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the list should be a dict containing a field
                `boxes`, with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.

        Raises:
            ValueError: if the module is in training mode and ``targets`` is None.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        # anchors per level = A * H * W of each objectness map (first image's shape)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = \
            concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN do not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            # explicit check (not assert): survives `python -O` and refines the
            # Optional type for TorchScript
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets)
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses