# keypoint_rcnn.py
from typing import Any, Optional

3
4
5
6
import torch
from torch import nn
from torchvision.ops import MultiScaleRoIAlign

7
from ...ops import misc as misc_nn_ops
8
from ...transforms._presets import ObjectDetection
9
from .._api import register_model, Weights, WeightsEnum
10
from .._meta import _COCO_PERSON_CATEGORIES, _COCO_PERSON_KEYPOINT_NAMES
11
12
from .._utils import _ovewrite_value_param, handle_legacy_interface
from ..resnet import resnet50, ResNet50_Weights
13
from ._utils import overwrite_eps
14
from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
15
from .faster_rcnn import FasterRCNN
16
17


# Names re-exported by ``from <this module> import *``.
__all__ = ["KeypointRCNN", "KeypointRCNN_ResNet50_FPN_Weights", "keypointrcnn_resnet50_fpn"]


class KeypointRCNN(FasterRCNN):
    """
    Implements Keypoint R-CNN.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the class label for each ground-truth box
        - keypoints (FloatTensor[N, K, 3]): the K keypoints location for each of the N instances, in the
          format [x, y, visibility], where visibility=0 means that the keypoint is not visible.

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
            ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (Int64Tensor[N]): the predicted labels for each detected instance
        - scores (Tensor[N]): the scores of each prediction
        - keypoints (FloatTensor[N, K, 3]): the locations of the predicted keypoints, in [x, y, v] format.

    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (including the background).
            If box_predictor is specified, num_classes should be None.
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone.
            If None, defaults to ``(640, 672, 704, 736, 768, 800)``.
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been trained on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        rpn_anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        rpn_head (nn.Module): module that computes the objectness and regression deltas from the RPN
        rpn_pre_nms_top_n_train (int): number of proposals to keep before applying NMS during training
        rpn_pre_nms_top_n_test (int): number of proposals to keep before applying NMS during testing
        rpn_post_nms_top_n_train (int): number of proposals to keep after applying NMS during training
        rpn_post_nms_top_n_test (int): number of proposals to keep after applying NMS during testing
        rpn_nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
        rpn_fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training of the RPN.
        rpn_bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training of the RPN.
        rpn_batch_size_per_image (int): number of anchors that are sampled during training of the RPN
            for computing the loss
        rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
            of the RPN
        rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
        box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes
        box_head (nn.Module): module that takes the cropped feature maps as input
        box_predictor (nn.Module): module that takes the output of box_head and returns the
            classification logits and box regression deltas.
        box_score_thresh (float): during inference, only return proposals with a classification score
            greater than box_score_thresh
        box_nms_thresh (float): NMS threshold for the prediction head. Used during inference
        box_detections_per_img (int): maximum number of detections per image, for all classes.
        box_fg_iou_thresh (float): minimum IoU between the proposals and the GT box so that they can be
            considered as positive during training of the classification head
        box_bg_iou_thresh (float): maximum IoU between the proposals and the GT box so that they can be
            considered as negative during training of the classification head
        box_batch_size_per_image (int): number of proposals that are sampled during training of the
            classification head
        box_positive_fraction (float): proportion of positive proposals in a mini-batch during training
            of the classification head
        bbox_reg_weights (Tuple[float, float, float, float]): weights for the encoding/decoding of the
            bounding boxes
        keypoint_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
            the locations indicated by the bounding boxes, which will be used for the keypoint head.
            If None, a MultiScaleRoIAlign over feature maps ``["0", "1", "2", "3"]`` with
            ``output_size=14`` and ``sampling_ratio=2`` is used.
        keypoint_head (nn.Module): module that takes the cropped feature maps as input.
            If None, a stack of eight 512-channel conv layers (``KeypointRCNNHeads``) is used.
        keypoint_predictor (nn.Module): module that takes the output of the keypoint_head and returns the
            heatmap logits. If None, a ``KeypointRCNNPredictor`` with 512 input channels is used.
        num_keypoints (int, optional): number of keypoints predicted by the default keypoint_predictor.
            Must be None when keypoint_predictor is specified; defaults to 17 otherwise.
        **kwargs: forwarded unchanged to the ``FasterRCNN`` constructor.

    Raises:
        TypeError: if keypoint_roi_pool is neither a MultiScaleRoIAlign nor None.
        ValueError: if both num_keypoints and keypoint_predictor are specified.

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models import MobileNet_V2_Weights
        >>> from torchvision.models.detection import KeypointRCNN
        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
        >>>
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
        >>> # KeypointRCNN needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the RPN generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # let's define what are the feature maps that we will
        >>> # use to perform the region of interest cropping, as well as
        >>> # the size of the crop after rescaling.
        >>> # if your backbone returns a Tensor, featmap_names is expected to
        >>> # be ['0']. More generally, the backbone should return an
        >>> # OrderedDict[Tensor], and in featmap_names you can choose which
        >>> # feature maps to use.
        >>> roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                 output_size=7,
        >>>                                                 sampling_ratio=2)
        >>>
        >>> keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
        >>>                                                          output_size=14,
        >>>                                                          sampling_ratio=2)
        >>> # put the pieces together inside a KeypointRCNN model
        >>> model = KeypointRCNN(backbone,
        >>>                      num_classes=2,
        >>>                      rpn_anchor_generator=anchor_generator,
        >>>                      box_roi_pool=roi_pooler,
        >>>                      keypoint_roi_pool=keypoint_roi_pooler)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(
        self,
        backbone,
        num_classes=None,
        # transform parameters
        min_size=None,
        max_size=1333,
        image_mean=None,
        image_std=None,
        # RPN parameters
        rpn_anchor_generator=None,
        rpn_head=None,
        rpn_pre_nms_top_n_train=2000,
        rpn_pre_nms_top_n_test=1000,
        rpn_post_nms_top_n_train=2000,
        rpn_post_nms_top_n_test=1000,
        rpn_nms_thresh=0.7,
        rpn_fg_iou_thresh=0.7,
        rpn_bg_iou_thresh=0.3,
        rpn_batch_size_per_image=256,
        rpn_positive_fraction=0.5,
        rpn_score_thresh=0.0,
        # Box parameters
        box_roi_pool=None,
        box_head=None,
        box_predictor=None,
        box_score_thresh=0.05,
        box_nms_thresh=0.5,
        box_detections_per_img=100,
        box_fg_iou_thresh=0.5,
        box_bg_iou_thresh=0.5,
        box_batch_size_per_image=512,
        box_positive_fraction=0.25,
        bbox_reg_weights=None,
        # keypoint parameters
        keypoint_roi_pool=None,
        keypoint_head=None,
        keypoint_predictor=None,
        num_keypoints=None,
        **kwargs,
    ):

        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
            # f-string so the offending type is actually interpolated into the message.
            raise TypeError(
                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
            )
        if min_size is None:
            min_size = (640, 672, 704, 736, 768, 800)

        # num_keypoints only configures the default predictor, so it conflicts with a custom one.
        if num_keypoints is not None:
            if keypoint_predictor is not None:
                raise ValueError("num_keypoints should be None when keypoint_predictor is specified")
        else:
            num_keypoints = 17

        out_channels = backbone.out_channels

        if keypoint_roi_pool is None:
            keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2)

        if keypoint_head is None:
            keypoint_layers = tuple(512 for _ in range(8))
            keypoint_head = KeypointRCNNHeads(out_channels, keypoint_layers)

        if keypoint_predictor is None:
            # The default predictor expects 512 input channels, matching the last layer of the default
            # keypoint_head; a custom keypoint_head with a different output width needs a custom predictor.
            keypoint_dim_reduced = 512  # == keypoint_layers[-1]
            keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints)

        super().__init__(
            backbone,
            num_classes,
            # transform parameters
            min_size,
            max_size,
            image_mean,
            image_std,
            # RPN-specific parameters
            rpn_anchor_generator,
            rpn_head,
            rpn_pre_nms_top_n_train,
            rpn_pre_nms_top_n_test,
            rpn_post_nms_top_n_train,
            rpn_post_nms_top_n_test,
            rpn_nms_thresh,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_score_thresh,
            # Box parameters
            box_roi_pool,
            box_head,
            box_predictor,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            **kwargs,
        )

        # The keypoint branch is attached to the roi_heads module created by FasterRCNN.
        self.roi_heads.keypoint_roi_pool = keypoint_roi_pool
        self.roi_heads.keypoint_head = keypoint_head
        self.roi_heads.keypoint_predictor = keypoint_predictor


class KeypointRCNNHeads(nn.Sequential):
    """Convolutional feature head of the keypoint branch.

    Builds one ``Conv2d(3x3, stride=1, padding=1)`` + ``ReLU(inplace=True)`` pair per entry of
    ``layers``, chaining channel counts starting from ``in_channels``. Every convolution gets
    Kaiming-normal weights (``mode="fan_out"``, ``nonlinearity="relu"``) and a zero bias.

    Args:
        in_channels (int): number of channels of the incoming feature map.
        layers (Iterable[int]): output channel count of each successive conv layer.
    """

    def __init__(self, in_channels, layers):
        modules = []
        channels_in = in_channels
        for channels_out in layers:
            modules.extend(
                (
                    nn.Conv2d(channels_in, channels_out, 3, stride=1, padding=1),
                    nn.ReLU(inplace=True),
                )
            )
            channels_in = channels_out
        super().__init__(*modules)

        # Re-initialize only the convolutions; ReLU layers have no parameters.
        conv_layers = (child for child in self.children() if isinstance(child, nn.Conv2d))
        for conv in conv_layers:
            nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
            nn.init.constant_(conv.bias, 0)


class KeypointRCNNPredictor(nn.Module):
    """Turns keypoint-head features into per-keypoint heatmap logits.

    A ``ConvTranspose2d`` (kernel 4, stride 2, padding 1) maps ``in_channels`` to ``num_keypoints``
    channels and doubles the spatial size; a bilinear interpolation with scale factor 2 then doubles it
    again, so the output is 4x the input resolution.

    Args:
        in_channels (int): number of channels produced by the keypoint head.
        num_keypoints (int): number of keypoint heatmaps to predict (also exposed as ``out_channels``).
    """

    def __init__(self, in_channels, num_keypoints):
        super().__init__()
        kernel_size = 4
        self.kps_score_lowres = nn.ConvTranspose2d(
            in_channels,
            num_keypoints,
            kernel_size,
            stride=2,
            padding=kernel_size // 2 - 1,
        )
        deconv = self.kps_score_lowres
        nn.init.kaiming_normal_(deconv.weight, mode="fan_out", nonlinearity="relu")
        nn.init.constant_(deconv.bias, 0)
        self.up_scale = 2
        self.out_channels = num_keypoints

    def forward(self, x):
        lowres_logits = self.kps_score_lowres(x)
        upsampled = torch.nn.functional.interpolate(
            lowres_logits,
            scale_factor=float(self.up_scale),
            mode="bilinear",
            align_corners=False,
            recompute_scale_factor=False,
        )
        return upsampled


# Metadata merged into the ``meta`` of every KeypointRCNN_ResNet50_FPN_Weights entry: the COCO person
# category list, the COCO person keypoint names, and a ``min_size`` entry.
_COMMON_META = dict(
    categories=_COCO_PERSON_CATEGORIES,
    keypoint_names=_COCO_PERSON_KEYPOINT_NAMES,
    min_size=(1, 1),
)


class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
    """Pre-trained weights for :func:`keypointrcnn_resnet50_fpn`.

    ``COCO_LEGACY`` and ``COCO_V1`` are checkpoints evaluated on COCO-val2017; ``DEFAULT`` is an alias
    for ``COCO_V1``.
    """

    COCO_LEGACY = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
        transforms=ObjectDetection,
        meta=dict(
            _COMMON_META,
            num_params=59137258,
            recipe="https://github.com/pytorch/vision/issues/1606",
            _metrics={
                "COCO-val2017": {
                    "box_map": 50.6,
                    "kp_map": 61.1,
                }
            },
            _ops=133.924,
            _file_size=226.054,
            _docs="""
                These weights were produced by following a similar training recipe as on the paper but use a checkpoint
                from an early epoch.
            """,
        ),
    )
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
        transforms=ObjectDetection,
        meta=dict(
            _COMMON_META,
            num_params=59137258,
            recipe="https://github.com/pytorch/vision/tree/main/references/detection#keypoint-r-cnn",
            _metrics={
                "COCO-val2017": {
                    "box_map": 54.6,
                    "kp_map": 65.0,
                }
            },
            _ops=137.42,
            _file_size=226.054,
            _docs="""These weights were produced by following a similar training recipe as on the paper.""",
        ),
    )
    DEFAULT = COCO_V1


@register_model()
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY
        if kwargs["pretrained"] == "legacy"
        else KeypointRCNN_ResNet50_FPN_Weights.COCO_V1,
    ),
    weights_backbone=("pretrained_backbone", ResNet50_Weights.IMAGENET1K_V1),
)
def keypointrcnn_resnet50_fpn(
    *,
    weights: Optional[KeypointRCNN_ResNet50_FPN_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    num_keypoints: Optional[int] = None,
    weights_backbone: Optional[ResNet50_Weights] = ResNet50_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    **kwargs: Any,
) -> KeypointRCNN:
    """
    Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone.

    .. betastatus:: detection module

    Reference: `Mask R-CNN <https://arxiv.org/abs/1703.06870>`__.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on if it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionary),
    containing:

        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
        - keypoints (``FloatTensor[N, K, 3]``): the ``K`` keypoints location for each of the ``N`` instances, in the
          format ``[x, y, visibility]``, where ``visibility=0`` means that the keypoint is not visible.

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses for both the RPN and the R-CNN, and the keypoint loss.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows, where ``N`` is the number of detected instances:

        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
          ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
        - labels (``Int64Tensor[N]``): the predicted labels for each instance
        - scores (``Tensor[N]``): the scores of each instance
        - keypoints (``FloatTensor[N, K, 3]``): the locations of the predicted keypoints, in ``[x, y, v]`` format.

    For more details on the output, you may refer to :ref:`instance_seg_output`.

    Keypoint R-CNN is exportable to ONNX for a fixed batch size with input images of fixed size.

    Example::

        >>> from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights
        >>> model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
        >>>
        >>> # optionally, if you want to export the model to ONNX:
        >>> torch.onnx.export(model, x, "keypoint_rcnn.onnx", opset_version=11)

    Args:
        weights (:class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights`
            below for more details, and possible values. By default, no
            pre-trained weights are used.
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background).
            When ``weights`` is given it is derived from the weights' categories; otherwise it defaults to 2.
        num_keypoints (int, optional): number of keypoints. When ``weights`` is given it is derived from the
            weights' keypoint names; otherwise it defaults to 17.
        weights_backbone (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights for the backbone. Ignored when ``weights`` is given.
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 3.
        **kwargs: parameters passed to the ``torchvision.models.detection.KeypointRCNN``
            base class.

    .. autoclass:: torchvision.models.detection.KeypointRCNN_ResNet50_FPN_Weights
        :members:
    """
    weights = KeypointRCNN_ResNet50_FPN_Weights.verify(weights)
    weights_backbone = ResNet50_Weights.verify(weights_backbone)

    if weights is not None:
        # Full-model weights take precedence: no separate backbone weights are loaded, and the class and
        # keypoint counts are taken from the weights' metadata via _ovewrite_value_param.
        weights_backbone = None
        num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"]))
        num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"]))
    else:
        if num_classes is None:
            num_classes = 2
        if num_keypoints is None:
            num_keypoints = 17

    # Frozen BatchNorm is used whenever any pre-trained weights (full model or backbone) are loaded.
    is_trained = weights is not None or weights_backbone is not None
    trainable_backbone_layers = _validate_trainable_layers(is_trained, trainable_backbone_layers, 5, 3)
    norm_layer = misc_nn_ops.FrozenBatchNorm2d if is_trained else nn.BatchNorm2d

    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=norm_layer)
    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
    model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
            # NOTE(review): only COCO_V1 gets overwrite_eps(model, 0.0); presumably that checkpoint was
            # trained with a BatchNorm eps of 0 — confirm against overwrite_eps.
            overwrite_eps(model, 0.0)

    return model