resnet.py 15.1 KB
Newer Older
1
import torch
2
from torch import Tensor
3
import torch.nn as nn
4
from .utils import load_state_dict_from_url
5
from typing import Type, Any, Callable, Union, List, Optional
6
7
8


# Public API of this module: the ResNet class plus every model-builder function.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'resnext50_32x4d',
    'resnext101_32x8d',
    'wide_resnet50_2',
    'wide_resnet101_2',
]
11
12
13


# Download locations of pretrained weights, keyed by the ``arch`` name that
# ``_resnet`` looks up when ``pretrained=True``.
model_urls = {
    # Plain ResNet depths.
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    # ResNeXt variants (grouped 3x3 convolutions).
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    # Wide ResNet variants (doubled bottleneck width).
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


26
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 convolution.

    Padding is set equal to ``dilation``, so with ``stride=1`` the output keeps
    the input's spatial size (effective kernel extent ``2 * dilation + 1``).
    Bias is omitted because every use in this module follows the convolution
    with a normalisation layer.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
30
31


32
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free 1x1 (pointwise) convolution.

    A 1x1 kernel needs no padding; ``stride`` lets the same helper serve as a
    strided channel projection.
    """
    return nn.Conv2d(in_planes, out_planes, 1, stride=stride, bias=False)


Soumith Chintala's avatar
Soumith Chintala committed
37
class BasicBlock(nn.Module):
    """Two-convolution residual block (used by ``resnet18`` / ``resnet34``).

    Main branch: 3x3 conv (carries ``stride``) -> norm -> ReLU -> 3x3 conv ->
    norm.  The branch output is added to the shortcut (``x``, or
    ``downsample(x)`` when a downsample module is given) and passed through a
    final ReLU.  ``expansion`` is 1, so the block outputs ``planes`` channels.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock, self).__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Grouped / widened convolutions are only supported by Bottleneck.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # With stride != 1, conv1 shrinks the main branch; the caller-supplied
        # ``downsample`` is expected to shrink the shortcut to match.
        # Submodules are registered in this exact order (it fixes state_dict
        # ordering and the order in which ResNet initialises weights).
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual block to ``x``."""
        # Main branch.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # Shortcut: identity unless a projection was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        y += shortcut
        return self.relu(y)


Soumith Chintala's avatar
Soumith Chintala committed
86
class Bottleneck(nn.Module):
    """Three-convolution residual block (ResNet-50/101/152, ResNeXt, Wide ResNet).

    Main branch: 1x1 conv (to ``width`` channels) -> norm -> ReLU -> 3x3 conv
    (carries ``stride``, ``groups`` and ``dilation``) -> norm -> ReLU -> 1x1
    conv (to ``planes * expansion`` channels) -> norm.  The result is added to
    the shortcut (``x`` or ``downsample(x)``) and passed through a final ReLU.

    Bottleneck in torchvision places the stride for downsampling at the 3x3
    convolution (``self.conv2``), while the original implementation places the
    stride at the first 1x1 convolution (``self.conv1``) according to
    "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385).  This variant is also known as
    ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck, self).__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Inner channel count: planes scaled by base_width / 64, times groups.
        # e.g. planes=64, base_width=4, groups=32 -> int(64 * 0.0625) * 32 = 128.
        width = int(planes * (base_width / 64.)) * groups
        out_channels = planes * self.expansion

        # With stride != 1, conv2 shrinks the main branch; the caller-supplied
        # ``downsample`` is expected to shrink the shortcut to match.
        # Submodules are registered in this exact order (it fixes state_dict
        # ordering and the order in which ResNet initialises weights).
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the bottleneck block to ``x``."""
        # Main branch: reduce -> spatial conv -> expand.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Shortcut: identity unless a projection was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        y += shortcut
        return self.relu(y)


Soumith Chintala's avatar
Soumith Chintala committed
144
class ResNet(nn.Module):
    """ResNet backbone plus linear classifier, shared by every builder here.

    Layout: 7x7 stride-2 conv -> norm -> ReLU -> 3x3 stride-2 max-pool ->
    four residual stages (``layer1``..``layer4`` with 64, 128, 256, 512 base
    planes; stages 2-4 downsample by 2 unless dilation replaces the stride) ->
    adaptive average pool to 1x1 -> flatten -> ``fc``.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``); its
            ``expansion`` attribute multiplies each stage's output channels.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final ``fc`` layer.
        zero_init_residual: if True, zero the weight of the last norm layer of
            every residual branch (``bn3`` of Bottleneck, ``bn2`` of BasicBlock).
        groups: forwarded to every block (groups of the 3x3 convolutions).
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: three flags for layer2..layer4; a True
            flag replaces that stage's stride of 2 with dilation.
        norm_layer: factory called with a channel count to build every norm
            layer; defaults to ``nn.BatchNorm2d``.  Norm layers without
            learnable affine parameters (``weight``/``bias`` is ``None``) are
            supported.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Channel count entering the next stage; updated by _make_layer.
        self.inplanes = 64
        # Current dilation; multiplied by the stride whenever a stage dilates.
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # Norm layers built with affine=False have weight/bias == None;
                # there is nothing to initialise, and constant_(None, ...) would
                # raise AttributeError.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # Skip affine-less norm layers (weight is None) for the same
                # reason as above.
                if isinstance(m, Bottleneck):
                    if m.bn3.weight is not None:
                        nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    if m.bn2.weight is not None:
                        nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        """Build one residual stage consisting of ``blocks`` blocks.

        Only the first block receives ``stride`` and, when the resolution or
        channel count changes, a conv1x1 + norm projection for its shortcut.
        If ``dilate`` is True the stride is folded into ``self.dilation`` and
        the stage keeps its input resolution.  The first block uses the
        dilation in effect before this stage; the rest use the updated value.

        Side effects: updates ``self.inplanes`` (and ``self.dilation`` when
        dilating).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape ``(N, num_classes)`` for input ``x``."""
        return self._forward_impl(x)
250

251

252
253
254
255
256
257
258
259
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Construct a ``ResNet`` and optionally load pretrained weights into it.

    Args:
        arch: key into ``model_urls``; only looked up when ``pretrained`` is True.
        block: residual block class forwarded to ``ResNet``.
        layers: per-stage block counts forwarded to ``ResNet``.
        pretrained: if True, download the state dict at ``model_urls[arch]``
            and load it into the new model.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: extra ``ResNet`` constructor arguments.
    """
    net = ResNet(block, layers, **kwargs)
    if not pretrained:
        return net
    weights = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(weights)
    return net


268
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Built from ``BasicBlock`` with 2, 2, 2, 2 blocks per stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``
            (e.g. ``num_classes``, ``norm_layer``).
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
278
279


280
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Built from ``BasicBlock`` with 3, 4, 6, 3 blocks per stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``
            (e.g. ``num_classes``, ``norm_layer``).
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
290
291


292
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Built from ``Bottleneck`` with 3, 4, 6, 3 blocks per stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``
            (e.g. ``num_classes``, ``norm_layer``).
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
302
303


304
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Built from ``Bottleneck`` with 3, 4, 23, 3 blocks per stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``
            (e.g. ``num_classes``, ``norm_layer``).
    """
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnet101', Bottleneck, stage_depths, pretrained, progress, **kwargs)
314
315


316
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Built from ``Bottleneck`` with 3, 8, 36, 3 blocks per stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``
            (e.g. ``num_classes``, ``norm_layer``).
    """
    stage_depths = [3, 8, 36, 3]
    return _resnet('resnet152', Bottleneck, stage_depths, pretrained, progress, **kwargs)
326
327


328
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    ``Bottleneck`` blocks with 3, 4, 6, 3 blocks per stage, ``groups=32`` and
    ``width_per_group=4`` (the "32x4d" in the name).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``groups`` / ``width_per_group`` supplied here are overridden.
    """
    kwargs.update(groups=32, width_per_group=4)
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, stage_depths, pretrained, progress, **kwargs)
340
341


342
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    ``Bottleneck`` blocks with 3, 4, 23, 3 blocks per stage, ``groups=32`` and
    ``width_per_group=8`` (the "32x8d" in the name).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``groups`` / ``width_per_group`` supplied here are overridden.
    """
    kwargs.update(groups=32, width_per_group=8)
    stage_depths = [3, 4, 23, 3]
    return _resnet('resnext101_32x8d', Bottleneck, stage_depths, pretrained, progress, **kwargs)
354
355


356
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Identical to ResNet-50 except that every bottleneck's inner width is
    doubled (``width_per_group=128``); the outer 1x1 convolutions keep their
    channel counts.  For example, the last block of ResNet-50 has
    2048-512-2048 channels, whereas in Wide ResNet-50-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``width_per_group`` supplied here is overridden.
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, stage_depths, pretrained, progress, **kwargs)


374
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Identical to ResNet-101 except that every bottleneck's inner width is
    doubled (``width_per_group=128``); the outer 1x1 convolutions keep their
    channel counts.  For example, the last block of ResNet-101 has
    2048-512-2048 channels, whereas in Wide ResNet-101-2 it has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: Extra keyword arguments forwarded to ``ResNet``. Any
            ``width_per_group`` supplied here is overridden.
    """
    kwargs.update(width_per_group=64 * 2)
    stage_depths = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', Bottleneck, stage_depths, pretrained, progress, **kwargs)