# shufflenetv2.py
from functools import partial
from typing import Callable, Any, List, Optional
import torch
import torch.nn as nn
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param


# Public API of this module: the model class, one weights enum per width
# multiplier, and one builder function per width multiplier.
__all__ = [
    "ShuffleNetV2",
    "ShuffleNet_V2_X0_5_Weights",
    "ShuffleNet_V2_X1_0_Weights",
    "ShuffleNet_V2_X1_5_Weights",
    "ShuffleNet_V2_X2_0_Weights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of ``x`` across ``groups`` equally sized groups.

    The channel axis is split into ``groups`` contiguous chunks, the group
    axis and the per-group axis are swapped, and the result is flattened back
    to a 4-D tensor.

    Args:
        x: 4-D tensor whose ``size()`` unpacks as
            ``(batch, channels, height, width)``.
        groups: number of channel groups. The first ``view`` call fails if
            the channel count is not a multiple of ``groups``.

    Returns:
        A tensor with the same shape as ``x`` and permuted channels.
    """
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape: (B, C, H, W) -> (B, groups, C // groups, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # Swap the two channel axes; the transposed tensor is non-contiguous, so
    # it has to be materialised before it can be viewed again.
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    """Two-branch building block used by every ShuffleNetV2 stage.

    With ``stride == 1`` the input is split in half along the channel axis:
    the first half is passed through untouched and the second half goes
    through ``branch2``. With ``stride > 1`` the full input is fed to both
    ``branch1`` and ``branch2``. In both cases the two results are
    concatenated on the channel axis and shuffled with two groups.

    Args:
        inp: number of input channels.
        oup: number of output channels; each branch produces ``oup // 2``.
        stride: stride of the depthwise convolutions, must be in ``[1, 3]``.

    Raises:
        ValueError: if ``stride`` is outside ``[1, 3]``, or if ``stride == 1``
            and ``inp`` differs from ``(oup // 2) << 1``.
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        super().__init__()

        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride

        branch_features = oup // 2
        # With stride 1 half of the input bypasses the block, so the input
        # must already have the (even) output width.
        if (self.stride == 1) and (inp != branch_features << 1):
            raise ValueError(
                f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. If stride == 1 then inp should be equal to oup // 2 << 1."
            )

        if self.stride > 1:
            # Downsampling shortcut: depthwise 3x3 followed by pointwise 1x1.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Empty container: forward() bypasses branch1 when stride == 1.
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(
                # branch2 sees the whole input when downsampling, otherwise
                # only the second half of the channels.
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(
        i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False
    ) -> nn.Conv2d:
        """Return a ``Conv2d`` with ``groups=i``, i.e. one group per input channel."""
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on ``x`` and return the channel-shuffled concatenation of both branches."""
        if self.stride == 1:
            # Channel split: the first half is the identity shortcut.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 image classifier.

    The network is a 3x3 stride-2 stem convolution, a max-pool, three stages
    of ``inverted_residual`` blocks, a 1x1 convolution, a spatial mean and a
    linear classifier.

    Args:
        stages_repeats: number of blocks in stage2, stage3 and stage4
            (exactly 3 entries).
        stages_out_channels: output channels of the stem, the three stages
            and the final 1x1 convolution (exactly 5 entries).
        num_classes: size of the classifier output.
        inverted_residual: factory called as ``(inp, oup, stride)`` to build
            each block.

    Raises:
        ValueError: if ``stages_repeats`` does not have 3 entries or
            ``stages_out_channels`` does not have 5 entries.
    """

    def __init__(
        self,
        stages_repeats: List[int],
        stages_out_channels: List[int],
        num_classes: int = 1000,
        inverted_residual: Callable[..., nn.Module] = InvertedResidual,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = [f"stage{i}" for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
            # Each stage opens with one stride-2 block that changes the width,
            # followed by ``repeats - 1`` stride-1 blocks at constant width.
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for _ in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # globalpool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier output for the input batch ``x``."""
        return self._forward_impl(x)


def _shufflenetv2(
    weights: Optional[WeightsEnum],
    progress: bool,
    *args: Any,
    **kwargs: Any,
) -> ShuffleNetV2:
    """Build a ``ShuffleNetV2`` and optionally load pretrained weights into it.

    Args:
        weights: weights entry to load, or ``None`` for a freshly initialised model.
        progress: forwarded to ``weights.get_state_dict``.
        *args: positional arguments forwarded to the ``ShuffleNetV2`` constructor.
        **kwargs: keyword arguments forwarded to the ``ShuffleNetV2`` constructor.

    Returns:
        The constructed model.
    """
    if weights is not None:
        # Pin ``num_classes`` to the number of categories recorded in the
        # weights' metadata before the model is constructed.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ShuffleNetV2(*args, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


# Metadata entries shared by every pretrained checkpoint declared below; each
# Weights entry spreads this dict into its own ``meta``.
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/barrh/Shufflenet-v2-Pytorch/tree/v0.1.0",
}


class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
    """Pretrained checkpoints available for the 0.5x ShuffleNetV2."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 1366792,
            # Accuracy of the 0.5x checkpoint; the narrower network scores
            # below the 1.0x one.
            "acc@1": 60.552,
            "acc@5": 81.746,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
    """Pretrained checkpoints available for the 1.0x ShuffleNetV2."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2278604,
            # Accuracy of the 1.0x checkpoint.
            "acc@1": 69.362,
            "acc@5": 88.316,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
    # Placeholder enum: no pretrained checkpoint is declared for the 1.5x model here.
    pass


class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
    # Placeholder enum: no pretrained checkpoint is declared for the 2.0x model here.
    pass


@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1))
def shufflenet_v2_x0_5(
    *, weights: Optional[ShuffleNet_V2_X0_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X0_5_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments passed on to the model builder
    """
    weights = ShuffleNet_V2_X0_5_Weights.verify(weights)

    # Stage repeats [4, 8, 4]; channel widths for stem, stage2-4 and conv5.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


@handle_legacy_interface(weights=("pretrained", ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1))
def shufflenet_v2_x1_0(
    *, weights: Optional[ShuffleNet_V2_X1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X1_0_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments passed on to the model builder
    """
    weights = ShuffleNet_V2_X1_0_Weights.verify(weights)

    # Stage repeats [4, 8, 4]; channel widths for stem, stage2-4 and conv5.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x1_5(
    *, weights: Optional[ShuffleNet_V2_X1_5_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X1_5_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments passed on to the model builder
    """
    weights = ShuffleNet_V2_X1_5_Weights.verify(weights)

    # Stage repeats [4, 8, 4]; channel widths for stem, stage2-4 and conv5.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


@handle_legacy_interface(weights=("pretrained", None))
def shufflenet_v2_x2_0(
    *, weights: Optional[ShuffleNet_V2_X2_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> ShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X2_0_Weights, optional): The pretrained weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments passed on to the model builder
    """
    weights = ShuffleNet_V2_X2_0_Weights.verify(weights)

    # Stage repeats [4, 8, 4]; channel widths for stem, stage2-4 and conv5.
    return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)