import warnings
from functools import partial
from typing import Any, List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import inception as inception_module
from torchvision.models.inception import Inception_V3_Weights, InceptionOutputs

from ...transforms._presets import ImageClassification
from .._api import register_model, Weights, WeightsEnum
from .._meta import _IMAGENET_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from .utils import _fuse_modules, _replace_relu, quantize_model


# Public names exported by ``from ... import *``.
__all__ = [
    "QuantizableInception3",
    "Inception_V3_QuantizedWeights",
    "inception_v3",
]


class QuantizableBasicConv2d(inception_module.BasicConv2d):
    """Conv -> BatchNorm -> ReLU unit whose activation is a named ``nn.ReLU`` submodule.

    The activation is registered as ``self.relu`` so that :meth:`fuse_model` can refer to
    ``conv``, ``bn`` and ``relu`` by submodule name and fuse them into one module.
    NOTE(review): the parent ``BasicConv2d`` (not shown here) is assumed to provide the
    ``conv`` and ``bn`` submodules used below.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # A module (not a functional call) is required for module fusion to target it.
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``conv``, then ``bn``, then ``relu`` to ``x``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse ``conv``, ``bn`` and ``relu`` in place.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules``; presumably selects QAT-style
                vs. post-training fusion (``None`` lets the helper decide) — confirm in
                ``.utils``.
        """
        _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)


class QuantizableInceptionA(inception_module.InceptionA):
    """Quantization-friendly ``InceptionA`` block.

    Every convolution is built with :class:`QuantizableBasicConv2d`, and the branch outputs
    are concatenated through a ``FloatFunctional``. This lets eager-mode quantization
    observe the concat and swap it for a quantized op.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]
        # Stateful wrapper around ``cat`` so the quantization workflow can attach an observer.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run all branches (via the inherited ``_forward``) and concatenate them along dim 1."""
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionB(inception_module.InceptionB):
    """Quantization-friendly ``InceptionB`` block.

    Every convolution is built with :class:`QuantizableBasicConv2d`, and the branch outputs
    are concatenated through a ``FloatFunctional``. This lets eager-mode quantization
    observe the concat and swap it for a quantized op.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]
        # Stateful wrapper around ``cat`` so the quantization workflow can attach an observer.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run all branches (via the inherited ``_forward``) and concatenate them along dim 1."""
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionC(inception_module.InceptionC):
    """Quantization-friendly ``InceptionC`` block.

    Every convolution is built with :class:`QuantizableBasicConv2d`, and the branch outputs
    are concatenated through a ``FloatFunctional``. This lets eager-mode quantization
    observe the concat and swap it for a quantized op.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]
        # Stateful wrapper around ``cat`` so the quantization workflow can attach an observer.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run all branches (via the inherited ``_forward``) and concatenate them along dim 1."""
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionD(inception_module.InceptionD):
    """Quantization-friendly ``InceptionD`` block.

    Every convolution is built with :class:`QuantizableBasicConv2d`, and the branch outputs
    are concatenated through a ``FloatFunctional``. This lets eager-mode quantization
    observe the concat and swap it for a quantized op.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]
        # Stateful wrapper around ``cat`` so the quantization workflow can attach an observer.
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run all branches (via the inherited ``_forward``) and concatenate them along dim 1."""
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionE(inception_module.InceptionE):
    """Quantization-friendly ``InceptionE`` block.

    This block concatenates tensors in three places: inside the 3x3 branch, inside the
    double-3x3 branch, and across all four branches. Each concat gets its own
    ``FloatFunctional`` (``myop1``/``myop2``/``myop3``), so each one can carry its own
    quantization statistics.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]
        self.myop1 = nn.quantized.FloatFunctional()  # concat inside the 3x3 branch
        self.myop2 = nn.quantized.FloatFunctional()  # concat inside the double-3x3 branch
        self.myop3 = nn.quantized.FloatFunctional()  # final concat of all branches

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Compute the four branch outputs: 1x1, 3x3, double-3x3 and pooled.

        The two intra-branch concatenations (along dim 1) are routed through ``myop1`` and
        ``myop2`` rather than a plain ``torch.cat``.
        """
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = self.myop1.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = self.myop2.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the four branch outputs along dim 1 using ``myop3``."""
        outputs = self._forward(x)
        return self.myop3.cat(outputs, 1)


class QuantizableInceptionAux(inception_module.InceptionAux):
    """Auxiliary classifier head built from :class:`QuantizableBasicConv2d` convolutions.

    Only the convolution factory is swapped; the forward pass is inherited unchanged.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, conv_block=QuantizableBasicConv2d, **kwargs)  # type: ignore[misc]


class QuantizableInception3(inception_module.Inception3):
    """Inception v3 prepared for eager-mode quantization.

    Every building block is replaced by its ``Quantizable*`` counterpart. The input is
    passed through a ``QuantStub`` and the main output through a ``DeQuantStub``, so that
    after conversion the network runs quantized in between.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(  # type: ignore[misc]
            *args,
            # NOTE(review): the order is assumed to match the slots expected by
            # ``Inception3`` (conv, A, B, C, D, E, aux) — confirm against the parent.
            inception_blocks=[
                QuantizableBasicConv2d,
                QuantizableInceptionA,
                QuantizableInceptionB,
                QuantizableInceptionC,
                QuantizableInceptionD,
                QuantizableInceptionE,
                QuantizableInceptionAux,
            ],
            **kwargs,
        )
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Run the network on ``x``.

        Under TorchScript an ``InceptionOutputs`` tuple is always returned; a warning is
        emitted when the aux output is not expected to be produced. In eager mode the return
        value is whatever ``eager_outputs`` (inherited) builds from ``(x, aux)``.
        """
        # Input normalisation happens in float, before quantization.
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux = self._forward(x)
        # Only the main logits are dequantized; ``aux`` is returned as produced.
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableInception3 always returns QuantizableInception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in inception model

        Fuses the ``conv``, ``bn`` and ``relu`` submodules of every
        :class:`QuantizableBasicConv2d` in the model, to prepare for quantization. The model
        is modified in place. Fusion does not change numerics, and the model stays in
        floating point afterwards.

        Args:
            is_qat: forwarded unchanged to each ``QuantizableBasicConv2d.fuse_model``.
        """
        for m in self.modules():
            # Exact type match on purpose: subclasses of QuantizableBasicConv2d are not fused here.
            if type(m) is QuantizableBasicConv2d:
                m.fuse_model(is_qat)


class Inception_V3_QuantizedWeights(WeightsEnum):
    """Pretrained weights available for the quantized Inception v3 model."""

    # Post-training-quantized ImageNet-1K weights for the ``fbgemm`` backend, derived from
    # ``Inception_V3_Weights.IMAGENET1K_V1``.
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
        transforms=partial(ImageClassification, crop_size=299, resize_size=342),
        meta={
            "num_params": 27161264,
            "min_size": (75, 75),
            "categories": _IMAGENET_CATEGORIES,
            "backend": "fbgemm",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            "unquantized": Inception_V3_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.176,
                    "acc@5": 93.354,
                }
            },
            "_ops": 5.713,
            "_file_size": 23.146,
            "_docs": """
                These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
                weights listed below.
            """,
        },
    )
    DEFAULT = IMAGENET1K_FBGEMM_V1


@register_model(name="quantized_inception_v3")
@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else Inception_V3_Weights.IMAGENET1K_V1,
    )
)
def inception_v3(
    *,
    weights: Optional[Union[Inception_V3_QuantizedWeights, Inception_V3_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableInception3:
    r"""Inception v3 model architecture from
    `Rethinking the Inception Architecture for Computer Vision <http://arxiv.org/abs/1512.00567>`__.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.Inception_V3_QuantizedWeights` or :class:`~torchvision.models.Inception_V3_Weights`, optional): The pretrained
            weights for the model. See
            :class:`~torchvision.models.quantization.Inception_V3_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        quantize (bool, optional): If True, return a quantized version of the model.
            Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableInception3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/inception.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.quantization.Inception_V3_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.Inception_V3_Weights
        :members:
        :noindex:
    """
    # Which weights enum is accepted depends on whether a quantized model was requested.
    weights = (Inception_V3_QuantizedWeights if quantize else Inception_V3_Weights).verify(weights)

    # Record whether the caller asked for the auxiliary head before kwargs are overwritten
    # below. When weights are given, aux_logits is forced to True for construction and the
    # head is stripped again afterwards unless the caller requested it.
    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # ``backend`` is consumed here; it is not a QuantizableInception3 constructor argument.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        # Quantized path: the aux head is dropped *before* loading. NOTE(review): this
        # presumably means the quantized checkpoint has no AuxLogits entries — confirm.
        if quantize and not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        # Float path: the checkpoint is loaded with the aux head present, then the head is
        # dropped if the caller did not ask for it.
        if not quantize and not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None

    return model