import warnings
import torch
import torch.nn as nn
from torch.nn import functional as F

from typing import Any
from torch import Tensor

from ..._internally_replaced_utils import load_state_dict_from_url

from torchvision.models.googlenet import (
    GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls)

from .utils import _replace_relu, quantize_model


__all__ = ['QuantizableGoogLeNet', 'googlenet']

# Download URLs for quantized checkpoints, keyed by '<arch>_<backend>'.
quant_model_urls = {
    # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch
    'googlenet_fbgemm': 'https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth',
}


def googlenet(
    pretrained: bool = False,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> "QuantizableGoogLeNet":
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Note that quantize = True returns a quantized model with 8 bit
    weights. Quantized models only support inference and run on CPUs.
    GPU inference is not yet supported.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        quantize (bool): If True, return a quantized version of the model
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' not in kwargs:
            kwargs['aux_logits'] = False
        if kwargs['aux_logits']:
            warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
                          'so make sure to train them')
        original_aux_logits = kwargs['aux_logits']
        # The pretrained checkpoint contains aux-head weights, so the model
        # must be built with aux_logits=True; the heads are stripped again
        # after loading if the caller did not ask for them.
        kwargs['aux_logits'] = True
        kwargs['init_weights'] = False

    model = QuantizableGoogLeNet(**kwargs)
    _replace_relu(model)

    if quantize:
        # TODO use pretrained as a string to specify the backend
        backend = 'fbgemm'
        quantize_model(model, backend)
    else:
        assert pretrained in [True, False]

    if pretrained:
        if quantize:
            model_url = quant_model_urls['googlenet' + '_' + backend]
        else:
            model_url = model_urls['googlenet']

        state_dict = load_state_dict_from_url(model_url,
                                              progress=progress)

        model.load_state_dict(state_dict)

        if not original_aux_logits:
            # Drop the auxiliary heads now that their (possibly quantized)
            # weights have been loaded and discarded.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]

    return model


class QuantizableBasicConv2d(BasicConv2d):
    """Conv2d + BatchNorm + ReLU unit prepared for quantization.

    The ReLU is held as an explicit submodule (instead of a functional
    call) so that it can be fused together with conv/bn below.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableBasicConv2d, self).__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

    def fuse_model(self) -> None:
        # Fuse conv+bn+relu into a single module in place so quantization
        # observers see one op instead of three.
        torch.quantization.fuse_modules(self, ["conv", "bn", "relu"], inplace=True)


class QuantizableInception(Inception):
    """Inception block built from quantizable conv units.

    The branch concatenation goes through ``nn.quantized.FloatFunctional``
    so the cat op can carry quantization observers/qparams.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInception, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d, *args, **kwargs)
        self.cat = nn.quantized.FloatFunctional()

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        # Concatenate the four branch outputs along the channel dimension.
        return self.cat.cat(outputs, 1)


class QuantizableInceptionAux(InceptionAux):
    """Auxiliary classifier head using the quantizable conv block and an
    explicit (fusable) ReLU module."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableInceptionAux, self).__init__(  # type: ignore[misc]
            conv_block=QuantizableBasicConv2d,
            *args,
            **kwargs
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.7)

    def forward(self, x: Tensor) -> Tensor:
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.relu(self.fc1(x))
        # N x 1024
        x = self.dropout(x)
        # N x 1024
        x = self.fc2(x)
        # N x 1000 (num_classes)

        return x


class QuantizableGoogLeNet(GoogLeNet):
    """GoogLeNet assembled from quantizable blocks, with quant/dequant
    stubs marking the quantized region of the forward pass."""

    # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super(QuantizableGoogLeNet, self).__init__(  # type: ignore[misc]
            blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux],
            *args,
            **kwargs
        )
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        x = self._transform_input(x)
        x = self.quant(x)
        x, aux1, aux2 = self._forward(x)
        x = self.dequant(x)
        # Aux outputs are only meaningful while training with aux_logits on.
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted QuantizableGoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)

    def fuse_model(self) -> None:
        r"""Fuse conv/bn/relu modules in googlenet model

        Fuse conv+bn+relu/ conv+relu/conv+bn modules to prepare for quantization.
        Model is modified in place.  Note that this operation does not change numerics
        and the model after modification is in floating point
        """

        for m in self.modules():
            if type(m) == QuantizableBasicConv2d:
                m.fuse_model()