from typing import Any, List, Optional

import torch
from torch import nn, Tensor
from torch.quantization import QuantStub, DeQuantStub, fuse_modules

from ..._internally_replaced_utils import load_state_dict_from_url
from ...ops.misc import ConvNormActivation, SqueezeExcitation
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf
from .utils import _replace_relu


__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"]
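
# Pre-trained quantized weights, keyed as "<arch>_<backend>" (see _load_weights below).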
quant_model_urls = {
    "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
}


class QuantizableSqueezeExcitation(SqueezeExcitation):
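    # Quantization-friendly SqueezeExcitation. _version is bumped when the
    # serialized format changes; _load_from_state_dict migrates old checkpoints.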
    _version = 2

    def __init__(self, *args: Any, **kwargs: Any) -> None:
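        # MobileNetV3 gates the SE branch with a hardsigmoid; enforce it here.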
        kwargs["scale_activation"] = nn.Hardsigmoid
        super().__init__(*args, **kwargs)
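        # FloatFunctional stands in for the `*` operator so the elementwise
        # multiply is visible to quantization observers.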
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, input: Tensor) -> Tensor:
        return self.skip_mul.mul(self._scale(input), input)

    def fuse_model(self) -> None:
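        # Fuse the squeeze conv (fc1) with its ReLU into one module for inference.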
        fuse_modules(self, ["fc1", "activation"], inplace=True)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if hasattr(self, "qconfig") and (version is None or version < 2):
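            # Pre-version-2 checkpoints lack observer state for scale_activation;
            # seed the missing keys with defaults so strict loading still succeeds.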
            default_state_dict = {
                "scale_activation.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.activation_post_process.scale": torch.tensor([1.0]),
                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
                "scale_activation.activation_post_process.activation_post_process.zero_point": torch.tensor(
                    [0], dtype=torch.int32
                ),
                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
            }
            for k, v in default_state_dict.items():
                full_key = prefix + k
                if full_key not in state_dict:
                    state_dict[full_key] = v

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class QuantizableInvertedResidual(InvertedResidual):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(se_layer=QuantizableSqueezeExcitation, *args, **kwargs)  # type: ignore[misc]
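        # FloatFunctional replaces `+` so the residual addition can be quantized.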
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        if self.use_res_connect:
            return self.skip_add.add(x, self.block(x))
        else:
            return self.block(x)


class QuantizableMobileNetV3(MobileNetV3):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        MobileNet V3 main class

        Args:
           Inherits args from floating point MobileNetV3
        """
        super().__init__(*args, **kwargs)
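        # QuantStub/DeQuantStub mark where tensors enter and leave the quantized
        # region; convert() later turns them into real (de)quantize ops.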
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self) -> None:
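        # Fuse Conv+BN (and the trailing ReLU when present) plus the SE blocks
        # ahead of quantization.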
        for m in self.modules():
            if type(m) is ConvNormActivation:
                modules_to_fuse = ["0", "1"]
                if len(m) == 3 and type(m[2]) is nn.ReLU:
                    modules_to_fuse.append("2")
                fuse_modules(m, modules_to_fuse, inplace=True)
            elif type(m) is QuantizableSqueezeExcitation:
                m.fuse_model()


def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None:
    if model_url is None:
        raise ValueError(f"No checkpoint is available for {arch}")
    state_dict = load_state_dict_from_url(model_url, progress=progress)
    model.load_state_dict(state_dict)


def _mobilenet_v3_model(
    arch: str,
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    pretrained: bool,
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        backend = "qnnpack"
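
        # Eager-mode QAT recipe: fuse modules, attach the backend qconfig,
        # insert fake-quant/observer modules, optionally load weights, then convert.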
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        torch.quantization.prepare_qat(model, inplace=True)

        if pretrained:
            _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress)
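
        # convert() swaps the observed/fake-quant modules for real quantized ops.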
        torch.quantization.convert(model, inplace=True)
        model.eval()
    else:
        if pretrained:
            _load_weights(arch, model, model_urls.get(arch, None), progress)

    return model


def mobilenet_v3_large(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """
    Constructs a MobileNetV3 Large architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs;
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        progress (bool): If True, displays a progress bar of the download to stderr.
        quantize (bool): If True, returns a quantized model, else returns a float model.
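
    Example (an illustrative sketch; the checkpoint downloads on first use, and on
    some platforms the quantized engine must be set to match it)::

        >>> # torch.backends.quantized.engine = "qnnpack"  # if needed
        >>> model = mobilenet_v3_large(pretrained=True, quantize=True)
        >>> x = torch.rand(1, 3, 224, 224)
        >>> with torch.no_grad():
        ...     out = model(x)
        >>> out.shape
        torch.Size([1, 1000])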
    """
    arch = "mobilenet_v3_large"
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)