from typing import Type, Any, Callable, Union, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from .._internally_replaced_utils import load_state_dict_from_url
# Public API of this module: the generic ResNet class plus one constructor
# per published architecture variant.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]
# Download URLs for the pretrained ImageNet checkpoints, keyed by the
# architecture name used by ``_resnet``.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding.

    ``padding=dilation`` keeps the spatial size unchanged at stride 1, for
    any dilation. ``bias=False`` because every convolution in this file is
    followed by a normalization layer that has its own affine bias.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution (channel projection), bias-free because a norm layer
    with an affine bias always follows it."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Basic residual block (two 3x3 convolutions) used by ResNet-18/34.

    Args:
        inplanes: number of input channels.
        planes: number of output channels (``expansion`` is 1, so the block
            outputs exactly ``planes`` channels).
        stride: stride of the first convolution; when != 1 the block
            downsamples and ``downsample`` must reduce the identity likewise.
        downsample: optional module projecting the identity branch to the
            residual branch's shape; ``None`` for an identity shortcut.
        groups, base_width, dilation: accepted only for interface parity with
            ``Bottleneck``; non-default values raise.
        norm_layer: normalization layer factory; ``nn.BatchNorm2d`` if None.
    """

    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Compute ``relu(bn2(conv2(relu(bn1(conv1(x))))) + identity)``."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) used by ResNet-50+.

    Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    while original implementation places the stride at the first 1x1 convolution(self.conv1)
    according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    This variant is also known as ResNet V1.5 and improves accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    Args:
        inplanes: number of input channels.
        planes: base channel count; the block outputs ``planes * expansion``.
        stride: stride of the middle 3x3 convolution.
        downsample: optional projection for the identity branch.
        groups: number of groups for the 3x3 convolution (ResNeXt variants).
        base_width: bottleneck base width (64 for plain ResNet, 128 for the
            wide variants).
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalization layer factory; ``nn.BatchNorm2d`` if None.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Compute the bottleneck residual sum followed by a final ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """Generic ResNet backbone + classifier head.

    See `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/abs/1512.03385>`_.

    Args:
        block: residual block class, ``BasicBlock`` or ``Bottleneck``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final fully-connected layer.
        zero_init_residual: zero-init the last BN of every residual branch so
            each block starts as an identity (https://arxiv.org/abs/1706.02677).
        groups: convolution groups for ResNeXt-style blocks.
        width_per_group: bottleneck base width (128 for the wide variants).
        replace_stride_with_dilation: three flags, one per stage 2-4; a True
            entry swaps that stage's stride-2 downsampling for dilation.
        norm_layer: normalization layer factory; ``nn.BatchNorm2d`` if None.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7/2 conv + 3x3/2 max-pool reduce the input resolution by 4x.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one ResNet stage of ``blocks`` residual blocks.

        Only the first block may change resolution/channels; it receives the
        stride and, when needed, a 1x1 projection shortcut. The remaining
        blocks map ``planes * expansion`` channels to themselves.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation to keep the feature-map size.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape ``(batch, num_classes)`` for an
        input batch of 3-channel images."""
        return self._forward_impl(x)
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    """Construct a ResNet and optionally load pretrained ImageNet weights.

    Args:
        arch: key into ``model_urls`` identifying the checkpoint.
        block: residual block class to use.
        layers: per-stage block counts forwarded to ``ResNet``.
        pretrained: if True, download and load the pretrained state dict.
        progress: if True, show a download progress bar.
        **kwargs: forwarded to the ``ResNet`` constructor.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 4 in every bottleneck's 3x3 convolution.
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 4
    return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 groups of width 8 in every bottleneck's 3x3 convolution.
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 8
    return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs["width_per_group"] = 64 * 2
    return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
    channels, and in Wide ResNet-101-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs["width_per_group"] = 64 * 2
    return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)