import warnings
from collections import namedtuple
from typing import Optional, Tuple, List, Callable, Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url


__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]


model_urls = {
    # GoogLeNet ported from TensorFlow
    'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}


# Structured return value of GoogLeNet.forward: the main classifier logits plus
# the two auxiliary classifier outputs, which are None whenever the corresponding
# auxiliary head is absent or was not evaluated (it only runs in training mode).
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
                                    'aux_logits1': Optional[Tensor]}

# TorchScript annotations failed with `_GoogleNetOutputs = namedtuple(...)`;
# the underscored name is kept as an alias for backwards compatibility.
_GoogLeNetOutputs = GoogLeNetOutputs

def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "GoogLeNet":
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    The required minimum input size of the model is 15x15.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *True* when pretrained is True otherwise *False*
        **kwargs: Remaining keyword arguments are forwarded to :class:`GoogLeNet`.
            When ``pretrained`` is True, ``init_weights`` is always forced to False
            because the parameters are overwritten by the downloaded checkpoint.

    Returns:
        GoogLeNet: the constructed (and, if requested, pre-trained) model.
    """
    if not pretrained:
        return GoogLeNet(**kwargs)

    kwargs.setdefault('transform_input', True)
    kwargs.setdefault('aux_logits', False)
    if kwargs['aux_logits']:
        warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
                      'so make sure to train them')

    # Always build the auxiliary heads so the parameter names match the
    # checkpoint (load_state_dict is strict by default); they are dropped again
    # below if the caller did not ask for them.
    original_aux_logits = kwargs['aux_logits']
    kwargs['aux_logits'] = True
    # Random initialization would be discarded by the checkpoint load anyway.
    kwargs['init_weights'] = False
    model = GoogLeNet(**kwargs)

    state_dict = load_state_dict_from_url(model_urls['googlenet'], progress=progress)
    model.load_state_dict(state_dict)

    if not original_aux_logits:
        model.aux_logits = False
        model.aux1 = None  # type: ignore[assignment]
        model.aux2 = None  # type: ignore[assignment]
    return model


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) image classifier.

    Args:
        num_classes: Output size of the main classifier and of both auxiliary heads.
        aux_logits: If True, build the auxiliary heads ``aux1`` (fed by
            ``inception4a``) and ``aux2`` (fed by ``inception4d``).
        transform_input: If True, ``forward`` re-normalizes the input via
            ``_transform_input`` before running the network.
        init_weights: If True, re-initialize parameters with ``_initialize_weights``.
            ``None`` emits a FutureWarning and then behaves like True.
        blocks: Optional list of exactly three factories
            ``[conv_block, inception_block, inception_aux_block]``; defaults to
            ``[BasicConv2d, Inception, InceptionAux]``.
    """
    __constants__ = ['aux_logits', 'transform_input']

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        init_weights: Optional[bool] = None,
        blocks: Optional[List[Callable[..., nn.Module]]] = None
    ) -> None:
        super().__init__()
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
        if init_weights is None:
            warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(blocks) == 3
        conv_block, inception_block, inception_aux_block = blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem.
        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Inception stages.
        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        # Auxiliary heads; in_channels match inception4a (512) and inception4d (528).
        if aux_logits:
            self.aux1 = inception_aux_block(512, num_classes)
            self.aux2 = inception_aux_block(528, num_classes)
        else:
            self.aux1 = None  # type: ignore[assignment]
            self.aux2 = None  # type: ignore[assignment]

        # Main classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Truncated-normal init for conv/linear weights; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # NOTE(review): trunc_normal_'s a/b are absolute cutoffs, not multiples
                # of std, so with std=0.01 the [-2, 2] truncation effectively never
                # triggers. Kept unchanged to preserve the existing initialization.
                torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Re-normalize the first three channels of ``x`` when ``transform_input`` is set.

        An input normalized with per-channel mean (0.485, 0.456, 0.406) and std
        (0.229, 0.224, 0.225) is mapped to one normalized with mean 0.5 and std 0.5.
        Expects at least 3 channels on dim 1; otherwise ``x`` is returned unchanged.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Run the network and return ``(logits, aux2, aux1)``.

        The auxiliary outputs are None unless the head exists and the module is in
        training mode. Shape comments assume a 224x224 input.
        """
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        aux1: Optional[Tensor] = None
        if self.aux1 is not None:
            if self.training:
                aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        aux2: Optional[Tensor] = None
        if self.aux2 is not None:
            if self.training:
                aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux2, aux1

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux2: Optional[Tensor], aux1: Optional[Tensor]) -> GoogLeNetOutputs:
        """Eager-mode result: the full namedtuple only when training with aux heads, else the logits."""
        if self.training and self.aux_logits:
            return _GoogLeNetOutputs(x, aux2, aux1)
        else:
            return x   # type: ignore[return-value]

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        """Classify ``x``.

        Returns ``GoogLeNetOutputs(logits, aux_logits2, aux_logits1)`` when scripted,
        or in eager mode when training with ``aux_logits``; otherwise (eager) just
        the logits tensor.
        """
        x = self._transform_input(x)
        # _forward returns (logits, aux2, aux1); unpack in that exact order so each
        # auxiliary output ends up in its matching GoogLeNetOutputs field.
        x, aux2, aux1 = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)


class Inception(nn.Module):
    """Inception module: four parallel branches concatenated along the channel dim.

    Branches, each built with ``conv_block`` (default ``BasicConv2d``):
      1. 1x1 conv -> ``ch1x1`` channels
      2. 1x1 conv -> ``ch3x3red``, then 3x3 conv -> ``ch3x3``
      3. 1x1 conv -> ``ch5x5red``, then 3x3 conv -> ``ch5x5`` (3x3 rather than 5x5, see below)
      4. 3x3 stride-1 max-pool, then 1x1 conv -> ``pool_proj``

    Assuming ``conv_block`` defaults to stride 1 (as BasicConv2d / nn.Conv2d do),
    every branch keeps the spatial size, and the output has
    ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.
    """

    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d

        self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            conv_block(in_channels, ch3x3red, kernel_size=1),
            conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        self.branch3 = nn.Sequential(
            conv_block(in_channels, ch5x5red, kernel_size=1),
            # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            conv_block(in_channels, pool_proj, kernel_size=1)
        )

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs, in branch order, before concatenation."""
        return [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dim 1 (channels)."""
        return torch.cat(self._forward(x), 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Adaptive-average-pools the input to 4x4, projects it to 128 channels with a
    1x1 ``conv_block`` (default ``BasicConv2d``), then applies
    Linear(2048 -> 1024) + ReLU + dropout(p=0.7) + Linear(1024 -> num_classes).
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv = conv_block(in_channels, 128, kernel_size=1)

        # 2048 = 128 channels * 4 * 4 positions left by the adaptive pool in forward.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``N x num_classes`` logits; dropout is active only in training mode."""
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        # N x 1024
        x = F.dropout(x, 0.7, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x 1000 (num_classes)
        return x


class BasicConv2d(nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d (eps=0.001) and an in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are forwarded
    to ``nn.Conv2d``. The conv has no bias because BatchNorm supplies a shift.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> batch norm -> ReLU."""
        return F.relu(self.bn(self.conv(x)), inplace=True)