# inception.py
1
import warnings
2
from collections import namedtuple
3
from functools import partial
4
from typing import Any, Callable, List, Optional, Tuple
5

6
7
import torch
import torch.nn.functional as F
8
9
from torch import nn, Tensor

10
from ..transforms._presets import ImageClassification
11
from ..utils import _log_api_usage_once
12
from ._api import register_model, Weights, WeightsEnum
13
from ._meta import _IMAGENET_CATEGORIES
14
from ._utils import _ovewrite_named_param, handle_legacy_interface
15
16


17
# Public names exported by ``from ... import *``.
# _InceptionOutputs is kept in the list as a backwards-compatible alias of InceptionOutputs.
__all__ = [
    "Inception3",
    "InceptionOutputs",
    "_InceptionOutputs",
    "Inception_V3_Weights",
    "inception_v3",
]
18
19


20
21
# Return type of Inception3.forward when the auxiliary head is active (and always under TorchScript).
InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
# Field types are attached explicitly; presumably this is what lets TorchScript type the tuple
# (aux_logits is None when the auxiliary head did not run).
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}

# Script annotations failed with `_InceptionOutputs = namedtuple ...`, so the public
# InceptionOutputs above is the real type and _InceptionOutputs is kept only as a
# backwards-compatible alias.
_InceptionOutputs = InceptionOutputs
26

27
28

class Inception3(nn.Module):
    """Inception v3 image classifier.

    Implements the network from "Rethinking the Inception Architecture for Computer Vision".
    The shape comments in ``_forward`` assume ``N x 3 x 299 x 299`` inputs.

    Args:
        num_classes: size of the logits produced by ``fc`` (and by the auxiliary head).
        aux_logits: if True, build the ``AuxLogits`` head; it is only evaluated while
            ``self.training`` is True.
        transform_input: if True, ``_transform_input`` re-scales the input channels before the
            stem (see that method).
        inception_blocks: optional list of exactly 7 module factories used in place of
            ``[BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]``
            (same order).
        init_weights: if truthy, parameters are initialized by the loop at the end of
            ``__init__``. ``None`` emits a ``FutureWarning`` and is treated as True.
        dropout: dropout probability applied before the final ``fc`` layer.

    Raises:
        ValueError: if ``inception_blocks`` does not contain exactly 7 entries.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        if len(inception_blocks) != 7:
            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
        (
            conv_block,
            inception_a,
            inception_b,
            inception_c,
            inception_d,
            inception_e,
            inception_aux,
        ) = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem: plain convolutions and max-pools that bring 299x299 inputs down to 35x35.
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Inception stages. Attribute names are state_dict keys and must not change.
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)

        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # A module may carry a custom ``stddev`` attribute; otherwise 0.1 is used.
                    # NOTE(review): only attributes set directly on nn.Conv2d / nn.Linear
                    # instances are seen here, not ones set on wrapper modules such as BasicConv2d.
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    # a/b are absolute cutoff values for trunc_normal_, not multiples of std.
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Optionally re-normalize the three input channels.

        For a channel value ``v = (p - mean) / std`` with ImageNet statistics
        mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), each output channel equals
        ``(p - 0.5) / 0.5``. The input is returned unchanged when ``transform_input`` is False.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the backbone and head; return ``(logits, aux)``.

        ``aux`` is None unless the auxiliary head exists and the module is in training mode.
        """
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            # The auxiliary classifier is only evaluated during training.
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        """Eager-mode result: a tuple while training with the aux head, otherwise bare logits."""
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Classify ``x``.

        In eager mode this returns ``InceptionOutputs(logits, aux_logits)`` while training with the
        auxiliary head enabled, and a plain logits tensor otherwise. Under TorchScript an
        ``InceptionOutputs`` tuple is always returned (``aux_logits`` may be None), with a warning
        when the auxiliary output is not defined.
        """
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)
174
175
176


class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and average-pool branches.

    Every branch uses "same" padding, so with the default stride the spatial size is kept.
    The branch outputs are concatenated on dim 1, giving ``64 + 64 + 96 + pool_features``
    channels.

    Args:
        in_channels: channels of the incoming tensor.
        pool_features: output channels of the pooling branch's 1x1 projection.
        conv_block: factory for the convolution layers; ``BasicConv2d`` when None.
    """

    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: 1x1 projection.
        self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)

        # Branch 2: 1x1 reduction, then a padded 5x5.
        self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)

        # Branch 3: 1x1 reduction, then two padded 3x3 convolutions.
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)

        # Branch 4: 1x1 projection applied after a 3x3 average pool (see _forward).
        self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs in concatenation order."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_5x5, out_3x3dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionB(nn.Module):
    """Grid-reduction block: strided 3x3, strided double-3x3 and a strided max-pool branch.

    Outputs are concatenated on dim 1, giving ``384 + 96 + in_channels`` channels.

    Args:
        in_channels: channels of the incoming tensor.
        conv_block: factory for the convolution layers; ``BasicConv2d`` when None.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: a single stride-2 3x3.
        self.branch3x3 = make_conv(in_channels, 384, kernel_size=3, stride=2)

        # Branch 2: 1x1 reduction, padded 3x3, then a stride-2 3x3.
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs in concatenation order."""
        strided = self.branch3x3(x)
        deep = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        # The pooling branch has no parameters; it passes the input channels through.
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return [strided, deep, pooled]

    def forward(self, x: Tensor) -> Tensor:
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionC(nn.Module):
    """Inception block with factorized 7x7 convolutions (1x7 / 7x1 pairs).

    Every branch outputs 192 channels, so the concatenation has 768 channels. Padding keeps
    the spatial size with the default stride.

    Args:
        in_channels: channels of the incoming tensor.
        channels_7x7: width of the intermediate layers in the factorized 7x7 branches.
        conv_block: factory for the convolution layers; ``BasicConv2d`` when None.
    """

    def __init__(
        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: 1x1 projection.
        self.branch1x1 = make_conv(in_channels, 192, kernel_size=1)

        mid = channels_7x7

        # Branch 2: 1x1 reduction, then one 1x7 + 7x1 pair.
        self.branch7x7_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = make_conv(mid, 192, kernel_size=(7, 1), padding=(3, 0))

        # Branch 3: 1x1 reduction, then two alternating 7x1 / 1x7 pairs.
        self.branch7x7dbl_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = make_conv(mid, 192, kernel_size=(1, 7), padding=(0, 3))

        # Branch 4: 1x1 projection applied after a 3x3 average pool (see _forward).
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs in concatenation order."""
        out_1x1 = self.branch1x1(x)

        single = self.branch7x7_1(x)
        single = self.branch7x7_2(single)
        out_7x7 = self.branch7x7_3(single)

        double = self.branch7x7dbl_1(x)
        double = self.branch7x7dbl_2(double)
        double = self.branch7x7dbl_3(double)
        double = self.branch7x7dbl_4(double)
        out_7x7dbl = self.branch7x7dbl_5(double)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))

        return [out_1x1, out_7x7, out_7x7dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionD(nn.Module):
    """Grid-reduction block: strided 3x3, factorized-7x7 + strided 3x3, and a strided max-pool.

    Outputs are concatenated on dim 1, giving ``320 + 192 + in_channels`` channels.

    Args:
        in_channels: channels of the incoming tensor.
        conv_block: factory for the convolution layers; ``BasicConv2d`` when None.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: 1x1 reduction, then a stride-2 3x3.
        self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=2)

        # Branch 2: 1x1 reduction, 1x7 + 7x1 pair, then a stride-2 3x3.
        self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs in concatenation order."""
        reduced = self.branch3x3_2(self.branch3x3_1(x))

        factored = self.branch7x7x3_1(x)
        factored = self.branch7x7x3_2(factored)
        factored = self.branch7x7x3_3(factored)
        factored = self.branch7x7x3_4(factored)

        # Parameter-free branch that passes the input channels through.
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return [reduced, factored, pooled]

    def forward(self, x: Tensor) -> Tensor:
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionE(nn.Module):
    """Inception block whose 3x3 paths split into parallel 1x3 and 3x1 convolutions.

    Output channels: ``320 + (384 + 384) + (384 + 384) + 192 = 2048``. Padding keeps the
    spatial size with the default stride.

    Args:
        in_channels: channels of the incoming tensor.
        conv_block: factory for the convolution layers; ``BasicConv2d`` when None.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: 1x1 projection.
        self.branch1x1 = make_conv(in_channels, 320, kernel_size=1)

        # Branch 2: 1x1 reduction feeding parallel 1x3 and 3x1 convolutions.
        self.branch3x3_1 = make_conv(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Branch 3: 1x1 reduction, padded 3x3, then parallel 1x3 and 3x1 convolutions.
        self.branch3x3dbl_1 = make_conv(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Branch 4: 1x1 projection applied after a 3x3 average pool (see _forward).
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs in concatenation order."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))

        return [out_1x1, out_3x3, out_3x3dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head (Inception3 applies it to the Mixed_6e output while training).

    Pipeline: 5x5 / stride-3 average pool -> 1x1 conv to 128 channels -> 5x5 conv to 768
    channels -> global average pool -> flatten -> linear layer with ``num_classes`` outputs.

    Args:
        in_channels: channels of the incoming feature map.
        num_classes: number of output logits.
        conv_block: factory for the convolution layers; ``BasicConv2d`` when None.
    """

    def __init__(
        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        self.conv0 = make_conv(in_channels, 128, kernel_size=1)
        self.conv1 = make_conv(128, 768, kernel_size=5)
        # Custom init std, intended for Inception3's weight-init loop (it reads ``m.stddev``).
        # NOTE(review): with the default BasicConv2d this attribute lands on the wrapper, not on
        # the inner nn.Conv2d, so an init loop that only inspects nn.Conv2d modules appears to
        # skip it - confirm whether that is intended.
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        # Shape comments assume the N x 768 x 17 x 17 map Inception3 feeds in.
        pooled = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5 -> N x 128 x 5 x 5 -> N x 768 x 1 x 1
        features = self.conv1(self.conv0(pooled))
        # Adaptive average pooling keeps the head usable if the map is larger than 5x5.
        features = F.adaptive_avg_pool2d(features, (1, 1))
        # N x 768 x 1 x 1 -> N x 768 -> N x num_classes
        return self.fc(torch.flatten(features, 1))


class BasicConv2d(nn.Module):
    """Bias-free Conv2d -> BatchNorm2d(eps=0.001) -> in-place ReLU.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...) go straight to
    ``nn.Conv2d``.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # No conv bias: the following BatchNorm2d has its own learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
408
409


410
411
412
413
414
415
416
417
418
class Inception_V3_Weights(WeightsEnum):
    # ImageNet-1K checkpoint. Preprocessing: ImageClassification with resize_size=342 and
    # crop_size=299; "_metrics" records the reported top-1 / top-5 accuracy.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
        transforms=partial(ImageClassification, crop_size=299, resize_size=342),
        meta={
            "num_params": 27161264,
            "min_size": (75, 75),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#inception-v3",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.294,
                    "acc@5": 93.450,
                }
            },
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias for the recommended weights.
    DEFAULT = IMAGENET1K_V1


431
@register_model()
@handle_legacy_interface(weights=("pretrained", Inception_V3_Weights.IMAGENET1K_V1))
def inception_v3(*, weights: Optional[Inception_V3_Weights] = None, progress: bool = True, **kwargs: Any) -> Inception3:
    """
    Inception v3 model architecture from
    `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        weights (:class:`~torchvision.models.Inception_V3_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.Inception_V3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.Inception3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/inception.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Inception_V3_Weights
        :members:
    """
    weights = Inception_V3_Weights.verify(weights)
    # Capture the caller's aux-head choice before it is forced to True for checkpoint loading.
    keep_aux_head = kwargs.get("aux_logits", True)

    if weights is None:
        return Inception3(**kwargs)

    # transform_input defaults to True for pretrained weights, but an explicit value is honored.
    if "transform_input" not in kwargs:
        _ovewrite_named_param(kwargs, "transform_input", True)
    for param_name, forced_value in (
        ("aux_logits", True),
        ("init_weights", False),
        ("num_classes", len(weights.meta["categories"])),
    ):
        _ovewrite_named_param(kwargs, param_name, forced_value)

    model = Inception3(**kwargs)
    model.load_state_dict(weights.get_state_dict(progress=progress))

    # The model is built with the aux head so its layout presumably matches the checkpoint;
    # drop the head afterwards if the caller did not ask for it.
    if not keep_aux_head:
        model.aux_logits = False
        model.AuxLogits = None
    return model
477
478
479
480
481
482
483
484
485
486
487
488


# Legacy name -> checkpoint URL mapping. It is an internal implementation detail and
# will be removed in v0.15; the URL is read from the weights enum rather than duplicated.
from ._utils import _ModelURLs


# Inception v3 ported from TensorFlow.
model_urls = _ModelURLs({"inception_v3_google": Inception_V3_Weights.IMAGENET1K_V1.url})