# ResNet model definitions (ResNet, ResNeXt, Wide ResNet variants).
import torch.nn as nn
from .utils import load_state_dict_from_url


# Public API of this module: the generic ResNet class plus one factory
# function per published architecture.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


# Download URLs of the ImageNet-pretrained weights, keyed by architecture
# name (the ``arch`` argument of ``_resnet``).
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding.

    ``padding`` is set equal to ``dilation`` so the spatial size is
    preserved whenever ``stride == 1``. ``bias`` is disabled because every
    use in this file is immediately followed by a normalization layer.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (channel projection, no padding, no bias)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (used by ResNet-18/34).

    Output channels are ``planes * expansion``; ``expansion == 1`` here, so
    the block does not widen its input.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Grouped/widened convolutions only exist in Bottleneck; reject them
        # here so a misconfigured ResNeXt/Wide model fails loudly.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run conv-bn-relu-conv-bn, add the (possibly downsampled) identity,
        then apply the final ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Project the skip connection so shapes match the main branch.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 reduce, 3x3, 1x1 expand
    (used by ResNet-50 and deeper, ResNeXt and Wide ResNet).

    Output channels are ``planes * expansion`` (4x). The stride is placed on
    the 3x3 convolution (``conv2``) rather than the first 1x1.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Inner width scales with base_width and groups — this is how ResNeXt
        # (groups > 1) and Wide ResNet (base_width > 64) widen the block.
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv-bn stages, add the (possibly downsampled)
        identity, then apply the final ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the skip connection so shapes match the main branch.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class ResNet(nn.Module):
    """Generic ResNet backbone ("Deep Residual Learning for Image
    Recognition", https://arxiv.org/abs/1512.03385).

    Args:
        block: residual block class, ``BasicBlock`` or ``Bottleneck``.
        layers: number of blocks in each of the four stages.
        num_classes (int): size of the final fully-connected layer.
        zero_init_residual (bool): zero-init the last BN of every residual
            branch so each block initially acts as an identity.
        groups (int): groups for the 3x3 convolutions (ResNeXt).
        width_per_group (int): base width of the bottleneck (Wide ResNet).
        replace_stride_with_dilation: optional 3 flags, one per stage 2-4,
            replacing that stage's stride-2 with dilation (for dense
            prediction backbones).
        norm_layer: normalization layer factory; defaults to
            ``nn.BatchNorm2d``.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool (overall stride 4).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve resolution unless dilated.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks.

        Only the first block may change stride/channels; the remaining
        blocks operate at the stage's resolution and width.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation to keep spatial resolution.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Skip-connection projection when shape or channels change.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``(N, 3, H, W)`` to class logits
        ``(N, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Flatten all feature dimensions before the classifier.
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ResNet and optionally load pretrained weights.

    Args:
        arch (str): key into ``model_urls`` for the pretrained checkpoint.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: per-stage block counts, forwarded to ``ResNet``.
        pretrained (bool): if True, download and load ImageNet weights.
        progress (bool): if True, show a download progress bar.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 4 in every bottleneck's 3x3 convolution.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 8 in every bottleneck's 3x3 convolution.
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Double the bottleneck width (base_width 64 -> 128).
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Double the bottleneck width (base_width 64 -> 128).
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)