import warnings
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once

__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"]

_MODEL_URLS = {
    "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
    "mnasnet0_75": None,
    "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
    "mnasnet1_3": None,
}

# The paper suggests a batch norm momentum of 0.9997 (TensorFlow convention).
# The equivalent PyTorch momentum is 1.0 minus the TensorFlow value.
_BN_MOMENTUM = 1 - 0.9997
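
# Illustrative note: TensorFlow updates running statistics as
#     running = momentum * running + (1 - momentum) * batch
# while PyTorch uses
#     running = (1 - momentum) * running + momentum * batch
# so the paper's TensorFlow momentum of 0.9997 corresponds to a PyTorch momentum of
# 1 - 0.9997 = 0.0003, which is the value stored in _BN_MOMENTUM above.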


class _InvertedResidual(nn.Module):
    def __init__(
        self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
    ) -> None:
        super().__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = in_ch == out_ch and stride == 1
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


def _stack(
    in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
) -> nn.Sequential:
    """Creates a stack of inverted residuals."""
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)
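
# Illustrative note: a call such as _stack(16, 24, 3, 2, 3, 3, _BN_MOMENTUM) builds three
# inverted residuals; only the first block downsamples (stride 2) and changes the channel
# count, so only the remaining blocks apply the residual connection.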


def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """Asymmetric rounding to make `val` divisible by `divisor`. Rounds `val` to
    the nearest multiple of `divisor` (ties round up), then rounds up one more step
    whenever that multiple would fall below `round_up_bias * val`,
    e.g. (83, 8) -> 80, but (84, 8) -> 88."""
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor
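
# Illustrative examples of the rounding above:
#     _round_to_multiple_of(83, 8) == 80   # nearest multiple, kept since 80 >= 0.9 * 83
#     _round_to_multiple_of(84, 8) == 88   # 84 + 4 already rounds up to 88
#     _round_to_multiple_of(11, 8) == 16   # nearest multiple 8 < 0.9 * 11, so bump up one step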


def _get_depths(alpha: float) -> List[int]:
    """Scales tensor depths as in reference MobileNet code, preferring to round up
    rather than down."""
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
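
# Illustrative note: with alpha=1.0 the base depths are already multiples of 8 and are
# returned unchanged, i.e. _get_depths(1.0) == [32, 16, 24, 40, 80, 96, 192, 320].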


class MNASNet(torch.nn.Module):
    """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
    implements the B1 variant of the model.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """

    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
        super().__init__()
        _log_api_usage_once(self)
        assert alpha > 0.0
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes))
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)
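
    # Illustrative note: for an input of shape (N, 3, 224, 224), self.layers produces a
    # (N, 1280, 7, 7) feature map, and x.mean([2, 3]) collapses it to (N, 1280) before
    # the classifier.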

    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        assert version in [1, 2]

        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), the stem was fixed-size.
            # All other layer configurations were the same. This patches the
            # model so that it is identical to v1. A model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning,
            )

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )


def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
    if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
        raise ValueError(f"No checkpoint is available for model type {model_name}")
    checkpoint_url = _MODEL_URLS[model_name]
    model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))


def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 0.5 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(0.5, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_5", model, progress)
    return model


def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 0.75 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(0.75, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_75", model, progress)
    return model


def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 1.0 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(1.0, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_0", model, progress)
    return model


def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    r"""MNASNet with depth multiplier of 1.3 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(1.3, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_3", model, progress)
    return model
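

# Illustrative usage sketch, assuming this module is installed as part of torchvision and
# the pretrained checkpoint can be downloaded (mnasnet0_75 and mnasnet1_3 have no released
# checkpoint and raise ValueError when pretrained=True):
#
#     import torch
#     from torchvision.models import mnasnet1_0
#
#     model = mnasnet1_0(pretrained=True)
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.rand(1, 3, 224, 224))  # logits.shape == (1, 1000)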