# googlenet.py
1
import warnings
2
3
from functools import partial
from typing import Any, Optional, Union
4

5
6
import torch
import torch.nn as nn
7
from torch import Tensor
8
from torch.nn import functional as F
9

10
from ...transforms._presets import ImageClassification
11
from .._api import Weights, WeightsEnum
12
from .._meta import _IMAGENET_CATEGORIES
13
14
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..googlenet import BasicConv2d, GoogLeNet, GoogLeNet_Weights, GoogLeNetOutputs, Inception, InceptionAux
15
from .utils import _fuse_modules, _replace_relu, quantize_model
16
17


18
19
20
21
22
# Names exported by ``from <this module> import *``.
__all__ = ["QuantizableGoogLeNet", "GoogLeNet_QuantizedWeights", "googlenet"]


class QuantizableBasicConv2d(BasicConv2d):
    """Quantization-friendly variant of :class:`BasicConv2d`.

    Adds a ``ReLU`` submodule (``self.relu``) and applies ``conv -> bn -> relu``
    in ``forward`` so the three stages can be fused in place by :meth:`fuse_model`.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Kept as a named submodule: ``fuse_model`` refers to it by the name "relu".
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``self.conv``, ``self.bn`` and ``self.relu`` in sequence."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        """Fuse the ``conv``, ``bn`` and ``relu`` submodules of this block in place.

        Args:
            is_qat: forwarded unchanged to ``_fuse_modules``; presumably selects
                QAT-style vs. post-training fusion (TODO confirm in ``.utils``).
        """
        _fuse_modules(self, ["conv", "bn", "relu"], is_qat, inplace=True)


class QuantizableInception(Inception):
    """Inception block whose convolutions are :class:`QuantizableBasicConv2d` layers.

    Branch outputs are concatenated through a ``FloatFunctional`` module rather
    than a bare tensor op, so the quantization flow can track and convert it.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        # Module wrapper around ``cat``; replaced by its quantized counterpart on conversion.
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        """Run every branch via ``self._forward`` and concatenate the results along dim 1."""
        outputs = self._forward(x)
        return self.cat.cat(outputs, 1)


class QuantizableInceptionAux(InceptionAux):
    """Auxiliary classifier head built from :class:`QuantizableBasicConv2d` layers.

    The activation after ``fc1`` is the ``self.relu`` submodule, so it is visible
    to module-level quantization tooling (e.g. ``_replace_relu``).
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        """Pool to 4x4, convolve, flatten and classify with ``fc1 -> relu -> dropout -> fc2``."""
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.relu(self.fc1(x))
        # N x 1024
        x = self.dropout(x)
        # N x 1024
        x = self.fc2(x)
        # N x 1000 (num_classes)

        return x


class QuantizableGoogLeNet(GoogLeNet):
    """GoogLeNet assembled from quantization-friendly building blocks.

    The backbone is wrapped in ``QuantStub`` / ``DeQuantStub`` modules that mark
    where tensors enter and leave the quantized region in eager-mode quantization.
    """

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(  # type: ignore[misc]
            blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs
        )
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        """Run the network, quantizing the input and dequantizing the main logits.

        ``_transform_input`` is applied before ``self.quant``. Only the main
        output passes through ``self.dequant``; ``aux1`` / ``aux2`` are returned
        exactly as produced by ``_forward``.

        Returns:
            Under TorchScript, always ``GoogLeNetOutputs(x, aux2, aux1)``, with a
            warning when the auxiliary outputs are not active
            (``not (self.training and self.aux_logits)``). In eager mode, the
            result of ``self.eager_outputs(x, aux2, aux1)``.
        """
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux1, aux2 = self._forward(x)
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableGoogLeNet always returns GoogLeNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)

    def fuse_model(self, is_qat: Optional[bool] = None) -> None:
        r"""Fuse conv/bn/relu modules in the GoogLeNet model.

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point

        Args:
            is_qat: forwarded to each ``QuantizableBasicConv2d.fuse_model`` call.
        """

        for m in self.modules():
            # Exact type match on purpose: subclasses of QuantizableBasicConv2d are skipped.
            if type(m) is QuantizableBasicConv2d:
                m.fuse_model(is_qat)


class GoogLeNet_QuantizedWeights(WeightsEnum):
    # Eager-mode post-training-quantized weights for the "fbgemm" backend, derived
    # from the float checkpoint GoogLeNet_Weights.IMAGENET1K_V1 (see "unquantized").
    IMAGENET1K_FBGEMM_V1 = Weights(
        url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 6624904,
            "min_size": (15, 15),
            "categories": _IMAGENET_CATEGORIES,
            # Read by googlenet() to pick the backend passed to quantize_model.
            "backend": "fbgemm",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
            "unquantized": GoogLeNet_Weights.IMAGENET1K_V1,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.826,
                    "acc@5": 89.404,
                }
            },
            "_docs": """
                These weights were produced by doing Post Training Quantization (eager mode) on top of the unquantized
                weights listed below.
            """,
        },
    )
    # Enum alias: DEFAULT resolves to IMAGENET1K_FBGEMM_V1.
    DEFAULT = IMAGENET1K_FBGEMM_V1


@handle_legacy_interface(
    weights=(
        "pretrained",
        lambda kwargs: GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
        if kwargs.get("quantize", False)
        else GoogLeNet_Weights.IMAGENET1K_V1,
    )
)
def googlenet(
    *,
    weights: Optional[Union[GoogLeNet_QuantizedWeights, GoogLeNet_Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableGoogLeNet:
    """GoogLeNet (Inception v1) model architecture from `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`__.

    .. note::
        Note that ``quantize = True`` returns a quantized model with 8 bit
        weights. Quantized models only support inference and run on CPUs.
        GPU inference is not yet supported.

    Args:
        weights (:class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` or :class:`~torchvision.models.GoogLeNet_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.quantization.GoogLeNet_QuantizedWeights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        quantize (bool, optional): If True, return a quantized version of the model. Default is False.
        **kwargs: parameters passed to the ``torchvision.models.quantization.QuantizableGoogLeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/quantization/googlenet.py>`_
            for more details about this class. A ``backend`` entry, if present, is
            removed here and used only for quantization (default ``"fbgemm"``, or the
            weights' ``meta["backend"]`` when pretrained weights provide one). When
            pretrained weights are used, ``transform_input`` defaults to True, and the
            auxiliary heads are dropped unless ``aux_logits=True`` is passed (they are
            not pretrained, and a warning is emitted in that case).

    Returns:
        QuantizableGoogLeNet: the model, quantized if ``quantize`` is True.

    .. autoclass:: torchvision.models.quantization.GoogLeNet_QuantizedWeights
        :members:

    .. autoclass:: torchvision.models.GoogLeNet_Weights
        :members:
        :noindex:
    """
    # Validate ``weights`` against the enum that matches the requested mode.
    weights = (GoogLeNet_QuantizedWeights if quantize else GoogLeNet_Weights).verify(weights)

    # The model is built with aux heads when loading pretrained weights (presumably so
    # the checkpoint's keys match); remember what the caller wanted so they can be dropped.
    original_aux_logits = kwargs.get("aux_logits", False)
    if weights is not None:
        # Default to transform_input=True, but let an explicit caller value win.
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        if "backend" in weights.meta:
            _ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
    # Popped so it is not forwarded to the model constructor.
    backend = kwargs.pop("backend", "fbgemm")

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)
    # Quantize before loading, so a quantized checkpoint loads into converted modules.
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))
        if not original_aux_logits:
            # Drop the auxiliary heads the caller did not ask for.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
        else:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )

    return model


# NOTE: ``quant_model_urls`` below is an internal implementation detail and will be removed in v0.15.
from .._utils import _ModelURLs
from ..googlenet import model_urls  # noqa: F401

# Legacy name -> URL mapping. The single entry is the fp32 GoogLeNet ported from
# TensorFlow, with weights quantized in PyTorch; the URL is read from the weights enum.
quant_model_urls = _ModelURLs({"googlenet_fbgemm": GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url})