# googlenet.py (11.2 KB)
1
2
import warnings
from collections import namedtuple
3
4
from typing import Optional, Tuple, List, Callable, Any

5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F
8
from torch import Tensor
9

10
from .._internally_replaced_utils import load_state_dict_from_url
11
from ..utils import _log_api_usage_once
12

13
# Public names exported by ``from <this module> import *``; the underscore alias is
# listed explicitly because star-imports would otherwise skip it.
__all__ = [
    "GoogLeNet",
    "googlenet",
    "GoogLeNetOutputs",
    "_GoogLeNetOutputs",
]
14
15
16

# Checkpoint location used by ``googlenet(pretrained=True)``; the weights were ported from TensorFlow.
model_urls = dict(
    googlenet="https://download.pytorch.org/models/googlenet-1378be20.pth",
)

# Structured return value of GoogLeNet.forward when auxiliary outputs are available.
GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ("logits", "aux_logits2", "aux_logits1"))
# Field types are attached by hand (presumably for TorchScript); the auxiliary entries may be
# None, e.g. when the aux heads are disabled or the model is not training.
GoogLeNetOutputs.__annotations__ = dict(
    logits=Tensor,
    aux_logits2=Optional[Tensor],
    aux_logits1=Optional[Tensor],
)

# Scripting failed when the namedtuple was created directly under the private name
# (``_GoogLeNetOutputs = namedtuple(...)``); the private alias is kept for backwards compat.
_GoogLeNetOutputs = GoogLeNetOutputs
26

27

28
def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "GoogLeNet":
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
    The required minimum input size of the model is 15x15.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *True* when pretrained is True otherwise *False*
        **kwargs: Any remaining keyword arguments are forwarded to :class:`GoogLeNet`.
            When ``pretrained`` is True, ``init_weights`` is always overridden to False
            because the weights are replaced by the downloaded checkpoint.

    Returns:
        GoogLeNet: the constructed (and, if requested, weight-loaded) model.
    """
    if pretrained:
        kwargs.setdefault("transform_input", True)
        kwargs.setdefault("aux_logits", False)
        if kwargs["aux_logits"]:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        original_aux_logits = kwargs["aux_logits"]
        # Build with the aux heads so the strict load_state_dict below matches the checkpoint's
        # keys (presumably the checkpoint carries aux1/aux2 entries, untrained per the warning above);
        # they are dropped again afterwards if the caller did not ask for them.
        kwargs["aux_logits"] = True
        # Random init would be overwritten by the checkpoint anyway, so skip it.
        kwargs["init_weights"] = False
        model = GoogLeNet(**kwargs)
        state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]
        return model

    return GoogLeNet(**kwargs)


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) network with optional auxiliary classifiers.

    Args:
        num_classes: Number of outputs of the main classifier and of each auxiliary head.
        aux_logits: If True, build the auxiliary heads ``aux1`` (fed by ``inception4a``)
            and ``aux2`` (fed by ``inception4d``); otherwise both attributes are ``None``.
        transform_input: If True, ``forward`` passes the input through
            ``_transform_input`` before the stem.
        init_weights: If True, call ``_initialize_weights``. ``None`` emits a
            ``FutureWarning`` and is then treated as True.
        blocks: Exactly three callables ``[conv_block, inception_block, inception_aux_block]``
            used to build the layers; defaults to ``[BasicConv2d, Inception, InceptionAux]``.
        dropout: Dropout probability applied before the final ``fc`` layer.
        dropout_aux: Dropout probability passed to each auxiliary head.
    """

    # Attributes TorchScript should treat as constants.
    __constants__ = ["aux_logits", "transform_input"]

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        init_weights: Optional[bool] = None,
        blocks: Optional[List[Callable[..., nn.Module]]] = None,
        dropout: float = 0.2,
        dropout_aux: float = 0.7,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of GoogleNet will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        assert len(blocks) == 3
        conv_block = blocks[0]
        inception_block = blocks[1]
        inception_aux_block = blocks[2]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem.
        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Inception stages; positional args are
        # (in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj).
        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        # Auxiliary heads read the outputs of inception4a (512 ch) and inception4d (528 ch).
        if aux_logits:
            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
        else:
            self.aux1 = None  # type: ignore[assignment]
            self.aux2 = None  # type: ignore[assignment]

        # Classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialise Conv2d/Linear weights from a truncated normal and BatchNorm2d to weight=1, bias=0.

        Linear biases are not touched here and keep PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # NOTE: a/b are absolute cut-offs, i.e. +-200 std for std=0.01, so the
                # truncation practically never triggers.
                torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Optionally re-normalise the first three channels of ``x``.

        When ``transform_input`` is set, channel ``c`` becomes
        ``x_c * (std_c / 0.5) + (mean_c - 0.5) / 0.5``. Algebraically this maps an input
        normalised as ``(raw - mean) / std`` to ``(raw - 0.5) / 0.5``; the mean/std constants
        look like the usual ImageNet statistics. Only channels 0-2 of dim 1 are used, so the
        result has exactly 3 channels. Otherwise ``x`` is returned unchanged.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Run the network body.

        Returns:
            ``(logits, aux2, aux1)`` — note the aux order. Each aux output is ``None``
            unless its head exists and the module is in training mode.
        """
        # Shape comments assume an N x 3 x 224 x 224 input.
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        aux1: Optional[Tensor] = None
        if self.aux1 is not None:
            if self.training:
                aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        aux2: Optional[Tensor] = None
        if self.aux2 is not None:
            if self.training:
                aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux2, aux1

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux2: Optional[Tensor], aux1: Optional[Tensor]) -> GoogLeNetOutputs:
        """Eager-mode return value.

        Returns the full ``GoogLeNetOutputs`` only while training with aux heads enabled;
        otherwise returns just the logits tensor.
        """
        if self.training and self.aux_logits:
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        """Classify ``x``.

        In eager mode the result is ``GoogLeNetOutputs(logits, aux_logits2, aux_logits1)``
        while training with aux heads, and a plain logits tensor otherwise. Under
        TorchScript the namedtuple is always returned (with a warning when the aux
        outputs are not defined).
        """
        x = self._transform_input(x)
        # _forward returns (logits, aux2, aux1); unpack in exactly that order so each aux
        # output lands in the correspondingly named GoogLeNetOutputs field (the previous
        # ``x, aux1, aux2 = ...`` silently swapped aux_logits1 and aux_logits2).
        x, aux2, aux1 = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)

221
222

class Inception(nn.Module):
    """Inception module: four parallel branches whose outputs are concatenated along dim 1.

    The output has ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.

    Args:
        in_channels: Channels of the incoming feature map (shared by all branches).
        ch1x1: Output channels of the 1x1 branch.
        ch3x3red: Channels of the 1x1 reduction in front of the 3x3 conv.
        ch3x3: Output channels of the 3x3 branch.
        ch5x5red: Channels of the 1x1 reduction in front of the "5x5" branch conv.
        ch5x5: Output channels of the "5x5" branch (implemented with a 3x3 kernel, see below).
        pool_proj: Output channels of the 1x1 projection after max-pooling.
        conv_block: Factory called as ``conv_block(in_ch, out_ch, **conv_kwargs)``;
            ``BasicConv2d`` is used when omitted.
    """

    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: plain 1x1 conv.
        self.branch1 = make_conv(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction followed by a 3x3 conv.
        reduce_3x3 = make_conv(in_channels, ch3x3red, kernel_size=1)
        conv_3x3 = make_conv(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 reduction followed by the "5x5" conv.
        reduce_5x5 = make_conv(in_channels, ch5x5red, kernel_size=1)
        # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
        # Please see https://github.com/pytorch/vision/issues/906 for details.
        conv_5x5 = make_conv(ch5x5red, ch5x5, kernel_size=3, padding=1)
        self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: size-preserving 3x3 max-pool followed by a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        projection = make_conv(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, projection)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Evaluate the four branches on ``x``; results are returned in branch order."""
        return [
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x),
        ]

    def forward(self, x: Tensor) -> Tensor:
        branch_outputs = self._forward(x)
        return torch.cat(branch_outputs, dim=1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: adaptive average pool to 4x4 -> 1x1 conv to 128 channels -> flatten
    (128 * 4 * 4 = 2048) -> fc1 (1024) + ReLU -> dropout -> fc2 (``num_classes``).

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of output logits.
        conv_block: Factory for the 1x1 conv; ``BasicConv2d`` when omitted.
        dropout: Dropout probability applied after ``fc1``.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.7,
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.conv = make_conv(in_channels, 128, kernel_size=1)

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        # Pooling to a fixed 4x4 grid makes the flattened width independent of the input size
        # (aux1 sees N x 512 x 14 x 14, aux2 N x 528 x 14 x 14 for a 224x224 image).
        pooled = F.adaptive_avg_pool2d(x, (4, 4))
        # N x 128 x 4 x 4 -> N x 2048
        features = torch.flatten(self.conv(pooled), 1)
        # N x 1024
        hidden = self.dropout(F.relu(self.fc1(features), inplace=True))
        # N x num_classes
        return self.fc2(hidden)


class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d`` (eps=1e-3) and an in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) go straight to ``nn.Conv2d``.
    The conv bias is disabled since the BatchNorm that follows carries its own shift term.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)