import warnings
from typing import Any

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls

from ..._internally_replaced_utils import load_state_dict_from_url
from .utils import _replace_relu, quantize_model


__all__ = ["QuantizableGoogLeNet", "googlenet"]

quant_model_urls = {
    # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch
    "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
}


class QuantizableBasicConv2d(BasicConv2d):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def fuse_model(self) -> None:
        torch.ao.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInception(Inception):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return self.cat.cat(outputs, 1)


class QuantizableInceptionAux(InceptionAux):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)  # type: ignore[misc]
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.relu(self.fc1(x))
        # N x 1024
        x = self.dropout(x)
        # N x 1024
        x = self.fc2(x)
        # N x 1000 (num_classes)

        return x


class QuantizableGoogLeNet(GoogLeNet):
    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(  # type: ignore[misc]
            blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs
        )
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux1, aux2 = self._forward(x)
        x = self.dequant(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)

    def fuse_model(self) -> None:
        r"""Fuse conv/bn/relu modules in googlenet model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        for m in self.modules():
            if type(m) is QuantizableBasicConv2d:
                m.fuse_model()
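

# A minimal sketch (not part of the torchvision API) of the eager-mode post-training
# quantization flow that ``fuse_model`` prepares for. The ``calibration_loader``
# argument is a hypothetical placeholder for any iterable of representative input
# batches; in this library the equivalent work is done by ``quantize_model``.
def _example_post_training_quantization(model: QuantizableGoogLeNet, calibration_loader: Any) -> nn.Module:
    model.eval()
    # Fuse conv+bn(+relu) blocks so each is quantized as a single unit.
    model.fuse_model()
    # Attach the default quantization config for the fbgemm (x86) backend.
    model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    # Insert observers that record activation statistics.
    torch.ao.quantization.prepare(model, inplace=True)
    # Calibrate with a few representative batches.
    with torch.no_grad():
        for images in calibration_loader:
            model(images)
    # Swap observed modules for their quantized counterparts.
    torch.ao.quantization.convert(model, inplace=True)
    return model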


def googlenet(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableGoogLeNet:
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" not in kwargs:
            kwargs["aux_logits"] = False
        if kwargs["aux_logits"]:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        original_aux_logits = kwargs["aux_logits"]
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = "fbgemm"
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls["googlenet_" + backend]
        else:
            model_url = model_urls["googlenet"]

        state_dict = load_state_dict_from_url(model_url, progress=progress)

        model.load_state_dict(state_dict)

        if not original_aux_logits:
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
    return model
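

# Illustrative usage sketch (not part of the library API): in application code this
# builder is normally imported as ``torchvision.models.quantization.googlenet``.
def _example_usage() -> Tensor:
    # Build the INT8 model; quantized models support CPU inference only.
    quantized_model = googlenet(pretrained=True, quantize=True)
    quantized_model.eval()
    with torch.no_grad():
        batch = torch.rand(1, 3, 224, 224)  # GoogLeNet expects 224x224 RGB inputs
        logits = quantized_model(batch)  # Tensor of shape [1, 1000] in eval mode
    return logits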