# retinanet.py
1
2
import math
import warnings
3
from collections import OrderedDict
4
from functools import partial
5
from typing import Any, Callable, Dict, List, Optional, Tuple
6
7

import torch
8
from torch import nn, Tensor
9

10
from ...ops import boxes as box_ops, misc as misc_nn_ops, sigmoid_focal_loss
11
from ...ops.feature_pyramid_network import LastLevelP6P7
12
from ...transforms._presets import ObjectDetection
13
from ...utils import _log_api_usage_once
14
from .._api import register_model, Weights, WeightsEnum
15
from .._meta import _COCO_CATEGORIES
16
17
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
18
from . import _utils as det_utils
19
from ._utils import _box_loss, overwrite_eps
20
from .anchor_utils import AnchorGenerator
21
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
22
from .transform import GeneralizedRCNNTransform
23
24


25
26
27
# Public names exported by this module.
__all__ = [
    "RetinaNet",
    "RetinaNet_ResNet50_FPN_Weights",
    "RetinaNet_ResNet50_FPN_V2_Weights",
    "retinanet_resnet50_fpn",
    "retinanet_resnet50_fpn_v2",
]
32
33
34
35
36
37
38
39
40


def _sum(x: List[Tensor]) -> Tensor:
    res = x[0]
    for i in x[1:]:
        res = res + i
    return res


41
42
43
44
45
def _v1_to_v2_weights(state_dict, prefix):
    for i in range(4):
        for type in ["weight", "bias"]:
            old_key = f"{prefix}conv.{2*i}.{type}"
            new_key = f"{prefix}conv.{i}.0.{type}"
46
47
            if old_key in state_dict:
                state_dict[new_key] = state_dict.pop(old_key)
48
49
50
51
52
53
54
55
56


def _default_anchorgen():
    """
    Build the anchor generator used when the caller does not supply one.

    Five size groups are created, one per base size in ``[32, 64, 128, 256, 512]``; each group
    holds the base size and the base size scaled by ``2 ** (1/3)`` and ``2 ** (2/3)`` (truncated
    to int). Every group uses the aspect ratios ``(0.5, 1.0, 2.0)``.
    """
    scale_one_third = 2 ** (1.0 / 3)
    scale_two_thirds = 2 ** (2.0 / 3)
    anchor_sizes = tuple(
        (base, int(base * scale_one_third), int(base * scale_two_thirds)) for base in [32, 64, 128, 256, 512]
    )
    aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
    return AnchorGenerator(anchor_sizes, aspect_ratios)


57
58
59
60
class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    def __init__(self, in_channels, num_anchors, num_classes, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        self.classification_head = RetinaNetClassificationHead(
            in_channels, num_anchors, num_classes, norm_layer=norm_layer
        )
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors, norm_layer=norm_layer)

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor]
        """
        Delegate loss computation to the two sub-heads.

        Returns:
            Dict[str, Tensor]: the classification loss under ``"classification"`` and the
            box regression loss under ``"bbox_regression"``.
        """
        return {
            "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs),
            "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs),
        }

    def forward(self, x):
        # type: (List[Tensor]) -> Dict[str, Tensor]
        """
        Run both sub-heads on the list of feature maps.

        Returns:
            Dict[str, Tensor]: classification output under ``"cls_logits"`` and regression
            output under ``"bbox_regression"``.
        """
        return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)}
85
86
87
88
89
90


class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
        prior_probability (float): prior ``p`` used to initialise the bias of the final
            classification conv to ``-log((1 - p) / p)``. Default: 0.01
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    # State-dict layout version; checkpoints without a version or with version < 2 are
    # converted by ``_load_from_state_dict``.
    _version = 2

    def __init__(
        self,
        in_channels,
        num_anchors,
        num_classes,
        prior_probability=0.01,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()

        # Tower of four channel-preserving conv blocks.
        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.cls_logits.weight, std=0.01)
        # Bias chosen so that sigmoid(bias) == prior_probability at initialisation.
        torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability))

        self.num_classes = num_classes
        self.num_anchors = num_anchors

        # This is to fix using det_utils.Matcher.BETWEEN_THRESHOLDS in TorchScript.
        # TorchScript doesn't support class attributes.
        # https://github.com/pytorch/vision/pull/1697#issuecomment-630255584
        self.BETWEEN_THRESHOLDS = det_utils.Matcher.BETWEEN_THRESHOLDS

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Convert pre-version-2 conv keys in ``state_dict`` in place, then defer to ``nn.Module``."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor
        """
        Compute the focal classification loss averaged over the images of the batch.

        Args:
            targets (List[Dict[str, Tensor]]): per-image targets; the ``"labels"`` entry is read.
            head_outputs (Dict[str, Tensor]): head outputs; the ``"cls_logits"`` entry is read.
            matched_idxs (List[Tensor]): per-image matched ground-truth index for every anchor.
                Values ``>= 0`` mark foreground anchors; anchors equal to
                ``self.BETWEEN_THRESHOLDS`` are excluded from the loss.

        Returns:
            Tensor: sum over images of (summed focal loss / max(1, #foreground)), divided by
            the number of images.
        """
        losses = []

        cls_logits = head_outputs["cls_logits"]

        for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs):
            # determine only the foreground
            foreground_idxs_per_image = matched_idxs_per_image >= 0
            num_foreground = foreground_idxs_per_image.sum()

            # create the target classification
            gt_classes_target = torch.zeros_like(cls_logits_per_image)
            gt_classes_target[
                foreground_idxs_per_image,
                targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]],
            ] = 1.0

            # find indices for which anchors should be ignored
            valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS

            # compute the classification loss
            losses.append(
                sigmoid_focal_loss(
                    cls_logits_per_image[valid_idxs_per_image],
                    gt_classes_target[valid_idxs_per_image],
                    reduction="sum",
                )
                / max(1, num_foreground)
            )

        # Same guarded normalisation as RetinaNetRegressionHead.compute_loss.
        return _sum(losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        """
        Compute class logits for every anchor of every feature map.

        Args:
            x (List[Tensor]): feature maps, each of shape ``(N, C, H, W)``.

        Returns:
            Tensor: logits of shape ``(N, sum over levels of H*W*A, K)`` where ``K`` is
            ``self.num_classes``, concatenated over levels along dim 1.
        """
        all_cls_logits = []

        for features in x:
            cls_logits = self.conv(features)
            cls_logits = self.cls_logits(cls_logits)

            # Permute classification output from (N, A * K, H, W) to (N, HWA, K).
            N, _, H, W = cls_logits.shape
            cls_logits = cls_logits.view(N, -1, self.num_classes, H, W)
            cls_logits = cls_logits.permute(0, 3, 4, 1, 2)
            cls_logits = cls_logits.reshape(N, -1, self.num_classes)  # Size=(N, HWA, K)

            all_cls_logits.append(cls_logits)

        return torch.cat(all_cls_logits, dim=1)


class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Args:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        norm_layer (callable, optional): Module specifying the normalization layer to use. Default: None
    """

    # State-dict layout version; checkpoints without a version or with version < 2 are
    # converted by ``_load_from_state_dict``.
    _version = 2

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
    }

    def __init__(self, in_channels, num_anchors, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()

        # Tower of four channel-preserving conv blocks.
        conv = []
        for _ in range(4):
            conv.append(misc_nn_ops.Conv2dNormActivation(in_channels, in_channels, norm_layer=norm_layer))
        self.conv = nn.Sequential(*conv)

        # Final conv predicts 4 regression values per anchor.
        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
        torch.nn.init.normal_(self.bbox_reg.weight, std=0.01)
        torch.nn.init.zeros_(self.bbox_reg.bias)

        for layer in self.conv.modules():
            if isinstance(layer, nn.Conv2d):
                torch.nn.init.normal_(layer.weight, std=0.01)
                if layer.bias is not None:
                    torch.nn.init.zeros_(layer.bias)

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        # Loss name handed to ``_box_loss`` in ``compute_loss``.
        self._loss_type = "l1"

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Convert pre-version-2 conv keys in ``state_dict`` in place, then defer to ``nn.Module``."""
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            _v1_to_v2_weights(state_dict, prefix)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def compute_loss(self, targets, head_outputs, anchors, matched_idxs):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor
        """
        Compute the box regression loss over foreground anchors, averaged over the images.

        Args:
            targets (List[Dict[str, Tensor]]): per-image targets; the ``"boxes"`` entry is read.
            head_outputs (Dict[str, Tensor]): head outputs; the ``"bbox_regression"`` entry is read.
            anchors (List[Tensor]): per-image anchors.
            matched_idxs (List[Tensor]): per-image matched ground-truth index for every anchor;
                values ``>= 0`` mark foreground anchors.

        Returns:
            Tensor: sum over images of (box loss / max(1, #foreground)), divided by
            ``max(1, len(targets))``.
        """
        losses = []

        bbox_regression = head_outputs["bbox_regression"]

        for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip(
            targets, bbox_regression, anchors, matched_idxs
        ):
            # determine only the foreground indices, ignore the rest
            foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0]
            num_foreground = foreground_idxs_per_image.numel()

            # select only the foreground boxes
            matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]]
            bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :]
            anchors_per_image = anchors_per_image[foreground_idxs_per_image, :]

            # compute the loss
            losses.append(
                _box_loss(
                    self._loss_type,
                    self.box_coder,
                    anchors_per_image,
                    matched_gt_boxes_per_image,
                    bbox_regression_per_image,
                )
                / max(1, num_foreground)
            )

        return _sum(losses) / max(1, len(targets))

    def forward(self, x):
        # type: (List[Tensor]) -> Tensor
        """
        Compute box regression values for every anchor of every feature map.

        Args:
            x (List[Tensor]): feature maps, each of shape ``(N, C, H, W)``.

        Returns:
            Tensor: regression output of shape ``(N, sum over levels of H*W*A, 4)``,
            concatenated over levels along dim 1.
        """
        all_bbox_regression = []

        for features in x:
            bbox_regression = self.conv(features)
            bbox_regression = self.bbox_reg(bbox_regression)

            # Permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4).
            N, _, H, W = bbox_regression.shape
            bbox_regression = bbox_regression.view(N, -1, 4, H, W)
            bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2)
            bbox_regression = bbox_regression.reshape(N, -1, 4)  # Size=(N, HWA, 4)

            all_bbox_regression.append(bbox_regression)

        return torch.cat(all_bbox_regression, dim=1)


class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending if it is in training or evaluation mode.

    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained
            on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        proposal_matcher (det_utils.Matcher): module that assigns a ground-truth index to every anchor
            during training. Defaults to a ``Matcher`` built from ``fg_iou_thresh`` and ``bg_iou_thresh``
            with ``allow_low_quality_matches=True``.
        score_thresh (float): Score threshold used for postprocessing the detections.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        detections_per_img (int): Number of best detections to keep after NMS.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.
        topk_candidates (int): Number of best detections to keep before NMS.
        **kwargs: extra keyword arguments forwarded to ``GeneralizedRCNNTransform``.

    Example:

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(
        >>>     sizes=((32, 64, 128, 256, 512),),
        >>>     aspect_ratios=((0.5, 1.0, 2.0),)
        >>> )
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
    }

    def __init__(
        self,
        backbone,
        num_classes,
        # transform parameters
        min_size=800,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # Anchor parameters
        anchor_generator=None,
        head=None,
        proposal_matcher=None,
        score_thresh=0.05,
        nms_thresh=0.5,
        detections_per_img=300,
        fg_iou_thresh=0.5,
        bg_iou_thresh=0.4,
        topk_candidates=1000,
        **kwargs,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )
        self.backbone = backbone

        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
            raise TypeError(
                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
            )

        if anchor_generator is None:
            anchor_generator = _default_anchorgen()
        self.anchor_generator = anchor_generator

        if head is None:
            # The default head uses the anchor count of the first level for all levels.
            head = RetinaNetHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
        self.head = head

        if proposal_matcher is None:
            proposal_matcher = det_utils.Matcher(
                fg_iou_thresh,
                bg_iou_thresh,
                allow_low_quality_matches=True,
            )
        self.proposal_matcher = proposal_matcher

        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std, **kwargs)

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img
        self.topk_candidates = topk_candidates

        # used only on torchscript mode
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        # Eager mode returns only the output relevant to the current mode:
        # the losses when training, the detections otherwise.
        if self.training:
            return losses

        return detections

    def compute_loss(self, targets, head_outputs, anchors):
        # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor]
        """
        Match anchors to ground-truth boxes per image and delegate the loss to ``self.head``.

        Args:
            targets (List[Dict[str, Tensor]]): per-image targets; the ``"boxes"`` entry is read here.
            head_outputs (Dict[str, Tensor]): outputs of ``self.head``.
            anchors (List[Tensor]): per-image anchors.

        Returns:
            Dict[str, Tensor]: whatever ``self.head.compute_loss`` returns.
        """
        matched_idxs = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            if targets_per_image["boxes"].numel() == 0:
                # No ground truth in this image: mark every anchor with -1 so that none
                # of them passes the ``>= 0`` foreground test used by the heads.
                matched_idxs.append(
                    torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device)
                )
                continue

            match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image)
            matched_idxs.append(self.proposal_matcher(match_quality_matrix))

        return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)

    def postprocess_detections(self, head_outputs, anchors, image_shapes):
        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
        """
        Turn per-level head outputs into final detections for each image.

        For every image and level: apply a sigmoid to the logits, drop scores not above
        ``self.score_thresh``, keep the top ``self.topk_candidates`` of the rest, decode the
        selected boxes against their anchors and clip them to the image. The levels are then
        concatenated, batched NMS is applied with ``self.nms_thresh`` and at most
        ``self.detections_per_img`` results are kept.

        Args:
            head_outputs (Dict[str, List[Tensor]]): ``"cls_logits"`` and ``"bbox_regression"``,
                each a list with one tensor per level.
            anchors (List[List[Tensor]]): anchors per image, split per level.
            image_shapes (List[Tuple[int, int]]): size of every image, used for clipping.

        Returns:
            List[Dict[str, Tensor]]: one dict per image with ``"boxes"``, ``"scores"`` and ``"labels"``.
        """
        class_logits = head_outputs["cls_logits"]
        box_regression = head_outputs["bbox_regression"]

        num_images = len(image_shapes)

        detections: List[Dict[str, Tensor]] = []

        for index in range(num_images):
            box_regression_per_image = [br[index] for br in box_regression]
            logits_per_image = [cl[index] for cl in class_logits]
            anchors_per_image, image_shape = anchors[index], image_shapes[index]

            image_boxes = []
            image_scores = []
            image_labels = []

            for box_regression_per_level, logits_per_level, anchors_per_level in zip(
                box_regression_per_image, logits_per_image, anchors_per_image
            ):
                num_classes = logits_per_level.shape[-1]

                # remove low scoring boxes
                scores_per_level = torch.sigmoid(logits_per_level).flatten()
                keep_idxs = scores_per_level > self.score_thresh
                scores_per_level = scores_per_level[keep_idxs]
                topk_idxs = torch.where(keep_idxs)[0]

                # keep only topk scoring predictions
                # NOTE(review): _topk_min presumably caps topk_candidates at the number of
                # remaining candidates - confirm against det_utils.
                num_topk = det_utils._topk_min(topk_idxs, self.topk_candidates, 0)
                scores_per_level, idxs = scores_per_level.topk(num_topk)
                topk_idxs = topk_idxs[idxs]

                # The scores were flattened over (anchor, class), so the flat index encodes both.
                anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor")
                labels_per_level = topk_idxs % num_classes

                boxes_per_level = self.box_coder.decode_single(
                    box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs]
                )
                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)

                image_boxes.append(boxes_per_level)
                image_scores.append(scores_per_level)
                image_labels.append(labels_per_level)

            image_boxes = torch.cat(image_boxes, dim=0)
            image_scores = torch.cat(image_scores, dim=0)
            image_labels = torch.cat(image_labels, dim=0)

            # non-maximum suppression
            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
            keep = keep[: self.detections_per_img]

            detections.append(
                {
                    "boxes": image_boxes[keep],
                    "scores": image_scores[keep],
                    "labels": image_labels[keep],
                }
            )

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Args:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (dict[Tensor] or list[Dict[Tensor]]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[Dict[Tensor]], one per image, with the
                fields `boxes`, `scores` and `labels`.
                When scripted, it always returns the tuple (losses, detections).

        """
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    torch._assert(isinstance(boxes, torch.Tensor), "Expected target boxes to be of type Tensor.")
                    torch._assert(
                        len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                        "Expected target boxes to be a tensor of shape [N, 4].",
                    )

        # get the original image sizes
        original_image_sizes: List[Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        # transform the input
        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        # TODO: Move this to a function
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        # get the features from the backbone
        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = OrderedDict([("0", features)])

        # TODO: Do we want a list or a dict?
        features = list(features.values())

        # compute the retinanet heads outputs using the features
        head_outputs = self.head(features)

        # create the set of anchors
        anchors = self.anchor_generator(images, features)

        losses = {}
        detections: List[Dict[str, Tensor]] = []
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                # compute the losses
                losses = self.compute_loss(targets, head_outputs, anchors)
        else:
            # recover level sizes
            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
            HW = 0
            for v in num_anchors_per_level:
                HW += v
            HWA = head_outputs["cls_logits"].size(1)
            # Anchors per spatial location, derived from the head output length.
            A = HWA // HW
            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]

            # split outputs per level
            split_head_outputs: Dict[str, List[Tensor]] = {}
            for k in head_outputs:
                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]

            # compute the detections
            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return losses, detections
        return self.eager_outputs(losses, detections)


674
675
# Metadata entries shared by every pretrained-weights definition below.
_COMMON_META = {
    "categories": _COCO_CATEGORIES,
    "min_size": (1, 1),
}


680
681
682
683
684
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
    # Pretrained weights for the ResNet-50-FPN RetinaNet; ``DEFAULT`` aliases ``COCO_V1``.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 34014999,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 36.4,
                }
            },
            "_ops": 151.54,
            "_weight_size": 130.267,
            "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
        },
    )
    DEFAULT = COCO_V1
699
700


701
class RetinaNet_ResNet50_FPN_V2_Weights(WeightsEnum):
    # Pretrained weights for the V2 ResNet-50-FPN RetinaNet; ``DEFAULT`` aliases ``COCO_V1``.
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/retinanet_resnet50_fpn_v2_coco-5905b1c5.pth",
        transforms=ObjectDetection,
        meta={
            **_COMMON_META,
            "num_params": 38198935,
            "recipe": "https://github.com/pytorch/vision/pull/5756",
            "_metrics": {
                "COCO-val2017": {
                    "box_map": 41.5,
                }
            },
            "_ops": 152.238,
            "_weight_size": 146.037,
            "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
        },
    )
    DEFAULT = COCO_V1
720
721


722
@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Focal Loss for Dense Object Detection <https://arxiv.org/abs/1708.02002>`_.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (a list of dictionaries),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detections:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each detection
        - scores (``Tensor[N]``): the scores of each detection

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(weights=RetinaNet_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
            When ``weights`` is given, the value is reconciled with the number of categories stored in the
            weights' metadata; when neither ``weights`` nor ``num_classes`` is given, 91 is used.
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone. Defaults to ``ResNet50_Weights.IMAGENET1K_V1``. Ignored when ``weights`` is given.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # The full detection checkpoint is loaded into the model below, so no
        # separate backbone weights are requested.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    # Batch-norm layers are frozen whenever any pretrained weights are in use.
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    # skip P2 because it generates too many anchors (according to their paper)
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
    )
    model = RetinaNet(backbone, num_classes, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
        # The eps override is applied only for the COCO_V1 checkpoint.
        # NOTE(review): presumably needed to reproduce the numerics this checkpoint
        # was trained with -- confirm against ``overwrite_eps`` in ``._utils``.
        if weights == RetinaNet_ResNet50_FPN_Weights.COCO_V1:
            overwrite_eps(model, 0.0)

    return model
823
824


825
@register_model()
@handle_legacy_interface(
    weights=("pretrained", RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def retinanet_resnet50_fpn_v2(
    *,
    weights: Optional[RetinaNet_ResNet50_FPN_V2_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = None,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> RetinaNet:
    """
    Constructs an improved RetinaNet model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
    <https://arxiv.org/abs/1912.02424>`_.

    See :func:`~torchvision.models.detection.retinanet_resnet50_fpn` for more details.

    Args:
        weights (:class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        num_classes (int, optional): number of output classes of the model (including the background).
            When ``weights`` is given, the value is reconciled with the number of categories stored in the
            weights' metadata; when neither ``weights`` nor ``num_classes`` is given, 91 is used.
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The pretrained weights for
            the backbone. By default, no pre-trained backbone weights are used. Ignored when ``weights``
            is given.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.RetinaNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/detection/retinanet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.detection.RetinaNet_ResNet50_FPN_V2_Weights
        :members:
    """
    weights = RetinaNet_ResNet50_FPN_V2_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # The full detection checkpoint is loaded into the model below, so no
        # separate backbone weights are requested.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)

    # Unlike the V1 builder, no ``norm_layer`` override is passed to the backbone here.
    backbone = resnet50(weights=weights_backbone, progress=progress)
    # NOTE(review): the extra block is built with in_channels=2048 here versus 256 in
    # ``retinanet_resnet50_fpn`` -- presumably intentional; confirm against ``LastLevelP6P7``.
    backbone = _resnet_fpn_extractor(
        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(2048, 256)
    )
    anchor_generator = _default_anchorgen()
    # Custom head: GroupNorm with 32 groups as the normalization layer.
    head = RetinaNetHead(
        backbone.out_channels,
        anchor_generator.num_anchors_per_location()[0],
        num_classes,
        norm_layer=partial(nn.GroupNorm, 32),
    )
    # Switch the box-regression loss of the head to "giou".
    head.regression_head._loss_type = "giou"
    model = RetinaNet(backbone, num_classes, anchor_generator=anchor_generator, head=head, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
900
901
902
903
904
905
906
907
908
909
910


# Legacy URL table: an internal implementation detail that will be removed in v0.15.
from .._utils import _ModelURLs


# Maps the legacy checkpoint name to the URL of the matching weights entry.
model_urls = _ModelURLs(dict(retinanet_resnet50_fpn_coco=RetinaNet_ResNet50_FPN_Weights.COCO_V1.url))