# inception.py
1
import warnings
2
3
4
from collections import namedtuple
from typing import Callable, Any, Optional, Tuple, List

5
6
import torch
import torch.nn.functional as F
7
8
from torch import nn, Tensor

9
from .._internally_replaced_utils import load_state_dict_from_url
10
11


__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"]


# Download locations of pretrained checkpoints, keyed by checkpoint name.
model_urls = {
    # Inception v3 ported from TensorFlow
    "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
}

# Named pair produced by Inception3 (see ``forward`` / ``eager_outputs``): the main logits and,
# when they were computed, the auxiliary-classifier logits (otherwise ``None``).
InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"])
# Explicit field types so TorchScript can type the namedtuple.
InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]}

# Private alias kept for backwards compatibility with code that imports ``_InceptionOutputs``.
_InceptionOutputs = InceptionOutputs

def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
    The required minimum input size of the model is 75x75.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False* (but *True* when ``pretrained`` is set and the
            caller does not pass it explicitly)
        **kwargs: Any other keyword arguments are forwarded unchanged to :class:`Inception3`.

    Returns:
        Inception3: the constructed (and, if requested, weight-loaded) model.
    """
    if not pretrained:
        return Inception3(**kwargs)

    # Pretrained weights expect the input transform unless the caller explicitly overrides it.
    kwargs.setdefault("transform_input", True)
    # Build the model with the auxiliary branch so the (strict) state-dict load sees every key;
    # the branch is removed again below if the caller did not ask for it.
    original_aux_logits = kwargs.get("aux_logits", True)
    kwargs["aux_logits"] = True
    kwargs["init_weights"] = False  # we are loading weights from a pretrained model
    model = Inception3(**kwargs)
    state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress)
    model.load_state_dict(state_dict)
    if not original_aux_logits:
        model.aux_logits = False
        model.AuxLogits = None
    return model


class Inception3(nn.Module):
    """Inception v3 classification network.

    Args:
        num_classes: output size of the final ``fc`` layer (and of the auxiliary classifier,
            if it is built).
        aux_logits: if True, build the ``AuxLogits`` branch on the ``Mixed_6e`` features. Its
            output is only computed while ``self.training`` is True.
        transform_input: if True, ``_transform_input`` re-normalises the three input channels
            before the stem.
        inception_blocks: optional list of exactly seven module factories, in the order
            ``[conv_block, inception_a, inception_b, inception_c, inception_d, inception_e,
            inception_aux]``. Defaults to ``BasicConv2d``, ``InceptionA`` ... ``InceptionAux``.
        init_weights: if True, initialise ``nn.Conv2d``/``nn.Linear`` weights from a truncated
            normal and ``nn.BatchNorm2d`` layers to weight=1, bias=0. ``None`` emits a
            ``FutureWarning`` and then behaves like True.
        dropout: dropout probability applied to the pooled features before ``fc``.
    """

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
        dropout: float = 0.5,
    ) -> None:
        super().__init__()
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of inception_v3 will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        assert len(inception_blocks) == 7
        (
            conv_block,
            inception_a,
            inception_b,
            inception_c,
            inception_d,
            inception_e,
            inception_aux,
        ) = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem (spatial sizes for a 299x299 input are traced in ``_forward``).
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Inception stages.
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        # Declared as Optional up front so TorchScript sees a consistent attribute type.
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048, num_classes)

        if init_weights:
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # Layers may carry a custom ``stddev`` attribute; everything else uses 0.1.
                    # NOTE(review): InceptionAux sets ``conv1.stddev`` on the conv_block wrapper,
                    # not on the inner nn.Conv2d matched here, so with the default BasicConv2d
                    # that value is not picked up (only ``AuxLogits.fc`` gets its 0.001).
                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        """Re-normalise the three input channels when ``transform_input`` is enabled.

        Per channel this computes ``x * std / 0.5 + (mean - 0.5) / 0.5`` with
        mean=(0.485, 0.456, 0.406) and std=(0.229, 0.224, 0.225): an input already normalised as
        ``(p - mean) / std`` is mapped to ``(p - 0.5) / 0.5``. Otherwise ``x`` is returned as-is.
        """
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Run the backbone and classifier; return ``(logits, aux_logits_or_None)``."""
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        # The auxiliary head only runs in training mode.
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        """Eager-mode return value: the named pair while training with aux logits, else just ``x``."""
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> InceptionOutputs:
        """Classify ``x``.

        In eager mode this returns ``InceptionOutputs(logits, aux_logits)`` while training with
        ``aux_logits`` enabled, and just the logits tensor otherwise. A scripted model always
        returns ``InceptionOutputs`` (``aux_logits`` may be ``None``) and warns when the auxiliary
        output is not defined.
        """
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)


class InceptionA(nn.Module):
    """Four-branch block that keeps the spatial size (all convs padded, stride 1).

    With a conv_block that honours ``out_channels`` the output has
    ``64 + 64 + 96 + pool_features`` channels: 1x1, 1x1->5x5, 1x1->3x3->3x3 and
    avg-pool->1x1 branches concatenated along dim 1.
    """

    def __init__(
        self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs (not yet concatenated)."""
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        return [branch1x1, branch5x5, branch3x3dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)


class InceptionB(nn.Module):
    """Downsampling block: stride-2 3x3 conv, 1x1->3x3->stride-2 3x3, and stride-2 max-pool.

    None of the stride-2 ops are padded. With a conv_block that honours ``out_channels`` the
    output has ``384 + 96 + in_channels`` channels.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs (not yet concatenated)."""
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Pooling branch passes the input channels through unchanged.
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        return [branch3x3, branch3x3dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)


class InceptionC(nn.Module):
    """Four-branch block using factorised 1x7 / 7x1 convolutions; keeps the spatial size.

    ``channels_7x7`` sets the width of the intermediate 7x7 branches. With a conv_block that
    honours ``out_channels`` every branch ends in 192 channels, so the output has 768.
    """

    def __init__(
        self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs (not yet concatenated)."""
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        return [branch1x1, branch7x7, branch7x7dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)


class InceptionD(nn.Module):
    """Downsampling block: 1x1->stride-2 3x3, 1x1->1x7->7x1->stride-2 3x3, and stride-2 max-pool.

    With a conv_block that honours ``out_channels`` the output has
    ``320 + 192 + in_channels`` channels.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the three branch outputs (not yet concatenated)."""
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        # Pooling branch passes the input channels through unchanged.
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [branch3x3, branch7x7x3, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)


class InceptionE(nn.Module):
    """Four-branch block whose 3x3 branches split into parallel 1x3 and 3x1 convs.

    The spatial size is preserved. With a conv_block that honours ``out_channels`` the output
    has ``320 + 2*384 + 2*384 + 192 = 2048`` channels.
    """

    def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        """Return the four branch outputs (not yet concatenated)."""
        branch1x1 = self.branch1x1(x)

        # The 1x3 and 3x1 convs both read the same 1x1 output; their results are stacked.
        stem3x3 = self.branch3x3_1(x)
        branch3x3 = torch.cat([self.branch3x3_2a(stem3x3), self.branch3x3_2b(stem3x3)], 1)

        stem3x3dbl = self.branch3x3dbl_1(x)
        stem3x3dbl = self.branch3x3dbl_2(stem3x3dbl)
        branch3x3dbl = torch.cat([self.branch3x3dbl_3a(stem3x3dbl), self.branch3x3dbl_3b(stem3x3dbl)], 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        return [branch1x1, branch3x3, branch3x3dbl, branch_pool]

    def forward(self, x: Tensor) -> Tensor:
        """Concatenate the branch outputs along the channel dimension."""
        return torch.cat(self._forward(x), 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier: avg-pool, 1x1 conv, 5x5 conv, global average pool, linear layer.

    ``conv1`` and ``fc`` carry ``stddev`` attributes (0.01 and 0.001). These are read by the
    weight-initialisation loop of ``Inception3`` for modules of type ``nn.Conv2d``/``nn.Linear``.
    """

    def __init__(
        self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super().__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # type: ignore[assignment]
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001  # type: ignore[assignment]

    def forward(self, x: Tensor) -> Tensor:
        """Return auxiliary logits of shape ``N x num_classes`` (shapes below assume a 17x17 input)."""
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling (makes the head work for larger feature maps too)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by ``nn.BatchNorm2d(eps=0.001)`` and an in-place ReLU.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...) are forwarded to
    ``nn.Conv2d``.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # No conv bias: the following batch norm supplies the learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> batch norm -> ReLU."""
        return F.relu(self.bn(self.conv(x)), inplace=True)