import warnings
from functools import partial
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "MNASNet",
    "MNASNet0_5_Weights",
    "MNASNet0_75_Weights",
    "MNASNet1_0_Weights",
    "MNASNet1_3_Weights",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
]


# The paper suggests a batch norm momentum of 0.9997 (TensorFlow convention). The
# equivalent PyTorch momentum is 1.0 - (TensorFlow momentum).
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):
    def __init__(
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
    ) -> None:
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        if kernel_size not in [3, 5]:
            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
        mid_ch = in_ch * expansion_factor
        self.apply_residual = in_ch == out_ch and stride == 1
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)
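
# A minimal illustration of the block above: _InvertedResidual(16, 16, kernel_size=3,
# stride=1, expansion_factor=3) keeps the channel count and spatial size, so
# apply_residual is True and the skip connection is added; with stride=2 or a different
# out_ch the block is purely feed-forward.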


def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Creates a stack of inverted residuals."""
    if repeats < 1:
        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
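
# Illustrative sketch: with the alpha=1.0 depths defined below, the first MNASNet stage
# is built as _stack(16, 24, 3, 2, 3, 3, _BN_MOMENTUM), i.e. one stride-2
# _InvertedResidual that changes the channel count, followed by two stride-1 blocks
# with residual connections.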


def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """Asymmetric rounding to make `val` divisible by `divisor`. With default
82
    bias, will round up, unless the number is no more than 10% greater than the
83
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
    if not 0.0 < round_up_bias < 1.0:
        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor
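
# Illustrative values for the helper above:
#   _round_to_multiple_of(83, 8) -> 80   nearest multiple, already >= 0.9 * 83
#   _round_to_multiple_of(84, 8) -> 88
#   _round_to_multiple_of(18, 8) -> 24   nearest multiple 16 < 0.9 * 18, so bump up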


def _get_depths(alpha: float) -> List[int]:
    """Scales tensor depths as in reference MobileNet code, prefers rouding up
    rather than down."""
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
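
# For example, _get_depths(1.0) leaves the base depths unchanged (each entry is already
# a multiple of 8), while _get_depths(0.5) should yield [16, 8, 16, 24, 40, 48, 96, 160].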


class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
    implements the B1 variant of the model.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if alpha <= 0.0:
            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)
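
    # Illustrative shape flow for a 224x224 input (a sketch): self.layers produces a
    # (N, 1280, 7, 7) feature map, the mean over dims [2, 3] collapses it to (N, 1280),
    # and self.classifier maps that to (N, num_classes).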

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        if version not in [1, 2]:
            raise ValueError(f"version shluld be set to 1 or 2 instead of {version}")

        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning,
            )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/1e100/mnasnet_trainer",
}


class MNASNet0_5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2218512,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 67.734,
                    "acc@5": 87.490,
                }
            },
            "_ops": 0.104,
            "_weight_size": 8.591,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet0_75_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 3170208,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.180,
                    "acc@5": 90.496,
                }
            },
            "_ops": 0.215,
            "_weight_size": 12.303,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet1_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4383312,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.456,
                    "acc@5": 91.510,
                }
            },
            "_ops": 0.314,
            "_weight_size": 16.915,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet1_3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 6282256,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.506,
                    "acc@5": 93.522,
                }
            },
            "_ops": 0.526,
            "_weight_size": 24.246,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
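
# Illustrative sketch: each Weights entry above bundles the checkpoint URL, the inference
# transforms, and metadata, e.g.
#   MNASNet1_3_Weights.IMAGENET1K_V1.meta["_metrics"]["ImageNet-1K"]["acc@1"]  # 76.506
#   MNASNet1_3_Weights.IMAGENET1K_V1.transforms()  # ImageClassification preset (crop 224, resize 232)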


def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MNASNet(alpha, **kwargs)

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
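
# Usage sketch for the builder above: _mnasnet(1.0, MNASNet1_0_Weights.IMAGENET1K_V1, True)
# overrides num_classes with len(weights.meta["categories"]) (1000 for ImageNet), builds
# the alpha=1.0 network, and loads the downloaded state dict.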


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.5 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_5_Weights
        :members:
    """
    weights = MNASNet0_5_Weights.verify(weights)

    return _mnasnet(0.5, weights, progress, **kwargs)
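
# Illustrative end-to-end usage (a sketch; assumes the checkpoint can be downloaded):
#
#     from torchvision.models import mnasnet0_5, MNASNet0_5_Weights
#
#     weights = MNASNet0_5_Weights.DEFAULT
#     model = mnasnet0_5(weights=weights).eval()
#     preprocess = weights.transforms()
#     batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)
#     with torch.no_grad():
#         logits = model(batch)  # shape (1, 1000)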


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.75 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_75_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_75_Weights
        :members:
    """
    weights = MNASNet0_75_Weights.verify(weights)

    return _mnasnet(0.75, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.0 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_0_Weights
        :members:
    """
    weights = MNASNet1_0_Weights.verify(weights)

    return _mnasnet(1.0, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.3 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/pdf/1807.11626.pdf>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_3_Weights
        :members:
    """
    weights = MNASNet1_3_Weights.verify(weights)

    return _mnasnet(1.3, weights, progress, **kwargs)