shufflenetv2.py 10.2 KB
Newer Older
1
2
from functools import partial
from typing import Callable, Any, List, Optional
3

Bar's avatar
Bar committed
4
5
import torch
import torch.nn as nn
6
7
from torch import Tensor

8
from ..transforms._presets import ImageClassification, InterpolationMode
9
from ..utils import _log_api_usage_once
10
11
12
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
Bar's avatar
Bar committed
13

14

15
16
17
18
19
20
21
22
23
24
25
# Public API of this module: the model class, one weights enum per width
# multiplier, and the corresponding builder functions.
__all__ = [
    "ShuffleNetV2",
    "ShuffleNet_V2_X0_5_Weights",
    "ShuffleNet_V2_X1_0_Weights",
    "ShuffleNet_V2_X1_5_Weights",
    "ShuffleNet_V2_X2_0_Weights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]
Bar's avatar
Bar committed
26
27


28
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of ``x`` across ``groups`` channel groups.

    The channel axis is split into ``groups`` consecutive groups, the group
    axis and the per-group axis are swapped, and the result is flattened back
    to four dimensions. Output channel ``i * groups + g`` is therefore input
    channel ``g * (channels // groups) + i``.

    Args:
        x: 4-D tensor whose size unpacks as ``(batch, channels, height, width)``.
        groups: Number of channel groups. ``channels`` must be divisible by
            ``groups``; otherwise the ``view`` call below fails.

    Returns:
        A tensor with the same shape as ``x`` and shuffled channels.
    """
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape: (B, C, H, W) -> (B, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # Swap the two channel axes. ``contiguous`` materialises the new memory
    # order so that the following ``view`` is valid.
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten back to (B, C, H, W)
    x = x.view(batchsize, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    """Two-branch building block used by the ShuffleNetV2 stages.

    With ``stride == 1`` the input is split in half along the channel axis:
    one half is passed through unchanged and the other goes through
    ``branch2``. With ``stride > 1`` the full input is fed to both ``branch1``
    and ``branch2``. In both cases the two results are concatenated along the
    channel axis and shuffled with ``channel_shuffle(out, 2)``.

    Args:
        inp: Number of input channels.
        oup: Number of output channels; each branch produces ``oup // 2``.
        stride: Stride of the depthwise convolutions; must satisfy
            ``1 <= stride <= 3``.

    Raises:
        ValueError: If ``stride`` is outside ``[1, 3]``, or if ``stride == 1``
            and ``inp != (oup // 2) << 1``.
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        super().__init__()

        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride

        branch_features = oup // 2
        # With stride 1 the input is chunked in two and only one half is
        # transformed, so the input must already have 2 * branch_features
        # channels for the concatenation to yield the expected width.
        if (self.stride == 1) and (inp != branch_features << 1):
            raise ValueError(
                f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. If stride == 1 then inp should be equal to oup // 2 << 1."
            )

        if self.stride > 1:
            # Downsampling shortcut: depthwise 3x3 followed by pointwise 1x1.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Not used by ``forward`` when stride == 1; kept as an empty module.
            self.branch1 = nn.Sequential()

        # Main branch: pointwise 1x1 -> depthwise 3x3 -> pointwise 1x1.
        self.branch2 = nn.Sequential(
            nn.Conv2d(
                # Receives the full input when downsampling, half of it otherwise.
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(
        i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
    ) -> nn.Conv2d:
        """Return an ``nn.Conv2d`` with ``groups=i`` (one group per input channel)."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the block to ``x`` and return the channel-shuffled result."""
        if self.stride == 1:
            # Split the channels in two; only the second half is transformed.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # Mix the channels coming from the two concatenated halves.
        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classification network.

    The network is a stem (``conv1`` + ``maxpool``), three stages
    (``stage2``, ``stage3``, ``stage4``) built from ``inverted_residual``
    blocks, a final 1x1 convolution (``conv5``), a mean over the two spatial
    axes, and a linear classifier (``fc``).

    Args:
        stages_repeats: Number of blocks in each of the three stages. The
            first block of a stage is built with stride 2, the rest with
            stride 1.
        stages_out_channels: Five channel counts: the stem output, the three
            stage outputs, and the ``conv5`` output.
        num_classes: Size of the classifier output.
        inverted_residual: Callable invoked as ``(inp, oup, stride)`` to build
            each block.

    Raises:
        ValueError: If ``stages_repeats`` does not have length 3 or
            ``stages_out_channels`` does not have length 5.
    """

    def __init__(
        self,
        stages_repeats: List[int],
        stages_out_channels: List[int],
        num_classes: int = 1000,
        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # Stem: 3x3 stride-2 convolution on a 3-channel input, then max pooling.
        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = [f"stage{i}" for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
            # The first block of every stage downsamples (stride 2); the
            # remaining ``repeats - 1`` blocks keep resolution and width.
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for _ in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run the full network on ``x`` and return the classifier output."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # globalpool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Delegate to :meth:`_forward_impl`."""
        return self._forward_impl(x)
167

Bar's avatar
Bar committed
168

169
170
171
172
173
174
175
176
177
def _shufflenetv2(
    weights: Optional[WeightsEnum],
    progress: bool,
    *args: Any,
    **kwargs: Any,
) -> ShuffleNetV2:
    """Build a :class:`ShuffleNetV2` and optionally load pretrained weights.

    Args:
        weights: Weights entry to load, or ``None`` to leave the model with
            its freshly constructed parameters.
        progress: Forwarded to ``weights.get_state_dict``.
        *args: Positional arguments forwarded to ``ShuffleNetV2``.
        **kwargs: Keyword arguments forwarded to ``ShuffleNetV2``.

    Returns:
        The constructed model.
    """
    if weights is not None:
        # Tie ``num_classes`` to the number of categories the weights were
        # trained on, so the classifier matches the loaded state dict.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ShuffleNetV2(*args, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# Metadata shared by every weights entry in this file; each entry unpacks it
# into its own ``meta`` dict and adds its model-specific fields.
_COMMON_META = {
    "task": "image_classification",
    "architecture": "ShuffleNetV2",
    "publication_year": 2018,
    "size": (224, 224),
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
    "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
}


class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
    # Pretrained weights for the 0.5x variant.
    # NOTE: the accuracy figures here were previously swapped with the 1.0x
    # entry; the values below follow the published torchvision numbers.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 1366792,
            "acc@1": 60.552,
            "acc@5": 81.746,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
    # Pretrained weights for the 1.0x variant.
    # NOTE: the accuracy figures here were previously swapped with the 0.5x
    # entry; the values below follow the published torchvision numbers.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2278604,
            "acc@1": 69.362,
            "acc@5": 88.316,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
    # Empty enum: no weight entries are declared for the 1.5x variant in this file.
    pass


class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
    # Empty enum: no weight entries are declared for the 2.0x variant in this file.
    pass


@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x0_5(
    *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    The model is built with stage repeats ``[4, 8, 4]`` and stage output
    channels ``[24, 48, 96, 192, 1024]``.

    Args:
        weights (ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    weights = ShuffleNet_V2_X0_5_Weights.verify(weights)

    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
Bar's avatar
Bar committed
250

251
252
253
254
255

@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_0(
    *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    The model is built with stage repeats ``[4, 8, 4]`` and stage output
    channels ``[24, 116, 232, 464, 1024]``.

    Args:
        weights (ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    weights = ShuffleNet_V2_X1_0_Weights.verify(weights)

    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
Bar's avatar
Bar committed
268
269


270
271
272
273
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x1_5(
    *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    The model is built with stage repeats ``[4, 8, 4]`` and stage output
    channels ``[24, 176, 352, 704, 1024]``.

    Args:
        weights (ShuffleNet_V2_X1_5_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    weights = ShuffleNet_V2_X1_5_Weights.verify(weights)

    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
Bar's avatar
Bar committed
286

287
288
289
290
291

@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x2_0(
    *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    The model is built with stage repeats ``[4, 8, 4]`` and stage output
    channels ``[24, 244, 488, 976, 2048]``.

    Args:
        weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    weights = ShuffleNet_V2_X2_0_Weights.verify(weights)

    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)