import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn, Tensor

from ...ops.misc import Conv2dNormActivation
from ...transforms._presets import ObjectDetection
from ...utils import _log_api_usage_once
from .. import mobilenet
from .._api import WeightsEnum, Weights
from .._meta import _COCO_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_value_param
from ..mobilenetv3 import MobileNet_V3_Large_Weights, mobilenet_v3_large
from . import _utils as det_utils
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .ssd import SSD, SSDScoringHead


__all__ = [
    "SSDLite320_MobileNet_V3_Large_Weights",
    "ssdlite320_mobilenet_v3_large",
]


# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper
def _prediction_block(
    in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
) -> nn.Sequential:
    return nn.Sequential(
        # 3x3 depthwise with stride 1 and padding 1
        Conv2dNormActivation(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            norm_layer=norm_layer,
            activation_layer=nn.ReLU6,
        ),
        # 1x1 projection to output channels
        nn.Conv2d(in_channels, out_channels, 1),
    )
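
# Illustrative usage of the block above (assumed values, not part of the
# original file): the depthwise 3x3 keeps the spatial size and the 1x1 conv
# remaps channels, e.g. 672 -> 6 anchors * 4 box offsets = 24:
#
#   block = _prediction_block(672, 24, kernel_size=3, norm_layer=nn.BatchNorm2d)
#   out = block(torch.rand(1, 672, 10, 10))  # -> torch.Size([1, 24, 10, 10])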


def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
    activation = nn.ReLU6
    intermediate_channels = out_channels // 2
    return nn.Sequential(
        # 1x1 projection to half output channels
        Conv2dNormActivation(
            in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
        # 3x3 depthwise with stride 2 and padding 1
        Conv2dNormActivation(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=2,
            groups=intermediate_channels,
            norm_layer=norm_layer,
            activation_layer=activation,
        ),
        # 1x1 projection to output channels
        Conv2dNormActivation(
            intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
    )
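
# Illustrative usage (assumed values): the stride-2 depthwise in the middle of
# the block halves the spatial resolution while the 1x1 convs remap channels:
#
#   block = _extra_block(672, 512, norm_layer=nn.BatchNorm2d)
#   out = block(torch.rand(1, 672, 10, 10))  # -> torch.Size([1, 512, 5, 5])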


def _normal_init(conv: nn.Module):
    for layer in conv.modules():
        if isinstance(layer, nn.Conv2d):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0.0)


class SSDLiteHead(nn.Module):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        super().__init__()
        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)

    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        return {
            "bbox_regression": self.regression_head(x),
            "cls_logits": self.classification_head(x),
        }


class SSDLiteClassificationHead(SSDScoringHead):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        cls_logits = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
        _normal_init(cls_logits)
        super().__init__(cls_logits, num_classes)


class SSDLiteRegressionHead(SSDScoringHead):
    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
        bbox_reg = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
        _normal_init(bbox_reg)
        super().__init__(bbox_reg, 4)


class SSDLiteFeatureExtractorMobileNet(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        c4_pos: int,
        norm_layer: Callable[..., nn.Module],
        width_mult: float = 1.0,
        min_depth: int = 16,
    ):
        super().__init__()
        _log_api_usage_once(self)

        if backbone[c4_pos].use_res_connect:
            raise ValueError("backbone[c4_pos].use_res_connect should be False")

        self.features = nn.Sequential(
            # As described in section 6.3 of MobileNetV3 paper
            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]),  # from C4 depthwise until end
        )

        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
        extra = nn.ModuleList(
            [
                _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
                _extra_block(get_depth(512), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(128), norm_layer),
            ]
        )
        _normal_init(extra)

        self.extra = extra

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
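        # ``self.features`` contributes two maps (up to C4 and the backbone tail)
        # and ``self.extra`` four more, so six maps come back, keyed "0" to "5".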
        output = []
        for block in self.features:
            x = block(x)
            output.append(x)

        for block in self.extra:
            x = block(x)
            output.append(x)

        return OrderedDict([(str(i), v) for i, v in enumerate(output)])


def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    trainable_layers: int,
    norm_layer: Callable[..., nn.Module],
):
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
    if not 0 <= trainable_layers <= num_stages:
        raise ValueError(f"trainable_layers should be in the range [0, {num_stages}], instead got {trainable_layers}")
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
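
# Illustrative usage (assumed setup, mirroring the builder below): wrap a
# MobileNetV3-Large backbone, leaving only its last two stages trainable:
#
#   norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
#   net = mobilenet_v3_large(norm_layer=norm_layer, reduced_tail=True)
#   extractor = _mobilenet_extractor(net, trainable_layers=2, norm_layer=norm_layer)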


class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
    COCO_V1 = Weights(
        url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
        transforms=ObjectDetection,
        meta={
            "num_params": 3440060,
            "categories": _COCO_CATEGORIES,
            "min_size": (1, 1),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#ssdlite320-mobilenetv3-large",
            "metrics": {
                "box_map": 21.3,
            },
        },
    )
    DEFAULT = COCO_V1


@handle_legacy_interface(
    weights=("pretrained", SSDLite320_MobileNet_V3_Large_Weights.COCO_V1),
    weights_backbone=("pretrained_backbone", MobileNet_V3_Large_Weights.IMAGENET1K_V1),
)
def ssdlite320_mobilenet_v3_large(
    *,
    weights: Optional[SSDLite320_MobileNet_V3_Large_Weights] = None,
    progress: bool = True,
    num_classes: Optional[int] = None,
    weights_backbone: Optional[MobileNet_V3_Large_Weights] = MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    trainable_backbone_layers: Optional[int] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    **kwargs: Any,
) -> SSD:
    """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at
    `"Searching for MobileNetV3"
    <https://arxiv.org/abs/1905.02244>`_ and
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
    <https://arxiv.org/abs/1801.04381>`_.

    See :func:`~torchvision.models.detection.ssd300_vgg16` for more details.

    Example:

        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(weights=SSDLite320_MobileNet_V3_Large_Weights.DEFAULT)
        >>> model.eval()
        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        weights (SSDLite320_MobileNet_V3_Large_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int, optional): number of output classes of the model (including the background)
        weights_backbone (MobileNet_V3_Large_Weights, optional): The pretrained weights for the backbone
        trainable_backbone_layers (int, optional): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. If ``None`` is
            passed (the default) this value is set to 6.
        norm_layer (callable, optional): Module specifying the normalization layer to use.
    """
    weights = SSDLite320_MobileNet_V3_Large_Weights.verify(weights)
    weights_backbone = MobileNet_V3_Large_Weights.verify(weights_backbone)

    if "size" in kwargs:
        warnings.warn("The size of the model is already fixed; ignoring the parameter.")

    if weights is not None:
        weights_backbone = None
        num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"]))
    elif num_classes is None:
        num_classes = 91

    trainable_backbone_layers = _validate_trainable_layers(
        weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 6
    )

    # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper.
    reduce_tail = weights_backbone is None

    if norm_layer is None:
        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

    backbone = mobilenet_v3_large(
        weights=weights_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
    )
    if weights_backbone is None:
        # Change the default initialization scheme if not pretrained
        _normal_init(backbone)
    backbone = _mobilenet_extractor(
        backbone,
        trainable_backbone_layers,
        norm_layer,
    )

    size = (320, 320)
    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
    out_channels = det_utils.retrieve_out_channels(backbone, size)
    num_anchors = anchor_generator.num_anchors_per_location()
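    # [[2, 3]] * 6 assigns aspect ratios 2 and 3 to each of the six feature maps;
    # together with the two default square boxes this gives 2 + 2 * 2 = 6 anchors
    # per location, which is what num_anchors holds for every map.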
    if len(out_channels) != len(anchor_generator.aspect_ratios):
        raise ValueError(
            f"The length of the output channels from the backbone {len(out_channels)} do not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}"
        )

    defaults = {
        "score_thresh": 0.001,
        "nms_thresh": 0.55,
        "detections_per_img": 300,
        "topk_candidates": 300,
        # Rescale the input in a way compatible to the backbone:
        # The following mean/std rescale the data from [0, 1] to [-1, 1]
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
    }
    kwargs: Any = {**defaults, **kwargs}
    model = SSD(
        backbone,
        anchor_generator,
        size,
        num_classes,
        head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
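

# Illustrative fine-tuning sketch (not part of the original file): build the
# model with only an ImageNet backbone and a custom class count; the SSDLite
# head is then randomly initialized for the new number of classes:
#
#   model = ssdlite320_mobilenet_v3_large(
#       weights_backbone=MobileNet_V3_Large_Weights.IMAGENET1K_V1,
#       num_classes=5,  # 4 object classes + background
#   )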