import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Any, List

from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs
from ..._internally_replaced_utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model


__all__ = [
    "QuantizableInception3",
    "inception_v3",
]


quant_model_urls = {
    # fp32 weights ported from TensorFlow, quantized in PyTorch
    "inception_v3_google_fbgemm":
        "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
}


def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableInception3":
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models, inception_v3 expects tensors of size
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Note that ``quantize = True`` returns a quantized model with 8-bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which the
            model was trained on ImageNet. Default: *False*
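
    A minimal usage sketch (illustrative only; assumes nothing beyond the
    public API defined in this module)::

        model = inception_v3(pretrained=True, quantize=True)
        model.eval()
        with torch.no_grad():
            out = model(torch.rand(1, 3, 299, 299))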
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" in kwargs:
            original_aux_logits = kwargs["aux_logits"]
            kwargs["aux_logits"] = True
        else:
            original_aux_logits = False

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
            model_url = quant_model_urls["inception_v3_google_" + backend]
        else:
            model_url = inception_module.model_urls["inception_v3_google"]

        state_dict = load_state_dict_from_url(model_url,
                                              progress=progress)

        model.load_state_dict(state_dict)

        if not quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
    return model


class QuantizableBasicConv2d(inception_module.BasicConv2d):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def fuse_model(self) -> None:
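        # Fuse Conv2d + BatchNorm2d + ReLU into a single module; this does not
        # change the numerics, but lets prepare()/convert() treat the sequence
        # as one quantized op.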
        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInceptionA(inception_module.InceptionA):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInceptionA, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d,
            *args,
            **kwargs
        )
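        # nn.quantized.FloatFunctional wraps ops such as cat so that they are
        # observed during calibration and replaced by quantized counterparts
        # on convert(); a plain torch.cat would not be quantized.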
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionB(inception_module.InceptionB):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInceptionB, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d,
            *args,
            **kwargs
        )
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionC(inception_module.InceptionC):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInceptionC, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d,
            *args,
            **kwargs
        )
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionD(inception_module.InceptionD):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInceptionD, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d,
            *args,
            **kwargs
        )
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionE(inception_module.InceptionE):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInceptionE, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d,
            *args,
            **kwargs
        )
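        # Each concatenation gets its own FloatFunctional so that every cat
        # point is observed and quantized with its own scale and zero-point.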
        self.myop1 = nn.quantized.FloatFunctional()
        self.myop2 = nn.quantized.FloatFunctional()
        self.myop3 = nn.quantized.FloatFunctional()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = self.myop1.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = self.myop2.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop3.cat(outputs, 1)


class QuantizableInceptionAux(inception_module.InceptionAux):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d,
            *args,
            **kwargs
        )


class QuantizableInception3(inception_module.Inception3):
    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
    ) -> None:
        super(QuantizableInception3, self).__init__(
            num_classes=num_classes,
            aux_logits=aux_logits,
            transform_input=transform_input,
            inception_blocks=[
                QuantizableBasicConv2d,
                QuantizableInceptionA,
                QuantizableInceptionB,
                QuantizableInceptionC,
                QuantizableInceptionD,
                QuantizableInceptionE,
                QuantizableInceptionAux
            ]
        )
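        # QuantStub/DeQuantStub mark where tensors enter and leave the
        # quantized region; convert() replaces them with real quantize and
        # dequantize ops.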
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux = self._forward(x)
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableInception3 always returns an InceptionOutputs tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)

    def fuse_model(self) -> None:
        r"""Fuse conv/bn/relu modules in inception model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        for m in self.modules():
            if type(m) == QuantizableBasicConv2d:
                m.fuse_model()
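

# A minimal post-training static quantization sketch of the
# fuse/prepare/calibrate/convert flow (illustrative only: it assumes the
# fbgemm backend and uses random data in place of a real calibration set;
# run via `python -m ...` since this module uses relative imports).
if __name__ == "__main__":
    model = QuantizableInception3(aux_logits=False)
    _replace_relu(model)
    model.eval()  # quantized models support inference only
    model.fuse_model()
    torch.backends.quantized.engine = "fbgemm"
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    with torch.no_grad():
        model(torch.rand(2, 3, 299, 299))  # calibration pass
    torch.quantization.convert(model, inplace=True)
    with torch.no_grad():
        out = model(torch.rand(1, 3, 299, 299))  # eval mode returns the main logits tensor
    print(out.shape)  # torch.Size([1, 1000])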