"""MobileNetV2 model definition and the ``mobilenet_v2`` constructor function (mobilenet.py)."""
from typing import Callable, Any, Optional, List

from torch import nn
from torch import Tensor

from .utils import load_state_dict_from_url


# Public names of this module (what ``from <module> import *`` exports).
__all__ = ['MobileNetV2', 'mobilenet_v2']


# Download location of the pretrained weights, keyed by model name.
# Read by ``mobilenet_v2`` when it is called with ``pretrained=True``.
model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    Round ``v`` to the nearest multiple of ``divisor``, with a lower bound.

    This function is taken from the original tf repo.
    It is used to keep layer channel counts divisible by ``divisor``
    (8 by default in this file).
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: value to round, e.g. a channel count scaled by a width multiplier
    :param divisor: ``v`` is rounded to the nearest multiple of this number
    :param min_value: lower bound applied after rounding; defaults to ``divisor``.
        Note that a ``min_value`` that is not itself a multiple of ``divisor``
        can be returned as-is.
    :return: the rounded value; one extra ``divisor`` is added when the
        rounded value would be more than 10% below ``v``
    """
    if min_value is None:
        min_value = divisor
    # Adding half a divisor before the floor division rounds to nearest.
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBNReLU(nn.Sequential):
    """Bias-free ``Conv2d`` followed by a normalization layer and ``ReLU6``."""

    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int = 3,
        stride: int = 1,
        groups: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        Args:
            in_planes (int): Number of input channels of the convolution
            out_planes (int): Number of output channels of the convolution
            kernel_size (int): Convolution kernel size
            stride (int): Convolution stride
            groups (int): Number of convolution groups
            norm_layer: Callable that builds the normalization layer from the
                number of output channels. Defaults to ``nn.BatchNorm2d``
        """
        # For odd kernel sizes this padding keeps the spatial size at stride 1.
        padding = (kernel_size - 1) // 2
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        super().__init__(
            nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
            norm_layer(out_planes),
            nn.ReLU6(inplace=True)
        )


class InvertedResidual(nn.Module):
    """Inverted residual block.

    Layers, in order: an optional 1x1 expansion ``ConvBNReLU`` (skipped when
    ``expand_ratio == 1``), a depthwise ``ConvBNReLU``, and a 1x1 bias-free
    convolution followed by a normalization layer with no activation.
    """

    def __init__(
        self,
        inp: int,
        oup: int,
        stride: int,
        expand_ratio: int,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        Args:
            inp (int): Number of input channels
            oup (int): Number of output channels
            stride (int): Stride of the depthwise convolution; must be 1 or 2
            expand_ratio (int): Hidden channels are ``round(inp * expand_ratio)``
            norm_layer: Callable that builds a normalization layer from a
                channel count. Defaults to ``nn.BatchNorm2d``
        """
        super().__init__()
        self.stride = stride
        assert stride in [1, 2], f"stride should be 1 or 2, got {stride}"

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        hidden_dim = int(round(inp * expand_ratio))
        # The skip connection is only valid when input and output shapes match.
        self.use_res_connect = self.stride == 1 and inp == oup

        layers: List[nn.Module] = []
        if expand_ratio != 1:
            # pw
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block, adding the input back when the skip connection is enabled."""
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    """MobileNet V2: a convolutional feature extractor followed by global
    average pooling and a dropout + linear classifier."""

    def __init__(
        self,
        num_classes: int = 1000,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure, a list of ``[t, c, n, s]`` rows
                (expand ratio, output channels, number of repeats, stride of the first repeat)
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            block: Module specifying inverted residual building block for mobilenet
            norm_layer: Module specifying the normalization layer to use

        Raises:
            ValueError: If ``inverted_residual_setting`` is empty or contains a
                row that does not have exactly 4 elements
        """
        super().__init__()

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        input_channel = 32
        last_channel = 1280

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # Check every row up front so a malformed entry is reported with a clear
        # message before any layer is built, rather than failing while unpacking.
        if len(inverted_residual_setting) == 0 or any(len(row) != 4 for row in inverted_residual_setting):
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        # The last layer is never narrowed below its base width.
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # Only the first repeat of each stage uses the stage stride.
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
        # make it nn.Sequential
        self.features = nn.Sequential(*features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False have weight/bias set to None.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run features, global average pooling and the classifier on ``x``."""
        # This exists since TorchScript doesn't support inheritance, so the superclass method
        # (this one) needs to have a name other than `forward` that can be accessed in a subclass
        x = self.features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to ``_forward_impl``."""
        return self._forward_impl(x)


def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Forwarded unchanged to the ``MobileNetV2`` constructor

    Returns:
        MobileNetV2: The constructed model, with the downloaded weights loaded
        when ``pretrained`` is True
    """
    model = MobileNetV2(**kwargs)
    if pretrained:
        # NOTE(review): the downloaded state dict is loaded as-is, so kwargs that
        # change parameter shapes presumably cannot be combined with
        # pretrained=True - confirm before relying on it.
        state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model