# mobilenetv3.py
from functools import partial
from typing import Any, List, Optional, Union
3

4
5
import torch
from torch import nn, Tensor
6
from torch.ao.quantization import QuantStub, DeQuantStub
7

8
from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
9
from ...transforms._presets import ImageClassification
10
11
12
13
14
15
16
17
18
19
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv3 import (
    InvertedResidual,
    InvertedResidualConfig,
    MobileNetV3,
    _mobilenet_v3_conf,
    MobileNet_V3_Large_Weights,
)
20
from .utils import _fuse_modules, _replace_relu
21
22


23
24
25
26
27
# Public API of this module: the quantizable model class, its pretrained
# quantized weights enum, and the builder function.
__all__ = [
    "QuantizableMobileNetV3",
    "MobileNet_V3_Large_QuantizedWeights",
    "mobilenet_v3_large",
]
28
29


30
class QuantizableSqueezeExcitation(SqueezeExcitation):
    # Serialization layout version; bumped to 2 when observer state was added
    # under scale_activation (see _load_from_state_dict for the migration).
    _version = 2

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Force Hardsigmoid as the gating activation, overriding whatever the
        # caller passed, so the module stays quantization-friendly.
        kwargs["scale_activation"] = nn.Hardsigmoid
        super().__init__(*args, **kwargs)
        # FloatFunctional makes the elementwise multiply observable/quantizable.
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, input: Tensor) -> Tensor:
        """Rescale ``input`` by the squeeze-excitation gate via a quantizable mul."""
        gate = self._scale(input)
        return self.skip_mul.mul(gate, input)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        # Fuse the first FC/conv with its activation into one module in place.
        _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        # Checkpoints written before version 2 are missing the observer entries
        # for scale_activation; backfill neutral defaults so loading succeeds
        # when this module carries a qconfig.
        if hasattr(self, "qconfig") and (version is None or version < 2):
            backfill = {
                "scale_activation.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
                "scale_activation.activation_post_process.activation_post_process.zero_point": torch.tensor(
                    [0], dtype=torch.int32
                ),
                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
            }
            for key, default in backfill.items():
                # Only fill in keys the checkpoint does not already provide.
                state_dict.setdefault(prefix + key, default)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class QuantizableInvertedResidual(InvertedResidual):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Swap in the quantizable SE layer; everything else follows the float block.
        super().__init__(se_layer=QuantizableSqueezeExcitation, *args, **kwargs)  # type: ignore[misc]
        # FloatFunctional makes the residual addition observable/quantizable.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        branch = self.block(x)
        if not self.use_res_connect:
            return branch
        # Residual connection goes through FloatFunctional so it can be quantized.
        return self.skip_add.add(x, branch)


class QuantizableMobileNetV3(MobileNetV3):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Quantization-ready MobileNetV3.

        Args:
           Inherits args from floating point MobileNetV3
        """
        super().__init__(*args, **kwargs)
        # Stubs marking where tensors enter/leave the quantized region.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        # Quantize on entry, run the shared float implementation, dequantize on exit.
        return self.dequant(self._forward_impl(self.quant(x)))

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        # Walk the module tree and fuse conv+bn(+relu) triples and SE blocks.
        for module in self.modules():
            # Exact type checks on purpose: subclasses may not be fuse-safe.
            if type(module) is Conv2dNormActivation:
                to_fuse = ["0", "1"]
                if len(module) == 3 and type(module[2]) is nn.ReLU:
                    to_fuse.append("2")
                _fuse_modules(module, to_fuse, is_qat, inplace=True)
            elif type(module) is QuantizableSqueezeExcitation:
                module.fuse_model(is_qat)


def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """Build a QuantizableMobileNetV3, optionally pretrained and/or quantized.

    Args:
        inverted_residual_setting: per-stage architecture configuration.
        last_channel: width of the last feature layer.
        weights: pretrained weights to load, or ``None`` for random init.
        progress: whether to show a download progress bar for the weights.
        quantize: if ``True``, fuse/prepare for QAT and convert to a quantized model.
        **kwargs: forwarded to the ``QuantizableMobileNetV3`` constructor.
    """
    if weights is not None:
        # Pretrained weights pin the classifier size and (when present) the backend.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    backend = kwargs.pop("backend", "qnnpack")

    net = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(net)

    if quantize:
        # Fuse and attach fake-quant observers before any state dict is loaded,
        # so the checkpoint keys line up with the prepared module tree.
        net.fuse_model(is_qat=True)
        net.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
        torch.ao.quantization.prepare_qat(net, inplace=True)

    if weights is not None:
        net.load_state_dict(weights.get_state_dict(progress=progress))

    if quantize:
        # Swap fake-quant modules for real quantized kernels; inference only.
        torch.ao.quantization.convert(net, inplace=True)
        net.eval()

    return net


class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
    # INT8 checkpoint produced via quantization-aware training with the qnnpack
    # backend; "unquantized" references the float weights it corresponds to.
    IMAGENET1K_QNNPACK_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "task": "image_classification",
            "architecture": "MobileNetV3",
            "num_params": 5483032,
            "size": (224, 224),
            "min_size": (1, 1),
            "categories": _IMAGENET_CATEGORIES,
            "backend": "qnnpack",
            "quantization": "Quantization Aware Training",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
            "acc@1": 73.004,
            "acc@5": 90.858,
        },
    )
    # Default member for this architecture.
    DEFAULT = IMAGENET1K_QNNPACK_V1


@handle_legacy_interface(
    weights=(
        "pretrained",
        # Legacy pretrained=True maps to the quantized or float default weights
        # depending on the quantize flag.
        lambda kwargs: (
            MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
            if kwargs.get("quantize", False)
            else MobileNet_V3_Large_Weights.IMAGENET1K_V1
        ),
    )
)
def mobilenet_v3_large(
    *,
    weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """
    Constructs a MobileNetV3 Large architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported

    Args:
        weights (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, returns a quantized model, else returns a float model
    """
    # Validate the weights against the enum that matches the requested mode.
    weights_enum = MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights
    weights = weights_enum.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)