inception.py 13.9 KB
Newer Older
1
2
from __future__ import division

3
from collections import namedtuple
4
import warnings
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F
8
from torch.jit.annotations import Optional
9
from .utils import load_state_dict_from_url
10
11


12
# Names exported by ``from <module> import *``.
__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
13
14
15
16
17
18
19


model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}

20
21
22
23
24
25
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs
26

27

28
def inception_v3(pretrained=False, progress=True, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if not pretrained:
        return Inception3(**kwargs)

    # The released checkpoint was trained with the TF-style input transform
    # and with the auxiliary head present, so force both on while loading.
    kwargs.setdefault('transform_input', True)
    wants_aux = kwargs.get('aux_logits', True)
    kwargs['aux_logits'] = True

    model = Inception3(**kwargs)
    state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
                                          progress=progress)
    model.load_state_dict(state_dict)

    # Drop the auxiliary head again if the caller did not ask for it.
    if not wants_aux:
        model.aux_logits = False
        del model.AuxLogits
    return model


class Inception3(nn.Module):
    """Inception v3 network.

    Expects inputs of size N x 3 x 299 x 299.  During training (and with
    ``aux_logits=True``) ``forward`` returns an :class:`InceptionOutputs`
    namedtuple of ``(logits, aux_logits)``; in eval mode it returns the
    main logits tensor alone (scripted modules always return the tuple).

    Args:
        num_classes (int): size of the final classifier output.
        aux_logits (bool): if True, build the auxiliary classifier head
            attached to the 17x17 feature map.
        transform_input (bool): if True, re-normalize inputs from the
            standard torchvision normalization to the normalization the
            original model was trained with.
    """

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
        super(Inception3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        # Stem: N x 3 x 299 x 299 -> N x 192 x 35 x 35 (with the pooling in forward).
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        # 35 x 35 blocks.
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        # 17 x 17 blocks.
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        # 8 x 8 blocks.
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.fc = nn.Linear(2048, num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        # Truncated-normal init for conv/linear weights, constant init for
        # batch norm.  scipy is imported once here (it previously sat inside
        # the loop body) and only when a model is actually constructed.
        import scipy.stats as stats
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                # Layers may carry a custom `stddev` attribute (see InceptionAux).
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.transform_input:
            # Undo the standard (0.485/0.456/0.406, 0.229/0.224/0.225)
            # normalization and re-apply the (0.5, 0.5) normalization used
            # when the original weights were trained.
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        # Auxiliary head only runs in training mode (and when enabled).
        aux_defined = self.training and self.aux_logits
        if aux_defined:
            aux = self.AuxLogits(x)
        else:
            aux = None
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        if torch.jit.is_scripting():
            # TorchScript cannot return a data-dependent type, so scripted
            # models always return the (logits, aux) tuple.
            if not aux_defined:
                warnings.warn("Scripted InceptionNet always returns InceptionOutputs Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)

    @torch.jit.unused
    def eager_outputs(self, x, aux):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> InceptionOutputs
        # Eager mode: return the tuple only while training with the aux head;
        # otherwise just the main logits.
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        return x


class InceptionA(nn.Module):
    """35x35 Inception block: 1x1, 5x5, double-3x3 and pooled branches.

    Output channels: 64 + 64 + 96 + pool_features.
    """

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))

        out_3x3dbl = self.branch3x3dbl_1(x)
        out_3x3dbl = self.branch3x3dbl_2(out_3x3dbl)
        out_3x3dbl = self.branch3x3dbl_3(out_3x3dbl)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        pooled = self.branch_pool(pooled)

        # Concatenate all four branches along the channel dimension.
        return torch.cat([out_1x1, out_5x5, out_3x3dbl, pooled], 1)


class InceptionB(nn.Module):
    """Downsampling Inception block (stride-2 branches), 35x35 -> 17x17."""

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)

    def forward(self, x):
        out_3x3 = self.branch3x3(x)

        out_3x3dbl = self.branch3x3dbl_1(x)
        out_3x3dbl = self.branch3x3dbl_2(out_3x3dbl)
        out_3x3dbl = self.branch3x3dbl_3(out_3x3dbl)

        # Third branch is a plain stride-2 max-pool of the input.
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)

        return torch.cat([out_3x3, out_3x3dbl, pooled], 1)


class InceptionC(nn.Module):
    """17x17 Inception block with 7x7 convolutions factorized into 1x7/7x1."""

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        mid = channels_7x7
        # One factorized 7x7: 1x7 followed by 7x1.
        self.branch7x7_1 = BasicConv2d(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(mid, 192, kernel_size=(7, 1), padding=(3, 0))

        # Two stacked factorized 7x7 convolutions.
        self.branch7x7dbl_1 = BasicConv2d(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(mid, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        out_7x7 = self.branch7x7_3(self.branch7x7_2(self.branch7x7_1(x)))

        out_7x7dbl = self.branch7x7dbl_1(x)
        out_7x7dbl = self.branch7x7dbl_2(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_3(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_4(out_7x7dbl)
        out_7x7dbl = self.branch7x7dbl_5(out_7x7dbl)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        pooled = self.branch_pool(pooled)

        # All branches emit 192 channels; concatenated -> 768.
        return torch.cat([out_1x1, out_7x7, out_7x7dbl, pooled], 1)


class InceptionD(nn.Module):
    """Downsampling Inception block (stride-2 branches), 17x17 -> 8x8."""

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        out_7x7x3 = self.branch7x7x3_1(x)
        out_7x7x3 = self.branch7x7x3_2(out_7x7x3)
        out_7x7x3 = self.branch7x7x3_3(out_7x7x3)
        out_7x7x3 = self.branch7x7x3_4(out_7x7x3)

        # Parallel stride-2 max-pool branch.
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([out_3x3, out_7x7x3, pooled], 1)


class InceptionE(nn.Module):
    """8x8 Inception block whose 3x3 branches fan out into 1x3 and 3x1."""

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        # 3x3 branch: shared 1x1 stem, then parallel 1x3 and 3x1 outputs.
        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Double-3x3 branch with the same split at the end.
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        out_1x1 = self.branch1x1(x)

        stem = self.branch3x3_1(x)
        out_3x3 = torch.cat([self.branch3x3_2a(stem), self.branch3x3_2b(stem)], 1)

        dbl_stem = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            [self.branch3x3dbl_3a(dbl_stem), self.branch3x3dbl_3b(dbl_stem)], 1)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        pooled = self.branch_pool(pooled)

        return torch.cat([out_1x1, out_3x3, out_3x3dbl, pooled], 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to the 17x17 feature map."""

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        # Custom init stddevs, read by Inception3's weight initialization.
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # x: N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        # N x num_classes after the linear layer.
        return self.fc(x)


class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU.

    The convolution carries no bias: the following batch norm supplies the
    affine shift.  Extra keyword args (kernel_size, stride, padding, ...)
    are forwarded to :class:`nn.Conv2d`.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)