from functools import partial
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor
from torchvision.models import shufflenetv2

from ...transforms._presets import ImageClassification
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
from .utils import _fuse_modules, _replace_relu, quantize_model


__all__ = [
    "QuantizableShuffleNetV2",
    "ShuffleNet_V2_X0_5_QuantizedWeights",
    "ShuffleNet_V2_X1_0_QuantizedWeights",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
]


class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
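        # A plain torch.cat cannot carry quantization parameters, so FloatFunctional
        # wraps the concatenation in a module that observers can be attached to
        # during quantization.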
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = self.cat.cat([x1, self.branch2(x2)], dim=1)
        else:
            out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1)

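        # Shuffle channels across the two concatenated branches so that information
        # mixes between them in the following blocks.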
        out = shufflenetv2.channel_shuffle(out, 2)

        return out


class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs)  # type: ignore[misc]
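        # QuantStub/DeQuantStub mark where tensors enter and leave the quantized
        # region; they become real (de)quantize ops once the model is converted.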
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in shufflenetv2 model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        for name, m in self._modules.items():
            if name in ["conv1", "conv5"] and m is not None:
                _fuse_modules(m, [["0", "1", "2"]], is_qat, inplace=True)
        for m in self.modules():
            if type(m) is QuantizableInvertedResidual:
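                # branch1 is only populated in the down-sampling (stride > 1) blocks; its
                # layout is [dwconv, bn, conv, bn, relu], while branch2's layout is
                # [conv, bn, relu, dwconv, bn, conv, bn, relu], matching the index groups
                # fused below.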
                if len(m.branch1._modules.items()) > 0:
                    _fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], is_qat, inplace=True)
                _fuse_modules(
                    m.branch2,
                    [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
                    is_qat,
                    inplace=True,
                )
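
# A minimal eager-mode post-training quantization sketch for QuantizableShuffleNetV2
# (an illustration only; `calibration_loader` is a hypothetical user-provided DataLoader):
#
#     model = QuantizableShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024])
#     model.eval()
#     _replace_relu(model)
#     model.fuse_model(is_qat=False)
#     torch.backends.quantized.engine = "fbgemm"
#     model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
#     torch.ao.quantization.prepare(model, inplace=True)
#     with torch.no_grad():
#         for images, _ in calibration_loader:  # calibrate the observers
#             model(images)
#     torch.ao.quantization.convert(model, inplace=True)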


def _shufflenetv2(
    stages_repeats: List[int],
    stages_out_channels: List[int],
    *,
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
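    # `_replace_relu` swaps ReLU/ReLU6 modules for out-of-place ReLU so that they can be
    # fused and quantized; `quantize_model` then runs the eager-mode
    # fuse/prepare/calibrate/convert flow for the chosen backend.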
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
    "backend": "fbgemm",
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
}


class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 1366792,
            "unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 57.972,
                "acc@5": 79.780,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2278604,
            "unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 68.360,
                "acc@5": 87.582,
            },
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
    )
)
def shufflenet_v2_x0_5(
    *,
    weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights)
    return _shufflenetv2(
        [4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
    )
)
def shufflenet_v2_x1_0(
    *,
    weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        weights (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
    """
    weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights)
    return _shufflenetv2(
        [4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
    )
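

# Example usage of shufflenet_v2_x1_0 (a sketch; `img` is a hypothetical PIL image and
# downloading the published quantized weights requires network access):
#
#     weights = ShuffleNet_V2_X1_0_QuantizedWeights.DEFAULT
#     model = shufflenet_v2_x1_0(weights=weights, quantize=True)
#     model.eval()
#     batch = weights.transforms()(img).unsqueeze(0)
#     class_id = model(batch).softmax(dim=1).argmax().item()
#     print(weights.meta["categories"][class_id])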