mnasnet.py 12.7 KB
Newer Older
Dmitry Belenko's avatar
Dmitry Belenko committed
1
import warnings
2
3
from functools import partial
from typing import Any, Dict, List, Optional
4
5
6

import torch
import torch.nn as nn
7
8
from torch import Tensor

9
from ..transforms._presets import ImageClassification
10
from ..utils import _log_api_usage_once
11
12
13
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
14
15


16
17
18
19
20
21
22
23
24
25
26
27
__all__ = [
    "MNASNet",
    "MNASNet0_5_Weights",
    "MNASNet0_75_Weights",
    "MNASNet1_0_Weights",
    "MNASNet1_3_Weights",
    "mnasnet0_5",
    "mnasnet0_75",
    "mnasnet1_0",
    "mnasnet1_3",
]

28
29
30
31
32
33
34

# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):
35
    def __init__(
36
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
37
    ) -> None:
38
        super().__init__()
39
40
41
42
        if stride not in [1, 2]:
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        if kernel_size not in [3, 5]:
            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
43
        mid_ch = in_ch * expansion_factor
44
        self.apply_residual = in_ch == out_ch and stride == 1
45
46
47
48
49
50
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
51
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
52
53
54
55
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
56
57
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )
58

59
    def forward(self, input: Tensor) -> Tensor:
60
61
62
63
64
65
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


66
67
68
69
def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Build ``repeats`` chained inverted residual blocks as one stack."""
    if repeats < 1:
        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
    # Only the first block may change channels/resolution, so it carries the
    # requested stride and (therefore) has no skip connection.
    blocks = [_InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)]
    blocks.extend(
        _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum)
        for _ in range(repeats - 1)
    )
    return nn.Sequential(*blocks)


80
def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
81
    """Asymmetric rounding to make `val` divisible by `divisor`. With default
82
    bias, will round up, unless the number is no more than 10% greater than the
83
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
84
85
    if not 0.0 < round_up_bias < 1.0:
        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
86
87
88
89
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor


90
def _get_depths(alpha: float) -> List[int]:
    """Scales tensor depths as in reference MobileNet code, prefers rounding up
    rather than down.

    Args:
        alpha: Width multiplier applied to every stage's base channel count.

    Returns:
        Per-stage channel counts, each rounded to a multiple of 8.
    """
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
    implements the B1 variant of the model.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        """
        Args:
            alpha: Depth (width) multiplier applied to the per-stage channel counts.
            num_classes: Number of classifier outputs.
            dropout: Dropout probability applied before the final linear layer.

        Raises:
            ValueError: If ``alpha`` is not strictly positive.
        """
        super().__init__()
        _log_api_usage_once(self)
        if alpha <= 0.0:
            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))

        # Weight initialization for conv, batchnorm and linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Run the network; returns (N, num_classes) logits."""
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        # Hook invoked by nn.Module.load_state_dict; patches the stem in place
        # so that version-1 checkpoints (which used a fixed-size stem) remain
        # loadable into this version-2 module.
        version = local_metadata.get("version", None)
        if version not in [1, 2]:
            # NOTE: fixed misspelling "shluld" in the original error message.
            raise ValueError(f"version should be set to 1 or 2 instead of {version}")

        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning,
            )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
Dmitry Belenko's avatar
Dmitry Belenko committed
212

213

214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# Metadata fields shared by every MNASNet weights entry below.
_COMMON_META = {
    "task": "image_classification",
    "architecture": "MNASNet",
    "size": (224, 224),
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/1e100/mnasnet_trainer",
}


class MNASNet0_5_Weights(WeightsEnum):
    """Pretrained ImageNet-1K weights for MNASNet with depth multiplier 0.5."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2218512,
            "acc@1": 67.734,
            "acc@5": 87.490,
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet0_75_Weights(WeightsEnum):
    """Weights enum for depth multiplier 0.75 — no pretrained weights released yet."""

    # If a default model is added here the corresponding changes need to be done in mnasnet0_75
    pass


class MNASNet1_0_Weights(WeightsEnum):
    """Pretrained ImageNet-1K weights for MNASNet with depth multiplier 1.0."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4383312,
            "acc@1": 73.456,
            "acc@5": 91.510,
        },
    )
    DEFAULT = IMAGENET1K_V1


class MNASNet1_3_Weights(WeightsEnum):
    """Weights enum for depth multiplier 1.3 — no pretrained weights released yet."""

    # If a default model is added here the corresponding changes need to be done in mnasnet1_3
    pass


def _mnasnet(alpha: float, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> MNASNet:
    """Instantiate an MNASNet with the given multiplier and optionally load weights."""
    if weights is not None:
        # A pretrained checkpoint dictates its own classifier width.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    net = MNASNet(alpha, **kwargs)

    if weights:
        net.load_state_dict(weights.get_state_dict(progress=progress))

    return net


@handle_legacy_interface(weights=("pretrained", MNASNet0_5_Weights.IMAGENET1K_V1))
def mnasnet0_5(*, weights: Optional[MNASNet0_5_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""Constructs MNASNet with a depth multiplier of 0.5, as described in
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        weights (MNASNet0_5_Weights, optional): Pretrained weights to load into the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = MNASNet0_5_Weights.verify(weights)
    return _mnasnet(0.5, weights, progress, **kwargs)
287
288


289
290
@handle_legacy_interface(weights=("pretrained", None))
def mnasnet0_75(*, weights: Optional[MNASNet0_75_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""Constructs MNASNet with a depth multiplier of 0.75, as described in
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        weights (MNASNet0_75_Weights, optional): Pretrained weights to load into the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = MNASNet0_75_Weights.verify(weights)
    return _mnasnet(0.75, weights, progress, **kwargs)
302
303


304
305
@handle_legacy_interface(weights=("pretrained", MNASNet1_0_Weights.IMAGENET1K_V1))
def mnasnet1_0(*, weights: Optional[MNASNet1_0_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""Constructs MNASNet with a depth multiplier of 1.0, as described in
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        weights (MNASNet1_0_Weights, optional): Pretrained weights to load into the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = MNASNet1_0_Weights.verify(weights)
    return _mnasnet(1.0, weights, progress, **kwargs)
317

318
319
320

@handle_legacy_interface(weights=("pretrained", None))
def mnasnet1_3(*, weights: Optional[MNASNet1_3_Weights] = None, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""Constructs MNASNet with a depth multiplier of 1.3, as described in
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        weights (MNASNet1_3_Weights, optional): Pretrained weights to load into the model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    weights = MNASNet1_3_Weights.verify(weights)
    return _mnasnet(1.3, weights, progress, **kwargs)