from typing import List, Optional, Dict, Tuple

import torch
import torchvision
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops

from . import _utils as det_utils

# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator  # noqa: F401
from .image_list import ImageList


@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
    # type: (Tensor, int) -> Tuple[int, int]
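    # Tracing-safe counterpart of the eager-mode branch in _get_top_n_idx: both the
    # number of anchors and the clamped pre_nms_top_n are computed as tensors so that
    # the exported ONNX graph remains valid for dynamic input sizes.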
    from torch.onnx import operators

    num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
    pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0))

    return num_anchors, pre_nms_top_n


class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
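
    Example (illustrative sketch; the channel count and feature-map sizes below are
    arbitrary choices for the example, not values required by this head)::

        >>> import torch
        >>> head = RPNHead(in_channels=256, num_anchors=3)
        >>> # two feature maps, e.g. from an FPN-style backbone
        >>> feats = [torch.rand(1, 256, 32, 32), torch.rand(1, 256, 16, 16)]
        >>> logits, bbox_reg = head(feats)
        >>> [tuple(l.shape) for l in logits]
        [(1, 3, 32, 32), (1, 3, 16, 16)]
        >>> [tuple(b.shape) for b in bbox_reg]
        [(1, 12, 32, 32), (1, 12, 16, 16)]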
    """

    def __init__(self, in_channels, num_anchors):
        super(RPNHead, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        for layer in self.children():
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        logits = []
        bbox_reg = []
        for feature in x:
            t = F.relu(self.conv(feature))
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg


def permute_and_flatten(layer, N, A, C, H, W):
    # type: (Tensor, int, int, int, int, int) -> Tensor
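    # Reorder a per-level prediction map from (N, A*C, H, W) to (N, H*W*A, C) so that
    # predictions line up with the anchors/labels, which are laid out per spatial
    # location and then per anchor. Worked example with assumed sizes: a classification
    # map of shape (2, 3, 4, 5) with A=3 and C=1 becomes (2, 3, 1, 4, 5) after the view,
    # (2, 4, 5, 3, 1) after the permute, and finally (2, 60, 1).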
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def concat_box_prediction_layers(box_cls, box_regression):
    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs to make them be in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        box_regression_flattened.append(box_regression_per_level)
    # concatenate on the first dimension (representing the feature levels), to
    # take into account the way the labels were generated (with all feature maps
    # being concatenated as well)
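    # After concatenation, box_cls has shape (N * total_anchors, C) and box_regression
    # has shape (N * total_anchors, 4), where total_anchors sums over all feature levels.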
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression


class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
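        score_thresh (float): minimum objectness score; proposals scoring below this
            value are discarded during postprocessing (before NMS)

    Example (construction sketch; the anchor sizes, aspect ratios, channel count and
    numeric thresholds below are illustrative choices, not defaults of this class)::

        >>> anchor_generator = AnchorGenerator(
        ...     sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),)
        ... )
        >>> head = RPNHead(256, anchor_generator.num_anchors_per_location()[0])
        >>> rpn = RegionProposalNetwork(
        ...     anchor_generator, head,
        ...     fg_iou_thresh=0.7, bg_iou_thresh=0.3,
        ...     batch_size_per_image=256, positive_fraction=0.5,
        ...     pre_nms_top_n=dict(training=2000, testing=1000),
        ...     post_nms_top_n=dict(training=2000, testing=1000),
        ...     nms_thresh=0.7,
        ... )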

    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
        "pre_nms_top_n": Dict[str, int],
        "post_nms_top_n": Dict[str, int],
    }

    def __init__(
        self,
        anchor_generator,
        head,
        #
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        #
        pre_nms_top_n,
        post_nms_top_n,
        nms_thresh,
        score_thresh=0.0,
    ):
        super(RegionProposalNetwork, self).__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_size = 1e-3

    def pre_nms_top_n(self):
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self):
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(self, anchors, targets):
        # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
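        # For every anchor this produces a label (1.0 foreground, 0.0 background,
        # -1.0 discarded because its IoU falls between the two thresholds) and a matched
        # ground-truth box; for non-foreground anchors the matched box is a placeholder
        # that the regression loss never uses.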
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the ground-truth box corresponding to each anchor
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness, num_anchors_per_level):
        # type: (Tensor, List[int]) -> Tensor
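        # Select the pre-NMS top-k anchors independently for each feature level: the
        # objectness tensor is a per-image concatenation of all levels, so the running
        # offset converts per-level top-k indices into indices over that flat layout.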
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            if torchvision._is_tracing():
                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
            else:
                num_anchors = ob.shape[1]
                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level):
        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
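        # Postprocessing pipeline: take the per-level top-n proposals, clip them to the
        # image, drop tiny and low-scoring boxes, run NMS independently per level, and
        # keep at most post_nms_top_n() proposals per image.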
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

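        # The box loss is a smooth L1 (beta = 1/9) over the sampled positive anchors
        # only, normalized by the total number of sampled anchors (positives plus
        # negatives); the objectness loss below is BCE-with-logits over all sampled anchors.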
        box_loss = (
            F.smooth_l1_loss(
                pred_bbox_deltas[sampled_pos_inds],
                regression_targets[sampled_pos_inds],
                beta=1 / 9,
                reduction="sum",
            )
            / (sampled_inds.numel())
        )

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss

    def forward(
        self,
        images,  # type: ImageList
        features,  # type: Dict[str, Tensor]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (OrderedDict[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
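
        Example (call sketch; the image size, feature shape and single feature level
        below are assumptions for illustration, and ``rpn`` is assumed to have been
        built as in the class-level example)::

            >>> from collections import OrderedDict
            >>> import torch
            >>> images = ImageList(torch.rand(2, 3, 224, 224), [(224, 224), (224, 224)])
            >>> features = OrderedDict([("0", torch.rand(2, 256, 56, 56))])
            >>> _ = rpn.eval()  # inference mode; no targets required
            >>> boxes, losses = rpn(images, features)
            >>> len(boxes), losses
            (2, {})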
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN does not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            assert targets is not None
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses