import warnings
from collections import namedtuple
from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]


GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}
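# The forward pass returns (logits, aux_logits2, aux_logits1); in eager mode the
# aux entries are None unless the network is training with aux_logits enabled.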

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs


class GoogLeNet(nn.Module):
    __constants__ = ["aux_logits", "transform_input"]

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        init_weights: Optional[bool] = None,
        blocks: Optional[List[Callable[..., nn.Module]]] = None,
        dropout: float = 0.2,
        dropout_aux: float = 0.7,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of GoogleNet will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        if len(blocks) != 3:
            raise ValueError(f"blocks length should be 3 instead of {len(blocks)}")
        conv_block = blocks[0]
        inception_block = blocks[1]
        inception_aux_block = blocks[2]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:
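            # The auxiliary classifiers tap the outputs of inception4a (512 channels)
            # and inception4d (528 channels); they are only evaluated during training.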
            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
        else:
            self.aux1 = None  # type: ignore[assignment]
            self.aux2 = None  # type: ignore[assignment]

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
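            # Truncated-normal weight init replaces the older scipy.stats.truncnorm
            # based initialization (see the FutureWarning above) without needing scipy.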
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
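        # Converts inputs normalized with the ImageNet statistics
        # (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) to the mean=0.5/std=0.5
        # normalization that the ported pretrained weights expect.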
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        aux1: Optional[Tensor] = None
        if self.aux1 is not None:
            if self.training:
                aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        aux2: Optional[Tensor] = None
        if self.aux2 is not None:
            if self.training:
                aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux2, aux1

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux2: Optional[Tensor], aux1: Optional[Tensor]) -> GoogLeNetOutputs:
        if self.training and self.aux_logits:
            return _GoogLeNetOutputs(x, aux2, aux1)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
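        # Eager mode returns a plain Tensor unless training with aux_logits enabled;
        # scripted models always return the namedtuple so TorchScript sees a single
        # static return type.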
        x = self._transform_input(x)
        x, aux2, aux1 = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)


class Inception(nn.Module):
    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
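        # Four parallel branches: a 1x1 conv, a 1x1 -> 3x3 stack, a 1x1 -> "5x5"
        # stack (implemented as 3x3, see the note below), and max-pool -> 1x1 projection.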
        self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        self.branch3 = nn.Sequential(
            conv_block(in_channels, ch5x5red, kernel_size=1),
            # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            conv_block(in_channels, pool_proj, kernel_size=1),
        )

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
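        # Concatenate the four branch outputs along the channel dimension, giving
        # ch1x1 + ch3x3 + ch5x5 + pool_proj output channels.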
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.7,
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv = conv_block(in_channels, 128, kernel_size=1)
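        # Adaptive pooling to 4x4 followed by the 1x1 conv yields 128 * 4 * 4 = 2048
        # features after flattening, which fixes fc1's input size.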

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        # N x 1024
        x = self.dropout(x)
        # N x 1024
        x = self.fc2(x)
        # N x 1000 (num_classes)

        return x


class BasicConv2d(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
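        # bias=False because the subsequent BatchNorm supplies its own affine shift.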
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)


class GoogLeNet_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/googlenet-1378be20.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 6624904,
            "min_size": (15, 15),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.778,
                    "acc@5": 89.530,
                }
            },
            "_ops": 1.498,
            "_file_size": 49.731,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
    """GoogLeNet (Inception v1) model architecture from
    `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`_.

    Args:
        weights (:class:`~torchvision.models.GoogLeNet_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.GoogLeNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.GoogLeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/googlenet.py>`_
            for more details about this class.
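
    Example:
        A minimal usage sketch (downloads the weights on first use; in eval mode
        the model returns a plain logits tensor rather than ``GoogLeNetOutputs``)::

            import torch
            from torchvision.models import googlenet, GoogLeNet_Weights

            model = googlenet(weights=GoogLeNet_Weights.IMAGENET1K_V1).eval()
            with torch.inference_mode():
                logits = model(torch.rand(1, 3, 224, 224))
            print(logits.shape)  # torch.Size([1, 1000])
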
    .. autoclass:: torchvision.models.GoogLeNet_Weights
        :members:
    """
    weights = GoogLeNet_Weights.verify(weights)

    original_aux_logits = kwargs.get("aux_logits", False)
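    # Remember whether the caller asked for aux heads: pretrained checkpoints always
    # contain them, so aux_logits is forced to True for state-dict loading below and
    # the heads are stripped afterwards if the caller did not request them.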
    if weights is not None:
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = GoogLeNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if not original_aux_logits:
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
        else:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )

    return model