"test/ut/sdk/test_pruners.py" did not exist on "654e8242b47b4bce4586feb3aa1aa01a0d3ce3a9"
rpn.py 15.5 KB
Newer Older
limm's avatar
limm committed
1
2
from typing import Dict, List, Optional, Tuple

3
import torch
eellison's avatar
eellison committed
4
from torch import nn, Tensor
limm's avatar
limm committed
5
6
from torch.nn import functional as F
from torchvision.ops import boxes as box_ops, Conv2dNormActivation
7
8
9

from . import _utils as det_utils

10
# Import AnchorGenerator to keep compatibility.
limm's avatar
limm committed
11
12
from .anchor_utils import AnchorGenerator  # noqa: 401
from .image_list import ImageList
13
14


class RPNHead(nn.Module):
    """
    Adds a simple RPN Head with classification and regression heads
18

19
    Args:
20
21
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
limm's avatar
limm committed
22
        conv_depth (int, optional): number of convolutions
23
24
    """

    _version = 2

    def __init__(self, in_channels: int, num_anchors: int, conv_depth=1) -> None:
        super().__init__()
        convs = []
        for _ in range(conv_depth):
            convs.append(Conv2dNormActivation(in_channels, in_channels, kernel_size=3, norm_layer=None))
        self.conv = nn.Sequential(*convs)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)  # type: ignore[arg-type]
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)  # type: ignore[arg-type]

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
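        # Checkpoints produced before _version 2 stored the single 3x3 conv as
        # `conv.{weight,bias}`; it now lives inside a Conv2dNormActivation block
        # wrapped in an nn.Sequential, i.e. under `conv.0.0.{weight,bias}`.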

        if version is None or version < 2:
            for type in ["weight", "bias"]:
                old_key = f"{prefix}conv.{type}"
                new_key = f"{prefix}conv.0.0.{type}"
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        logits = []
        bbox_reg = []
        for feature in x:
            t = self.conv(feature)
            logits.append(self.cls_logits(t))
            bbox_reg.append(self.bbox_pred(t))
        return logits, bbox_reg
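

# A minimal usage sketch for RPNHead (the channel count, anchor count, and
# feature shapes below are illustrative assumptions, not values required by
# this module).
def _rpn_head_usage_sketch() -> Tuple[List[Tensor], List[Tensor]]:
    head = RPNHead(in_channels=256, num_anchors=3)
    features = [torch.rand(2, 256, 32, 32), torch.rand(2, 256, 16, 16)]
    # Per level, the logits have shape (N, num_anchors, H, W) and the box
    # regression output has shape (N, num_anchors * 4, H, W).
    return head(features)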


def permute_and_flatten(layer: Tensor, N: int, A: int, C: int, H: int, W: int) -> Tensor:
    # (N, A*C, H, W) -> (N, H*W*A, C): list the predictions per image so that each
    # row corresponds to one (spatial location, anchor) pair
    layer = layer.view(N, -1, C, H, W)
    layer = layer.permute(0, 3, 4, 1, 2)
    layer = layer.reshape(N, -1, C)
    return layer


def concat_box_prediction_layers(box_cls: List[Tensor], box_regression: List[Tensor]) -> Tuple[Tensor, Tensor]:
    box_cls_flattened = []
    box_regression_flattened = []
    # for each feature level, permute the outputs so that they are in the
    # same format as the labels. Note that the labels are computed for
    # all feature levels concatenated, so we keep the same representation
    # for the objectness and the box_regression
    for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression):
        N, AxC, H, W = box_cls_per_level.shape
        Ax4 = box_regression_per_level.shape[1]
        A = Ax4 // 4
        C = AxC // A
        box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W)
        box_cls_flattened.append(box_cls_per_level)

        box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W)
        box_regression_flattened.append(box_regression_per_level)
    # concatenate on the first dimension (representing the feature levels), to
    # take into account the way the labels were generated (with all feature maps
    # being concatenated as well)
    box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2)
    box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4)
    return box_cls, box_regression
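

# A minimal sketch of the flattening performed above (batch of 2, 3 anchors per
# location, two feature levels; the shapes are illustrative assumptions).
def _concat_box_prediction_layers_sketch() -> Tuple[Tensor, Tensor]:
    box_cls = [torch.rand(2, 3, 32, 32), torch.rand(2, 3, 16, 16)]
    box_regression = [torch.rand(2, 12, 32, 32), torch.rand(2, 12, 16, 16)]
    # The returned tensors have shapes (2 * (3 * 32 * 32 + 3 * 16 * 16), 1) = (7680, 1)
    # and (7680, 4): one row per (image, location, anchor) triple.
    return concat_box_prediction_layers(box_cls, box_regression)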


class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Args:
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep before applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        post_nms_top_n (Dict[str, int]): number of proposals to keep after applying NMS. It should
            contain two fields: training and testing, to allow for different values depending
            on training or evaluation
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        score_thresh (float): only return proposals with an objectness score greater than score_thresh

    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        anchor_generator: AnchorGenerator,
        head: nn.Module,
        # Faster-RCNN Training
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        # Faster-RCNN Inference
        pre_nms_top_n: Dict[str, int],
        post_nms_top_n: Dict[str, int],
        nms_thresh: float,
        score_thresh: float = 0.0,
    ) -> None:
        super().__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used during training
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)
        # used during testing
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.score_thresh = score_thresh
        self.min_size = 1e-3

    def pre_nms_top_n(self) -> int:
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self) -> int:
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(
        self, anchors: List[Tensor], targets: List[Dict[str, Tensor]]
    ) -> Tuple[List[Tensor], List[Tensor]]:
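        """
        Match anchors to ground-truth boxes and produce, per image, a label for
        every anchor (1 for foreground, 0 for background, -1 for anchors to be
        ignored) together with the matched GT box for every anchor.
        """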

        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(anchors_per_image.shape, dtype=torch.float32, device=device)
                labels_per_image = torch.zeros((anchors_per_image.shape[0],), dtype=torch.float32, device=device)
            else:
                match_quality_matrix = self.box_similarity(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the corresponding GT box for each anchor
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness: Tensor, num_anchors_per_level: List[int]) -> Tensor:
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            num_anchors = ob.shape[1]
            pre_nms_top_n = det_utils._topk_min(ob, self.pre_nms_top_n(), 1)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(
        self,
        proposals: Tensor,
        objectness: Tensor,
        image_shapes: List[Tuple[int, int]],
        num_anchors_per_level: List[int],
    ) -> Tuple[List[Tensor], List[Tensor]]:
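        """
        Select the final proposals for each image: take the top pre_nms_top_n
        anchors per feature level, clip boxes to the image, drop small and
        low-scoring boxes, run NMS independently per level, and keep at most
        post_nms_top_n proposals along with their objectness scores.
        """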

        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        objectness_prob = torch.sigmoid(objectness)

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)

            # remove small boxes
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # remove low scoring boxes
            # use >= for backwards compatibility
            keep = torch.where(scores >= self.score_thresh)[0]
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]

            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)

            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]

            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(
        self, objectness: Tensor, pred_bbox_deltas: Tensor, labels: List[Tensor], regression_targets: List[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """

        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
        sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        # The box loss is summed over the sampled positive anchors and normalized
        # by the total number of sampled anchors (positives + negatives).
        box_loss = F.smooth_l1_loss(
            pred_bbox_deltas[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1 / 9,
            reduction="sum",
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds])

        return objectness_loss, box_loss

    def forward(
        self,
        images: ImageList,
        features: Dict[str, Tensor],
        targets: Optional[List[Dict[str, Tensor]]] = None,
    ) -> Tuple[List[Tensor], Dict[str, Tensor]]:

        """
        Args:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth boxes present in the image (optional).
                If provided, each element in the list should contain a field `boxes`,
                with the locations of the ground-truth boxes.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas)
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN does not backprop through
        # the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)

        losses = {}
        if self.training:
            if targets is None:
                raise ValueError("targets should not be None")
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses
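

# A minimal end-to-end sketch of wiring the RPN together. All sizes, thresholds,
# and feature shapes below are illustrative assumptions, not library defaults;
# in practice these pieces are assembled inside a full detection model.
def _region_proposal_network_sketch() -> Tuple[List[Tensor], Dict[str, Tensor]]:
    anchor_generator = AnchorGenerator(sizes=((32,), (64,)), aspect_ratios=((0.5, 1.0, 2.0),) * 2)
    head = RPNHead(in_channels=256, num_anchors=anchor_generator.num_anchors_per_location()[0])
    rpn = RegionProposalNetwork(
        anchor_generator,
        head,
        fg_iou_thresh=0.7,
        bg_iou_thresh=0.3,
        batch_size_per_image=256,
        positive_fraction=0.5,
        pre_nms_top_n={"training": 2000, "testing": 1000},
        post_nms_top_n={"training": 2000, "testing": 1000},
        nms_thresh=0.7,
    )
    rpn.eval()
    # Two 128x128 images with two feature levels (strides 4 and 8, 256 channels each).
    images = ImageList(torch.rand(2, 3, 128, 128), [(128, 128), (128, 128)])
    features = {"0": torch.rand(2, 256, 32, 32), "1": torch.rand(2, 256, 16, 16)}
    # In eval mode this returns one (num_proposals, 4) box tensor per image and an
    # empty loss dict; in training mode `targets` must be provided and the losses
    # are returned under "loss_objectness" and "loss_rpn_box_reg".
    return rpn(images, features)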