from functools import partial
from typing import Any, List, Optional, Union

import torch
from torch import nn, Tensor
from torch.ao.quantization import QuantStub, DeQuantStub

from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
from ...transforms._presets import ImageClassification
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv3 import (
    InvertedResidual,
    InvertedResidualConfig,
    MobileNetV3,
    _mobilenet_v3_conf,
    MobileNet_V3_Large_Weights,
)
from .utils import _fuse_modules, _replace_relu


__all__ = [
    "QuantizableMobileNetV3",
    "MobileNet_V3_Large_QuantizedWeights",
    "mobilenet_v3_large",
]


class QuantizableSqueezeExcitation(SqueezeExcitation):
    _version = 2

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        kwargs["scale_activation"] = nn.Hardsigmoid
        super().__init__(*args, **kwargs)
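        # FloatFunctional stands in for the bare ``*`` so that the elementwise
        # multiply gets its own observer/quantization state in eager mode.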
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, input: Tensor) -> Tensor:
        return self.skip_mul.mul(self._scale(input), input)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

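        # Checkpoints written before version 2 lack the observer buffers for
        # ``scale_activation``; seed any missing keys with defaults so older
        # weights still load once a qconfig has been attached.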
        if hasattr(self, "qconfig") and (version is None or version < 2):
            default_state_dict = {
                "scale_activation.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
                "scale_activation.activation_post_process.activation_post_process.zero_point": torch.tensor(
                    [0], dtype=torch.int32
                ),
                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
            }
            for k, v in default_state_dict.items():
                full_key = prefix + k
                if full_key not in state_dict:
                    state_dict[full_key] = v

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class QuantizableInvertedResidual(InvertedResidual):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(se_layer=QuantizableSqueezeExcitation, *args, **kwargs)  # type: ignore[misc]
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
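        # use_res_connect is set by the base class: stride == 1 and matching
        # input/output channels. The add goes through FloatFunctional so the
        # skip connection stays quantizable.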
        if self.use_res_connect:
            return self.skip_add.add(x, self.block(x))
        else:
            return self.block(x)


class QuantizableMobileNetV3(MobileNetV3):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        MobileNet V3 main class

        Args:
           Inherits args from floating point MobileNetV3
        """
        super().__init__(*args, **kwargs)
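        # QuantStub/DeQuantStub mark where tensors enter and leave the
        # quantized region of the graph in eager-mode quantization.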
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
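        # Fuse Conv+BN (and the trailing ReLU when present) into single modules
        # before observers are inserted; SE blocks handle their own fusion.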
        for m in self.modules():
            if type(m) is Conv2dNormActivation:
                modules_to_fuse = ["0", "1"]
                if len(m) == 3 and type(m[2]) is nn.ReLU:
                    modules_to_fuse.append("2")
                _fuse_modules(m, modules_to_fuse, is_qat, inplace=True)
            elif type(m) is QuantizableSqueezeExcitation:
                m.fuse_model(is_qat)


def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
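    # _replace_relu swaps ReLU6 and in-place ReLU for plain nn.ReLU, which the
    # fusion and quantized kernels expect.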
    _replace_relu(model)

    if quantize:
        # Instead of quantizing the model and then loading the quantized weights we take a different approach.
        # We prepare the QAT model, load the QAT weights from training and then convert it.
        # This is done to avoid extremely low accuracies observed on the specific model. This is rather a workaround
        # for an unresolved bug on the eager quantization API detailed at: https://github.com/pytorch/vision/issues/5890
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    if quantize:
        torch.ao.quantization.convert(model, inplace=True)
        model.eval()

    return model


class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
    IMAGENET1K_QNNPACK_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 5483032,
            "min_size": (1, 1),
            "categories": _IMAGENET_CATEGORIES,
            "backend": "qnnpack",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 73.004,
                "acc@5": 90.858,
            },
            "_docs": """
                These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized
                weights listed below.
            """,
        },
    )
    DEFAULT = IMAGENET1K_QNNPACK_V1


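# handle_legacy_interface maps the deprecated ``pretrained=True`` flag to the
# matching weights enum, picking the quantized variant when quantize=True.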
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
        if kwargs.get("quantize", False)
        else MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    )
)
def mobilenet_v3_large(
    *,
    weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """
    MobileNetV3 (Large) model from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`_.

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` or :class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/mobilenetv3.py>`_
            for more details about this class.
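
    Example (a minimal usage sketch):

        >>> import torch
        >>> from torchvision.models import quantization
        >>> model = quantization.mobilenet_v3_large(
        ...     weights=quantization.MobileNet_V3_Large_QuantizedWeights.DEFAULT,
        ...     quantize=True,
        ... )
        >>> model = model.eval()  # quantized models support CPU inference only
        >>> with torch.inference_mode():
        ...     out = model(torch.rand(1, 3, 224, 224))
        >>> out.shape
        torch.Size([1, 1000])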

    .. autoclass:: torchvision.models.quantization.MobileNet_V3_Large_QuantizedWeights
        :members:
    .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights
        :members:
        :noindex:
    """
    weights = (MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights).verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)


# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
from ..mobilenetv3 import model_urls  # noqa: F401


quant_model_urls = _ModelURLs(
    {
        "mobilenet_v3_large_qnnpack": MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1.url,
    }
)