mobilenetv2.py 5.75 KB
Newer Older
limm's avatar
limm committed
1
2
from functools import partial
from typing import Any, Optional, Union
3

limm's avatar
limm committed
4
5
6
from torch import nn, Tensor
from torch.ao.quantization import DeQuantStub, QuantStub
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNet_V2_Weights, MobileNetV2
7

limm's avatar
limm committed
8
9
10
11
12
13
from ...ops.misc import Conv2dNormActivation
from ...transforms._presets import ImageClassification
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from .utils import _fuse_modules, _replace_relu, quantize_model
14

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "QuantizableMobileNetV2",
    "MobileNet_V2_QuantizedWeights",
    "mobilenet_v2",
]


class QuantizableInvertedResidual(InvertedResidual):
    """Quantization-friendly variant of torchvision's ``InvertedResidual`` block.

    Identical to the float block except that the residual addition goes through an
    ``nn.quantized.FloatFunctional`` module instead of a bare tensor ``+``. This lets
    eager-mode quantization handle the add like any other module.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Module wrapper used for the skip-connection add in ``forward``.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``self.conv``, adding ``x`` back when the residual path is enabled.

        ``use_res_connect`` and ``conv`` are set up by the parent ``InvertedResidual``
        constructor (not visible in this file).
        """
        if self.use_res_connect:
            return self.skip_add.add(x, self.conv(x))
        return self.conv(x)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse every bare ``nn.Conv2d`` in ``self.conv`` with the module right after it.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules``; presumably selects
                QAT-style vs. post-training fusion (TODO confirm against ``.utils``).
        """
        # Iterate over a snapshot: fusion (inplace=True) rewrites entries of
        # ``self.conv`` while we are walking it.
        for idx, module in enumerate(list(self.conv)):
            # Exact type match on purpose (not isinstance): subclasses of nn.Conv2d
            # are not fused here. The partner at ``idx + 1`` is presumably the
            # block's BatchNorm layer.
            if type(module) is nn.Conv2d:
                _fuse_modules(self.conv, [str(idx), str(idx + 1)], is_qat, inplace=True)


class QuantizableMobileNetV2(MobileNetV2):
    """MobileNetV2 wrapped with quant/dequant stubs for eager-mode quantization."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        MobileNet V2 main class

        Args:
           Inherits args from floating point MobileNetV2
        """
        super().__init__(*args, **kwargs)
        # Entry/exit points of the quantized region: ``forward`` passes the input
        # through ``quant`` and the output through ``dequant``.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        """Run the float network body (``_forward_impl``) between the quant stubs."""
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse conv/norm/activation groups throughout the network, in place.

        ``Conv2dNormActivation`` modules have their children ``"0"``, ``"1"`` and
        ``"2"`` fused together. ``QuantizableInvertedResidual`` blocks delegate to
        their own ``fuse_model``.

        Args:
            is_qat: forwarded unchanged to the fusion helpers.
        """
        # Snapshot the module tree before mutating it. Fusion replaces children of
        # the modules being visited, so we avoid doing that under a live generator.
        for m in list(self.modules()):
            # Exact type matches; the two cases are mutually exclusive.
            if type(m) is Conv2dNormActivation:
                _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True)
            elif type(m) is QuantizableInvertedResidual:
                m.fuse_model(is_qat)


class MobileNet_V2_QuantizedWeights(WeightsEnum):
    """Pretrained weight entries for the quantized MobileNetV2 model.

    Each entry wraps a ``Weights`` record, which holds:

    - the checkpoint URL;
    - the preprocessing transform factory;
    - metadata (parameter count, category names, quantization backend, the float
      weights it was derived from, reported ImageNet-1K accuracy, ...).
    """

    # Quantized checkpoint whose meta["backend"] is "qnnpack".
    IMAGENET1K_QNNPACK_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
        # Preprocessing factory: ImageClassification preset with crop_size=224.
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 3504872,
            "min_size": (1, 1),
            # mobilenet_v2() sets num_classes from len(categories) when these
            # weights are used.
            "categories": _IMAGENET_CATEGORIES,
            # mobilenet_v2() reads this key and uses it as the backend passed to
            # quantize_model().
            "backend": "qnnpack",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
            # The float weights this checkpoint was produced from (see "_docs").
            "unquantized": MobileNet_V2_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.658,
                    "acc@5": 90.150,
                }
            },
            # NOTE(review): units not stated here; presumably GFLOPs and MB
            # respectively — confirm against the WeightsEnum docs tooling.
            "_ops": 0.301,
            "_file_size": 3.423,
            "_docs": """
                These weights were produced by doing Quantization Aware Training (eager mode) on top of the unquantized
                weights listed below.
            """,
        },
    )
    # DEFAULT refers to the same Weights object as IMAGENET1K_QNNPACK_V1
    # (presumably an Enum alias, if WeightsEnum is an Enum — not visible here).
    DEFAULT = IMAGENET1K_QNNPACK_V1


@register_model(name="quantized_mobilenet_v2")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1
        if kwargs.get("quantize", False)
        else MobileNet_V2_Weights.IMAGENET1K_V1,
    )
)
def mobilenet_v2(
    *,
    weights: Optional[Union[MobileNet_V2_QuantizedWeights, MobileNet_V2_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    """
    Constructs a MobileNetV2 architecture from
    `MobileNetV2: Inverted Residuals and Linear Bottlenecks
    <https://arxiv.org/abs/1801.04381>`_.

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.MobileNet_V2_QuantizedWeights` or :class:`~torchvision.models.MobileNet_V2_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.MobileNet_V2_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        quantize (bool, optional): If True, returns a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableMobileNetV2``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/mobilenetv2.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.MobileNet_V2_QuantizedWeights
        :members:
    .. autoclass:: torchvision.models.MobileNet_V2_Weights
        :members:
        :noindex:
    """
    # Resolve ``weights`` against the enum matching the requested mode: quantized
    # weights when ``quantize`` is set, float weights otherwise.
    weights = (MobileNet_V2_QuantizedWeights if quantize else MobileNet_V2_Weights).verify(weights)

    if weights is not None:
        # Pretrained weights dictate the classifier size and, when recorded, the
        # quantization backend.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # ``backend`` is consumed here and not forwarded to the model constructor.
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)
    if quantize:
        # Quantize before loading weights. Presumably this makes the module
        # structure match a quantized checkpoint's state dict.
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model