# resnet.py: ResNet, ResNeXt and Wide ResNet model definitions.
1
import torch
2
import torch.nn as nn
3
from .utils import load_state_dict_from_url
4
5
6


# Public API of this module: the ResNet class plus one factory function per
# architecture listed in ``model_urls``.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']
9
10
11


# Download locations of the ImageNet-pretrained weights, keyed by the ``arch``
# string that ``_resnet`` looks up when ``pretrained=True``.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


24
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding.

    Padding equals ``dilation``, so with ``stride=1`` the output keeps the
    input's spatial size for any dilation rate. The convolution has no bias.
    In this file it is always followed by a normalization layer.

    Args:
        in_planes (int): number of input channels.
        out_planes (int): number of output channels.
        stride (int): convolution stride.
        groups (int): number of channel groups (grouped convolution).
        dilation (int): spacing between kernel taps; also used as the padding.

    Returns:
        nn.Conv2d: the configured convolution module.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)
28
29


30
31
32
33
34
def conv1x1(in_planes, out_planes, stride=1):
    """Create a bias-free pointwise (1x1) ``nn.Conv2d``.

    Args:
        in_planes (int): number of input channels.
        out_planes (int): number of output channels.
        stride (int): convolution stride (a value > 1 also subsamples spatially).

    Returns:
        nn.Conv2d: the convolution module; no padding, no bias.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )


# Committed by Soumith Chintala.
35
class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18 and ResNet-34.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where
    ``shortcut`` is ``downsample`` when given and the identity otherwise.
    The output has ``planes * expansion`` (= ``planes``) channels.
    """

    # Ratio of this block's output channels to its ``planes`` argument.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """
        Args:
            inplanes (int): channels of the input tensor.
            planes (int): channels produced by both 3x3 convolutions.
            stride (int): stride of the first 3x3 convolution.
            downsample (nn.Module, optional): applied to the input to build the
                shortcut when its shape differs from the block output.
            groups (int): must be 1. It is accepted because ``ResNet._make_layer``
                passes the same arguments to every block type.
            base_width (int): must be 64, for the same reason as ``groups``.
            dilation (int): must be 1.
            norm_layer (callable, optional): ``norm_layer(num_channels)``
                factory; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: if ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: if ``dilation > 1``.
        """
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        # A single in-place ReLU module is reused after both norm layers.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        # Kept as an attribute; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Reshape the shortcut (channels/resolution) so it can be added to `out`.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


# Committed by Soumith Chintala.
75
class Bottleneck(nn.Module):
    """Three-convolution bottleneck residual block.

    This block is used by ResNet-50/101/152, ResNeXt and Wide ResNet.

    Structure: a 1x1 conv reduces the input to ``width`` channels. A 3x3 conv
    then operates at ``width`` and carries the block's stride, groups and
    dilation. A final 1x1 conv expands to ``planes * expansion`` channels.
    A norm layer follows each conv. The result is added to the shortcut
    (``downsample(x)`` when given, else ``x``) and passed through ReLU.

    The stride is placed on the 3x3 convolution (``conv2``), not on the first
    1x1 convolution. This placement is the variant often called "ResNet V1.5".
    """

    # Output channels are ``planes * expansion``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        """
        Args:
            inplanes (int): channels of the input tensor.
            planes (int): base channel count; the block outputs
                ``planes * expansion`` channels.
            stride (int): stride of the 3x3 convolution.
            downsample (nn.Module, optional): applied to the input to build the
                shortcut when its shape differs from the block output.
            groups (int): number of groups of the 3x3 convolution.
            base_width (int): per-group width. The inner width is
                ``int(planes * (base_width / 64.)) * groups``.
            dilation (int): dilation (and padding) of the 3x3 convolution.
            norm_layer (callable, optional): ``norm_layer(num_channels)``
                factory; defaults to ``nn.BatchNorm2d``.
        """
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Inner (bottleneck) width. For the defaults (base_width=64, groups=1)
        # this equals `planes`; ResNeXt and Wide ResNet change it.
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        # A single in-place ReLU module is reused after each activation point.
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # Kept as an attribute; forward() does not read it.
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Reshape the shortcut (channels/resolution) so it can be added to `out`.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


# Committed by Soumith Chintala.
118
class ResNet(nn.Module):
    """ResNet image classifier assembled from a residual ``block`` type.

    Architecture:

    * Stem: a 7x7 stride-2 conv on a 3-channel input, a norm layer, ReLU,
      and a 3x3 stride-2 max-pool.
    * Four stages of ``block`` with 64/128/256/512 base planes. Stages 2-4
      start with stride 2.
    * Global average pooling and a linear classifier.

    Args:
        block (type): residual block class (``BasicBlock`` or ``Bottleneck``
            in this file). It must expose an ``expansion`` attribute.
        layers (sequence of int): number of blocks in each of the four stages.
        num_classes (int): output size of the final ``nn.Linear``.
        zero_init_residual (bool): if True, set the weight of the last norm
            layer in each residual branch to zero (``bn3`` for Bottleneck,
            ``bn2`` for BasicBlock).
        groups (int): convolution groups passed to every block.
        width_per_group (int): ``base_width`` passed to every block.
        replace_stride_with_dilation (sequence of 3 bool, optional): for
            stages 2, 3 and 4, replace the stride-2 downsampling with dilation.
        norm_layer (callable, optional): ``norm_layer(num_channels)`` factory;
            defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have exactly
            3 elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count of the feature map; _make_layer updates it.
        self.inplanes = 64
        # Running dilation rate; _make_layer grows it when `dilate` is set.
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the stride-2 downsampling of a stage with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Convolutions get Kaiming-normal weights. BatchNorm/GroupNorm layers
        # start as the identity affine map (weight 1, bias 0). The final
        # nn.Linear keeps PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage made of ``blocks`` residual blocks.

        Only the first block can change resolution or channel count. It
        receives ``stride`` and, when needed, a 1x1-conv + norm ``downsample``
        shortcut. The remaining blocks keep the shape.

        Side effects: sets ``self.inplanes`` to ``planes * block.expansion``.
        When ``dilate`` is True, multiplies ``self.dilation`` by ``stride``.

        Returns:
            nn.Sequential: the stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stage's stride for dilation: the resolution is kept and
            # the receptive field still grows.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # The first block still uses the dilation of the previous stage. The
        # new (possibly larger) dilation applies from the second block onward.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Collapse the pooled (N, C, 1, 1) map to (N, C) for the linear layer.
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)
215

216

217
218
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Construct a ``ResNet`` and optionally load pretrained weights into it.

    Args:
        arch (str): key into ``model_urls``. It is only looked up when
            ``pretrained`` is True.
        block (type): residual block class forwarded to ``ResNet``.
        layers (list of int): number of blocks per stage, forwarded to ``ResNet``.
        pretrained (bool): if True, download the weights for ``arch`` and load
            them with ``load_state_dict``.
        progress (bool): forwarded to ``load_state_dict_from_url``.
        **kwargs: extra keyword arguments forwarded to ``ResNet``.

    Returns:
        ResNet: the constructed model.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
            (e.g. ``num_classes``, ``norm_layer``).
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)
236
237


238
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
            (e.g. ``num_classes``, ``norm_layer``).
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
248
249


250
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
            (e.g. ``num_classes``, ``norm_layer``).
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)
260
261


262
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
            (e.g. ``num_classes``, ``norm_layer``).
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)
272
273


274
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`
            (e.g. ``num_classes``, ``norm_layer``).
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)
284
285


286
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`.
            ``groups`` and ``width_per_group`` are always set to 32 and 4;
            values passed for them by the caller are overridden.
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
298
299


300
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`.
            ``groups`` and ``width_per_group`` are always set to 32 and 8;
            values passed for them by the caller are overridden.
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
312
313
314


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`.
            ``width_per_group`` is always set to 128; a caller-supplied value
            is overridden.
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
    channels, and in Wide ResNet-101-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to :class:`ResNet`.
            ``width_per_group`` is always set to 128; a caller-supplied value
            is overridden.
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)