# resnet.py
from functools import partial
2
3
from typing import Type, Any, Callable, Union, List, Optional

4
import torch
5
import torch.nn as nn
6
7
from torch import Tensor

8
from ..transforms._presets import ImageClassification
9
from ..utils import _log_api_usage_once
10
11
12
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
13
14


15
16
# Public API of this module: the ResNet class, one weights enum per variant,
# and one builder function per variant.
__all__ = [
    "ResNet",
    "ResNet18_Weights",
    "ResNet34_Weights",
    "ResNet50_Weights",
    "ResNet101_Weights",
    "ResNet152_Weights",
    "ResNeXt50_32X4D_Weights",
    "ResNeXt101_32X8D_Weights",
    "ResNeXt101_64X4D_Weights",
    "Wide_ResNet50_2_Weights",
    "Wide_ResNet101_2_Weights",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "resnext101_64x4d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]
38
39


40
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding.

    ``padding`` is set equal to ``dilation``. For a 3x3 kernel this keeps the
    spatial size unchanged when ``stride == 1``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of convolution groups.
        dilation: convolution dilation (also used as the padding).

    Returns:
        A bias-free ``nn.Conv2d`` with ``kernel_size=3``.
    """
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
52
53


54
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        A bias-free ``nn.Conv2d`` with ``kernel_size=1`` and no padding.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions (used by ``resnet18`` / ``resnet34``).

    The residual branch is conv3x3 -> norm -> ReLU -> conv3x3 -> norm. Its output
    is added to the shortcut (the input, or ``downsample(input)`` when given),
    and a final ReLU is applied. The output has ``planes * expansion`` channels,
    which equals ``planes`` here.
    """

    # Ratio between this block's output channels and ``planes``.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Args:
            inplanes: number of input channels.
            planes: number of output channels of both convolutions.
            stride: stride of the first 3x3 convolution.
            downsample: optional module applied to the input to build the
                shortcut. Its output must be addable in place to the residual
                branch output.
            groups: must be 1.
            base_width: must be 64.
            dilation: must be 1.
            norm_layer: normalization layer factory; defaults to ``nn.BatchNorm2d``.

        Raises:
            ValueError: if ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: if ``dilation > 1``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual branch, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            # Project the input so it matches the residual branch's shape.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block of 1x1 -> 3x3 -> 1x1 convolutions (used by ResNet-50 and deeper).

    The inner width is ``int(planes * (base_width / 64.0)) * groups`` and the
    output has ``planes * expansion`` channels.
    """

    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    # Ratio between this block's output channels and ``planes``.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Args:
            inplanes: number of input channels.
            planes: base channel count; the output has ``planes * 4`` channels.
            stride: stride of the 3x3 convolution.
            downsample: optional module applied to the input to build the
                shortcut. Its output must be addable in place to the residual
                branch output.
            groups: number of groups of the 3x3 convolution.
            base_width: width per group; 64 gives the plain ResNet width.
            dilation: dilation of the 3x3 convolution.
            norm_layer: normalization layer factory; defaults to ``nn.BatchNorm2d``.
        """
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """Apply the 1x1/3x3/1x1 residual branch, add the shortcut, then ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            # Project the input so it matches the residual branch's shape.
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet / ResNeXt / Wide ResNet backbone followed by a linear classifier.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: number of outputs of the final ``fc`` layer.
        zero_init_residual: if True, zero-initialize the weight of the last
            normalization layer in every residual branch.
        groups: number of groups passed to every block.
        width_per_group: ``base_width`` passed to every block.
        replace_stride_with_dilation: three flags, one for each of stages 2-4;
            a True flag replaces that stage's stride-2 downsampling with dilation.
        norm_layer: normalization layer factory; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        # Running channel count / dilation, updated by _make_layer as stages are built.
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 7x7 stride-2 conv + norm + ReLU + 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks.

        Only the first block may change the stride and channel count. It gets a
        conv1x1 + norm shortcut when its input and output shapes differ. The
        remaining ``blocks - 1`` blocks keep the shape. Updates
        ``self.inplanes`` and (when ``dilate``) ``self.dilation``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation: spatial size is preserved from here on.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # The first block still uses the dilation in effect before this stage.
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the classifier logits for ``x``."""
        return self._forward_impl(x)
286

287

288
289
290
def _resnet(
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ResNet:
    """Construct a :class:`ResNet` and optionally load pretrained weights.

    Args:
        block: residual block class.
        layers: number of blocks per stage.
        weights: weights to load, or None for random initialization.
        progress: forwarded to ``weights.get_state_dict``.
        **kwargs: forwarded to :class:`ResNet`.

    Returns:
        The constructed model.
    """
    if weights is not None:
        # The classifier size must match the checkpoint, so num_classes is derived
        # from the weights' category list. How a conflicting user-supplied value is
        # handled is defined by _ovewrite_named_param in ._utils.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ResNet(block, layers, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


306
307
308
309
310
311
312
313
314
315
316
317
318
319
# Metadata shared by every weights entry below; each entry spreads it into its
# own ``meta`` dict and adds variant-specific fields.
_COMMON_META = dict(
    min_size=(1, 1),
    categories=_IMAGENET_CATEGORIES,
)


# Pretrained weights available for ``resnet18``.
class ResNet18_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11689512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "metrics": {
                "acc@1": 69.758,
                "acc@5": 89.078,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V1


# Pretrained weights available for ``resnet34``.
class ResNet34_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet34-b627a593.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 21797672,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "metrics": {
                "acc@1": 73.314,
                "acc@5": 91.420,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V1


# Pretrained weights available for ``resnet50``.
class ResNet50_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "metrics": {
                "acc@1": 76.130,
                "acc@5": 92.862,
            },
        },
    )
    # V2 uses resize_size=232 (with crop_size=224) in its inference transforms.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25557032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621",
            "metrics": {
                "acc@1": 80.858,
                "acc@5": 95.434,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


# Pretrained weights available for ``resnet101``.
class ResNet101_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "metrics": {
                "acc@1": 77.374,
                "acc@5": 93.546,
            },
        },
    )
    # V2 uses resize_size=232 (with crop_size=224) in its inference transforms.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 44549160,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 81.886,
                "acc@5": 95.780,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


# Pretrained weights available for ``resnet152``.
class ResNet152_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet",
            "metrics": {
                "acc@1": 78.312,
                "acc@5": 94.046,
            },
        },
    )
    # V2 uses resize_size=232 (with crop_size=224) in its inference transforms.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 60192808,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 82.284,
                "acc@5": 96.002,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


# Pretrained weights available for ``resnext50_32x4d``.
class ResNeXt50_32X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "metrics": {
                "acc@1": 77.618,
                "acc@5": 93.698,
            },
        },
    )
    # V2 uses resize_size=232 (with crop_size=224) in its inference transforms.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 25028904,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 81.198,
                "acc@5": 95.340,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


# Pretrained weights available for ``resnext101_32x8d``.
class ResNeXt101_32X8D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext",
            "metrics": {
                "acc@1": 79.312,
                "acc@5": 94.526,
            },
        },
    )
    # V2 uses resize_size=232 (with crop_size=224) in its inference transforms.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88791336,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "metrics": {
                "acc@1": 82.834,
                "acc@5": 96.228,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
# Pretrained weights available for ``resnext101_64x4d``. Only one version exists;
# its inference transforms use resize_size=232 with crop_size=224.
class ResNeXt101_64X4D_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83455272,
            "recipe": "https://github.com/pytorch/vision/pull/5935",
            "metrics": {
                # Mock
                # NOTE(review): the "Mock" marker above suggests these values were
                # placeholders; confirm they match this checkpoint's evaluation.
                "acc@1": 83.246,
                "acc@5": 96.454,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V1


514
515
516
517
518
519
520
521
# Pretrained weights available for ``wide_resnet50_2``.
class Wide_ResNet50_2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "metrics": {
                "acc@1": 78.468,
                "acc@5": 94.086,
            },
        },
    )
    # V2 uses resize_size=232 (with crop_size=224) in its inference transforms.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 68883240,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "metrics": {
                "acc@1": 81.602,
                "acc@5": 95.758,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


# Pretrained weights available for ``wide_resnet101_2``.
class Wide_ResNet101_2_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439",
            "metrics": {
                "acc@1": 78.848,
                "acc@5": 94.284,
            },
        },
    )
    # V2 uses resize_size=232 (with crop_size=224) in its inference transforms.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 126886696,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 82.510,
                "acc@5": 96.020,
            },
        },
    )
    # Alias of the weights used when callers ask for the default.
    DEFAULT = IMAGENET1K_V2


@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

    Args:
        weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet18_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet18_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNet18_Weights.verify(weights)

    # Two BasicBlocks in each of the four stages.
    return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs)
597
598


599
600
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

    Args:
        weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet34_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet34_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNet34_Weights.verify(weights)

    # BasicBlocks per stage: 3, 4, 6, 3.
    return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs)
622

623
624
625

@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

    Args:
        weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet50_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet50_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNet50_Weights.verify(weights)

    # Bottleneck blocks per stage: 3, 4, 6, 3.
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
647
648


649
650
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

    Args:
        weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet101_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet101_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNet101_Weights.verify(weights)

    # Bottleneck blocks per stage: 3, 4, 23, 3.
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
672

673
674
675

@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.

    Args:
        weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNet152_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNet152_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNet152_Weights.verify(weights)

    # Bottleneck blocks per stage: 3, 8, 36, 3.
    return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs)
697
698


699
700
701
702
@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1))
def resnext50_32x4d(
    *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-50 32x4d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Args:
        weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNeXt50_32X4D_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNeXt50_32X4D_Weights.verify(weights)

    # "32x4d": 32 convolution groups with a width of 4 per group.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
726

727
728
729
730
731

@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1))
def resnext101_32x8d(
    *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-101 32x8d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Args:
        weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNeXt101_32X8D_Weights.verify(weights)

    # "32x8d": 32 convolution groups with a width of 8 per group.
    _ovewrite_named_param(kwargs, "groups", 32)
    _ovewrite_named_param(kwargs, "width_per_group", 8)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
755

756

757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
def resnext101_64x4d(
    *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """ResNeXt-101 64x4d model from
    `Aggregated Residual Transformation for Deep Neural Networks <https://arxiv.org/abs/1611.05431>`_.

    Args:
        weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = ResNeXt101_64X4D_Weights.verify(weights)

    # "64x4d": 64 convolution groups with a width of 4 per group.
    _ovewrite_named_param(kwargs, "groups", 64)
    _ovewrite_named_param(kwargs, "width_per_group", 4)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)


785
786
787
788
@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1))
def wide_resnet50_2(
    *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Wide ResNet-50-2 model from
    `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Wide_ResNet50_2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = Wide_ResNet50_2_Weights.verify(weights)

    # Doubling width_per_group doubles the inner (bottleneck) width of every block.
    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs)
816
817


818
819
820
821
@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1))
def wide_resnet101_2(
    *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any
) -> ResNet:
    """Wide ResNet-101-2 model from
    `Wide Residual Networks <https://arxiv.org/abs/1605.07146>`_.

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048
    channels, and in Wide ResNet-101-2 has 2048-1024-2048.

    Args:
        weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Wide_ResNet101_2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights
        :members:
    """
    # Validate/normalize the ``weights`` argument via the enum's ``verify``.
    weights = Wide_ResNet101_2_Weights.verify(weights)

    # Doubling width_per_group doubles the inner (bottleneck) width of every block.
    _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
    return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)