import warnings
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import load_state_dict_from_url


__all__ = ['GoogLeNet', 'googlenet']

model_urls = {
    # GoogLeNet ported from TensorFlow
    'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}

# Return container used by GoogLeNet.forward when the auxiliary classifiers
# are active (training mode with aux_logits=True).
# NOTE(review): "Ouputs" (sic) matches the original spelling; renaming it
# would break any caller that references the type by name.
_GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])
def googlenet(pretrained=False, progress=True, **kwargs):
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        # The released checkpoint was trained with input transformation and
        # its auxiliary-head weights are not useful, so mirror that setup
        # unless the caller overrides it explicitly.
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' not in kwargs:
            kwargs['aux_logits'] = False
        if kwargs['aux_logits']:
            warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, '
                          'so make sure to train them')
        original_aux_logits = kwargs['aux_logits']
        # Always build the aux heads so the checkpoint's aux keys load
        # cleanly; strip them afterwards if the caller did not ask for them.
        kwargs['aux_logits'] = True
        # Random init is pointless here - the checkpoint overwrites everything.
        kwargs['init_weights'] = False
        model = GoogLeNet(**kwargs)
        state_dict = load_state_dict_from_url(model_urls['googlenet'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            del model.aux1, model.aux2
        return model

    return GoogLeNet(**kwargs)


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) network body.

    Args:
        num_classes (int): size of the final classification layer. Default: 1000
        aux_logits (bool): if True, build the two auxiliary classifiers
            (``aux1`` fed from inception4a, ``aux2`` from inception4d).
        transform_input (bool): if True, re-normalize inputs in ``forward``
            (the pretrained weights expect ImageNet mean/std statistics).
        init_weights (bool): if True, initialize conv/linear weights with a
            truncated normal and batch-norm layers with weight 1 / bias 0.
    """

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem: 224x224x3 input down to a 28x28x192 feature map.
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Inception stacks; the Inception args are the per-branch channel
        # counts (1x1, 3x3-reduce, 3x3, 5x5-reduce, 5x5, pool-proj).
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # Classifier head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        # Truncated-normal init, matching the TensorFlow-ported reference.
        # scipy is imported lazily so it is only required when
        # init_weights=True; the frozen distribution is built once instead
        # of once per module.
        import scipy.stats as stats
        X = stats.truncnorm(-2, 2, scale=0.01)
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.transform_input:
            # Convert from inputs normalized to [-1, 1] to the per-channel
            # ImageNet mean/std normalization the weights were trained with.
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)

        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)
        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = x.view(x.size(0), -1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            # Auxiliary outputs only exist in training mode with aux heads
            # enabled; eval mode returns the plain logits tensor.
            return _GoogLeNetOuputs(x, aux2, aux1)
        return x


class Inception(nn.Module):
    """Inception module: four parallel branches concatenated on channels.

    Branches: 1x1 conv; 1x1 reduce -> 3x3 conv; 1x1 reduce -> 3x3 conv
    (the paper's "5x5" branch, implemented as 3x3 in this TF-ported
    version); 3x3 max-pool -> 1x1 projection.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            BasicConv2d(in_channels, pool_proj, kernel_size=1),
        )

    def forward(self, x):
        # Feed the same input to every branch, then stack along channels.
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier attached to an intermediate feature map.

    Args:
        in_channels (int): channels of the incoming feature map
            (512 for aux1, 528 for aux2).
        num_classes (int): size of the output classification layer.
    """

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

        # 128 channels * 4 * 4 spatial positions = 2048 flattened features.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = x.view(x.size(0), -1)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        # N x 1024  (the original comment wrongly said N x 2048)
        x = F.dropout(x, 0.7, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes  (the original comment wrongly said N x 1024)
        return x


class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Extra keyword arguments (kernel_size, stride, padding, ...) are
    forwarded to ``nn.Conv2d``. The convolution is bias-free because the
    batch norm that follows provides the affine shift.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        out = self.bn(self.conv(x))
        return F.relu(out, inplace=True)