import warnings

import torch
import torch.nn as nn
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Any, Dict, List

__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3']

_MODEL_URLS = {
    "mnasnet0_5":
    "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
    "mnasnet0_75": None,
    "mnasnet1_0":
    "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
    "mnasnet1_3": None
}

# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
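# (PyTorch updates running stats as running = (1 - momentum) * running
# + momentum * batch_stat, whereas TensorFlow uses a decay factor:
# running = decay * running + (1 - decay) * batch_stat, so a TF decay of
# 0.9997 corresponds to a PyTorch momentum of 0.0003.)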
_BN_MOMENTUM = 1 - 0.9997


class _InvertedResidual(nn.Module):

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: int,
        stride: int,
        expansion_factor: int,
        bn_momentum: float = 0.1
    ) -> None:
        super(_InvertedResidual, self).__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5]
        mid_ch = in_ch * expansion_factor
        self.apply_residual = (in_ch == out_ch and stride == 1)
        self.layers = nn.Sequential(
            # Pointwise
            nn.Conv2d(in_ch, mid_ch, 1, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Depthwise
            nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
                      stride=stride, groups=mid_ch, bias=False),
            nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
            nn.ReLU(inplace=True),
            # Linear pointwise. Note that there's no activation.
            nn.Conv2d(mid_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum))

    def forward(self, input: Tensor) -> Tensor:
        if self.apply_residual:
            return self.layers(input) + input
        else:
            return self.layers(input)


def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int,
           bn_momentum: float) -> nn.Sequential:
    """ Creates a stack of inverted residuals. """
    assert repeats >= 1
    # First one has no skip, because feature map size changes.
    first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
                              bn_momentum=bn_momentum)
    remaining = []
    for _ in range(1, repeats):
        remaining.append(
            _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
                              bn_momentum=bn_momentum))
    return nn.Sequential(first, *remaining)


def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int:
    """ Asymmetric rounding to make `val` divisible by `divisor`. With default
    bias, will round up, unless the number is no more than 10% greater than the
    smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
    assert 0.0 < round_up_bias < 1.0
    new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return new_val if new_val >= round_up_bias * val else new_val + divisor


def _get_depths(alpha: float) -> List[int]:
    """ Scales tensor depths as in reference MobileNet code, prefers rouding up
    rather than down. """
    depths = [32, 16, 24, 40, 80, 96, 192, 320]
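    # Illustrative: alpha=1.0 keeps these base depths unchanged, while alpha=0.5
    # yields [16, 8, 16, 24, 40, 48, 96, 160] after rounding to multiples of 8.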
    return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


class MNASNet(torch.nn.Module):
    """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
    implements the B1 variant of the model.
    >>> model = MNASNet(1.0, num_classes=1000)
    >>> x = torch.rand(1, 3, 224, 224)
    >>> y = model(x)
    >>> y.dim()
    2
    >>> y.nelement()
    1000
    """
    # Version 2 adds depth scaling in the initial stages of the network.
    _version = 2

    def __init__(
        self,
        alpha: float,
        num_classes: int = 1000,
        dropout: float = 0.2
    ) -> None:
        super(MNASNet, self).__init__()
        assert alpha > 0.0
        self.alpha = alpha
        self.num_classes = num_classes
        depths = _get_depths(alpha)
        layers = [
            # First layer: regular conv.
            nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            # Depthwise separable, no skip.
            nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
                      groups=depths[0], bias=False),
            nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
            nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
            # MNASNet blocks: stacks of inverted residuals.
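            # Each call is _stack(in_ch, out_ch, kernel_size, stride,
            # exp_factor, repeats, bn_momentum), following the B1 layout.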
            _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
            _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
            _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
            _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
            _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
            # Final mapping to classifier input.
            nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
            nn.ReLU(inplace=True),
        ]
        self.layers = nn.Sequential(*layers)
        self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True),
                                        nn.Linear(1280, num_classes))
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        x = self.layers(x)
        # Equivalent to global avgpool and removing H and W dimensions.
        x = x.mean([2, 3])
        return self.classifier(x)

    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                        nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode="fan_out",
                                         nonlinearity="sigmoid")
                nn.init.zeros_(m.bias)

    def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool,
                              missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None:
        version = local_metadata.get("version", None)
        assert version in [1, 2]

        if version == 1 and not self.alpha == 1.0:
            # In the initial version of the model (v1), stem was fixed-size.
            # All other layer configurations were the same. This will patch
            # the model so that it's identical to v1. Model with alpha 1.0 is
            # unaffected.
            depths = _get_depths(self.alpha)
            v1_stem = [
                nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
                          bias=False),
                nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
                nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
                _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
            ]
            for idx, layer in enumerate(v1_stem):
                self.layers[idx] = layer

            # The model is now identical to v1, and must be saved as such.
            self._version = 1
            warnings.warn(
                "A new version of MNASNet model has been implemented. "
                "Your checkpoint was saved using the previous version. "
                "This checkpoint will load and work as before, but "
                "you may want to upgrade by training a newer model or "
                "transfer learning from an updated ImageNet checkpoint.",
                UserWarning)

        super(MNASNet, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs)


def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None:
    if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
        raise ValueError(
            "No checkpoint is available for model type {}".format(model_name))
    checkpoint_url = _MODEL_URLS[model_name]
    model.load_state_dict(
        load_state_dict_from_url(checkpoint_url, progress=progress))


def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.5 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(0.5, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_5", model, progress)
    return model


def mnasnet0_75(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 0.75 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(0.75, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet0_75", model, progress)
    return model


def mnasnet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.0 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(1.0, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_0", model, progress)
    return model


def mnasnet1_3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet:
    """MNASNet with depth multiplier of 1.3 from
    `"MnasNet: Platform-Aware Neural Architecture Search for Mobile"
    <https://arxiv.org/pdf/1807.11626.pdf>`_.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    model = MNASNet(1.3, **kwargs)
    if pretrained:
        _load_pretrained("mnasnet1_3", model, progress)
    return model
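

# Minimal usage sketch (illustrative, not part of the module itself):
#
#     model = mnasnet1_0(pretrained=True)
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.rand(1, 3, 224, 224))  # shape: (1, 1000)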