# inception.py - Inception v3 model definition.
1
import warnings
2
3
4
from collections import namedtuple
from typing import Callable, Any, Optional, Tuple, List

5
6
import torch
import torch.nn.functional as F
7
8
from torch import nn, Tensor

9
from .._internally_replaced_utils import load_state_dict_from_url
10
from ..utils import _log_api_usage_once
11
12


13
# Names exported by ``from <module> import *``. ``_InceptionOutputs`` is exported as well
# because it is kept as a backwards-compatible alias of ``InceptionOutputs``.
__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"]
14
15
16
17


# Download locations of pretrained checkpoints, keyed by checkpoint name.
model_urls = {
    # Inception v3 ported from TensorFlow
    "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
}

21
22
# Pair returned by ``Inception3``: ``logits`` is the main classifier output and
# ``aux_logits`` the auxiliary classifier output (``None`` when it was not computed).
InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
# ``namedtuple`` carries no field types, so they are attached explicitly afterwards.
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs
27

28
29

class Inception3(nn.Module):
    """Inception v3 network: stem convolutions, ``Mixed_*`` inception stages, an optional
    auxiliary classifier and a final linear classifier.

    Args:
        num_classes: Number of outputs of the final (and auxiliary) classifier.
        aux_logits: If True, build the auxiliary classifier that is fed from ``Mixed_6e``.
        transform_input: If True, ``forward`` re-scales every input channel first
            (see ``_transform_input``).
        inception_blocks: Seven module factories in the order ``[conv_block, inception_a,
            inception_b, inception_c, inception_d, inception_e, inception_aux]``.
            ``None`` selects the block classes defined in this module.
        init_weights: If True, conv/linear weights are drawn from a truncated normal and
            batch-norm layers are set to weight 1 / bias 0. ``None`` emits a
            ``FutureWarning`` and then behaves like True.
        dropout: Dropout probability applied before the final linear layer.

    Raises:
        ValueError: If ``inception_blocks`` does not contain exactly seven entries.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        # Explicit check instead of ``assert`` so the validation survives ``python -O``.
        if len(inception_blocks) != 7:
            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem: plain convolutions and max-pooling ahead of the inception stages.
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Inception stages; the first argument is the channel count of the incoming tensor.
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        # The attribute always exists (possibly None) so ``_forward`` can test it.
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)

        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # A layer may carry its own ``stddev`` attribute; 0.1 is the fallback.
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Affinely re-scale the first three channels of ``x`` when ``transform_input`` is set.

        Each channel ``c`` becomes ``x_c * (s_c / 0.5) + (m_c - 0.5) / 0.5``; otherwise ``x``
        is returned unchanged.
        """
        # NOTE(review): the constants look like ImageNet per-channel mean/std being mapped
        # to a mean=0.5/std=0.5 normalisation - confirm against the checkpoint's preprocessing.
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the network and return ``(logits, aux)``.

        ``aux`` is the auxiliary classifier output; it is only computed when the auxiliary
        module exists and the model is in training mode, otherwise it is ``None``.
        The shape comments below assume a 299 x 299 input.
        """
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        # NOTE(review): the two conditions are kept as separate nested ``if`` statements,
        # presumably so TorchScript can refine the Optional module - confirm before merging.
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        """Choose the eager-mode return value.

        Returns ``InceptionOutputs(x, aux)`` while training with ``aux_logits`` enabled,
        otherwise only the logits tensor ``x`` (hence the ``type: ignore``).
        """
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Optionally transform the input, run the network and package the result.

        When scripted, an ``InceptionOutputs`` tuple is always returned (a warning is issued
        if the auxiliary output is not defined); in eager mode see ``eager_outputs``.
        """
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)
174
175
176


class InceptionA(nn.Module):
    """Inception block with four parallel branches: 1x1, 1x1->5x5, 1x1->3x3->3x3 and
    average-pool->1x1. The branch outputs are concatenated along the channel dimension.

    Args:
        in_channels: Number of channels of the input tensor.
        pool_features: Output channels of the pooling branch.
        conv_block: Factory called as ``conv_block(in_ch, out_ch, **conv_kwargs)``;
            defaults to ``BasicConv2d``.

    The output has ``64 + 64 + 96 + pool_features`` channels; every branch uses stride 1
    with padding chosen so the spatial size is unchanged.
    """

    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs, un-concatenated, in concatenation order."""
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dimension 1 (channels)."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):
    """Down-sampling inception block with three stride-2 branches: a 3x3 convolution,
    a 1x1->3x3->3x3 stack and a 3x3 max-pool. Outputs are concatenated on the channel axis.

    Args:
        in_channels: Number of channels of the input tensor.
        conv_block: Factory called as ``conv_block(in_ch, out_ch, **conv_kwargs)``;
            defaults to ``BasicConv2d``.

    The output has ``384 + 96 + in_channels`` channels.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs, un-concatenated, in concatenation order."""
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # The pooling branch has no convolution, so it keeps ``in_channels`` channels.
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dimension 1 (channels)."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):
    """Inception block that factorises 7x7 convolutions into 1x7 / 7x1 pairs.

    Four parallel branches (1x1, a single factorised 7x7, a double factorised 7x7 and
    average-pool->1x1) each produce 192 channels and are concatenated on the channel axis.

    Args:
        in_channels: Number of channels of the input tensor.
        channels_7x7: Width of the intermediate layers in the two 7x7 branches.
        conv_block: Factory called as ``conv_block(in_ch, out_ch, **conv_kwargs)``;
            defaults to ``BasicConv2d``.

    The output has ``4 * 192 = 768`` channels; padding keeps the spatial size unchanged.
    """

    def __init__(
        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs, un-concatenated, in concatenation order."""
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dimension 1 (channels)."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):
    """Down-sampling inception block with three stride-2 branches: 1x1->3x3,
    1x1->1x7->7x1->3x3 and a 3x3 max-pool. Outputs are concatenated on the channel axis.

    Args:
        in_channels: Number of channels of the input tensor.
        conv_block: Factory called as ``conv_block(in_ch, out_ch, **conv_kwargs)``;
            defaults to ``BasicConv2d``.

    The output has ``320 + 192 + in_channels`` channels.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs, un-concatenated, in concatenation order."""
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        # The pooling branch has no convolution, so it keeps ``in_channels`` channels.
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dimension 1 (channels)."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):
    """Inception block whose 3x3 branches end in parallel 1x3 and 3x1 convolutions.

    Four branches are concatenated on the channel axis: 1x1 (320), 1x1 followed by
    parallel 1x3/3x1 (2 * 384), 1x1->3x3 followed by parallel 1x3/3x1 (2 * 384) and
    average-pool->1x1 (192).

    Args:
        in_channels: Number of channels of the input tensor.
        conv_block: Factory called as ``conv_block(in_ch, out_ch, **conv_kwargs)``;
            defaults to ``BasicConv2d``.

    The output has ``320 + 768 + 768 + 192 = 2048`` channels; padding keeps the spatial
    size unchanged.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs, un-concatenated, in concatenation order."""
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        # Both tail convolutions read the same intermediate tensor; their outputs are stacked.
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along dimension 1 (channels)."""
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head: average-pool, 1x1 conv, 5x5 conv, global pool, linear.

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Number of outputs of the linear layer.
        conv_block: Factory called as ``conv_block(in_ch, out_ch, **conv_kwargs)``;
            defaults to ``BasicConv2d``.
    """

    def __init__(
        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        # ``stddev`` attributes are hints for the weight-initialisation loop in ``Inception3``.
        # NOTE(review): with the default ``BasicConv2d`` this attribute sits on the wrapper
        # module, while the init loop only inspects ``nn.Conv2d`` / ``nn.Linear`` instances,
        # so this value looks unused for ``conv1`` - confirm whether that is intended.
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        """Return auxiliary logits of shape ``N x num_classes``.

        The shape comments below assume a ``17 x 17`` feature map with 768 channels.
        """
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x


class BasicConv2d(nn.Module):
    """Convolution (without bias) followed by batch normalisation and ReLU.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution and batch norm.
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``, ``stride``,
            ``padding``).
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # The convolution is created with ``bias=False``; the batch norm follows it directly.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> batch norm -> in-place ReLU and return the result."""
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Inception3:
    r"""Build an Inception v3 network, the architecture introduced in
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
    The required minimum input size of the model is 75x75.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, the pretrained ``inception_v3_google`` checkpoint is downloaded
            and loaded into the model
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False* (*True* when ``pretrained`` is set and the
            argument is not given)
        **kwargs: Any further keyword arguments are passed on to ``Inception3``.

    Returns:
        The constructed ``Inception3`` module.
    """
    # Without pretrained weights the caller's arguments are used untouched.
    if not pretrained:
        return Inception3(**kwargs)

    kwargs.setdefault("transform_input", True)
    # Remember the caller's wish, but always build the auxiliary head for loading; it is
    # removed again below if it was not wanted.
    # NOTE(review): presumably the checkpoint contains AuxLogits weights - confirm.
    original_aux_logits = kwargs.get("aux_logits", True)
    kwargs["aux_logits"] = True
    kwargs["init_weights"] = False  # we are loading weights from a pretrained model

    model = Inception3(**kwargs)
    state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
    model.load_state_dict(state_dict)

    if not original_aux_logits:
        model.aux_logits = False
        model.AuxLogits = None
    return model