import math
import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Optional

import torch
from torch import nn, Tensor

from ...ops import sigmoid_focal_loss
from ...ops import boxes as box_ops
from ...ops import misc as misc_nn_ops
from ...ops.feature_pyramid_network import LastLevelP6P7
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..resnet import ResNet50_Weights, resnet50
from . import _utils as det_utils
from ._utils import overwrite_eps, _box_loss
from .anchor_utils import AnchorGenerator
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
from .transform import GeneralizedRCNNTransform


__all__ = [
    "RetinaNet",
    "RetinaNet_ResNet50_FPN_Weights",
    "RetinaNet_ResNet50_FPN_V2_Weights",
    "retinanet_resnet50_fpn",
    "retinanet_resnet50_fpn_v2",
]


def _sum(x: List[Tensor]) -> Tensor:
    res = x[0]
    for i in x[1:]:
        res = res + i
    return res


def _v1_to_v2_weights(state_dict, prefix):
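    # Older (v1) checkpoints stored the head convs at even indices of a flat Sequential
    # ("conv.0", "conv.2", ...); the current head wraps each conv in a Conv2dNormActivation
    # block ("conv.0.0", "conv.1.0", ...), so remap the keys before loading.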
    for i in range(4):
        for type in ["weight", "bias"]:
            old_key = f"{prefix}conv.{2*i}.{type}"
            new_key = f"{prefix}conv.{i}.0.{type}"
            if old_key in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)


def _default_anchorgen():
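    # Three scales per octave (2**0, 2**(1/3), 2**(2/3)) for base sizes 32-512, each paired with
    # aspect ratios 0.5, 1.0 and 2.0, i.e. 9 anchors per spatial location as in the RetinaNet paper.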
    anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
    return anchor_generator


class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        self.classification_head = RetinaNetClassificationHead(
            in_channels, num_anchors, num_classes, norm_layer=norm_layer
        )
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
        return {
            "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
            "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
        }

    def forward(self, x):
        # type: (List[Tensor]) -> Dict[str, Tensor]
        return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}


class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    _version = 2

    def __init__(
        self,
        in_channels,
        num_anchors,
        num_classes,
        prior_probability=0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
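        # setting the bias to -log((1 - p) / p) makes sigmoid(bias) == prior_probability, so every
        # anchor starts out predicting foreground with probability ~0.01; this keeps the many easy
        # background anchors from overwhelming the loss at the beginning of training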
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
        # TorchScript doesn't support class attributes.
        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
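            # (the matcher assigns BETWEEN_THRESHOLDS to anchors whose IoU lies between
            # bg_iou_thresh and fg_iou_thresh; these are skipped, while background anchors
            # still contribute as negatives)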
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / len(targets)

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        all_cls_logits = []

        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)


class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    _version = 2

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
    }

    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()

        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
        torch.nn.init.zeros_(self.bbox_reg.bias)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.zeros_(layer.bias)

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self._loss_type = "l1"

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
        losses = []

        bbox_regression = head_outputs["bbox_regression"]

        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
            targets, bbox_regression, anchors, matched_idxs
        ):
            # determine only the foreground indices, ignore the rest
            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
            num_foreground = foreground_idxs_per_image.numel()

            # select only the foreground boxes
            matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

            # compute the loss
            losses.append(
                _box_loss(
                    self._loss_type,
                    self.box_coder,
                    anchors_per_image,
                    matched_gt_boxes_per_image,
                    bbox_regression_per_image,
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        all_bbox_regression = []

        for features in x:
            bbox_regression = self.conv(features)
            bbox_regression = self.bbox_reg(bbox_regression)

            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
            N, _, H, W = bbox_regression.shape
            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

            all_bbox_regression.append(bbox_regression)

        return torch.cat(all_bbox_regression, dim=1)


class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending if it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
        topk_candidates (int): Number of best detections to keep before NMS.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # Anchor parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_generator = _default_anchorgen()
        self.anchor_generator = anchor_generator

        if head is None:
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only in torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
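            # an image with no ground-truth boxes marks every anchor as background (-1)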
            if targets_per_image["boxes"].numel() == 0:
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

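                # topk_idxs indexes the scores flattened over (location, anchor, class), so dividing
                # by num_classes recovers the anchor index and the remainder gives the class label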
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[Dict[Tensor]], one per image, with the
                post-processed detections in the fields `boxes`, `labels` and `scores`.

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        "Expected target boxes to be a tensor of shape [N, 4].",
                    )

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
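            # HWA is the total number of predictions and HW the total number of spatial locations,
            # so A is the number of anchors per location; rescaling the per-level location counts
            # by A gives the number of predictions belonging to each feature level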
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)


_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 34014999,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
            "metrics": {
                "box_map": 36.4,
            },
        },
    )
    DEFAULT = COCO_V1


class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 38198935,
            "recipe": "https://github.com/pytorch/vision/pull/5756",
            "metrics": {
                "box_map": 41.5,
            },
        },
    )
    DEFAULT = COCO_V1


@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending if it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
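    # freeze the BatchNorm layers when starting from pretrained weights; otherwise train them normally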
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    # skip P2 because it generates too many anchors (according to their paper)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
    )
    model = RetinaNet(backbone, num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
        if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model


def retinanet_resnet50_fpn_v2(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.

    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
    <https://arxiv.org/abs/1912.02424>`_.

    See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.
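
    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn_v2(weights=RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)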

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    backbone = resnet50(weights=weights_backbone, progress=progress)
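    # unlike the v1 builder, the extra P6/P7 levels are computed from the C5 feature map
    # (2048 channels) instead of P5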
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
    )
    anchor_generator = _default_anchorgen()
    head = RetinaNetHead(
        backbone.out_channels,
        anchor_generator.num_anchors_per_location()[0],
        num_classes,
        norm_layer=partial(nn.GroupNorm, 32),
    )
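    # the v2 recipe trains the regression head with GIoU loss instead of the default L1 loss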
    head.regression_head._loss_type = "giou"
    model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs


model_urls = _ModelURLs(
    {
        "retinanet_resnet50_fpn_coco": RetinaNet_ResNet50_FPN_Weights.COCO_V1.url,
    }
)