# resnet.py -- ResNet / ResNeXt / Wide-ResNet model definitions (torchvision style).
1
import torch.nn as nn
2
from .utils import load_state_dict_from_url
3
4
5


# Public API: the ResNet base class plus one factory per published variant.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
    'wide_resnet50_2',
    'wide_resnet101_2',
]


# Download locations of the ImageNet-pretrained checkpoints, keyed by the
# architecture name passed to ``_resnet``.
model_urls = {
    # plain ResNets
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    # ResNeXt variants
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    # wide variants
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


23
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a 3x3 convolution whose padding equals its dilation.

    Choosing ``padding == dilation`` keeps the spatial size unchanged at
    stride 1.  Bias is omitted because every use is followed by a norm layer.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )


29
30
31
32
33
def conv1x1(in_planes, out_planes, stride=1):
    """Build a 1x1 (pointwise) convolution; bias omitted, a norm layer follows."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )


Soumith Chintala's avatar
Soumith Chintala committed
34
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions (used by resnet18/resnet34).

    The output channel count equals ``planes * expansion`` (expansion is 1
    here).  ``downsample``, when given, projects the input so it can be added
    to the residual branch.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Grouped / widened variants only make sense for Bottleneck.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # conv1 (and downsample, when present) performs the spatial
        # downsampling whenever stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Residual branch: conv-bn-relu, conv-bn.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut branch: project the input when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


Soumith Chintala's avatar
Soumith Chintala committed
74
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (resnet50 and deeper).

    Output channels are ``planes * expansion``.  ``base_width`` and ``groups``
    scale the inner 3x3 width for the ResNeXt / wide variants.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Inner width of the 3x3 conv; grows with base_width and groups.
        width = int(planes * (base_width / 64.)) * groups
        # conv2 (and downsample, when present) performs the spatial
        # downsampling whenever stride != 1.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Residual branch: squeeze (1x1), transform (3x3), expand (1x1).
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Shortcut branch: project the input when shapes differ.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


Soumith Chintala's avatar
Soumith Chintala committed
117
class ResNet(nn.Module):
    """Generic ResNet backbone.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four residual stages.
        num_classes: size of the final classification layer.
        zero_init_residual: zero-init the last norm layer of every block.
        groups / width_per_group: grouped-convolution settings forwarded to
            the blocks (ResNeXt / wide variants).
        replace_stride_with_dilation: three flags, one per stage 2/3/4,
            trading that stage's stride-2 downsampling for dilation.
        norm_layer: normalization layer factory; defaults to BatchNorm2d.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7/2 conv followed by a stride-2 maxpool (overall stride 4).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve the spatial size unless
        # dilation is requested instead.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming init for convs; unit-weight / zero-bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch so every block
        # starts as an identity mapping; improves accuracy by 0.2~0.3%
        # according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Assemble one residual stage of ``blocks`` blocks."""
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        if dilate:
            # Trade striding for dilation: spatial size is preserved while
            # the receptive field still grows.
            self.dilation *= stride
            stride = 1

        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the identity matches the branch shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        # First block may change stride/channels; the rest keep stride 1
        # and use the (possibly updated) dilation.
        layers = [block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer)]
        self.inplanes = planes * block.expansion
        layers.extend(
            block(self.inplanes, planes, groups=self.groups,
                  base_width=self.base_width, dilation=self.dilation,
                  norm_layer=norm_layer)
            for _ in range(1, blocks)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # Residual stages.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Head: global pool, flatten, linear classifier.
        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        return self.fc(x)


212
213
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate a ResNet and optionally load its pretrained weights.

    Args:
        arch (str): key into ``model_urls`` selecting the checkpoint.
        block: residual block class (BasicBlock or Bottleneck).
        layers (list): block counts for the four stages.
        pretrained (bool): download and load ImageNet weights when True.
        progress (bool): show a download progress bar.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    """ResNet-18: two BasicBlocks in each of the four stages.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2],
                   pretrained, progress, **kwargs)
230
231


232
def resnet34(pretrained=False, progress=True, **kwargs):
    """ResNet-34: BasicBlocks with a 3-4-6-3 stage layout.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
241
242


243
def resnet50(pretrained=False, progress=True, **kwargs):
    """ResNet-50: Bottleneck blocks with a 3-4-6-3 stage layout.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
252
253


254
def resnet101(pretrained=False, progress=True, **kwargs):
    """ResNet-101: Bottleneck blocks with a 3-4-23-3 stage layout.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
263
264


265
def resnet152(pretrained=False, progress=True, **kwargs):
    """ResNet-152: Bottleneck blocks with a 3-8-36-3 stage layout.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3],
                   pretrained, progress, **kwargs)
274
275


276
277
278
279
280
281
282
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    """ResNeXt-50 (32x4d): 32 groups, 4 channels per group.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    # Force the grouped-convolution configuration that defines this variant.
    kwargs.update(groups=32, width_per_group=4)
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
287
288


289
290
291
292
293
294
295
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    """ResNeXt-101 (32x8d): 32 groups, 8 channels per group.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    # Force the grouped-convolution configuration that defines this variant.
    kwargs.update(groups=32, width_per_group=8)
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    """Wide ResNet-50-2: ResNet-50 with the bottleneck 3x3 width doubled.

    Only the inner (3x3) channels grow; the outer 1x1 channel counts are
    unchanged, e.g. the last block goes from 2048-512-2048 channels in
    ResNet-50 to 2048-1024-2048 here.

    Args:
        pretrained (bool): load ImageNet-pretrained weights when True.
        progress (bool): show a download progress bar on stderr.
    """
    # Doubling width_per_group doubles every bottleneck's inner width.
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    """Constructs a Wide ResNet-101-2 model.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
    channels, and in Wide ResNet-101-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # Doubling width_per_group doubles every bottleneck's inner (3x3) width.
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)