import warnings
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn, Tensor

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation
from ...utils import _log_api_usage_once
from .. import mobilenet
from . import _utils as det_utils
from .anchor_utils import DefaultBoxGenerator
from .backbone_utils import _validate_trainable_layers
from .ssd import SSD, SSDScoringHead


__all__ = ["ssdlite320_mobilenet_v3_large"]

model_urls = {
    "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth"
}


# Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper
def _prediction_block(
    in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module]
) -> nn.Sequential:
    return nn.Sequential(
        # 3x3 depthwise with stride 1 and padding 1
        ConvNormActivation(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            norm_layer=norm_layer,
            activation_layer=nn.ReLU6,
        ),
        # 1x1 projection to output channels
        nn.Conv2d(in_channels, out_channels, 1),
    )
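

# Illustration (a hedged sketch, left as a comment so it does not run at import
# time): the prediction block keeps the spatial size and only remaps channels,
# e.g. for the 672-channel C4 map of MobileNetV3-Large with 6 anchors over 91
# classes:
#
#   block = _prediction_block(672, 6 * 91, kernel_size=3, norm_layer=nn.BatchNorm2d)
#   block(torch.rand(1, 672, 20, 20)).shape  # -> torch.Size([1, 546, 20, 20])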


def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., nn.Module]) -> nn.Sequential:
    activation = nn.ReLU6
    intermediate_channels = out_channels // 2
    return nn.Sequential(
        # 1x1 projection to half output channels
        ConvNormActivation(
            in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
        # 3x3 depthwise with stride 2 and padding 1
        ConvNormActivation(
            intermediate_channels,
            intermediate_channels,
            kernel_size=3,
            stride=2,
            groups=intermediate_channels,
            norm_layer=norm_layer,
            activation_layer=activation,
        ),
        # 1x1 projection to output channels
        ConvNormActivation(
            intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation
        ),
    )
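

# Illustration (hedged sketch, comment only): each extra block halves the spatial
# resolution through its strided depthwise convolution, e.g.:
#
#   extra = _extra_block(480, 512, norm_layer=nn.BatchNorm2d)
#   extra(torch.rand(1, 480, 10, 10)).shape  # -> torch.Size([1, 512, 5, 5])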


def _normal_init(conv: nn.Module):
    for layer in conv.modules():
        if isinstance(layer, nn.Conv2d):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=0.03)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0.0)


class SSDLiteHead(nn.Module):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        super().__init__()
        self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer)
        self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer)

    def forward(self, x: List[Tensor]) -> Dict[str, Tensor]:
        return {
            "bbox_regression": self.regression_head(x),
            "cls_logits": self.classification_head(x),
        }
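

# Worked example (hedged, for orientation): with 320x320 inputs the six feature
# maps have 20x20, 10x10, 5x5, 3x3, 2x2 and 1x1 locations with 6 anchors each,
# i.e. 6 * (400 + 100 + 25 + 9 + 4 + 1) = 3234 default boxes, so the head yields
# "bbox_regression" of shape (N, 3234, 4) and "cls_logits" of shape
# (N, 3234, num_classes).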


class SSDLiteClassificationHead(SSDScoringHead):
    def __init__(
        self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module]
    ):
        cls_logits = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer))
        _normal_init(cls_logits)
        super().__init__(cls_logits, num_classes)


class SSDLiteRegressionHead(SSDScoringHead):
    def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: Callable[..., nn.Module]):
        bbox_reg = nn.ModuleList()
        for channels, anchors in zip(in_channels, num_anchors):
            bbox_reg.append(_prediction_block(channels, 4 * anchors, 3, norm_layer))
        _normal_init(bbox_reg)
        super().__init__(bbox_reg, 4)


class SSDLiteFeatureExtractorMobileNet(nn.Module):
    def __init__(
        self,
        backbone: nn.Module,
        c4_pos: int,
        norm_layer: Callable[..., nn.Module],
        width_mult: float = 1.0,
        min_depth: int = 16,
    ):
        super().__init__()
        _log_api_usage_once(self)

        assert not backbone[c4_pos].use_res_connect
        self.features = nn.Sequential(
            # As described in section 6.3 of MobileNetV3 paper
            nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]),  # from start until C4 expansion layer
            nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]),  # from C4 depthwise until end
        )

        get_depth = lambda d: max(min_depth, int(d * width_mult))  # noqa: E731
        extra = nn.ModuleList(
            [
                _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer),
                _extra_block(get_depth(512), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(256), norm_layer),
                _extra_block(get_depth(256), get_depth(128), norm_layer),
            ]
        )
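        # For the default ssdlite320 configuration (width_mult=1.0 and the
        # reduced tail, so backbone[-1].out_channels == 480) this yields the
        # channel progression 480 -> 512 -> 256 -> 256 -> 128.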
        _normal_init(extra)

        self.extra = extra

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations.
        output = []
        for block in self.features:
            x = block(x)
            output.append(x)

        for block in self.extra:
            x = block(x)
            output.append(x)

        return OrderedDict([(str(i), v) for i, v in enumerate(output)])
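

# Illustration (hedged sketch, comment only): for a 320x320 input the extractor
# returns an OrderedDict of six maps keyed "0" to "5", with spatial sizes
# 20x20, 10x10, 5x5, 3x3, 2x2 and 1x1 respectively.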


def _mobilenet_extractor(
    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
    trainable_layers: int,
    norm_layer: Callable[..., nn.Module],
):
    backbone = backbone.features
    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they correspond to C0 (conv1) and Cn.
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

    # find the index of the first layer that won't be frozen
    assert 0 <= trainable_layers <= num_stages
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

    for b in backbone[:freeze_before]:
        for parameter in b.parameters():
            parameter.requires_grad_(False)

    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)
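

# Illustration (hedged sketch, comment only; the norm_layer choice here is an
# assumption, not the library default):
#
#   backbone = mobilenet.mobilenet_v3_large(pretrained=False, norm_layer=nn.BatchNorm2d)
#   extractor = _mobilenet_extractor(backbone, trainable_layers=6, norm_layer=nn.BatchNorm2d)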


def ssdlite320_mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    num_classes: int = 91,
    pretrained_backbone: bool = False,
    trainable_backbone_layers: Optional[int] = None,
    norm_layer: Optional[Callable[..., nn.Module]] = None,
    **kwargs: Any,
):
    """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone, as described at
    `"Searching for MobileNetV3"
    <https://arxiv.org/abs/1905.02244>`_ and
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
    <https://arxiv.org/abs/1801.04381>`_.

    See :func:`~torchvision.models.detection.ssd300_vgg16` for more details.

    Example:

        >>> model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 320, 320), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
        num_classes (int): number of output classes of the model (including the background)
        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from the final block.
            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
        norm_layer (callable, optional): Module specifying the normalization layer to use.
    """
    if "size" in kwargs:
        warnings.warn("The size of the model is already fixed; ignoring the argument.")

    trainable_backbone_layers = _validate_trainable_layers(
        pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6
    )

    if pretrained:
        pretrained_backbone = False

    # Enable reduced tail if no pretrained backbone is selected. See Table 6 of MobileNetV3 paper.
    reduce_tail = not pretrained_backbone

    if norm_layer is None:
        norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

    backbone = mobilenet.mobilenet_v3_large(
        pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
    )
    if not pretrained_backbone:
        # Change the default initialization scheme if not pretrained
        _normal_init(backbone)
    backbone = _mobilenet_extractor(
        backbone,
        trainable_backbone_layers,
        norm_layer,
    )

    size = (320, 320)
    anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
    out_channels = det_utils.retrieve_out_channels(backbone, size)
    num_anchors = anchor_generator.num_anchors_per_location()
    assert len(out_channels) == len(anchor_generator.aspect_ratios)

    defaults = {
        "score_thresh": 0.001,
        "nms_thresh": 0.55,
        "detections_per_img": 300,
        "topk_candidates": 300,
        # Rescale the input in a way compatible with the backbone:
        # the following mean/std rescale the data from [0, 1] to [-1, 1]
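        # (e.g. a pixel value x becomes (x - 0.5) / 0.5, so 0 -> -1 and 1 -> 1)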
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
    }
    kwargs = {**defaults, **kwargs}
    model = SSD(
        backbone,
        anchor_generator,
        size,
        num_classes,
        head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer),
        **kwargs,
    )

    if pretrained:
        weights_name = "ssdlite320_mobilenet_v3_large_coco"
        if model_urls.get(weights_name, None) is None:
            raise ValueError(f"No checkpoint is available for model {weights_name}")
        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
        model.load_state_dict(state_dict)
    return model
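

# Illustration (hedged sketch, comment only): in training mode the model expects
# images plus targets with "boxes" (FloatTensor[N, 4], x1y1x2y2) and "labels"
# (Int64Tensor[N]), and returns the SSD losses:
#
#   model = ssdlite320_mobilenet_v3_large(pretrained_backbone=True, num_classes=3)
#   images = [torch.rand(3, 320, 320)]
#   targets = [{"boxes": torch.tensor([[10.0, 20.0, 100.0, 200.0]]),
#               "labels": torch.tensor([1])}]
#   losses = model(images, targets)  # dict with "bbox_regression" and "classification"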