# googlenet.py
import warnings
from collections import namedtuple
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

__all__ = ['GoogLeNet', 'googlenet']

# Download locations of pretrained checkpoints, keyed by model name.
model_urls = {
    # GoogLeNet ported from TensorFlow
    'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}

# Return type of GoogLeNet.forward in training mode with aux_logits enabled:
# (main logits, logits of the second aux head, logits of the first aux head).
# NOTE: "Ouputs" is a historical misspelling; the identifier and typename are
# kept unchanged so existing references keep working.
_GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1'])

def googlenet(pretrained=False, **kwargs):
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Default: *False* when pretrained is True otherwise *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *True* when pretrained is True otherwise *False*
        **kwargs: Any other keyword arguments are forwarded to :class:`GoogLeNet`.
            When ``pretrained`` is True, ``init_weights`` is always forced to False
            because the random initialization would be overwritten by the checkpoint.

    Returns:
        GoogLeNet: the constructed model.
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' not in kwargs:
            kwargs['aux_logits'] = False
        if kwargs['aux_logits']:
            warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them')
        original_aux_logits = kwargs['aux_logits']
        # Build with the aux heads present, presumably so the module's keys match
        # the checkpoint under strict load_state_dict; they are removed below if
        # the caller did not ask for them.
        kwargs['aux_logits'] = True
        kwargs['init_weights'] = False
        model = GoogLeNet(**kwargs)
        model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
        if not original_aux_logits:
            model.aux_logits = False
            del model.aux1, model.aux2
        return model

    return GoogLeNet(**kwargs)


class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) network.

    Args:
        num_classes (int): number of logits produced by the main classifier
            (and by the auxiliary classifiers, if built). Default: 1000.
        aux_logits (bool): if True, builds the auxiliary heads ``aux1`` and
            ``aux2``; in training mode ``forward`` then also returns their logits.
        transform_input (bool): if True, ``forward`` re-normalizes the input
            channels before the first convolution.
        init_weights (bool): if True, calls ``_initialize_weights`` after
            all submodules are built.
    """

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # NOTE: submodule construction order is kept as-is; it determines the
        # order in which default initializers and _initialize_weights draw
        # random numbers.
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Inception(in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:
            # Attached after inception4a (512 channels) and inception4d (528 channels).
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize weights in place.

        Conv2d/Linear weights are drawn from a normal distribution with std
        0.01 truncated to +/-2 std; BatchNorm2d weights are set to 1 and biases
        to 0. Linear biases are left as they are.
        """
        # Imported lazily so scipy is only required when a fresh init is requested.
        import scipy.stats as stats
        # truncnorm bounds are in units of `scale`, so values lie in [-0.02, 0.02].
        # Building the frozen distribution draws no random numbers, so creating
        # it once outside the loop yields the same samples as before.
        truncated_normal = stats.truncnorm(-2, 2, scale=0.01)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                values = torch.as_tensor(truncated_normal.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Compute class logits.

        Returns:
            In training mode with ``aux_logits`` enabled, a ``_GoogLeNetOuputs``
            namedtuple ``(logits, aux_logits2, aux_logits1)``; otherwise only
            the main logits tensor.
        """
        if self.transform_input:
            # Undo ImageNet mean/std normalization and re-normalize each channel
            # as (p - 0.5) / 0.5:  x' = (x * std + mean - 0.5) / 0.5.
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)

        # Shape comments below assume a 224 x 224 input.
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = x.view(x.size(0), -1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            return _GoogLeNetOuputs(x, aux2, aux1)
        return x


class Inception(nn.Module):
    """Inception module: four parallel branches concatenated on the channel axis.

    The output has ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels and the same
    spatial size as the input (every branch is stride 1 with "same" padding).

    Args:
        in_channels (int): channels of the input feature map.
        ch1x1 (int): output channels of the 1x1 branch.
        ch3x3red (int): 1x1 reduction channels before the 3x3 conv.
        ch3x3 (int): output channels of the 3x3 branch.
        ch5x5red (int): 1x1 reduction channels before the "5x5" conv.
        ch5x5 (int): output channels of the "5x5" branch.
        pool_proj (int): 1x1 projection channels after the max-pool branch.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        # Submodules are created in the same order as always; construction
        # order determines how default initializers consume the RNG.

        # Branch 1: plain 1x1 convolution.
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction, then 3x3 convolution.
        reduce_3x3 = BasicConv2d(in_channels, ch3x3red, kernel_size=1)
        conv_3x3 = BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 reduction, then the "5x5" convolution.
        # NOTE: despite the ch5x5 naming (from the paper), this branch uses a
        # 3x3 kernel. The kernel size fixes the weight shapes, so presumably it
        # must stay 3x3 for the existing pretrained checkpoint to load.
        reduce_5x5 = BasicConv2d(in_channels, ch5x5red, kernel_size=1)
        conv_5x5 = BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
        self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: 3x3 max-pool (stride 1), then 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        projection = BasicConv2d(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, projection)

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)


class InceptionAux(nn.Module):
    """Auxiliary classifier head applied to an intermediate feature map.

    Pipeline: adaptive avg-pool to 4x4 -> 1x1 conv to 128 channels ->
    flatten (2048) -> fc1 (1024) + ReLU -> dropout(p=0.7) -> fc2 (num_classes).

    Args:
        in_channels (int): channels of the incoming feature map.
        num_classes (int): number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

        # 128 channels * 4 * 4 pooled positions = 2048 input features.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = x.view(x.size(0), -1)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        # N x 1024
        x = F.dropout(x, 0.7, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


class BasicConv2d(nn.Module):
    """Conv2d (without bias) -> BatchNorm2d (eps=1e-3) -> in-place ReLU.

    Any extra keyword arguments (kernel_size, stride, padding, ...) are passed
    unchanged to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # The convolution has no bias; the BatchNorm that follows supplies the shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)