# mobilenetv2.py
1
import warnings
2
from functools import partial
3
4
5
from typing import Callable, Any, Optional, List

import torch
6
from torch import Tensor
7
8
from torch import nn

9
from ..ops.misc import Conv2dNormActivation
10
from ..transforms._presets import ImageClassification
11
from ..utils import _log_api_usage_once
12
13
14
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible
15
16


17
__all__ = ["MobileNetV2", "MobileNet_V2_Weights", "mobilenet_v2"]
18
19


20
# necessary for backwards compatibility
21
class _DeprecatedConvBNAct(Conv2dNormActivation):
    """Deprecated shim behind the legacy ``ConvBNReLU`` / ``ConvBNActivation`` names.

    Construction emits a ``FutureWarning`` and then forwards all arguments to
    ``Conv2dNormActivation``. When ``norm_layer`` or ``activation_layer`` is
    absent from the keyword arguments (or explicitly ``None``), it is filled in
    with ``nn.BatchNorm2d`` and ``nn.ReLU6`` respectively.

    Args:
        *args: Positional arguments forwarded unchanged to ``Conv2dNormActivation``.
        **kwargs: Keyword arguments forwarded to ``Conv2dNormActivation`` after the
            two defaults above have been applied.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. "
            "Use torchvision.ops.misc.Conv2dNormActivation instead.",
            FutureWarning,
            # Attribute the warning to the code instantiating the deprecated class,
            # not to this constructor.
            stacklevel=2,
        )
        # Only keyword arguments are inspected; values passed positionally are left untouched.
        if kwargs.get("norm_layer") is None:
            kwargs["norm_layer"] = nn.BatchNorm2d
        if kwargs.get("activation_layer") is None:
            kwargs["activation_layer"] = nn.ReLU6
        super().__init__(*args, **kwargs)


# Both legacy public names resolve to the same deprecated shim.
ConvBNReLU = _DeprecatedConvBNAct
ConvBNActivation = _DeprecatedConvBNAct
37
38
39
40


class InvertedResidual(nn.Module):
    """Inverted residual building block used by ``MobileNetV2``.

    The block chains, in order:

    1. a 1x1 ``Conv2dNormActivation`` expanding ``inp`` to ``hidden_dim`` channels
       (skipped when ``expand_ratio == 1``),
    2. a ``Conv2dNormActivation`` with ``groups=hidden_dim`` (one group per channel)
       applying ``stride``,
    3. a bias-free 1x1 ``nn.Conv2d`` projecting to ``oup`` channels followed by
       ``norm_layer(oup)`` with no activation.

    The input is added to the output only when ``stride == 1`` and ``inp == oup``.

    Args:
        inp (int): Number of input channels.
        oup (int): Number of output channels.
        stride (int): Stride of the grouped convolution; must be 1 or 2.
        expand_ratio (int): Multiplier applied to ``inp`` to obtain the hidden width.
        norm_layer: Callable building the normalization layer. ``nn.BatchNorm2d``
            is used when ``None``.

    Raises:
        ValueError: If ``stride`` is neither 1 nor 2.
    """

    def __init__(
        self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        # Reject unsupported strides before any attribute or layer is created.
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        self.stride = stride

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        # The skip connection needs matching spatial size (stride 1) and channel count.
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(
                Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6)
            )
        layers.extend(
            [
                # dw (kernel size is left at Conv2dNormActivation's default)
                Conv2dNormActivation(
                    hidden_dim,
                    hidden_dim,
                    stride=stride,
                    groups=hidden_dim,
                    norm_layer=norm_layer,
                    activation_layer=nn.ReLU6,
                ),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                norm_layer(oup),
            ]
        )
        self.conv = nn.Sequential(*layers)
        self.out_channels = oup
        # True for down-sampling blocks; not read anywhere in this file.
        # NOTE(review): presumably consumed by code outside this module - confirm.
        self._is_cn = stride > 1

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block, adding ``x`` to the result when the residual connection is enabled."""
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure, a list of ``[t, c, n, s]`` entries
                (expand ratio, output channels, number of repeats, stride of the first repeat).
                A default seven-stage configuration is used when ``None``.
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
                Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
                (``InvertedResidual`` when ``None``)
            norm_layer: Module specifying the normalization layer to use
                (``nn.BatchNorm2d`` when ``None``)
            dropout (float): The dropout probability

        Raises:
            ValueError: If ``inverted_residual_setting`` is empty or its first entry
                does not have exactly four elements.
        """
        super().__init__()
        _log_api_usage_once(self)

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # Base widths of the stem and of the final 1x1 convolution, before width_mult is applied.
        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
            raise ValueError(
                f"inverted_residual_setting should be non-empty or a 4-element list, got {inverted_residual_setting}"
            )

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        # The final width is only ever scaled up: multipliers below 1.0 leave it at its base value.
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [
            Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6)
        ]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # Only the first block of each stage applies the configured stride.
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(
            Conv2dNormActivation(
                input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6
            )
        )
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run the feature extractor, global-average-pool to 1x1, flatten and classify."""
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)


197
198
199
200
201
202
203
204
205
206
207
208
209
210
# Metadata entries spread into the ``meta`` dict of every MobileNetV2 weight set.
_COMMON_META = dict(
    num_params=3504872,
    min_size=(1, 1),
    categories=_IMAGENET_CATEGORIES,
)


class MobileNet_V2_Weights(WeightsEnum):
    # Each member bundles a checkpoint URL, the preprocessing transform factory,
    # and metadata (the shared _COMMON_META entries plus recipe link and metrics).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
            "metrics": {
                "acc@1": 71.878,
                "acc@5": 90.286,
            },
        },
    )
    # Same crop size as V1, but with an explicit resize_size of 232.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "metrics": {
                "acc@1": 72.154,
                "acc@5": 90.822,
            },
        },
    )
    # Alias for the weight set used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


# NOTE(review): presumably maps the legacy ``pretrained`` argument onto
# ``weights=MobileNet_V2_Weights.IMAGENET1K_V1`` - confirm against handle_legacy_interface.
@handle_legacy_interface(weights=("pretrained", MobileNet_V2_Weights.IMAGENET1K_V1))
def mobilenet_v2(
    *, weights: Optional[MobileNet_V2_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV2:
    """MobileNetV2 architecture from the `MobileNetV2: Inverted Residuals and Linear
    Bottlenecks <https://arxiv.org/abs/1801.04381>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenetv2.MobileNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V2_Weights
        :members:
    """
    weights = MobileNet_V2_Weights.verify(weights)

    if weights is not None:
        # The classifier width must match the checkpoint's category list.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV2(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
266
267
268
269
270
271
272
273
274
275
276


# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs


# Legacy architecture-name -> checkpoint-URL mapping, wrapped in _ModelURLs.
model_urls = _ModelURLs(dict(mobilenet_v2=MobileNet_V2_Weights.IMAGENET1K_V1.url))