import warnings
from collections import namedtuple
from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_Weights", "inception_v3"]


InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}

# Script annotations failed with _InceptionOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


class Inception3(nn.Module):
    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        if len(inception_blocks) != 7:
            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
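        # Inception blocks; each first argument is the input channel count,
        # i.e. the concatenated output width of the preceding block
        # (e.g. Mixed_5b emits 64 + 64 + 96 + 32 = 256 channels for Mixed_5c).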
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)
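        # Truncated-normal weight init (default stddev 0.1); a module may
        # override the stddev via an optional ``stddev`` attribute
        # (e.g. the aux classifier's fc layer below).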
        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        if self.transform_input:
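            # Re-map inputs from the standard ImageNet normalization
            # (per-channel mean/std) to the mean=0.5 / std=0.5 scaling that
            # the ported original weights expect.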
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
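        # TorchScript cannot express a data-dependent return type, so scripted
        # modules always return the InceptionOutputs tuple; in eager mode,
        # eager_outputs unwraps it to a plain Tensor outside of training.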
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)


class InceptionA(nn.Module):
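    # 35x35 Inception block: parallel 1x1, 5x5, double-3x3, and pooled branches
    # concatenated along the channel dimension.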
    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):
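    # Grid-size reduction block: stride-2 branches plus a max pool shrink the
    # 35x35 grid to 17x17 (288 -> 384 + 96 + 288 = 768 channels).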
    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):
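    # 17x17 block with factorized convolutions: each 7x7 is decomposed into
    # cheaper 1x7 and 7x1 pairs.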
    def __init__(
        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):
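    # Second grid-size reduction: 17x17 -> 8x8
    # (768 -> 320 + 192 + 768 = 1280 channels).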
    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):
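    # 8x8 block with expanded filter banks: the 3x3 branches fan out into
    # parallel 1x3 and 3x1 convolutions whose outputs are concatenated.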
    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
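    # Auxiliary classifier over the 17x17 feature map; evaluated only during
    # training (see Inception3._forward) to improve gradient flow.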
    def __init__(
        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000
        return x


class BasicConv2d(nn.Module):
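    # Conv2d -> BatchNorm2d -> ReLU; the convolution has no bias because the
    # following BatchNorm supplies its own shift.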
    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)


class Inception_V3_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
        transforms=partial(ImageClassification, crop_size=299, resize_size=342),
        meta={
            "num_params": 27161264,
            "min_size": (75, 75),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.294,
                    "acc@5": 93.450,
                }
            },
            "_ops": 5.713,
            "_file_size": 103.903,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
    """
    Inception v3 model architecture from
    `Rethinking the Inception Architecture for Computer Vision <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models, inception_v3 expects tensors of size
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.Inception_V3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.Inception3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Inception_V3_Weights
        :members:
    """
    weights = Inception_V3_Weights.verify(weights)

    original_aux_logits = kwargs.get("aux_logits", True)
    if weights is not None:
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = Inception3(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None

    return model
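

# Minimal usage sketch (illustrative only; assumes the public
# torchvision.models API and network access for the pretrained weights):
#
#     import torch
#     from torchvision.models import inception_v3, Inception_V3_Weights
#
#     model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
#     model.eval()  # eval-mode forward() returns plain logits, not InceptionOutputs
#     x = torch.randn(1, 3, 299, 299)  # the model expects N x 3 x 299 x 299 inputs
#     with torch.no_grad():
#         logits = model(x)  # shape: torch.Size([1, 1000])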