import warnings
from functools import partial
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "MNASNet",
    "MNASNet0_5_Weights",
    "MNASNet0_75_Weights",
    "MNASNet1_0_Weights",
    "MNASNet1_3_Weights",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
]


# The paper suggests a BatchNorm momentum of 0.9997 for TensorFlow. The equivalent
# PyTorch momentum is 1.0 minus the TensorFlow value.
_BN_MOMENTUM = 1 - 0.9997
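# Worked out explicitly (restating the comment above): TensorFlow updates running
# statistics as  moving_mean = decay * moving_mean + (1 - decay) * batch_mean  with
# decay = 0.9997, while PyTorch uses
#   running_mean = (1 - momentum) * running_mean + momentum * batch_mean
# so momentum = 1 - 0.9997 = 3e-4 gives the same update rate.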


class _InvertedResidual(nn.Module):
    def __init__(
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
    ) -> None:
        super().__init__()
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        if kernel_size not in [3, 5]:
            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
        mid_ch = in_ch * expansion_factor
        self.apply_residual = in_ch == out_ch and stride == 1
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)
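# Illustrative sketch (not part of the upstream file): the residual skip above is taken
# only when the block preserves the feature-map shape, i.e. in_ch == out_ch and stride == 1.
#
#   block = _InvertedResidual(24, 24, kernel_size=3, stride=1, expansion_factor=3)
#   x = torch.rand(1, 24, 56, 56)
#   assert block(x).shape == x.shape            # skip branch: layers(x) + x
#
#   down = _InvertedResidual(24, 40, kernel_size=5, stride=2, expansion_factor=3)
#   assert down(x).shape == (1, 40, 28, 28)     # no skip: channels and resolution change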


def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Creates a stack of inverted residuals."""
    if repeats < 1:
        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
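# For example (with the alpha=1.0 depths used further below),
#   _stack(16, 24, kernel_size=3, stride=2, exp_factor=3, repeats=3, bn_momentum=_BN_MOMENTUM)
# builds one stride-2 block (16 -> 24 channels, no residual) followed by two stride-1
# blocks (24 -> 24 channels, with residuals), i.e. the first MNASNet stage.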


def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
    if not 0.0 < round_up_bias < 1.0:
        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor
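# Concrete values (a small illustration; the first two match the docstring above):
#   _round_to_multiple_of(83, 8) -> 80
#   _round_to_multiple_of(84, 8) -> 88
#   _round_to_multiple_of(3, 8)  -> 8    (the result is never smaller than `divisor`)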


def _get_depths(alpha: float) -> List[int]:
    """Scales tensor depths as in reference MobileNet code, prefers rounding up
    rather than down."""
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
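# For instance (a sketch of the helper's output):
#   _get_depths(1.0) -> [32, 16, 24, 40, 80, 96, 192, 320]   (already multiples of 8)
#   _get_depths(0.5) -> [16, 8, 16, 24, 40, 48, 96, 160]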


class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/abs/1807.11626. This
    implements the B1 variant of the model.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if alpha <= 0.0:
            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)
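    # Equivalent formulation (illustrative note, not upstream code): the spatial mean
    # above matches torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1), i.e. a
    # global average pool that leaves an (N, 1280) tensor for the classifier.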

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        if version not in [1, 2]:
            raise ValueError(f"version shluld be set to 1 or 2 instead of {version}")

        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning,
            )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/1e100/mnasnet_trainer",
}


class MNASNet0_5_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2218512,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 67.734,
                    "acc@5": 87.490,
                }
            },
            "_ops": 0.104,
            "_file_size": 8.591,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet0_75_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0_75-7090bc5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 3170208,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.180,
                    "acc@5": 90.496,
                }
            },
            "_ops": 0.215,
            "_file_size": 12.303,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet1_0_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4383312,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.456,
                    "acc@5": 91.510,
                }
            },
            "_ops": 0.314,
            "_file_size": 16.915,
            "_docs": """These weights reproduce closely the results of the paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet1_3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1_3-a4c69d6f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/pytorch/vision/pull/6019",
            "num_params": 6282256,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.506,
                    "acc@5": 93.522,
                }
            },
            "_ops": 0.526,
            "_file_size": 24.246,
            "_docs": """
                These weights were trained from scratch by using TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
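# Illustrative usage of the weight enums above (a sketch; the keys shown are the ones
# defined in the meta dicts in this file):
#
#   weights = MNASNet1_0_Weights.DEFAULT
#   preprocess = weights.transforms()        # ImageClassification(crop_size=224)
#   weights.meta["num_params"]               # 4383312
#   weights.meta["categories"]               # ImageNet-1K class names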


def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MNASNet(alpha, **kwargs)

    if weights:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
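# Note (illustrative): when weights are given, num_classes is pinned to
# len(weights.meta["categories"]) so the classifier head matches the checkpoint
# (1000 classes for the ImageNet-1K weights defined above).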


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.5 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_5_Weights
        :members:
    """
    weights = MNASNet0_5_Weights.verify(weights)

    return _mnasnet(0.5, weights, progress, **kwargs)
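# Example call (a sketch mirroring the docstring above):
#   model = mnasnet0_5(weights=MNASNet0_5_Weights.IMAGENET1K_V1)
#   model.eval()
#   with torch.no_grad():
#       logits = model(torch.rand(1, 3, 224, 224))    # shape (1, 1000)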


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet0_75_Weights.IMAGENET1K_V1))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.75 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet0_75_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet0_75_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet0_75_Weights
        :members:
    """
    weights = MNASNet0_75_Weights.verify(weights)

    return _mnasnet(0.75, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.0 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_0_Weights
        :members:
    """
    weights = MNASNet1_0_Weights.verify(weights)

    return _mnasnet(1.0, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MNASNet1_3_Weights.IMAGENET1K_V1))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.3 from
    `MnasNet: Platform-Aware Neural Architecture Search for Mobile
    <https://arxiv.org/abs/1807.11626>`_ paper.

    Args:
        weights (:class:`~torchvision.models.MNASNet1_3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MNASNet1_3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mnasnet.MNASNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mnasnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MNASNet1_3_Weights
        :members:
    """
    weights = MNASNet1_3_Weights.verify(weights)

    return _mnasnet(1.3, weights, progress, **kwargs)