import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

__all__ = ['GoogLeNet', 'googlenet']

model_urls = {
    # GoogLeNet ported from TensorFlow
    'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth',
}


def googlenet(pretrained=False, **kwargs):
    r"""GoogLeNet (Inception v1) model architecture from
    `"Going Deeper with Convolutions" <http://arxiv.org/abs/1409.4842>`_.
ekka's avatar
ekka committed
17

18
19
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
20
21
        aux_logits (bool): If True, adds two auxiliary branches that can improve training.
            Automatically set to False if 'pretrained' is True. Default: *True*
22
        transform_input (bool): If True, preprocesses the input according to the method with which it
ekka's avatar
ekka committed
23
            was trained on ImageNet. Default: *False*
24
25
26
27
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' not in kwargs:
            kwargs['aux_logits'] = False
        kwargs['init_weights'] = False
        model = GoogLeNet(**kwargs)
        model.load_state_dict(model_zoo.load_url(model_urls['googlenet']))
        return model

    return GoogLeNet(**kwargs)
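
# A minimal usage sketch (illustrative, not part of the original file): load the
# pretrained weights and run a single forward pass. Assumes a 224x224 RGB batch
# normalized the same way as for other ImageNet torchvision models.
#
#     model = googlenet(pretrained=True)
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 224, 224))  # N x 1000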


class GoogLeNet(nn.Module):

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        self.aux1 = InceptionAux(512, num_classes)
        self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                import scipy.stats as stats
                X = stats.truncnorm(-2, 2, scale=0.01)
                values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                values = values.view(m.weight.size())
                with torch.no_grad():
                    m.weight.copy_(values)
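                # Note (not in the original): on newer PyTorch releases the same
                # initialization can be written without scipy, e.g.
                #     nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-0.02, b=0.02)
                # where a and b are absolute cutoffs matching truncnorm's +/-2
                # standard deviations at scale=0.01.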
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.transform_input:
            # Re-normalize an ImageNet-normalized input (per-channel mean/std) to the
            # (x - 0.5) / 0.5 range that the TensorFlow-ported weights expect.
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)

        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = x.view(x.size(0), -1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:
            return aux1, aux2, x
        return x
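
# Training-loop sketch for consuming the train-mode outputs (illustrative, not part
# of the original file). The paper adds each auxiliary loss with a 0.3 weight;
# `images` and `targets` below are assumed to exist in the caller.
#
#     model = GoogLeNet(aux_logits=True)
#     criterion = nn.CrossEntropyLoss()
#     model.train()
#     aux1, aux2, out = model(images)
#     loss = criterion(out, targets) \
#          + 0.3 * (criterion(aux1, targets) + criterion(aux2, targets))
#     loss.backward()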


class Inception(nn.Module):

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )

        # Note: this "5x5" branch actually uses a 3x3 kernel. That quirk matches the
        # reference torchvision implementation (see
        # https://github.com/pytorch/vision/issues/906) and is what the pretrained
        # weights were trained with, so it is kept as-is.
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1)
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = x.view(x.size(0), -1)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        # N x 1024
        x = F.dropout(x, 0.7, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes

        return x


class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
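

# A quick shape check, runnable as a script (a sketch, not part of the original
# module). init_weights=False is used so the scipy-based initializer is not required.
if __name__ == '__main__':
    net = GoogLeNet(num_classes=1000, aux_logits=True, init_weights=False)

    net.eval()
    with torch.no_grad():
        logits = net(torch.randn(2, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([2, 1000])

    net.train()
    aux1, aux2, logits = net(torch.randn(2, 3, 224, 224))
    print(aux1.shape, aux2.shape, logits.shape)  # each expected: torch.Size([2, 1000])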