from typing import List, Optional, Dict, Tuple

import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops, Conv2dNormActivation

from . import _utils as det_utils

# Import AnchorGenerator to keep compatibility.
from .anchor_utils import AnchorGenerator  # noqa: F401
from .image_list import ImageList


class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        conv_depth (int, optional): number of convolutions
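
    Example (a minimal sketch; the channel count, anchor count, and feature sizes below
    are only illustrative)::

        >>> head = RPNHead(in_channels=256, num_anchors=3)
        >>> features = [torch.rand(1, 256, 32, 32), torch.rand(1, 256, 16, 16)]
        >>> cls_logits, bbox_deltas = head(features)
        >>> # one tensor per level: (N, num_anchors, H, W) and (N, num_anchors * 4, H, W)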
    """

    _version = 2

    def __init__(self, in_channels: int, num_anchors: int, conv_depth: int = 1) -> None:
        super().__init__()
        convs = []
        for _ in range(conv_depth):
            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
        self.conv = nn.Sequential(*convs)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
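            # checkpoints written before version 2 stored the single conv directly as
            # "conv.weight"/"conv.bias"; the conv is now an nn.Sequential of
            # Conv2dNormActivation blocks, so those keys map to "conv.0.0.*"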
            for param_type in ["weight", "bias"]:
                old_key = f"{prefix}conv.{param_type}"
                new_key = f"{prefix}conv.0.0.{param_type}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        logits = []
        bbox_reg = []
        for feature in x:
            t = self.conv(feature)
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg


def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
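    """
    Reshape a per-level prediction map from (N, A * C, H, W) to (N, H * W * A, C), where A is
    the number of anchors per location and C the number of values predicted per anchor
    (1 for objectness, 4 for box deltas). As an illustrative example, an objectness map of
    shape (2, 3, 4, 4) (N=2, A=3, C=1, H=W=4) becomes a tensor of shape (2, 48, 1).
    """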
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
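    """
    Flatten the per-level RPN outputs and concatenate them across feature levels.

    Returns a tuple ``(box_cls, box_regression)`` where, with N images and a total of K anchors
    over all levels, ``box_cls`` has shape (N * K, C) (C is 1 for the RPN objectness) and
    ``box_regression`` has shape (N * K, 4), matching the layout used for the anchor labels.
    """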
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs so that they are in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        box_regression_flattened.append(box_regression_per_level)
    # concatenate on the first dimension (representing the feature levels), to
    # take into account the way the labels were generated (with all feature maps
    # being concatenated as well)
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression


class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        score_thresh (float): minimum objectness score (after sigmoid); proposals scoring
            below it are discarded during postprocessing. Default: 0.0
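
    Example (a minimal sketch; every hyper-parameter value below is illustrative, and the
    head's ``in_channels`` must match the backbone feature channels)::

        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128),), aspect_ratios=((0.5, 1.0, 2.0),))
        >>> head = RPNHead(in_channels=256, num_anchors=anchor_generator.num_anchors_per_location()[0])
        >>> rpn = RegionProposalNetwork(
        ...     anchor_generator, head,
        ...     fg_iou_thresh=0.7, bg_iou_thresh=0.3,
        ...     batch_size_per_image=256, positive_fraction=0.5,
        ...     pre_nms_top_n={"training": 2000, "testing": 1000},
        ...     post_nms_top_n={"training": 2000, "testing": 1000},
        ...     nms_thresh=0.7,
        ... )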

    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_size = 1e-3

    def pre_nms_top_n(self) -> int:
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self) -> int:
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(
        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Tensor]]:
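        """
        Match every anchor to a ground-truth box and produce per-anchor training labels.

        Returns one label tensor and one matched-GT-box tensor per image. Labels are 1.0 for
        foreground anchors, 0.0 for background anchors, and -1.0 for anchors whose IoU falls
        between the two thresholds and which are therefore ignored during sampling.
        """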

        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the GT box corresponding to each anchor
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
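        """
        For each feature level, select the indices of the top ``pre_nms_top_n()`` anchors by
        objectness score, offset so that they index into the level-concatenated tensor.
        """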
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            num_anchors = ob.shape[1]
            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(
        self,
        proposals: Tensor,
        objectness: Tensor,
        image_shapes: List[Tuple[int, int]],
        num_anchors_per_level: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
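        """
        Postprocess the decoded proposals for each image: keep the top-scoring proposals per
        level, clip them to the image, drop near-degenerate and low-scoring boxes, run NMS per
        level, and keep at most ``post_nms_top_n()`` proposals per image.
        """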

        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

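        # the box regression loss is computed only over the sampled positive anchors but is
        # normalized by the total number of sampled anchors (positives and negatives)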
        box_loss = (
            F.smooth_l1_loss(
                pred_bbox_deltas[sampled_pos_inds],
                regression_targets[sampled_pos_inds],
                beta=1 / 9,
                reduction="sum",
            )
            / (sampled_inds.numel())
        )

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss

    def forward(
        self,
        images: ImageList,
        features: Dict[str, Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:

        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the dict should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)
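        # the anchor generator returns one (total_anchors_over_all_levels, 4) tensor per image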

        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN does not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses