shufflenetv2.py 10.1 KB
Newer Older
1
2
from functools import partial
from typing import Callable, Any, List, Optional
3

Bar's avatar
Bar committed
4
5
import torch
import torch.nn as nn
6
7
from torch import Tensor

8
from ..transforms._presets import ImageClassification
9
from ..utils import _log_api_usage_once
10
11
12
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
Bar's avatar
Bar committed
13

14

15
16
17
18
19
20
21
22
23
24
25
# Public names of this module: the model class, the per-variant weight enums,
# and the four builder functions.
__all__ = [
    "ShuffleNetV2",
    "ShuffleNet_V2_X0_5_Weights",
    "ShuffleNet_V2_X1_0_Weights",
    "ShuffleNet_V2_X1_5_Weights",
    "ShuffleNet_V2_X2_0_Weights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]
Bar's avatar
Bar committed
26
27


28
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of ``x`` across ``groups`` channel groups.

    The channel axis (dim 1) is split into ``groups`` blocks of
    ``num_channels // groups`` channels, the group and per-group axes are
    swapped, and the result is flattened back to four dimensions.

    Args:
        x (Tensor): 4-D tensor; its sizes are unpacked as
            (batch, channels, height, width).
        groups (int): Number of channel groups to interleave.

    Returns:
        Tensor: A contiguous tensor with the same shape as ``x`` and permuted channels.
    """
    n, c, h, w = x.size()
    per_group = c // groups

    # (n, c, h, w) -> (n, groups, per_group, h, w)
    grouped = x.view(n, groups, per_group, h, w)

    # Swap the group axis with the per-group axis; ``contiguous`` materialises
    # the permuted layout so the final ``view`` is legal.
    interleaved = grouped.transpose(1, 2).contiguous()

    # Collapse the two channel axes back into one.
    return interleaved.view(n, -1, h, w)


class InvertedResidual(nn.Module):
    """Two-branch ShuffleNetV2 unit followed by a channel shuffle.

    With ``stride == 1`` the input is split in half along the channel axis:
    one half passes through untouched and the other goes through ``branch2``.
    With ``stride > 1`` the full input is fed to both ``branch1`` and
    ``branch2``. In both cases the two results are concatenated on the channel
    axis and shuffled with two groups.

    Args:
        inp (int): Number of input channels.
        oup (int): Number of output channels; each branch produces ``oup // 2``.
        stride (int): Stride of the depthwise convolutions; must be in ``[1, 3]``.

    Raises:
        ValueError: If ``stride`` is outside ``[1, 3]``, or if ``stride == 1``
            and ``inp`` differs from ``(oup // 2) << 1``.
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        super().__init__()

        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride

        branch_features = oup // 2
        if self.stride == 1 and branch_features << 1 != inp:
            raise ValueError(
                f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. If stride == 1 then inp should be equal to oup // 2 << 1."
            )

        def pointwise(in_ch: int, out_ch: int) -> nn.Conv2d:
            # Bias-free 1x1 convolution; each one is followed by BatchNorm below.
            return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False)

        # branch1 only carries layers when the unit uses stride > 1; otherwise it
        # is an empty Sequential.
        if self.stride > 1:
            branch1_layers = [
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                pointwise(inp, branch_features),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            ]
        else:
            branch1_layers = []
        self.branch1 = nn.Sequential(*branch1_layers)

        # branch2 sees the whole input when stride > 1, and only half of the
        # channels (after the split in ``forward``) when stride == 1.
        branch2_in = inp if (self.stride > 1) else branch_features
        branch2_layers = [
            pointwise(branch2_in, branch_features),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            pointwise(branch_features, branch_features),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

    @staticmethod
    def depthwise_conv(
        i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
    ) -> nn.Conv2d:
        """Build a ``Conv2d`` with ``groups`` equal to the number of input channels ``i``."""
        return nn.Conv2d(
            in_channels=i,
            out_channels=o,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=i,
            bias=bias,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run both branches, concatenate on the channel axis and shuffle with 2 groups."""
        if self.stride == 1:
            passthrough, to_transform = x.chunk(2, dim=1)
            merged = torch.cat((passthrough, self.branch2(to_transform)), dim=1)
        else:
            merged = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        return channel_shuffle(merged, 2)


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier: stem conv + max-pool, three stages, a 1x1 conv head and a linear layer.

    Args:
        stages_repeats (List[int]): Number of units in each of the three stages.
        stages_out_channels (List[int]): Five channel counts: the stem output,
            the three stage outputs, and the output of the final 1x1 convolution.
        num_classes (int): Size of the output of the final linear layer.
        inverted_residual (Callable[..., nn.Module]): Factory called as
            ``inverted_residual(in_channels, out_channels, stride)`` for each unit.

    Raises:
        ValueError: If ``stages_repeats`` does not have 3 entries or
            ``stages_out_channels`` does not have 5 entries.
    """

    def __init__(
        self,
        stages_repeats: List[int],
        stages_out_channels: List[int],
        num_classes: int = 1000,
        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # Stem: 3-channel input -> first configured width, stride-2 3x3 conv.
        stem_width = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, stem_width, 3, 2, 1, bias=False),
            nn.BatchNorm2d(stem_width),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        # Each stage opens with one stride-2 unit that changes the width and is
        # followed by ``repeats - 1`` stride-1 units at the new width.
        in_width = stem_width
        stage_specs = zip(stages_repeats, self._stage_out_channels[1:])
        for stage_idx, (repeats, out_width) in enumerate(stage_specs, start=2):
            units = [inverted_residual(in_width, out_width, 2)]
            for _ in range(repeats - 1):
                units.append(inverted_residual(out_width, out_width, 1))
            setattr(self, f"stage{stage_idx}", nn.Sequential(*units))
            in_width = out_width

        # Head: 1x1 conv to the last configured width, then the classifier.
        head_width = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_width, head_width, 1, 1, 0, bias=False),
            nn.BatchNorm2d(head_width),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(head_width, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        feats = self.maxpool(self.conv1(x))
        feats = self.stage2(feats)
        feats = self.stage3(feats)
        feats = self.stage4(feats)
        feats = self.conv5(feats)
        pooled = feats.mean([2, 3])  # globalpool
        return self.fc(pooled)

    def forward(self, x: Tensor) -> Tensor:
        """Return the output of the final linear layer for input ``x``."""
        return self._forward_impl(x)
167

Bar's avatar
Bar committed
168

169
170
171
172
173
174
175
176
177
def _shufflenetv2(
    weights: Optional[WeightsEnum],
    progress: bool,
    *args: Any,
    **kwargs: Any,
) -> ShuffleNetV2:
    """Build a ``ShuffleNetV2`` and, when ``weights`` is given, load its state dict.

    Args:
        weights (WeightsEnum, optional): Weights entry to load. When given,
            ``num_classes`` in ``kwargs`` is set via ``_ovewrite_named_param``
            to the number of entries in ``weights.meta["categories"]``.
        progress (bool): Forwarded to ``weights.get_state_dict``.
        *args: Positional arguments for the ``ShuffleNetV2`` constructor.
        **kwargs: Keyword arguments for the ``ShuffleNetV2`` constructor.

    Returns:
        ShuffleNetV2: The constructed (and optionally weight-loaded) model.
    """
    # No weights requested: plain construction, nothing to load.
    if weights is None:
        return ShuffleNetV2(*args, **kwargs)

    # The classifier size must match the categories of the requested weights
    # before the model is built, otherwise the state dict would not fit.
    _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    net = ShuffleNetV2(*args, **kwargs)
    net.load_state_dict(weights.get_state_dict(progress=progress))
    return net


186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Metadata shared by every weights entry in this file; each ``Weights.meta``
# below starts from this dict via ``**_COMMON_META`` and adds its own
# parameter count and accuracy figures.
_COMMON_META = {
    "task": "image_classification",
    "architecture": "ShuffleNetV2",
    "size": (224, 224),
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
}


class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
    # Weights for the 0.5x-width model (about 1.37M parameters).
    # The accuracy figures were previously swapped with the 1.0x entry; the
    # smaller model is the one with the lower top-1/top-5 accuracy.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 1366792,
            "acc@1": 60.552,
            "acc@5": 81.746,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
    # Weights for the 1.0x-width model (about 2.28M parameters).
    # The accuracy figures were previously swapped with the 0.5x entry; the
    # larger model is the one with the higher top-1/top-5 accuracy.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2278604,
            "acc@1": 69.362,
            "acc@5": 88.316,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
    # No weight entries are defined for the 1.5x variant in this file; the
    # enum exists so the builder has a type to verify ``weights`` against.
    pass


class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
    # No weight entries are defined for the 2.0x variant in this file; the
    # enum exists so the builder has a type to verify ``weights`` against.
    pass


@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x0_5(
    *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Build the 0.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_depths = [4, 8, 4]
    stage_widths = [24, 48, 96, 192, 1024]

    checked_weights = ShuffleNet_V2_X0_5_Weights.verify(weights)
    return _shufflenetv2(checked_weights, progress, stage_depths, stage_widths, **kwargs)
Bar's avatar
Bar committed
248

249
250
251
252
253

@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_0(
    *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Build the 1.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_depths = [4, 8, 4]
    stage_widths = [24, 116, 232, 464, 1024]

    checked_weights = ShuffleNet_V2_X1_0_Weights.verify(weights)
    return _shufflenetv2(checked_weights, progress, stage_depths, stage_widths, **kwargs)
Bar's avatar
Bar committed
266
267


268
269
270
271
@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x1_5(
    *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Build the 1.5x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X1_5_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_depths = [4, 8, 4]
    stage_widths = [24, 176, 352, 704, 1024]

    checked_weights = ShuffleNet_V2_X1_5_Weights.verify(weights)
    return _shufflenetv2(checked_weights, progress, stage_depths, stage_widths, **kwargs)
Bar's avatar
Bar committed
284

285
286
287
288
289

@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x2_0(
    *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Build the 2.0x-width ShuffleNetV2 from
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to the ``ShuffleNetV2`` constructor
    """
    stage_depths = [4, 8, 4]
    stage_widths = [24, 244, 488, 976, 2048]

    checked_weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
    return _shufflenetv2(checked_weights, progress, stage_depths, stage_widths, **kwargs)