# resnet.py
1
import torch
2
import torch.nn as nn
3
from .utils import load_state_dict_from_url
4
5
6


# Public API of this module: the ResNet class and the model factory functions.
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


# Checkpoint URLs, keyed by the ``arch`` name that ``_resnet`` looks up when
# ``pretrained=True``.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding.

    ``padding`` is set equal to ``dilation``, so with ``stride=1`` the spatial
    size of the output equals that of the input.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Convolution stride.
        groups (int): Number of channel groups (grouped convolution).
        dilation (int): Spacing between kernel elements; also used as padding.

    Returns:
        nn.Conv2d: A 3x3 convolution without bias. In this file it is always
        followed by a normalization layer.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a 1x1 convolution without bias.

    Args:
        in_planes (int): Number of input channels.
        out_planes (int): Number of output channels.
        stride (int): Convolution stride (no padding is applied).

    Returns:
        nn.Conv2d: The pointwise convolution layer.
    """
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (used by ResNet-18/34 in this file).

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``, where
    ``shortcut`` is ``downsample`` when given and the identity otherwise.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Number of output channels (``planes * expansion``).
        stride (int): Stride of the first 3x3 convolution.
        downsample (nn.Module, optional): Shortcut applied to ``x``; its output
            must be addable to the residual branch output.
        groups (int): Must be 1.
        base_width (int): Must be 64.
        dilation (int): Must be <= 1.
        norm_layer (callable, optional): Builds a norm module from a channel
            count; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: If ``dilation > 1``.
    """
    # Output channels are ``planes * expansion``; ResNet uses this to size the
    # following stage and the final linear layer.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Project the shortcut so it can be added to ``out``.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block 1x1 -> 3x3 -> 1x1 (used by ResNet-50/101/152, ResNeXt, Wide ResNet).

    The inner width is ``int(planes * base_width / 64) * groups``. The block
    outputs ``planes * expansion`` channels.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Base number of channels; output is ``planes * expansion``.
        stride (int): Stride of the 3x3 convolution (``conv2``).
        downsample (nn.Module, optional): Shortcut applied to ``x``; its output
            must be addable to the residual branch output.
        groups (int): Groups of the 3x3 convolution.
        base_width (int): Width per group, relative to 64.
        dilation (int): Dilation (and padding) of the 3x3 convolution.
        norm_layer (callable, optional): Builds a norm module from a channel
            count; defaults to ``nn.BatchNorm2d``.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    # Output channels are ``planes * expansion``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck residual block to ``x``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the shortcut so it can be added to ``out``.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet with a global-average-pool + linear classification head.

    Layout: 7x7 stride-2 stem convolution, norm, ReLU, 3x3 stride-2 max-pool,
    four stages built by ``_make_layer`` (64/128/256/512 base planes), adaptive
    average pooling to 1x1, flatten, and ``nn.Linear``.

    Args:
        block: Residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            Must expose ``expansion`` and accept the constructor arguments
            passed in ``_make_layer``.
        layers: Number of blocks in each of the four stages.
        num_classes (int): Output size of the final linear layer.
        zero_init_residual (bool): If True, zero the weight of the last norm
            layer of every residual branch (``bn3`` / ``bn2``).
        groups (int): Forwarded to every block.
        width_per_group (int): Forwarded to every block as ``base_width``.
        replace_stride_with_dilation: None or 3 bools for stages 2-4; True
            replaces that stage's stride-2 downsampling with dilation.
        norm_layer (callable, optional): Builds a norm module from a channel
            count; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: If ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Channel count of the tensor entering the next stage; updated by _make_layer.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False have weight/bias set to None;
                # skip them instead of crashing in nn.init.constant_.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    last_norm_weight = getattr(m.bn3, 'weight', None)
                elif isinstance(m, BasicBlock):
                    last_norm_weight = getattr(m.bn2, 'weight', None)
                else:
                    continue
                # Non-affine (or weightless) norm layers have nothing to zero.
                if last_norm_weight is not None:
                    nn.init.constant_(last_norm_weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        When ``dilate`` is True the stride is folded into ``self.dilation`` and
        the stage runs at stride 1. The first block gets ``stride``, the
        previous dilation, and a conv1x1 + norm ``downsample`` shortcut whenever
        the stride or channel count changes. Later blocks use the (possibly
        updated) ``self.dilation``. Updates ``self.inplanes`` as a side effect.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run the stem, the four stages and the classifier head on ``x``."""
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``ResNet`` and optionally load pretrained weights into it.

    Args:
        arch (str): Key into ``model_urls``; used only when ``pretrained`` is True.
        block: Residual block class passed to ``ResNet``.
        layers (list[int]): Blocks per stage, passed to ``ResNet``.
        pretrained (bool): If True, fetch the state dict for ``arch`` with
            ``load_state_dict_from_url`` and load it (strict by default).
        progress (bool): Forwarded to ``load_state_dict_from_url``.
        **kwargs: Forwarded to the ``ResNet`` constructor.

    Returns:
        ResNet: The constructed model.

    Raises:
        KeyError: If ``pretrained`` is True and ``arch`` is not in ``model_urls``.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                   **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``groups`` / ``width_per_group`` entries are overwritten with 32 / 4.
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``groups`` / ``width_per_group`` entries are overwritten with 32 / 8.
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``width_per_group`` entry is overwritten with 128.
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice as large in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
    channels, and in Wide ResNet-101-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``width_per_group`` entry is overwritten with 128.
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)