import warnings
from collections import namedtuple
from typing import Callable, Any, Optional, Tuple, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from .._internally_replaced_utils import load_state_dict_from_url
from ..utils import _log_api_usage_once


# Public API of this module.
__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"]


# Download locations of pretrained checkpoints; inception_v3(pretrained=True) loads
# the "inception_v3_google" entry.
model_urls = {
    # Inception v3 ported from TensorFlow
    "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
}

# Output of Inception3.forward when the auxiliary classifier result is returned
# (eager training mode with aux_logits enabled, and always when scripted):
# the main classifier logits plus the auxiliary logits (None if not computed).
InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
# Explicit field types; presumably needed so TorchScript treats aux_logits as Optional[Tensor].
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}

# Script annotations failed with _InceptionOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
    The required minimum input size of the model is 75x75.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False* (*True* when ``pretrained=True``)
        **kwargs: Remaining keyword arguments are forwarded to :class:`Inception3`.

    Returns:
        Inception3: the constructed (and, if ``pretrained``, weight-loaded) model.
    """
    if not pretrained:
        return Inception3(**kwargs)

    # The pretrained weights expect the input re-normalization done by transform_input,
    # so enable it unless the caller chose explicitly.
    kwargs.setdefault("transform_input", True)
    # Build with the auxiliary branch so the checkpoint (which presumably contains
    # AuxLogits weights; load_state_dict is strict) loads cleanly, then drop the
    # branch afterwards if the caller did not ask for it.
    keep_aux_logits = kwargs.get("aux_logits", True)
    kwargs["aux_logits"] = True
    kwargs["init_weights"] = False  # we are loading weights from a pretrained model
    model = Inception3(**kwargs)
    state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
    model.load_state_dict(state_dict)

    if not keep_aux_logits:
        model.aux_logits = False
        model.AuxLogits = None
    return model


class Inception3(nn.Module):
    """Inception v3 image classifier.

    Args:
        num_classes: number of outputs of the main classifier (and of the auxiliary one).
        aux_logits: if True, build the auxiliary classifier ``AuxLogits`` on the Mixed_6e
            features; its logits are only computed in training mode.
        transform_input: if True, ``forward`` re-normalizes the input first (see ``_transform_input``).
        inception_blocks: optional list of exactly 7 module factories, in the order
            ``[conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux]``.
            Defaults to the classes defined in this module.
        init_weights: if True, re-initialize Conv2d/Linear weights with ``trunc_normal_`` and
            BatchNorm2d weight/bias to 1/0. ``None`` currently behaves like True and emits a FutureWarning.
        dropout: dropout probability applied before the final fully connected layer.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        assert len(inception_blocks) == 7
        conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        # Registration order below fixes state_dict key order and the order in which
        # self.modules() is visited during weight initialization; keep it stable.
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)

        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # A module may carry its own init std as a `stddev` attribute.
                    # NOTE(review): only nn.Conv2d / nn.Linear instances are checked here, so a
                    # `stddev` set on a wrapper module (e.g. a BasicConv2d) is not picked up.
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    # NOTE(review): trunc_normal_'s a/b are absolute cutoffs, not multiples of
                    # std, so with std=0.1 this truncates at +/-20 std; confirm whether
                    # a=-2*stddev, b=2*stddev was intended.
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Optionally re-normalize a 3-channel input batch.

        When ``transform_input`` is set, each channel c is mapped from
        ``(p - mean_c) / std_c`` with mean=(0.485, 0.456, 0.406) and std=(0.229, 0.224, 0.225)
        to ``(p - 0.5) / 0.5``. Otherwise ``x`` is returned unchanged.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the network; return ``(logits, aux_logits)``, where aux is None unless training with AuxLogits."""
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            # The auxiliary head is only evaluated in training mode.
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x num_classes
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        """Eager-mode return: InceptionOutputs while training with aux_logits, else just the logits tensor."""
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Classify ``x``.

        In eager mode, returns ``InceptionOutputs(logits, aux_logits)`` when training with
        ``aux_logits`` enabled and the plain logits tensor otherwise. When scripted, always
        returns ``InceptionOutputs`` (``aux_logits`` may be None) and warns if no aux output
        would have been produced.
        """
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)


class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and pooled 1x1 branches.

    Spatial size is preserved; output has ``64 + 64 + 96 + pool_features`` channels.

    Args:
        in_channels: channels of the input tensor.
        pool_features: output channels of the average-pool branch.
        conv_block: factory ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``; defaults to BasicConv2d.
    """

    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the per-branch outputs, before channel concatenation."""
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        return [branch1x1, branch5x5, branch3x3dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionB(nn.Module):
    """Grid-reduction block: strided 3x3, strided double-3x3 and max-pool branches.

    Output has ``384 + 96 + in_channels`` channels; each branch uses a stride-2,
    unpadded 3x3 window, so the spatial size shrinks accordingly.

    Args:
        in_channels: channels of the input tensor.
        conv_block: factory ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``; defaults to BasicConv2d.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the per-branch outputs, before channel concatenation."""
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        return [branch3x3, branch3x3dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionC(nn.Module):
    """Inception block with factorized 7x7 convolutions (1x7 / 7x1 pairs).

    Spatial size is preserved; output always has ``4 * 192 = 768`` channels.

    Args:
        in_channels: channels of the input tensor.
        channels_7x7: intermediate width of the factorized 7x7 branches.
        conv_block: factory ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``; defaults to BasicConv2d.
    """

    def __init__(
        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the per-branch outputs, before channel concatenation."""
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        return [branch1x1, branch7x7, branch7x7dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionD(nn.Module):
    """Grid-reduction block: 1x1->strided 3x3, factorized 7x7->strided 3x3, and max-pool branches.

    Output has ``320 + 192 + in_channels`` channels; each branch ends with a stride-2,
    unpadded 3x3 window, so the spatial size shrinks accordingly.

    Args:
        in_channels: channels of the input tensor.
        conv_block: factory ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``; defaults to BasicConv2d.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the per-branch outputs, before channel concatenation."""
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [branch3x3, branch7x7x3, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionE(nn.Module):
    """Inception block with expanded filter banks (parallel 1x3 / 3x1 convolutions).

    Spatial size is preserved; output has ``320 + 768 + 768 + 192 = 2048`` channels.

    Args:
        in_channels: channels of the input tensor.
        conv_block: factory ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``; defaults to BasicConv2d.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the per-branch outputs, before channel concatenation."""
        branch1x1 = self.branch1x1(x)

        # Shared 1x1 reduction feeding parallel 1x3 and 3x1 convolutions, concatenated on channels.
        # Distinct names keep every variable a single type (Tensor vs List[Tensor]).
        reduced3x3 = self.branch3x3_1(x)
        branch3x3 = torch.cat([self.branch3x3_2a(reduced3x3), self.branch3x3_2b(reduced3x3)], 1)

        reduced3x3dbl = self.branch3x3dbl_1(x)
        reduced3x3dbl = self.branch3x3dbl_2(reduced3x3dbl)
        branch3x3dbl = torch.cat([self.branch3x3dbl_3a(reduced3x3dbl), self.branch3x3dbl_3b(reduced3x3dbl)], 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        return [branch1x1, branch3x3, branch3x3dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat(self._forward(x), 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head (pool -> 1x1 conv -> 5x5 conv -> global pool -> linear).

    Args:
        in_channels: channels of the input feature map.
        num_classes: number of output logits.
        conv_block: factory ``(in_ch, out_ch, **conv_kwargs) -> nn.Module``; defaults to BasicConv2d.
    """

    def __init__(
        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        # `stddev` attributes are per-layer init hints read by Inception3's weight-init loop.
        # NOTE(review): that loop only inspects nn.Conv2d / nn.Linear modules; with the default
        # BasicConv2d wrapper this attribute sits on the wrapper rather than on its inner
        # nn.Conv2d, so it does not appear to take effect — confirm whether that is intended.
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17 (as used by Inception3)
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d (eps=0.001) -> in-place ReLU.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        **kwargs: forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # The bias is omitted because the following BatchNorm has its own shift term.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)