# mobilenetv2.py
from typing import Any

3
4
from torch import Tensor
from torch import nn
5
from torch.quantization import QuantStub, DeQuantStub, fuse_modules
6
7
8
from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls

from ..._internally_replaced_utils import load_state_dict_from_url
9
from ...ops.misc import ConvNormActivation
10
from .utils import _replace_relu, quantize_model
11
12


13
# Names exported by ``from <this module> import *``: the quantizable model
# class and its builder function.
__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"]
14
15

# Download URLs of the pretrained quantized checkpoints.
# Keys have the form "mobilenet_v2_<backend>".
quant_model_urls = {
    "mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
}


class QuantizableInvertedResidual(InvertedResidual):
    """``InvertedResidual`` variant prepared for quantization.

    The residual addition is routed through an ``nn.quantized.FloatFunctional``
    module (``self.skip_add``) instead of a bare ``+``, and ``fuse_model`` fuses
    the plain convolutions of ``self.conv`` with the module that follows them.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Build the block; all arguments are forwarded to ``InvertedResidual``."""
        super().__init__(*args, **kwargs)
        # Module wrapper used for the skip-connection addition in ``forward``.
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``self.conv`` to ``x``, adding ``x`` back when ``use_res_connect`` is set."""
        if self.use_res_connect:
            return self.skip_add.add(x, self.conv(x))
        return self.conv(x)

    def fuse_model(self) -> None:
        """Fuse, in place, every plain ``nn.Conv2d`` in ``self.conv`` with its successor."""
        for idx in range(len(self.conv)):
            # Exact type match on purpose: subclasses of nn.Conv2d are left alone.
            # NOTE(review): assumes a module exists at ``idx + 1`` (presumably the
            # normalization layer) -- confirm against InvertedResidual's layout.
            if type(self.conv[idx]) is nn.Conv2d:
                fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)


class QuantizableMobileNetV2(MobileNetV2):
    """MobileNetV2 wrapped with quantize / dequantize stubs.

    ``forward`` passes the input through ``self.quant``, runs the inherited
    ``_forward_impl`` and passes the result through ``self.dequant``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        MobileNet V2 main class

        Args:
           Inherits args from floating point MobileNetV2
        """
        super().__init__(*args, **kwargs)
        # Stubs applied at the entry and exit of ``forward``.
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        """Run the network between the quant and dequant stubs and return the output."""
        x = self.quant(x)
        x = self._forward_impl(x)
        x = self.dequant(x)
        return x

    def fuse_model(self) -> None:
        """Fuse, in place, the fusable sub-modules of the network.

        Every module whose type is exactly ``ConvNormActivation`` has its
        children ``"0"``, ``"1"`` and ``"2"`` fused together; every module whose
        type is exactly ``QuantizableInvertedResidual`` fuses itself through its
        own ``fuse_model``.
        """
        for m in self.modules():
            # Exact type matches: a module can satisfy at most one of the branches.
            if type(m) is ConvNormActivation:
                fuse_modules(m, ["0", "1", "2"], inplace=True)
            elif type(m) is QuantizableInvertedResidual:
                m.fuse_model()


63
64
65
66
67
68
def mobilenet_v2(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    """
    Constructs a MobileNetV2 architecture from
    `"MobileNetV2: Inverted Residuals and Linear Bottlenecks"
    <https://arxiv.org/abs/1801.04381>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet.
        progress (bool): If True, displays a progress bar of the download to stderr.
        quantize (bool): If True, returns a quantized model, else returns a float model.
        **kwargs: Forwarded to the ``QuantizableMobileNetV2`` constructor.

    Returns:
        The constructed model, with weights loaded when ``pretrained`` is true.
    """
    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "qnnpack"
        # Conversion happens before any checkpoint is loaded, so the quantized
        # weights below are loaded into the already-converted model.
        quantize_model(model, backend)
    else:
        # NOTE(review): this check disappears under ``python -O``; kept as an
        # assert so the exception type seen by callers does not change.
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls["mobilenet_v2_" + backend]
        else:
            model_url = model_urls["mobilenet_v2"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)
    return model