inception.py 17.2 KB
Newer Older
1
import warnings
2
3
4
from collections import namedtuple
from typing import Callable, Any, Optional, Tuple, List

5
6
import torch
import torch.nn.functional as F
7
8
from torch import nn, Tensor

9
from .._internally_replaced_utils import load_state_dict_from_url
10
from ..utils import _log_api_usage_once
11
12


13
# Public names exported by ``from <this module> import *``; ``_InceptionOutputs`` is listed
# explicitly because names with a leading underscore are otherwise not exported.
__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"]
14
15
16
17


model_urls = {
    # Inception v3 ported from TensorFlow
    "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
}

# Output container of ``Inception3.forward``: ``logits`` is the main classifier output and
# ``aux_logits`` the auxiliary-head output (``None`` when the auxiliary head was not run).
InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
# Explicit field types; presumably required so TorchScript can type the namedtuple fields.
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}

# Script annotations failed with _InceptionOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs
27

28
29

class Inception3(nn.Module):
    """Inception v3 network ("Rethinking the Inception Architecture for Computer Vision").

    Args:
        num_classes: Output size of the final ``fc`` layer and of the auxiliary classifier.
        aux_logits: If True, build the ``AuxLogits`` head. Its output is only computed (and
            returned alongside the main logits) while the module is in training mode.
        transform_input: If True, ``forward`` re-scales each of the 3 input channels with the
            constants in ``_transform_input`` before running the network.
        inception_blocks: Optional list of exactly 7 module factories, in the order
            ``[conv_block, inception_a, inception_b, inception_c, inception_d, inception_e,
            inception_aux]``. Defaults to the classes defined in this module.
        init_weights: If True, re-initialize Conv2d/Linear weights with a truncated normal and
            BatchNorm2d to weight=1 / bias=0. ``None`` emits a FutureWarning and acts as True.
        dropout: Dropout probability applied before the final ``fc`` layer.

    Raises:
        ValueError: If ``inception_blocks`` does not contain exactly 7 entries.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        if len(inception_blocks) != 7:
            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
        # Safe to unpack: the length was validated just above.
        conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # A per-layer ``stddev`` attribute overrides the 0.1 default. Only bare
                    # Conv2d/Linear modules are inspected here, so an attribute placed on a
                    # wrapper module (e.g. a BasicConv2d) is not seen.
                    # NOTE(review): ``a``/``b`` are absolute cutoffs in ``trunc_normal_``; with
                    # std <= 0.1 the +/-2 bounds lie >= 20 std out and barely truncate. If
                    # +/-2 std truncation was intended, use a=-2 * stddev, b=2 * stddev.
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Optionally re-normalize the 3 input channels.

        Per channel this computes ``x * (std / 0.5) + (mean - 0.5) / 0.5`` with the constants
        below. This maps ``(p - mean) / std`` to ``(p - 0.5) / 0.5``. It assumes the caller
        normalized with mean=(0.485, 0.456, 0.406) and std=(0.229, 0.224, 0.225) -- confirm.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the backbone; return ``(logits, aux)``.

        ``aux`` is None unless the aux head exists and the module is training.
        """
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        """Eager-mode result: ``InceptionOutputs`` while training with aux logits, else the bare logits."""
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Run the network.

        Under TorchScript an ``InceptionOutputs`` tuple is always returned, with a warning if
        the aux output is not defined. In eager mode see ``eager_outputs``.
        """
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)
175
176
177


class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and average-pool branches (``Mixed_5b``-``5d``).

    Output channels are ``64 + 64 + 96 + pool_features``. Every branch uses stride 1 with
    padding that keeps the spatial size (given a stride-1 default in ``conv_block``).
    """

    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: plain 1x1 projection.
        self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)

        # Branch 2: 1x1 reduction, then 5x5.
        self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)

        # Branch 3: 1x1 reduction, then two stacked 3x3 convolutions.
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)

        # Branch 4: 1x1 projection applied after a 3x3 average pool.
        self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # All branches consume the same input; the list order is the concatenation order.
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_5x5, out_3x3dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Run every branch and concatenate the results along dim 1."""
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionB(nn.Module):
    """Stride-2 downsampling block used as ``Mixed_6a``.

    Branches: a strided 3x3 conv (384 channels), a 1x1 -> 3x3 -> strided 3x3 stack
    (96 channels), and a strided 3x3 max-pool that keeps ``in_channels``. The output
    therefore has ``384 + 96 + in_channels`` channels.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: single strided 3x3.
        self.branch3x3 = make_conv(in_channels, 384, kernel_size=3, stride=2)

        # Branch 2: 1x1 reduction, padded 3x3, then strided 3x3.
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        strided = self.branch3x3(x)

        stacked = self.branch3x3dbl_1(x)
        stacked = self.branch3x3dbl_2(stacked)
        stacked = self.branch3x3dbl_3(stacked)

        # Parameter-free branch: the input channels pass through at reduced resolution.
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return [strided, stacked, pooled]

    def forward(self, x: Tensor) -> Tensor:
        """Run every branch and concatenate the results along dim 1."""
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionC(nn.Module):
    """Inception block with factorized 7x7 convolutions (``Mixed_6b``-``6e``).

    Four branches of 192 channels each (768 total): 1x1, 1x7 -> 7x1, a 7x1/1x7/7x1/1x7
    stack, and average-pool -> 1x1. ``channels_7x7`` sets the width of the intermediate
    7x7-factorized layers. The padding keeps the spatial size unchanged.
    """

    def __init__(
        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        mid = channels_7x7

        # Branch 1: plain 1x1 projection.
        self.branch1x1 = make_conv(in_channels, 192, kernel_size=1)

        # Branch 2: 1x1 reduction, then a 1x7 / 7x1 pair.
        self.branch7x7_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = make_conv(mid, 192, kernel_size=(7, 1), padding=(3, 0))

        # Branch 3: 1x1 reduction, then two alternating 7x1 / 1x7 pairs.
        self.branch7x7dbl_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = make_conv(mid, 192, kernel_size=(1, 7), padding=(0, 3))

        # Branch 4: 1x1 projection applied after a 3x3 average pool.
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_3(self.branch7x7_2(self.branch7x7_1(x)))

        deep = self.branch7x7dbl_1(x)
        deep = self.branch7x7dbl_2(deep)
        deep = self.branch7x7dbl_3(deep)
        deep = self.branch7x7dbl_4(deep)
        out_7x7dbl = self.branch7x7dbl_5(deep)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [out_1x1, out_7x7, out_7x7dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Run every branch and concatenate the results along dim 1."""
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionD(nn.Module):
    """Stride-2 downsampling block used as ``Mixed_7a``.

    Branches: 1x1 -> strided 3x3 (320 channels), 1x1 -> 1x7 -> 7x1 -> strided 3x3
    (192 channels), and a strided 3x3 max-pool that keeps ``in_channels``.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: 1x1 reduction, then strided 3x3.
        self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=2)

        # Branch 2: 1x1 reduction, a 1x7 / 7x1 pair, then strided 3x3.
        self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        tower = self.branch7x7x3_1(x)
        tower = self.branch7x7x3_2(tower)
        tower = self.branch7x7x3_3(tower)
        out_7x7x3 = self.branch7x7x3_4(tower)

        # Parameter-free branch: the input channels pass through at reduced resolution.
        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_7x7x3, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Run every branch and concatenate the results along dim 1."""
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionE(nn.Module):
    """Inception block with split 1x3 / 3x1 outputs (``Mixed_7b`` and ``Mixed_7c``).

    Output channels: 320 (1x1) + 2*384 (3x3 split) + 2*384 (double-3x3 split) + 192 (pool)
    = 2048. The padding keeps the spatial size unchanged.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # Branch 1: plain 1x1 projection.
        self.branch1x1 = make_conv(in_channels, 320, kernel_size=1)

        # Branch 2: 1x1 reduction feeding parallel 1x3 and 3x1 convolutions.
        self.branch3x3_1 = make_conv(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Branch 3: 1x1 -> 3x3, then parallel 1x3 and 3x1 convolutions.
        self.branch3x3dbl_1 = make_conv(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Branch 4: 1x1 projection applied after a 3x3 average pool.
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        out_1x1 = self.branch1x1(x)

        reduced = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(reduced), self.branch3x3_2b(reduced)], 1)

        expanded = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat([self.branch3x3dbl_3a(expanded), self.branch3x3dbl_3b(expanded)], 1)

        out_pool = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return [out_1x1, out_3x3, out_3x3dbl, out_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Run every branch and concatenate the results along dim 1."""
        branches = self._forward(x)
        return torch.cat(branches, 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head (attached after ``Mixed_6e`` in ``Inception3``).

    Pipeline: 5x5/stride-3 average pool -> 1x1 conv (128 ch) -> 5x5 conv (768 ch)
    -> global average pool -> flatten -> linear (``num_classes``).

    ``conv1`` and ``fc`` carry a ``stddev`` attribute (0.01 and 0.001) that the parent
    model's weight initialization reads instead of its 0.1 default.
    """

    def __init__(
        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        # The parent's init loop only reads ``stddev`` from bare nn.Conv2d / nn.Linear
        # modules. When conv_block wraps its convolution (BasicConv2d stores it as
        # ``.conv``), the attribute above would be ignored, so mirror it onto the
        # wrapped convolution as well.
        inner_conv = getattr(self.conv1, "conv", None)
        if isinstance(inner_conv, nn.Conv2d):
            inner_conv.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        """Compute auxiliary logits of shape ``N x num_classes``."""
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` -> ``nn.BatchNorm2d(eps=1e-3)`` -> in-place ReLU.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...) are forwarded
    unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # No conv bias: the following BatchNorm2d supplies its own learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-3)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self.bn(self.conv(x))
        # In-place is safe here: ``normalized`` is a fresh intermediate owned by this call.
        return F.relu(normalized, inplace=True)
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3:
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
    The required minimum input size of the model is 75x75.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: True if ``pretrained=True``, else False.
        **kwargs: Other keyword arguments forwarded to ``Inception3`` (e.g. ``num_classes``,
            ``dropout``). When ``pretrained=True``, ``init_weights`` is always forced to False.
    """
    if not pretrained:
        return Inception3(**kwargs)

    kwargs.setdefault("transform_input", True)
    # Build with the auxiliary head so the parameters match the checkpoint loaded below;
    # the head is detached afterwards when the caller asked for aux_logits=False.
    keep_aux_logits = kwargs.pop("aux_logits", True)
    # Random init is pointless: every parameter is overwritten by the checkpoint.
    kwargs["init_weights"] = False
    model = Inception3(aux_logits=True, **kwargs)
    state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
    model.load_state_dict(state_dict)
    if not keep_aux_logits:
        model.aux_logits = False
        model.AuxLogits = None
    return model