inception.py 17.3 KB
Newer Older
1
from collections import namedtuple
2
import warnings
3
import torch
4
from torch import nn, Tensor
5
import torch.nn.functional as F
6
from .._internally_replaced_utils import load_state_dict_from_url
7
from typing import Callable, Any, Optional, Tuple, List
8
9


10
# Public API of this module.
__all__ = [
    'Inception3',
    'inception_v3',
    'InceptionOutputs',
    '_InceptionOutputs',
]


# Download locations of pretrained checkpoints, keyed by architecture name.
model_urls = {
    # Inception v3 weights ported from TensorFlow.
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth',
}

# Result type of ``Inception3.forward`` when both the main and auxiliary logits
# are returned; ``aux_logits`` is Optional and may be None.
InceptionOutputs = namedtuple('InceptionOutputs', ('logits', 'aux_logits'))
# A ``collections.namedtuple`` carries no field types, so they are attached
# explicitly (the scripted forward path returns this type).
InceptionOutputs.__annotations__ = dict(logits=Tensor, aux_logits=Optional[Tensor])

# Private alias of the public type, kept for backwards compatibility.
_InceptionOutputs = InceptionOutputs


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
    The required minimum input size of the model is 75x75.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False* (*True* when ``pretrained`` is set and the
            argument is not given)

    Any other keyword arguments are forwarded to :class:`Inception3`.
    """
    if not pretrained:
        return Inception3(**kwargs)

    # The pretrained weights expect the input re-normalisation unless the caller opts out.
    kwargs.setdefault('transform_input', True)
    # The checkpoint contains the auxiliary classifier, so the branch has to exist while
    # loading; remember what the caller asked for and drop the branch afterwards if needed.
    keep_aux = kwargs.get('aux_logits', True)
    kwargs['aux_logits'] = True
    # Random initialisation would be overwritten by the checkpoint anyway.
    kwargs['init_weights'] = False

    net = Inception3(**kwargs)
    weights = load_state_dict_from_url(model_urls['inception_v3_google'], progress=progress)
    net.load_state_dict(weights)

    if not keep_aux:
        net.aux_logits = False
        net.AuxLogits = None
    return net


class Inception3(nn.Module):
    """Inception v3 classification network.

    Args:
        num_classes (int): Output size of the final classifier (and of the auxiliary
            classifier, if built). Default: 1000
        aux_logits (bool): If True, build the auxiliary classifier ``AuxLogits`` on the
            output of ``Mixed_6e``; it is only evaluated in training mode. Default: True
        transform_input (bool): If True, re-normalise the input in ``_transform_input``
            before the first convolution. Default: False
        inception_blocks (list, optional): Exactly 7 module factories, in the order
            ``[conv_block, inception_a, inception_b, inception_c, inception_d, inception_e,
            inception_aux]``. Defaults to ``BasicConv2d``, ``InceptionA`` ... ``InceptionAux``.
        init_weights (bool, optional): If True, initialise conv/linear weights from a
            truncated normal and batch-norm layers to weight 1 / bias 0. ``None`` behaves
            like True but emits a ``FutureWarning``.
        dropout (float): Dropout probability applied before the final fully connected
            layer. Default: 0.5 (the value previously hard-coded via ``nn.Dropout()``).

    Raises:
        ValueError: If ``inception_blocks`` does not contain exactly 7 entries.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        if inception_blocks is None:
            inception_blocks = [
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux
            ]
        if init_weights is None:
            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
        if len(inception_blocks) != 7:
            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
        (conv_block, inception_a, inception_b, inception_c,
         inception_d, inception_e, inception_aux) = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # A layer may override the default std of 0.1 via a ``stddev`` attribute.
                    # NOTE(review): only the nn.Conv2d / nn.Linear instances themselves are
                    # checked, so a ``stddev`` set on a wrapper module (e.g. a BasicConv2d)
                    # does not reach its inner conv -- confirm this is intended.
                    stddev = float(m.stddev) if hasattr(m, 'stddev') else 0.1  # type: ignore
                    # NOTE(review): ``trunc_normal_``'s a/b are absolute cut-offs, not multiples
                    # of std, so with std <= 0.1 the [-2, 2] bounds are practically never hit.
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Optionally re-normalise the three input channels.

        Each channel is mapped as ``x * (std / 0.5) + (mean - 0.5) / 0.5`` with
        mean=(0.485, 0.456, 0.406) and std=(0.229, 0.224, 0.225); i.e. an input normalised
        as ``(v - mean) / std`` becomes ``(v - 0.5) / 0.5`` (presumably the scaling the
        ported weights expect). Returns ``x`` unchanged when ``transform_input`` is False.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the backbone and classifier; return ``(logits, aux)``.

        ``aux`` is None unless the auxiliary classifier exists and the module is in
        training mode. Shape comments assume a 299 x 299 input.
        """
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        """Eager-mode result: the full named tuple while training with aux logits, else just the logits."""
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Classify ``x``.

        In eager mode this returns ``InceptionOutputs(logits, aux_logits)`` when training
        with ``aux_logits`` enabled and only the logits tensor otherwise. Under TorchScript
        it always returns ``InceptionOutputs`` (``aux_logits`` may then be None).
        """
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)


class InceptionA(nn.Module):
    """Four-branch Inception block whose convolution arguments keep the spatial size.

    Branches (concatenated along dim 1 in this order):
    1x1 conv (64); 1x1 -> 5x5 conv (48 -> 64); 1x1 -> 3x3 -> 3x3 conv (64 -> 96 -> 96);
    3x3 average pool -> 1x1 conv (``pool_features``). Output has ``224 + pool_features``
    channels.

    Args:
        in_channels: Channels of the input tensor.
        pool_features: Output channels of the pooling branch.
        conv_block: Factory ``(in_channels, out_channels, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` when None.
    """

    def __init__(
        self,
        in_channels: int,
        pool_features: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        # 1x1 projection branch.
        self.branch1x1 = make(in_channels, 64, kernel_size=1)

        # 1x1 reduction followed by a 5x5 convolution.
        self.branch5x5_1 = make(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make(48, 64, kernel_size=5, padding=2)

        # 1x1 reduction followed by two stacked 3x3 convolutions.
        self.branch3x3dbl_1 = make(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make(96, 96, kernel_size=3, padding=1)

        # Average pool (applied in _forward) followed by a 1x1 projection.
        self.branch_pool = make(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Evaluate every branch on ``x``; return the outputs in concatenation order."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_5x5, out_3x3dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)


class InceptionB(nn.Module):
    """Grid-reduction Inception block (stride-2 branches, no padding on the strided ops).

    Branches (concatenated along dim 1 in this order):
    3x3 stride-2 conv (384); 1x1 -> 3x3 -> 3x3 stride-2 conv (64 -> 96 -> 96);
    3x3 stride-2 max pool (passes ``in_channels`` through). Output has
    ``480 + in_channels`` channels.

    Args:
        in_channels: Channels of the input tensor.
        conv_block: Factory ``(in_channels, out_channels, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` when None.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        # Single strided 3x3 convolution.
        self.branch3x3 = make(in_channels, 384, kernel_size=3, stride=2)

        # 1x1 reduction, 3x3 convolution, then a strided 3x3 convolution.
        self.branch3x3dbl_1 = make(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Evaluate every branch on ``x``; return the outputs in concatenation order."""
        out_3x3 = self.branch3x3(x)
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_3x3dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)


class InceptionC(nn.Module):
    """Inception block with factorised 7x7 convolutions; padding keeps the spatial size.

    Branches (concatenated along dim 1 in this order), each producing 192 channels:
    1x1 conv; 1x1 -> 1x7 -> 7x1 conv; 1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7 conv;
    3x3 average pool -> 1x1 conv. Output has 768 channels.

    Args:
        in_channels: Channels of the input tensor.
        channels_7x7: Width of the intermediate layers in the two 7x7 branches.
        conv_block: Factory ``(in_channels, out_channels, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` when None.
    """

    def __init__(
        self,
        in_channels: int,
        channels_7x7: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        # 1x1 projection branch.
        self.branch1x1 = make(in_channels, 192, kernel_size=1)

        mid = channels_7x7

        # 1x1 reduction followed by a 1x7 / 7x1 factorised 7x7 convolution.
        self.branch7x7_1 = make(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = make(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = make(mid, 192, kernel_size=(7, 1), padding=(3, 0))

        # 1x1 reduction followed by two stacked factorised 7x7 convolutions.
        self.branch7x7dbl_1 = make(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = make(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = make(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = make(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = make(mid, 192, kernel_size=(1, 7), padding=(0, 3))

        # Average pool (applied in _forward) followed by a 1x1 projection.
        self.branch_pool = make(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Evaluate every branch on ``x``; return the outputs in concatenation order."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_3(self.branch7x7_2(self.branch7x7_1(x)))

        out_7x7dbl = self.branch7x7dbl_1(x)
        out_7x7dbl = self.branch7x7dbl_3(self.branch7x7dbl_2(out_7x7dbl))
        out_7x7dbl = self.branch7x7dbl_5(self.branch7x7dbl_4(out_7x7dbl))

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_7x7, out_7x7dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)


class InceptionD(nn.Module):
    """Grid-reduction Inception block (stride-2 branches, no padding on the strided ops).

    Branches (concatenated along dim 1 in this order):
    1x1 -> 3x3 stride-2 conv (192 -> 320); 1x1 -> 1x7 -> 7x1 -> 3x3 stride-2 conv (192 each);
    3x3 stride-2 max pool (passes ``in_channels`` through). Output has
    ``512 + in_channels`` channels.

    Args:
        in_channels: Channels of the input tensor.
        conv_block: Factory ``(in_channels, out_channels, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` when None.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        # 1x1 reduction followed by a strided 3x3 convolution.
        self.branch3x3_1 = make(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make(192, 320, kernel_size=3, stride=2)

        # 1x1 reduction, factorised 7x7 convolution, then a strided 3x3 convolution.
        self.branch7x7x3_1 = make(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Evaluate every branch on ``x``; return the outputs in concatenation order."""
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        out_7x7x3 = self.branch7x7x3_2(self.branch7x7x3_1(x))
        out_7x7x3 = self.branch7x7x3_4(self.branch7x7x3_3(out_7x7x3))

        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_7x7x3, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)


class InceptionE(nn.Module):
    """Inception block with split 1x3 / 3x1 outputs; padding keeps the spatial size.

    Branches (concatenated along dim 1 in this order):
    1x1 conv (320); 1x1 conv (384) -> [1x3, 3x1] convs (384 each, concatenated -> 768);
    1x1 (448) -> 3x3 (384) -> [1x3, 3x1] convs (concatenated -> 768);
    3x3 average pool -> 1x1 conv (192). Output has 2048 channels.

    Args:
        in_channels: Channels of the input tensor.
        conv_block: Factory ``(in_channels, out_channels, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` when None.
    """

    def __init__(
        self,
        in_channels: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block

        # 1x1 projection branch.
        self.branch1x1 = make(in_channels, 320, kernel_size=1)

        # 1x1 reduction whose output feeds two parallel convs (1x3 and 3x1).
        self.branch3x3_1 = make(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # 1x1 reduction and 3x3 conv, then the same parallel 1x3 / 3x1 split.
        self.branch3x3dbl_1 = make(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Average pool (applied in _forward) followed by a 1x1 projection.
        self.branch_pool = make(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Evaluate every branch on ``x``; return the outputs in concatenation order."""
        out_1x1 = self.branch1x1(x)

        reduced = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(reduced), self.branch3x3_2b(reduced)], 1)

        reduced_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(reduced_dbl), self.branch3x3dbl_3b(reduced_dbl)], 1)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_3x3, out_3x3dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head: avg pool -> 1x1 conv -> 5x5 conv -> global pool -> linear.

    Args:
        in_channels: Channels of the input feature map.
        num_classes: Output size of the final linear layer.
        conv_block: Factory ``(in_channels, out_channels, **conv_kwargs) -> nn.Module``;
            ``BasicConv2d`` when None.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make = BasicConv2d if conv_block is None else conv_block
        self.conv0 = make(in_channels, 128, kernel_size=1)
        self.conv1 = make(128, 768, kernel_size=5)
        # Custom init std, read via ``hasattr(m, 'stddev')`` by Inception3's weight init.
        # NOTE(review): with the default BasicConv2d this attribute sits on the wrapper, not on
        # its inner nn.Conv2d, which is the module that init loop inspects -- confirm.
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        """Return ``N x num_classes`` logits.

        Shape comments assume the N x 768 x 17 x 17 map Inception3 feeds in from Mixed_6e.
        """
        x = F.avg_pool2d(x, kernel_size=5, stride=3)    # N x 768 x 5 x 5
        x = self.conv0(x)                                # N x 128 x 5 x 5
        x = self.conv1(x)                                # N x 768 x 1 x 1
        x = F.adaptive_avg_pool2d(x, (1, 1))             # N x 768 x 1 x 1 (any spatial size)
        x = torch.flatten(x, 1)                          # N x 768
        return self.fc(x)                                # N x num_classes


class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d``, then ``nn.BatchNorm2d`` (eps=0.001), then an in-place ReLU.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution and of the batch norm.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``, ``stride``,
            ``padding``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # No conv bias: the following batch norm applies its own learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> batch norm -> ReLU."""
        return F.relu(self.bn(self.conv(x)), inplace=True)