# shufflenetv2.py 6.37 KB
# Newer Older
1
2
from typing import Any

3
4
import torch
import torch.nn as nn
5
from torch import Tensor
6
from torchvision.models import shufflenetv2
7

8
from ..._internally_replaced_utils import load_state_dict_from_url
9
10
11
from .utils import _replace_relu, quantize_model

# Public names exported by this module.
__all__ = [
    "QuantizableShuffleNetV2",
    "shufflenet_v2_x0_5",
    "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5",
    "shufflenet_v2_x2_0",
]

# Checkpoint URLs for the quantized models, keyed by "<arch>_<backend>".
# A value of None means no quantized checkpoint is registered for that variant.
quant_model_urls = {
    "shufflenetv2_x0.5_fbgemm": None,
    "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
    "shufflenetv2_x1.5_fbgemm": None,
    "shufflenetv2_x2.0_fbgemm": None,
}


class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
    """Inverted-residual unit whose branch concatenation goes through ``FloatFunctional``.

    The only differences from the parent block are the extra ``self.cat``
    module and a ``forward`` that concatenates via ``self.cat.cat`` instead of
    a plain tensor concatenation, so the concatenation is a module-level op
    that the quantization workflow can observe and replace.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the parent block and attach the functional used for concatenation.

        Args:
            *args: Positional arguments forwarded unchanged to the parent constructor.
            **kwargs: Keyword arguments forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run the unit on ``x`` and return the channel-shuffled result.

        Args:
            x: Input tensor; it is split and concatenated along ``dim=1``.

        Returns:
            The concatenated branch outputs after a channel shuffle with 2 groups.
        """
        if self.stride == 1:
            # Split the input in two along dim 1; only the second half goes through branch2.
            x1, x2 = x.chunk(2, dim=1)
            out = self.cat.cat([x1, self.branch2(x2)], dim=1)
        else:
            # Both branches consume the full input.
            out = self.cat.cat([self.branch1(x), self.branch2(x)], dim=1)

        out = shufflenetv2.channel_shuffle(out, 2)

        return out


class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
    """ShuffleNetV2 built from ``QuantizableInvertedResidual`` units with quant/dequant stubs.

    The input is passed through ``self.quant`` before the parent's
    ``_forward_impl`` and the output through ``self.dequant`` afterwards.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the parent network with quantizable residual units.

        Args:
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
                ``inverted_residual`` is always supplied by this class, so
                passing it here as well results in a duplicate-keyword error.
        """
        super().__init__(*args, inverted_residual=QuantizableInvertedResidual, **kwargs)  # type: ignore[misc]
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the quant stub, the parent's forward implementation, then the dequant stub."""
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self) -> None:
        r"""Fuse conv/bn/relu modules in shufflenetv2 model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        # Stem and head: fuse sub-modules "0", "1", "2" of conv1 and conv5 together.
        for name, m in self._modules.items():
            if name in ["conv1", "conv5"]:
                torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
        for m in self.modules():
            # Exact type match: subclasses of QuantizableInvertedResidual are skipped.
            if type(m) is QuantizableInvertedResidual:
                # branch1 is fused only when it actually contains sub-modules.
                if len(m.branch1._modules.items()) > 0:
                    torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
                torch.quantization.fuse_modules(
                    m.branch2,
                    [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]],
                    inplace=True,
                )


79
80
81
82
83
84
85
86
87
def _shufflenetv2(
    arch: str,
    pretrained: bool,
    progress: bool,
    quantize: bool,
    *args: Any,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """Build a ``QuantizableShuffleNetV2`` and optionally quantize it and load weights.

    Args:
        arch: Architecture key, e.g. ``"shufflenetv2_x1.0"``; used to look up
            the checkpoint URL.
        pretrained: If True, download and load a checkpoint for ``arch``.
        progress: Forwarded to ``load_state_dict_from_url`` as ``progress``.
        quantize: If True, quantize the model with the ``"fbgemm"`` backend and
            look the checkpoint up in ``quant_model_urls``; otherwise the float
            checkpoint from ``shufflenetv2.model_urls`` is used.
        *args: Positional arguments forwarded to ``QuantizableShuffleNetV2``.
        **kwargs: Keyword arguments forwarded to ``QuantizableShuffleNetV2``.

    Returns:
        The constructed (and possibly quantized / weight-loaded) model.

    Raises:
        NotImplementedError: If ``pretrained`` is True but the URL registered
            for the requested variant is ``None``.
    """
    model = QuantizableShuffleNetV2(*args, **kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls[arch + "_" + backend]
        else:
            model_url = shufflenetv2.model_urls[arch]

        # Several variants are registered with a None URL; fail with a clear
        # message instead of handing None to the downloader.
        if model_url is None:
            kind = "quantized" if quantize else "float"
            raise NotImplementedError(f"pretrained {kind} weights for {arch} are not supported as of now")

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)
    return model


110
111
112
113
114
115
def shufflenet_v2_x0_5(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Note: ``quant_model_urls`` registers no quantized checkpoint (``None``) for
    this variant.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Extra keyword arguments forwarded to ``QuantizableShuffleNetV2``
    """
    # The two lists are forwarded positionally to the model constructor
    # (presumably per-stage repeat counts and per-stage output channels).
    return _shufflenetv2(
        "shufflenetv2_x0.5",
        pretrained,
        progress,
        quantize,
        [4, 8, 4],
        [24, 48, 96, 192, 1024],
        **kwargs,
    )
129
130


131
132
133
134
135
136
def shufflenet_v2_x1_0(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Extra keyword arguments forwarded to ``QuantizableShuffleNetV2``
    """
    # The two lists are forwarded positionally to the model constructor
    # (presumably per-stage repeat counts and per-stage output channels).
    return _shufflenetv2(
        "shufflenetv2_x1.0",
        pretrained,
        progress,
        quantize,
        [4, 8, 4],
        [24, 116, 232, 464, 1024],
        **kwargs,
    )
150
151


152
153
154
155
156
157
def shufflenet_v2_x1_5(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Note: ``quant_model_urls`` registers no quantized checkpoint (``None``) for
    this variant.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Extra keyword arguments forwarded to ``QuantizableShuffleNetV2``
    """
    # The two lists are forwarded positionally to the model constructor
    # (presumably per-stage repeat counts and per-stage output channels).
    return _shufflenetv2(
        "shufflenetv2_x1.5",
        pretrained,
        progress,
        quantize,
        [4, 8, 4],
        [24, 176, 352, 704, 1024],
        **kwargs,
    )
171
172


173
174
175
176
177
178
def shufflenet_v2_x2_0(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableShuffleNetV2:
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`_.

    Note: ``quant_model_urls`` registers no quantized checkpoint (``None``) for
    this variant.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        **kwargs: Extra keyword arguments forwarded to ``QuantizableShuffleNetV2``
    """
    # The two lists are forwarded positionally to the model constructor
    # (presumably per-stage repeat counts and per-stage output channels).
    return _shufflenetv2(
        "shufflenetv2_x2.0",
        pretrained,
        progress,
        quantize,
        [4, 8, 4],
        [24, 244, 488, 976, 2048],
        **kwargs,
    )