import warnings
from typing import Any, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.models import inception as inception_module
from torchvision.models.inception import InceptionOutputs

from ..._internally_replaced_utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model


__all__ = [
    "QuantizableInception3",
    "inception_v3",
]


quant_model_urls = {
    # fp32 weights ported from TensorFlow, quantized in PyTorch
    "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth"
}


def inception_v3(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableInception3":
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models, inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Note that quantize = True returns a quantized model with 8-bit
    weights. Quantized models only support inference and run on CPUs;
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" in kwargs:
            original_aux_logits = kwargs["aux_logits"]
            kwargs["aux_logits"] = True
        else:
            original_aux_logits = False

    model = QuantizableInception3(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
            model_url = quant_model_urls["inception_v3_google_" + backend]
        else:
            model_url = inception_module.model_urls["inception_v3_google"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)

        if not quantize:
            if not original_aux_logits:
                model.aux_logits = False
                model.AuxLogits = None
    return model
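
# A minimal usage sketch for the quantized path: CPU-only inference on an
# input of size N x 3 x 299 x 299, as documented above.
#
#   model = inception_v3(pretrained=True, quantize=True)
#   model.eval()
#   x = torch.rand(1, 3, 299, 299)
#   with torch.no_grad():
#       out = model(x)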


class QuantizableBasicConv2d(inception_module.BasicConv2d):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def fuse_model(self) -> None:
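        # Fuse Conv2d + BatchNorm2d + ReLU into a single module so that, after
        # conversion, the three ops execute as one quantized operation.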
        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInceptionA(inception_module.InceptionA):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
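        # FloatFunctional stands in for the plain torch.cat in _forward so that
        # quantization observers can be attached to the concatenation (the same
        # pattern is used in the other Inception blocks below).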
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionB(inception_module.InceptionB):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionC(inception_module.InceptionC):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionD(inception_module.InceptionD):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        self.myop = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop.cat(outputs, 1)


class QuantizableInceptionE(inception_module.InceptionE):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        self.myop1 = nn.quantized.FloatFunctional()
        self.myop2 = nn.quantized.FloatFunctional()
        self.myop3 = nn.quantized.FloatFunctional()

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = self.myop1.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = self.myop2.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.myop3.cat(outputs, 1)


class QuantizableInceptionAux(inception_module.InceptionAux):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]


class QuantizableInception3(inception_module.Inception3):
    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
    ) -> None:
        super().__init__(
            num_classes=num_classes,
            aux_logits=aux_logits,
            transform_input=transform_input,
            inception_blocks=[
                QuantizableBasicConv2d,
                QuantizableInceptionA,
                QuantizableInceptionB,
                QuantizableInceptionC,
                QuantizableInceptionD,
                QuantizableInceptionE,
                QuantizableInceptionAux,
            ],
        )
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> InceptionOutputs:
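        # QuantStub/DeQuantStub bound the region that runs quantized after
        # conversion: the input is quantized after the float input transform,
        # and the network output is dequantized before being returned.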
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux = self._forward(x)
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableInception3 always returns InceptionOutputs Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)

    def fuse_model(self) -> None:
        r"""Fuse conv/bn/relu modules in inception model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        for m in self.modules():
            if type(m) is QuantizableBasicConv2d:
                m.fuse_model()
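
# A minimal eager-mode post-training quantization sketch built on this class;
# `calibration_loader` is a hypothetical DataLoader yielding representative
# batches, and `quantize_model` imported above performs essentially these steps
# (with a dummy input in place of real calibration data):
#
#   model = QuantizableInception3(aux_logits=False)
#   model.eval()
#   model.fuse_model()
#   model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
#   torch.quantization.prepare(model, inplace=True)
#   for images, _ in calibration_loader:  # calibrate the observers
#       model(images)
#   torch.quantization.convert(model, inplace=True)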