"tests/vscode:/vscode.git/clone" did not exist on "f5383a7e5a79bd9e912af0fb0199c557b9987877"
inception.py 13.1 KB
Newer Older
1
from collections import namedtuple
2
3
4
import torch
import torch.nn as nn
import torch.nn.functional as F
5
from .utils import load_state_dict_from_url
6
7
8
9
10
11
12
13
14
15


# Public API of this module.
__all__ = ['Inception3', 'inception_v3']


# Download locations for pretrained weights.
model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}


# Named return type used in training mode: (main logits, auxiliary logits).
# NOTE(review): the typename carries a historical typo ("Ouputs"); it is kept
# as-is so the public repr of returned tuples does not change.
_InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits'])

18

19
def inception_v3(pretrained=False, progress=True, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, add an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if not pretrained:
        return Inception3(**kwargs)

    # The pretrained weights were trained with input transformation enabled,
    # so default it on unless the caller explicitly overrides.
    kwargs.setdefault('transform_input', True)

    # The checkpoint contains auxiliary-classifier weights, so the model must
    # be constructed with aux_logits=True for load_state_dict to succeed.
    # Remember what the caller actually asked for so we can strip it later.
    wants_aux = kwargs.get('aux_logits', True)
    kwargs['aux_logits'] = True

    model = Inception3(**kwargs)
    state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
                                          progress=progress)
    model.load_state_dict(state_dict)

    if not wants_aux:
        # Remove the auxiliary branch after the weights have been loaded.
        model.aux_logits = False
        del model.AuxLogits
    return model


class Inception3(nn.Module):
    """Inception v3 network (https://arxiv.org/abs/1512.00567).

    Expects inputs of size N x 3 x 299 x 299. In training mode with
    ``aux_logits=True``, ``forward`` returns an ``_InceptionOuputs``
    namedtuple ``(logits, aux_logits)``; otherwise it returns the main
    logits tensor alone.
    """

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
        """
        Args:
            num_classes (int): number of output classes of the final fc layer.
            aux_logits (bool): if True, attach an auxiliary classifier fed
                from the Mixed_6e feature map (used only in training mode).
            transform_input (bool): if True, re-normalize inputs from the
                standard ImageNet statistics to the scaling these weights
                were trained with.
        """
        super(Inception3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.fc = nn.Linear(2048, num_classes)

        # Weight init: truncated normal for conv/linear layers, identity for
        # batch-norm affine parameters. scipy stays a function-local import so
        # it remains an optional dependency, but it is hoisted out of the loop
        # so the import machinery runs once per model, not once per module.
        import scipy.stats as stats
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # Layers may carry a per-layer ``stddev`` (see InceptionAux).
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the network; see the class docstring for the return convention."""
        if self.transform_input:
            # Re-normalize each channel from ImageNet mean/std to the (-1, 1)
            # scaling the pretrained weights expect.
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        if self.training and self.aux_logits:
            # Auxiliary head is evaluated only during training.
            aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = x.view(x.size(0), -1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            return _InceptionOuputs(x, aux)
        return x


class InceptionA(nn.Module):
    """Inception block with 1x1, 5x5, double-3x3 and pooled branches,
    concatenated along the channel dimension."""

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        # Branch 1: plain 1x1 convolution.
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        # Branch 2: 1x1 reduction followed by a 5x5 convolution.
        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        # Branch 3: 1x1 reduction followed by two stacked 3x3 convolutions.
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        # Branch 4: average pooling followed by a 1x1 projection.
        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        b1 = self.branch1x1(x)
        b2 = self.branch5x5_2(self.branch5x5_1(x))
        b3 = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        b4 = self.branch_pool(pooled)
        return torch.cat([b1, b2, b3, b4], 1)


class InceptionB(nn.Module):
    """Spatial-reduction block: two stride-2 convolutional branches
    concatenated with a stride-2 max-pool of the input."""

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        # Branch 1: single stride-2 3x3 convolution.
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        # Branch 2: 1x1 reduction, 3x3, then stride-2 3x3.
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)

    def forward(self, x):
        b1 = self.branch3x3(x)
        b2 = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        b3 = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([b1, b2, b3], 1)


class InceptionC(nn.Module):
    """Inception block with 7x7 convolutions factorized into 1x7/7x1 pairs."""

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        # One factorized 7x7: 1x1 reduction, then 1x7 followed by 7x1.
        self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        # Two stacked factorized 7x7s, alternating 7x1 and 1x7.
        self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        # Average-pool branch with a 1x1 projection.
        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        b1 = self.branch1x1(x)

        b2 = self.branch7x7_1(x)
        for layer in (self.branch7x7_2, self.branch7x7_3):
            b2 = layer(b2)

        b3 = self.branch7x7dbl_1(x)
        for layer in (self.branch7x7dbl_2, self.branch7x7dbl_3,
                      self.branch7x7dbl_4, self.branch7x7dbl_5):
            b3 = layer(b3)

        b4 = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return torch.cat([b1, b2, b3, b4], 1)


class InceptionD(nn.Module):
    """Spatial-reduction block: stride-2 convolutional branches concatenated
    with a stride-2 max-pool of the input."""

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        # Branch 1: 1x1 reduction then stride-2 3x3.
        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)

        # Branch 2: 1x1 reduction, factorized 7x7 (1x7 then 7x1), stride-2 3x3.
        self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        b1 = self.branch3x3_2(self.branch3x3_1(x))

        b2 = x
        for layer in (self.branch7x7x3_1, self.branch7x7x3_2,
                      self.branch7x7x3_3, self.branch7x7x3_4):
            b2 = layer(b2)

        b3 = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([b1, b2, b3], 1)


class InceptionE(nn.Module):
    """Inception block whose 3x3 convolutions are split into parallel
    1x3 and 3x1 halves that are concatenated channel-wise."""

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        # Branch 2: 1x1 reduction, then parallel 1x3 / 3x1 outputs.
        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Branch 3: 1x1 reduction, 3x3, then parallel 1x3 / 3x1 outputs.
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # Branch 4: average pooling followed by a 1x1 projection.
        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        b1 = self.branch1x1(x)

        mid = self.branch3x3_1(x)
        b2 = torch.cat([self.branch3x3_2a(mid), self.branch3x3_2b(mid)], 1)

        mid = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        b3 = torch.cat([self.branch3x3dbl_3a(mid), self.branch3x3dbl_3b(mid)], 1)

        b4 = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))

        return torch.cat([b1, b2, b3, b4], 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head fed from an intermediate feature map."""

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        # Narrower truncated-normal init for this head; the owning network's
        # weight initialization reads the ``stddev`` attribute when present.
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # N x 768 x 17 x 17 -> N x 768 x 5 x 5
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # -> N x 128 x 5 x 5
        x = self.conv0(x)
        # -> N x 768 x 1 x 1
        x = self.conv1(x)
        # Collapse any remaining spatial extent, then flatten to N x 768.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        # N x num_classes
        return self.fc(x)


class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are forwarded
    to the underlying ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # Bias is redundant here: batch norm immediately re-centers the output.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)