inception.py 8.57 KB
Newer Older
1
import warnings
2
from typing import Any, List
3
4
5
6

import torch
import torch.nn as nn
import torch.nn.functional as F
7
from torch import Tensor
8
9
from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs
10

11
from ..._internally_replaced_utils import load_state_dict_from_url
12
13
14
15
16
17
18
19
20
21
22
from .utils import _replace_relu, quantize_model


# Public API of this module.
__all__ = [
    "QuantizableInception3",
    "inception_v3",
]


# URLs of pre-quantized checkpoints, keyed by "<arch>_<quantization backend>".
quant_model_urls = {
    # fp32 weights ported from TensorFlow, quantized in PyTorch
    "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
}


class QuantizableBasicConv2d(inception_module.BasicConv2d):
    """Conv2d + BatchNorm + ReLU unit whose three modules can be fused for quantization."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Explicit (non-inplace) ReLU module so the conv/bn/relu triple below is fusable.
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        # Same pipeline as BasicConv2d, but routed through the explicit relu module.
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return out

    def fuse_model(self) -> None:
        """Fuse conv, bn and relu into a single module in place for quantized inference."""
        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInceptionA(inception_module.InceptionA):
    """InceptionA built from quantizable conv blocks, with an observable concat."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        # FloatFunctional makes the concat visible to quantization observers.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        branch_outputs = self._forward(x)
        return self.myop.cat(branch_outputs, 1)


class QuantizableInceptionB(inception_module.InceptionB):
    """InceptionB built from quantizable conv blocks, with an observable concat."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        # FloatFunctional makes the concat visible to quantization observers.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        branch_outputs = self._forward(x)
        return self.myop.cat(branch_outputs, 1)


class QuantizableInceptionC(inception_module.InceptionC):
    """InceptionC built from quantizable conv blocks, with an observable concat."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        # FloatFunctional makes the concat visible to quantization observers.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        branch_outputs = self._forward(x)
        return self.myop.cat(branch_outputs, 1)


class QuantizableInceptionD(inception_module.InceptionD):
    """InceptionD built from quantizable conv blocks, with an observable concat."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        # FloatFunctional makes the concat visible to quantization observers.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        branch_outputs = self._forward(x)
        return self.myop.cat(branch_outputs, 1)


class QuantizableInceptionE(inception_module.InceptionE):
    """InceptionE built from quantizable conv blocks.

    Unlike the float version, every tensor concatenation goes through its own
    ``FloatFunctional`` so each gets a dedicated quantization observer.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        # One FloatFunctional per concat: inner 3x3 split, inner 3x3dbl split,
        # and the final concatenation of all four branches.
        self.myop1 = nn.quantized.FloatFunctional()
        self.myop2 = nn.quantized.FloatFunctional()
        self.myop3 = nn.quantized.FloatFunctional()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.myop1.cat(
            [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)], 1
        )

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.myop2.cat(
            [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)], 1
        )

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        return [branch1x1, branch3x3, branch3x3dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        return self.myop3.cat(self._forward(x), 1)


class QuantizableInceptionAux(inception_module.InceptionAux):
    """Auxiliary classifier head assembled from quantizable conv blocks."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
126


class QuantizableInception3(inception_module.Inception3):
    """Inception v3 assembled entirely from quantizable blocks.

    Adds quant/dequant stubs around the network body so the whole feature
    extractor runs in the quantized domain.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
    ) -> None:
        super().__init__(
            num_classes=num_classes,
            aux_logits=aux_logits,
            transform_input=transform_input,
            inception_blocks=[
                QuantizableBasicConv2d,
                QuantizableInceptionA,
                QuantizableInceptionB,
                QuantizableInceptionC,
                QuantizableInceptionD,
                QuantizableInceptionE,
                QuantizableInceptionAux,
            ],
        )
        # Stubs mark where tensors enter and leave the quantized region.
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux = self._forward(x)
        x = self.dequant(x)
        # The aux output is only meaningful while training with aux_logits on.
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableInception3 always returns QuantizableInception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)

    def fuse_model(self) -> None:
        r"""Fuse conv/bn/relu modules in inception model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """
        # Exact type check on purpose: only fuse leaves that are our fusable unit.
        for mod in self.modules():
            if type(mod) is QuantizableBasicConv2d:
                mod.fuse_model()
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238


def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*

    Raises:
        ValueError: If ``pretrained`` is not a bool when ``quantize`` is False.
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" in kwargs:
            # The pretrained checkpoints contain the aux branch, so the model must
            # be built with it; it is stripped again below if the caller disabled it.
            original_aux_logits = kwargs["aux_logits"]
            kwargs["aux_logits"] = True
        else:
            original_aux_logits = False

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    elif pretrained not in (True, False):
        # Explicit validation instead of `assert`, which is stripped under `python -O`.
        raise ValueError(f"pretrained must be True or False, got {pretrained!r}")

    if pretrained:
        if quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
            # `backend` is always bound here: quantize is True on this path.
            model_url = quant_model_urls["inception_v3_google_" + backend]
        else:
            model_url = inception_module.model_urls["inception_v3_google"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)

        if not quantize:
            # Strip the aux head after loading so its weights still match the checkpoint.
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
    return model