"docs/source/en/training/cogvideox.md" did not exist on "3e9a28a8a19686b7b66701f7b93d3358d682a5ae"
mobilenetv3.py 8.09 KB
Newer Older
1
2
from functools import partial
from typing import Any, List, Optional, Union
3

4
5
import torch
from torch import nn, Tensor
6
from torch.ao.quantization import QuantStub, DeQuantStub
7

8
from ...ops.misc import Conv2dNormActivation, SqueezeExcitation
9
from ...transforms._presets import ImageClassification
10
11
12
13
14
15
16
17
18
19
from .._api import WeightsEnum, Weights
from .._meta import _IMAGENET_CATEGORIES
from .._utils import handle_legacy_interface, _ovewrite_named_param
from ..mobilenetv3 import (
    InvertedResidual,
    InvertedResidualConfig,
    MobileNetV3,
    _mobilenet_v3_conf,
    MobileNet_V3_Large_Weights,
)
20
from .utils import _fuse_modules, _replace_relu
21
22


23
24
25
26
27
# Names re-exported by ``from <this module> import *``.
__all__ = ["QuantizableMobileNetV3", "MobileNet_V3_Large_QuantizedWeights", "mobilenet_v3_large"]
28
29


30
class QuantizableSqueezeExcitation(SqueezeExcitation):
    """Squeeze-and-excitation block used by the quantizable MobileNetV3.

    Compared to the parent class, this variant always uses ``nn.Hardsigmoid``
    as the scale activation and routes the final multiplication through a
    ``FloatFunctional`` module (``skip_mul``), which is how eager-mode
    quantization attaches observers to element-wise ops.
    """

    # State-dict version of this module. Checkpoints saved with a lower (or
    # missing) version are patched in ``_load_from_state_dict``.
    _version = 2

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to the parent, overriding ``scale_activation``."""
        # Any caller-supplied ``scale_activation`` is replaced unconditionally.
        kwargs["scale_activation"] = nn.Hardsigmoid
        super().__init__(*args, **kwargs)
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` multiplied by the scale from the inherited ``_scale`` helper."""
        return self.skip_mul.mul(self._scale(input), input)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse the ``fc1`` and ``activation`` submodules in place.

        Args:
            is_qat: Forwarded unchanged to ``_fuse_modules``.
        """
        _fuse_modules(self, ["fc1", "activation"], is_qat, inplace=True)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Backfill default quantization entries for old checkpoints, then load.

        When this module carries a ``qconfig`` attribute and the checkpoint
        version is missing or lower than 2, default values for the
        ``scale_activation.activation_post_process`` entries listed below are
        inserted into ``state_dict`` (mutating it in place) for every key that
        is not already present. Loading is then delegated to the parent class
        with the original arguments.
        """
        version = local_metadata.get("version", None)

        if hasattr(self, "qconfig") and (version is None or version < 2):
            default_state_dict = {
                "scale_activation.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
                "scale_activation.activation_post_process.activation_post_process.zero_point": torch.tensor(
                    [0], dtype=torch.int32
                ),
                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
            }
            for k, v in default_state_dict.items():
                # Values already present in the checkpoint always win.
                state_dict.setdefault(prefix + k, v)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
81
82
83


class QuantizableInvertedResidual(InvertedResidual):
    """Inverted residual block wired for quantization.

    Uses ``QuantizableSqueezeExcitation`` as the squeeze-and-excitation layer
    and performs the residual addition through a ``FloatFunctional`` module
    (``skip_add``).
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward all arguments to the parent with ``se_layer`` fixed."""
        super().__init__(se_layer=QuantizableSqueezeExcitation, *args, **kwargs)  # type: ignore[misc]
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``self.block``; add the input back when ``use_res_connect`` is set."""
        if self.use_res_connect:
            return self.skip_add.add(x, self.block(x))
        return self.block(x)


class QuantizableMobileNetV3(MobileNetV3):
    """MobileNetV3 wrapped with quantize/dequantize stubs at its boundaries."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        MobileNet V3 main class

        Args:
           Inherits args from floating point MobileNetV3
        """
        super().__init__(*args, **kwargs)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``quant`` -> inherited ``_forward_impl`` -> ``dequant`` on ``x``."""
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse eligible submodules in place.

        Args:
            is_qat: Forwarded unchanged to ``_fuse_modules`` and to
                ``QuantizableSqueezeExcitation.fuse_model``.
        """
        for m in self.modules():
            # Exact type checks (not ``isinstance``): subclasses are deliberately skipped.
            if type(m) is Conv2dNormActivation:
                # Children "0" and "1" are always fused; child "2" joins only
                # when it is a plain ``nn.ReLU``.
                modules_to_fuse = ["0", "1"]
                if len(m) == 3 and type(m[2]) is nn.ReLU:
                    modules_to_fuse.append("2")
                _fuse_modules(m, modules_to_fuse, is_qat, inplace=True)
            elif type(m) is QuantizableSqueezeExcitation:
                m.fuse_model(is_qat)
123
124
125
126
127


def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """Build a ``QuantizableMobileNetV3``, optionally loading weights and quantizing it.

    Args:
        inverted_residual_setting: Block configurations passed to the model constructor.
        last_channel: Passed to the model constructor.
        weights: Optional pretrained weights. When given, ``num_classes`` (and
            ``backend``, if ``weights.meta`` has one) are set in ``kwargs``
            through ``_ovewrite_named_param`` and the state dict is loaded
            into the model.
        progress: Passed to ``weights.get_state_dict``.
        quantize: If True, the model is fused, prepared for QAT, converted to
            a quantized model after the (optional) weight loading, and put in
            eval mode.
        **kwargs: Extra constructor arguments. ``backend`` is popped
            (default ``"qnnpack"``) and used to select the QAT qconfig.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # ``backend`` is not a constructor argument, so it must leave ``kwargs`` here.
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        # Instead of quantizing the model and then loading the quantized weights we take a different approach.
        # We prepare the QAT model, load the QAT weights from training and then convert it.
        # This is done to avoid extremely low accuracies observed on the specific model. This is rather a workaround
        # for an unresolved bug on the eager quantization API detailed at: https://github.com/pytorch/vision/issues/5890
        model.fuse_model(is_qat=True)
        model.qconfig = torch.ao.quantization.get_default_qat_qconfig(backend)
        torch.ao.quantization.prepare_qat(model, inplace=True)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    if quantize:
        torch.ao.quantization.convert(model, inplace=True)
        model.eval()

    return model


161
162
163
164
165
166
167
168
169
170
171
class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
    """Pretrained quantized weights available for ``mobilenet_v3_large``."""

    IMAGENET1K_QNNPACK_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 5483032,
            "min_size": (1, 1),
            "categories": _IMAGENET_CATEGORIES,
            "backend": "qnnpack",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            # Float weights this quantized checkpoint corresponds to.
            "unquantized": MobileNet_V3_Large_Weights.IMAGENET1K_V1,
            "metrics": {
                "acc@1": 73.004,
                "acc@5": 90.858,
            },
        },
    )
    DEFAULT = IMAGENET1K_QNNPACK_V1


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1
        if kwargs.get("quantize", False)
        else MobileNet_V3_Large_Weights.IMAGENET1K_V1,
    )
)
def mobilenet_v3_large(
    *,
    weights: Optional[Union[MobileNet_V3_Large_QuantizedWeights, MobileNet_V3_Large_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """
    Constructs a MobileNetV3 Large architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported.

    Args:
        weights (MobileNet_V3_Large_QuantizedWeights or MobileNet_V3_Large_Weights, optional): The pretrained
            weights for the model
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, returns a quantized model, else returns a float model
        **kwargs: Additional arguments forwarded to ``_mobilenet_v3_conf`` and to the model builder

    Returns:
        QuantizableMobileNetV3: The constructed model.
    """
    # ``weights`` is verified against the enum that matches the requested model flavour.
    weights_enum = MobileNet_V3_Large_QuantizedWeights if quantize else MobileNet_V3_Large_Weights
    weights = weights_enum.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
214
215
216
217
218
219
220
221
222
223
224
225


# The dictionary below is an internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
from ..mobilenetv3 import model_urls  # noqa: F401


# Maps the legacy model key to the URL of the matching quantized checkpoint.
quant_model_urls = _ModelURLs(
    {"mobilenet_v3_large_qnnpack": MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1.url}
)