googlenet.py 9.2 KB
Newer Older
1
2
from __future__ import division

3
4
import warnings
from collections import namedtuple
5
6
7
import torch
import torch.nn as nn
import torch.nn.functional as F
8
9
from torch.jit.annotations import Optional
from torch import Tensor
10
from .utils import load_state_dict_from_url
11

12
__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"]
13
14
15
16
17
18

model_urls = {
    # GoogLeNet ported from TensorFlow
    'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}

19
20
21
22
23
24
25
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],
                                    'aux_logits1': Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs
26

27

28
def googlenet(pretrained=False, progress=True, **kwargs):
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if not pretrained:
        return GoogLeNet(**kwargs)

    # The released checkpoint was trained with input transformation enabled;
    # its auxiliary heads carry no pretrained weights.
    kwargs.setdefault('transform_input', True)
    kwargs.setdefault('aux_logits', False)
    keep_aux = kwargs['aux_logits']
    if keep_aux:
        warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
                      'so make sure to train them')

    # Always build the aux heads so the checkpoint's keys match, and skip the
    # (expensive) weight init since the weights are replaced right away.
    kwargs['aux_logits'] = True
    kwargs['init_weights'] = False
    model = GoogLeNet(**kwargs)
    state_dict = load_state_dict_from_url(model_urls['googlenet'],
                                          progress=progress)
    model.load_state_dict(state_dict)

    # Strip the aux heads again if the caller did not ask for them.
    if not keep_aux:
        model.aux_logits = False
        del model.aux1, model.aux2

    return model


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) network from `"Going Deeper with Convolutions"
    <http://arxiv.org/abs/1409.4842>`_, with optional auxiliary classifiers.

    Args:
        num_classes (int): number of output classes. Default: 1000.
        aux_logits (bool): if True, build the two auxiliary classifier heads
            (``aux1``, ``aux2``) used to aid training. Default: True.
        transform_input (bool): if True, re-normalize the input inside
            ``forward`` (see the channel-wise rescale there). Default: False.
        init_weights (bool): if True, initialize Conv2d/Linear weights from a
            truncated normal distribution (requires scipy). Default: True.
    """
    # TorchScript: treat these attributes as compile-time constants.
    __constants__ = ['aux_logits', 'transform_input']

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem: 224x224x3 input down to a 28x28x192 feature map.
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Inception(in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Auxiliary heads tap the outputs of inception4a (512ch) and
        # inception4d (528ch); they only exist when aux_logits is True.
        if aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # Main classifier: global average pool -> dropout -> linear.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Initialize Conv2d/Linear weights from a truncated normal
        (std=0.01, clipped at +/-2 sigma) and BatchNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                # scipy is imported lazily so it is only required when
                # init_weights is actually requested.
                import scipy.stats as stats
                X = stats.truncnorm(-2, 2, scale=0.01)
                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # type: (Tensor) -> GoogLeNetOutputs
        """Run the network; shape comments below track a 224x224 input.

        Returns GoogLeNetOutputs when scripted; in eager mode see
        eager_outputs (a plain Tensor unless training with aux heads).
        """
        if self.transform_input:
            # Re-map input from ImageNet mean/std normalization to the
            # [-1, 1]-style normalization the original weights expect.
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)

        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14

        # Aux heads fire only while training (and when enabled); a single
        # flag keeps the TorchScript control flow simple.
        aux_defined = self.training and self.aux_logits
        if aux_defined:
            aux1 = self.aux1(x)
        else:
            aux1 = None

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if aux_defined:
            aux2 = self.aux2(x)
        else:
            aux2 = None

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)

        # Scripted modules cannot return a union type, so they always return
        # the named tuple; eager mode defers to eager_outputs below.
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)

    @torch.jit.unused
    def eager_outputs(self, x, aux2, aux1):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> GoogLeNetOutputs
        """Eager-mode result: the named tuple while training with aux heads,
        otherwise just the main logits tensor (excluded from scripting)."""
        if self.training and self.aux_logits:
            return _GoogLeNetOutputs(x, aux2, aux1)
        else:
            return x


class Inception(nn.Module):
    """One GoogLeNet inception module: four parallel branches concatenated
    along the channel dimension, so the output has
    ch1x1 + ch3x3 + ch5x5 + pool_proj channels."""

    __constants__ = ['branch2', 'branch3', 'branch4']

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        # Branch 1: plain 1x1 convolution.
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 channel reduction, then a 3x3 convolution.
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )

        # Branch 3: 1x1 reduction then another 3x3 convolution (this port
        # uses a 3x3 kernel for the paper's "5x5" branch).
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1),
        )

        # Branch 4: 3x3 max-pool followed by a 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        out1 = self.branch1(x)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out4 = self.branch4(x)
        # Stack the branch outputs channel-wise.
        return torch.cat([out1, out2, out3, out4], 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map
    (512 channels after inception4a, 528 after inception4d)."""

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        # 1x1 conv squeezes the channel count down to 128.
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        # 128 channels * 4 * 4 pooled positions = 2048 features.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        out = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        out = self.conv(out)
        # N x 128 x 4 x 4
        out = torch.flatten(out, 1)
        # N x 2048
        out = F.relu(self.fc1(out), inplace=True)
        # N x 1024 -- heavy dropout (p=0.7), active only in training mode
        out = F.dropout(out, 0.7, training=self.training)
        # N x 1000 (num_classes)
        return self.fc2(out)


class BasicConv2d(nn.Module):
    """Conv2d (no bias) followed by BatchNorm and an in-place ReLU.

    Extra keyword arguments (kernel_size, stride, padding, ...) are passed
    straight through to nn.Conv2d.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # Bias is omitted: BatchNorm re-centers the activations anyway.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normed = self.bn(self.conv(x))
        return F.relu(normed, inplace=True)