import warnings
from collections import namedtuple
from typing import Optional, Tuple, List, Callable, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once

__all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"]

model_urls = {
    # GoogLeNet ported from TensorFlow
    "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
}

GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}
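# In eager training mode with aux_logits enabled, GoogLeNet.forward returns this
# named tuple; in eager eval mode it returns just the `logits` Tensor (see
# eager_outputs below). A scripted model always returns the full tuple.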

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs


class GoogLeNet(nn.Module):
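    # TorchScript treats attributes listed in __constants__ as constants, which
    # lets the scripted forward specialize on aux_logits and transform_input.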
    __constants__ = ["aux_logits", "transform_input"]

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        init_weights: Optional[bool] = None,
        blocks: Optional[List[Callable[..., nn.Module]]] = None,
        dropout: float = 0.2,
        dropout_aux: float = 0.7,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of GoogleNet will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        assert len(blocks) == 3
        conv_block = blocks[0]
        inception_block = blocks[1]
        inception_aux_block = blocks[2]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:
            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
        else:
            self.aux1 = None  # type: ignore[assignment]
            self.aux2 = None  # type: ignore[assignment]

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
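            # Undo the standard ImageNet normalization (mean 0.485/0.456/0.406,
            # std 0.229/0.224/0.225) and re-normalize each channel to mean=0.5,
            # std=0.5, matching the preprocessing used for the TF-ported weights.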
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        aux1: Optional[Tensor] = None
        if self.aux1 is not None:
            if self.training:
                aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        aux2: Optional[Tensor] = None
        if self.aux2 is not None:
            if self.training:
                aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux2, aux1

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
        if self.training and self.aux_logits:
            return _GoogLeNetOutputs(x, aux2, aux1)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        x = self._transform_input(x)
        # NOTE: _forward returns (x, aux2, aux1); unpack in that order so the aux
        # heads are not swapped in the returned GoogLeNetOutputs.
        x, aux2, aux1 = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)


class Inception(nn.Module):
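    # Four parallel branches (1x1 conv; 1x1 reduce then 3x3 conv; 1x1 reduce then
    # "5x5" conv, which is really 3x3, see the bug note below; 3x3 max-pool then
    # 1x1 conv), concatenated along the channel dimension.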
    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        self.branch3 = nn.Sequential(
            conv_block(in_channels, ch5x5red, kernel_size=1),
            # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            conv_block(in_channels, pool_proj, kernel_size=1),
        )

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.7,
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv = conv_block(in_channels, 128, kernel_size=1)

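        # 2048 = 128 channels * 4 * 4 spatial positions left after the adaptive
        # average pooling in forward below.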
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        # N x 1024
        x = self.dropout(x)
        # N x 1024
        x = self.fc2(x)
        # N x 1000 (num_classes)

        return x


class BasicConv2d(nn.Module):
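    # Conv2d -> BatchNorm2d -> ReLU. The convolution bias is disabled because the
    # following BatchNorm applies its own learned shift.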
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)


def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> GoogLeNet:
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
    The required minimum input size of the model is 15x15.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if "transform_input" not in kwargs:
            kwargs["transform_input"] = True
        if "aux_logits" not in kwargs:
            kwargs["aux_logits"] = False
        if kwargs["aux_logits"]:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        original_aux_logits = kwargs["aux_logits"]
        kwargs["aux_logits"] = True
        kwargs["init_weights"] = False
        model = GoogLeNet(**kwargs)
        state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
        return model

    return GoogLeNet(**kwargs)
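

# A minimal usage sketch (not part of the original module). It assumes the
# torchvision package layout, so the relative imports above resolve and the
# builder is importable as `torchvision.models.googlenet`:
#
#     import torch
#     from torchvision.models import googlenet
#
#     model = googlenet(pretrained=True)  # downloads the TF-ported weights
#     model.eval()
#     x = torch.randn(1, 3, 224, 224)  # minimum input size is 15x15
#     with torch.no_grad():
#         logits = model(x)  # plain Tensor of shape (1, 1000) in eval mode
#
# In training mode with aux_logits=True, model(x) instead returns a
# GoogLeNetOutputs named tuple of (logits, aux_logits2, aux_logits1).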