# regnet.py
1
2
3
import math
from collections import OrderedDict
from functools import partial
4
from typing import Any, Callable, Dict, List, Optional, Tuple
5
6

import torch
7
8
from torch import nn, Tensor

9
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
10
from ..transforms._presets import ImageClassification, InterpolationMode
11
from ..utils import _log_api_usage_once
12
from ._api import Weights, WeightsEnum
13
from ._meta import _IMAGENET_CATEGORIES
14
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
15
16


17
18
__all__ = [
    "RegNet",
    "RegNet_Y_400MF_Weights",
    "RegNet_Y_800MF_Weights",
    "RegNet_Y_1_6GF_Weights",
    "RegNet_Y_3_2GF_Weights",
    "RegNet_Y_8GF_Weights",
    "RegNet_Y_16GF_Weights",
    "RegNet_Y_32GF_Weights",
    "RegNet_Y_128GF_Weights",
    "RegNet_X_400MF_Weights",
    "RegNet_X_800MF_Weights",
    "RegNet_X_1_6GF_Weights",
    "RegNet_X_3_2GF_Weights",
    "RegNet_X_8GF_Weights",
    "RegNet_X_16GF_Weights",
    "RegNet_X_32GF_Weights",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_1_6gf",
    "regnet_y_3_2gf",
    "regnet_y_8gf",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_y_128gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_1_6gf",
    "regnet_x_3_2gf",
    "regnet_x_8gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
]


52
class SimpleStemIN(Conv2dNormActivation):
    """Simple stem for ImageNet: 3x3 conv (stride 2), norm, activation.

    A single ``Conv2dNormActivation`` with ``kernel_size=3`` and ``stride=2``.

    Args:
        width_in: Number of input channels.
        width_out: Number of output channels.
        norm_layer: Normalization layer constructor forwarded to ``Conv2dNormActivation``.
        activation_layer: Activation layer constructor forwarded to ``Conv2dNormActivation``.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__(
            width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer
        )


class BottleneckTransform(nn.Sequential):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1.

    Registered submodules, in order:
      - ``a``: 1x1 conv from ``width_in`` to the bottleneck width,
      - ``b``: grouped 3x3 conv that carries ``stride``,
      - ``se``: optional squeeze-excitation (only when ``se_ratio`` is truthy),
      - ``c``: 1x1 conv from the bottleneck width to ``width_out``, without activation.

    Args:
        width_in: Input channels.
        width_out: Output channels.
        stride: Stride of the 3x3 conv.
        norm_layer: Normalization layer constructor.
        activation_layer: Activation layer constructor (also used inside SE).
        group_width: Channels per group in the 3x3 conv.
        bottleneck_multiplier: Bottleneck width as a fraction of ``width_out``.
        se_ratio: SE squeeze ratio relative to ``width_in``; ``None``/0 disables SE.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float],
    ) -> None:
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        # Bottleneck width and the resulting number of groups for the 3x3 conv.
        w_b = int(round(width_out * bottleneck_multiplier))
        g = w_b // group_width

        layers["a"] = Conv2dNormActivation(
            width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
        )
        layers["b"] = Conv2dNormActivation(
            w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
        )

        if se_ratio:
            # The SE reduction ratio is defined with respect to the
            # beginning of the block
            width_se_out = int(round(se_ratio * width_in))
            layers["se"] = SqueezeExcitation(
                input_channels=w_b,
                squeeze_channels=width_se_out,
                activation=activation_layer,
            )

        # No activation here: the residual block applies it after the skip-add.
        layers["c"] = Conv2dNormActivation(
            w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
        )
        super().__init__(layers)


class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform.

    If the width or stride changes, the skip path is a 1x1 conv + norm
    projection (no activation) so that it matches ``F(x)``. The sum is passed
    through ``activation_layer``.

    Args:
        width_in: Input channels.
        width_out: Output channels.
        stride: Stride of the block (applied in both the transform and the projection).
        norm_layer: Normalization layer constructor.
        activation_layer: Activation constructor; must accept ``inplace=True``.
        group_width: Channels per group in the transform's 3x3 conv.
        bottleneck_multiplier: Bottleneck width as a fraction of ``width_out``.
        se_ratio: Squeeze-excitation ratio; ``None`` disables SE.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int = 1,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
    ) -> None:
        super().__init__()

        # Use skip connection with projection if shape changes
        self.proj = None
        should_proj = (width_in != width_out) or (stride != 1)
        if should_proj:
            self.proj = Conv2dNormActivation(
                width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
            )
        self.f = BottleneckTransform(
            width_in,
            width_out,
            stride,
            norm_layer,
            activation_layer,
            group_width,
            bottleneck_multiplier,
            se_ratio,
        )
        self.activation = activation_layer(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``activation(skip(x) + f(x))``, where ``skip`` is identity or the projection."""
        if self.proj is not None:
            x = self.proj(x) + self.f(x)
        else:
            x = x + self.f(x)
        return self.activation(x)


class AnyStage(nn.Sequential):
    """AnyNet stage: ``depth`` blocks that all produce the same output shape.

    The first block receives ``width_in`` channels and ``stride``; every
    subsequent block maps ``width_out`` -> ``width_out`` with stride 1.
    Children are registered in order under the names ``block{stage_index}-{i}``.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        block_constructor: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float] = None,
        stage_index: int = 0,
    ) -> None:
        super().__init__()

        in_channels, cur_stride = width_in, stride
        for block_idx in range(depth):
            self.add_module(
                f"block{stage_index}-{block_idx}",
                block_constructor(
                    in_channels,
                    width_out,
                    cur_stride,
                    norm_layer,
                    activation_layer,
                    group_width,
                    bottleneck_multiplier,
                    se_ratio,
                ),
            )
            # Only the first block changes width/resolution; the rest keep the stage shape.
            in_channels, cur_stride = width_out, 1


class BlockParams:
    """Per-stage RegNet configuration.

    ``depths``, ``widths``, ``group_widths``, ``bottleneck_multipliers`` and
    ``strides`` each hold one entry per stage; ``se_ratio`` is shared by all
    stages. Usually built from the RegNet design parameters with
    :meth:`from_init_params`.
    """

    def __init__(
        self,
        depths: List[int],
        widths: List[int],
        group_widths: List[int],
        bottleneck_multipliers: List[float],
        strides: List[int],
        se_ratio: Optional[float] = None,
    ) -> None:
        self.depths = depths  # number of blocks in each stage
        self.widths = widths  # output width of each stage
        self.group_widths = group_widths  # channels per group in each stage's 3x3 convs
        self.bottleneck_multipliers = bottleneck_multipliers  # bottleneck ratio per stage
        self.strides = strides  # stride of the first block of each stage
        self.se_ratio = se_ratio  # squeeze-excitation ratio; None/0 disables SE

    @classmethod
    def from_init_params(
        cls,
        depth: int,
        w_0: int,
        w_a: float,
        w_m: float,
        group_width: int,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
        **kwargs: Any,
    ) -> "BlockParams":
        """
        Programmatically compute all the per-block settings,
        given the RegNet parameters.

        The first step is to compute the quantized linear block parameters,
        in log space. Key parameters are:
        - `w_a` is the width progression slope
        - `w_0` is the initial width
        - `w_m` is the width stepping in the log space

        In other terms
        `log(block_width) = log(w_0) + log(w_m) * block_capacity`,
        with `block_capacity` ramping up following the w_0 and w_a params.
        This block width is finally quantized to multiples of 8.

        The second step is to compute the parameters per stage,
        taking into account the skip connection and the final 1x1 convolutions.
        We use the fact that the output width is constant within a stage.

        Extra ``**kwargs`` are accepted and ignored.

        Raises:
            ValueError: if ``w_a < 0``, ``w_0 <= 0``, ``w_m <= 1`` or ``w_0``
                is not a multiple of 8.
        """

        QUANT = 8
        STRIDE = 2

        if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % QUANT != 0:
            raise ValueError("Invalid RegNet settings")
        # Compute the block widths. Each stage has one unique block width
        widths_cont = torch.arange(depth) * w_a + w_0
        block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
        block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()
        num_stages = len(set(block_widths))

        # Convert to per stage parameters: a new stage starts wherever the block
        # width changes (the sentinel 0s mark the first and past-the-end positions).
        splits = [w != wp for w, wp in zip(block_widths + [0], [0] + block_widths)]

        stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
        stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist()

        strides = [STRIDE] * num_stages
        bottleneck_multipliers = [bottleneck_multiplier] * num_stages
        group_widths = [group_width] * num_stages

        # Adjust the compatibility of stage widths and group widths
        stage_widths, group_widths = cls._adjust_widths_groups_compatibilty(
            stage_widths, bottleneck_multipliers, group_widths
        )

        return cls(
            depths=stage_depths,
            widths=stage_widths,
            group_widths=group_widths,
            bottleneck_multipliers=bottleneck_multipliers,
            strides=strides,
            se_ratio=se_ratio,
        )

    def _get_expanded_params(self):
        """Iterate per stage over ``(width, stride, depth, group_width, bottleneck_multiplier)``."""
        return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)

    @staticmethod
    def _adjust_widths_groups_compatibilty(
        stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
    ) -> Tuple[List[int], List[int]]:
        """
        Adjusts the compatibility of widths and groups,
        depending on the bottleneck ratio.

        Each group width is capped at its stage's bottleneck width, the
        bottleneck width is rounded via ``_make_divisible`` to a multiple of the
        group width, and the stage width is recomputed from it.

        Returns:
            ``(adjusted_stage_widths, adjusted_group_widths)``.
        """
        # Compute all widths for the current settings
        widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
        group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]

        # Compute the adjusted widths so that stage and group widths fit
        ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)]
        stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
        return stage_widths, group_widths_min


class RegNet(nn.Module):
    """RegNet classifier: stem -> stages -> global average pool -> linear head.

    Args:
        block_params: Per-stage settings (widths, strides, depths, group widths,
            bottleneck multipliers, SE ratio).
        num_classes: Output features of the final linear layer.
        stem_width: Output channels of the stem.
        stem_type: Stem constructor; defaults to ``SimpleStemIN``.
        block_type: Block constructor used by every stage; defaults to ``ResBottleneckBlock``.
        norm_layer: Normalization layer constructor; defaults to ``nn.BatchNorm2d``.
        activation: Activation layer constructor; defaults to ``nn.ReLU``.
    """

    def __init__(
        self,
        block_params: BlockParams,
        num_classes: int = 1000,
        stem_width: int = 32,
        stem_type: Optional[Callable[..., nn.Module]] = None,
        block_type: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if stem_type is None:
            stem_type = SimpleStemIN
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if block_type is None:
            block_type = ResBottleneckBlock
        if activation is None:
            activation = nn.ReLU

        # Ad hoc stem
        self.stem = stem_type(
            3,  # width_in
            stem_width,
            norm_layer,
            activation,
        )

        current_width = stem_width

        # One AnyStage per entry in block_params; each stage's input width is
        # the previous stage's output width.
        blocks = []
        for i, (
            width_out,
            stride,
            depth,
            group_width,
            bottleneck_multiplier,
        ) in enumerate(block_params._get_expanded_params()):
            blocks.append(
                (
                    f"block{i+1}",
                    AnyStage(
                        current_width,
                        width_out,
                        stride,
                        depth,
                        block_type,
                        norm_layer,
                        activation,
                        group_width,
                        bottleneck_multiplier,
                        block_params.se_ratio,
                        stage_index=i + 1,
                    ),
                )
            )

            current_width = width_out

        self.trunk_output = nn.Sequential(OrderedDict(blocks))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=current_width, out_features=num_classes)

        # Performs ResNet-style weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Note that there is no bias due to BN
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` must have 3 channels (the stem is built with ``width_in=3``).
        """
        x = self.stem(x)
        x = self.trunk_output(x)

        x = self.avgpool(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)

        return x

386

387
388
389
390
391
392
393
394
395
def _regnet(
    block_params: BlockParams,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> RegNet:
    """Build a :class:`RegNet` and optionally load pretrained weights into it.

    Args:
        block_params: Stage configuration passed to :class:`RegNet`.
        weights: Pretrained weights to load, or ``None`` to keep the random init.
            When given, ``num_classes`` in ``kwargs`` is set (via
            ``_ovewrite_named_param``) to the number of ``weights.meta["categories"]``.
        progress: Forwarded to ``weights.get_state_dict``.
        **kwargs: Extra keyword arguments forwarded to :class:`RegNet`; a
            ``norm_layer`` entry overrides the default BatchNorm2d.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    # Default norm layer: BatchNorm2d with explicit eps/momentum (PyTorch's defaults).
    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
    model = RegNet(block_params, norm_layer=norm_layer, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


405
# Metadata shared by every RegNet weight entry in this module.
_COMMON_META: Dict[str, Any] = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}

410
411
412
413
414
415
# Metadata for SWAG-derived weights: the common RegNet metadata extended
# with the SWAG recipe and license links.
_COMMON_SWAG_META = dict(
    _COMMON_META,
    recipe="https://github.com/facebookresearch/SWAG",
    license="https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
)

416
417
418
419
420
421
422
423
424

class RegNet_Y_400MF_Weights(WeightsEnum):
    # Paper-reproduction weights (simple training recipe).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.046,
                    "acc@5": 91.716,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Weights from TorchVision's improved training recipe.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.804,
                    "acc@5": 92.742,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_800MF_Weights(WeightsEnum):
    # Paper-reproduction weights (simple training recipe).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.420,
                    "acc@5": 93.136,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Weights from TorchVision's improved training recipe.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.828,
                    "acc@5": 94.502,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_1_6GF_Weights(WeightsEnum):
    # Paper-reproduction weights (simple training recipe).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.950,
                    "acc@5": 93.966,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Weights from TorchVision's improved training recipe.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.876,
                    "acc@5": 95.444,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_3_2GF_Weights(WeightsEnum):
    # Paper-reproduction weights (simple training recipe).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.948,
                    "acc@5": 94.576,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Weights from TorchVision's improved training recipe.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.982,
                    "acc@5": 95.972,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_8GF_Weights(WeightsEnum):
    # Paper-reproduction weights (simple training recipe).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.032,
                    "acc@5": 95.048,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Weights from TorchVision's improved training recipe.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.828,
                    "acc@5": 96.330,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_16GF_Weights(WeightsEnum):
    # Paper-reproduction weights (simple training recipe).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.424,
                    "acc@5": 95.240,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Weights from TorchVision's improved training recipe.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.886,
                    "acc@5": 96.328,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # SWAG weights fine-tuned end-to-end on ImageNet-1K (384px, bicubic).
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.012,
                    "acc@5": 98.054,
                }
            },
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    # Frozen SWAG trunk + linear classifier trained on ImageNet-1K (224px, bicubic).
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.976,
                    "acc@5": 97.244,
                }
            },
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_32GF_Weights(WeightsEnum):
    # Paper-reproduction weights (simple training recipe).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.878,
                    "acc@5": 95.340,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    # Weights from TorchVision's improved training recipe.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.368,
                    "acc@5": 96.498,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    # SWAG weights fine-tuned end-to-end on ImageNet-1K (384px, bicubic).
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.838,
                    "acc@5": 98.362,
                }
            },
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    # Frozen SWAG trunk + linear classifier trained on ImageNet-1K (224px, bicubic).
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.622,
                    "acc@5": 97.480,
                }
            },
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_128GF_Weights(WeightsEnum):
    # SWAG weights fine-tuned end-to-end on ImageNet-1K (384px, bicubic).
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.228,
                    "acc@5": 98.682,
                }
            },
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    # Frozen SWAG trunk + linear classifier trained on ImageNet-1K (224px, bicubic).
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.068,
                    "acc@5": 97.844,
                }
            },
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    # No ImageNet-only (V1/V2) entries exist for this size; DEFAULT is the E2E SWAG entry.
    DEFAULT = IMAGENET1K_SWAG_E2E_V1


class RegNet_X_400MF_Weights(WeightsEnum):
    # Pretrained checkpoints for regnet_x_400mf.
    # DEFAULT selects the entry with the highest reported ImageNet-1K acc@1 (V2).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.834,
                    "acc@5": 90.950,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.864,
                    "acc@5": 92.322,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_800MF_Weights(WeightsEnum):
    # Pretrained checkpoints for regnet_x_800mf.
    # DEFAULT selects the entry with the highest reported ImageNet-1K acc@1 (V2).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.212,
                    "acc@5": 92.348,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.522,
                    "acc@5": 93.826,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_1_6GF_Weights(WeightsEnum):
    # Pretrained checkpoints for regnet_x_1_6gf.
    # DEFAULT selects the entry with the highest reported ImageNet-1K acc@1 (V2).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.040,
                    "acc@5": 93.440,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.668,
                    "acc@5": 94.922,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_3_2GF_Weights(WeightsEnum):
    # Pretrained checkpoints for regnet_x_3_2gf.
    # DEFAULT selects the entry with the highest reported ImageNet-1K acc@1 (V2).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.364,
                    "acc@5": 93.992,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.196,
                    "acc@5": 95.430,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_8GF_Weights(WeightsEnum):
    # Pretrained checkpoints for regnet_x_8gf.
    # DEFAULT selects the entry with the highest reported ImageNet-1K acc@1 (V2).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.344,
                    "acc@5": 94.686,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.682,
                    "acc@5": 95.678,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_16GF_Weights(WeightsEnum):
    # Pretrained checkpoints for regnet_x_16gf.
    # DEFAULT selects the entry with the highest reported ImageNet-1K acc@1 (V2).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.058,
                    "acc@5": 94.944,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.716,
                    "acc@5": 96.196,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_32GF_Weights(WeightsEnum):
    # Pretrained checkpoints for regnet_x_32gf.
    # DEFAULT selects the entry with the highest reported ImageNet-1K acc@1 (V2).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.622,
                    "acc@5": 95.248,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.014,
                    "acc@5": 96.288,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_400MF_Weights
        :members:
    """
    weights = RegNet_Y_400MF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_800MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_800MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_800MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_800MF_Weights
        :members:
    """
    weights = RegNet_Y_800MF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_1.6GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_1_6GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_1_6GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_1_6GF_Weights
        :members:
    """
    weights = RegNet_Y_1_6GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(
        depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_3.2GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_3_2GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_3_2GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_3_2GF_Weights
        :members:
    """
    weights = RegNet_Y_3_2GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(
        depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_8GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_8GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_8GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_8GF_Weights
        :members:
    """
    weights = RegNet_Y_8GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(
        depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_16GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_16GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_16GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_16GF_Weights
        :members:
    """
    weights = RegNet_Y_16GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(
        depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_32GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_32GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_32GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_32GF_Weights
        :members:
    """
    weights = RegNet_Y_32GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(
        depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


# No legacy ``pretrained=True`` mapping: RegNet_Y_128GF_Weights only defines SWAG-based entries.
@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_128GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_128GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_128GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_128GF_Weights
        :members:
    """
    weights = RegNet_Y_128GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(
        depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_400MF_Weights
        :members:
    """
    weights = RegNet_X_400MF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_800MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_800MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_800MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_800MF_Weights
        :members:
    """
    weights = RegNet_X_800MF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_1.6GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_1_6GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_1_6GF_Weights
        :members:
    """
    weights = RegNet_X_1_6GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_3.2GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_3_2GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_3_2GF_Weights
        :members:
    """
    weights = RegNet_X_3_2GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_8GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_8GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_8GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_8GF_Weights
        :members:
    """
    weights = RegNet_X_8GF_Weights.verify(weights)

    # kwargs are forwarded to both the block-parameter factory and the model builder.
    params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


1453
1454
# The legacy ``pretrained`` keyword is translated by ``handle_legacy_interface``;
# its default mapping points at ``RegNet_X_16GF_Weights.IMAGENET1K_V1``.
@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_16GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_16GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_16GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_16GF_Weights
        :members:
    """
    weights = RegNet_X_16GF_Weights.verify(weights)

    # Fixed design-space hyper-parameters for this model size. The same ``kwargs``
    # are forwarded both to BlockParams and to ``_regnet``.
    params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


# The legacy ``pretrained`` keyword is translated by ``handle_legacy_interface``;
# its default mapping points at ``RegNet_X_32GF_Weights.IMAGENET1K_V1``.
@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_32GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_32GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_32GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_32GF_Weights
        :members:
    """
    weights = RegNet_X_32GF_Weights.verify(weights)

    # Fixed design-space hyper-parameters for this model size. The same ``kwargs``
    # are forwarded both to BlockParams and to ``_regnet``.
    params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


# Internal implementation detail, scheduled for removal in v0.15.
# Maps each legacy builder name to the download URL of its IMAGENET1K_V1 checkpoint.
# NOTE(review): regnet_y_128gf is intentionally absent — presumably because its
# weights enum has no IMAGENET1K_V1 entry; confirm before adding it.
from ._utils import _ModelURLs


model_urls = _ModelURLs(
    {
        _builder_name: _weights_enum.IMAGENET1K_V1.url
        for _builder_name, _weights_enum in (
            ("regnet_y_400mf", RegNet_Y_400MF_Weights),
            ("regnet_y_800mf", RegNet_Y_800MF_Weights),
            ("regnet_y_1_6gf", RegNet_Y_1_6GF_Weights),
            ("regnet_y_3_2gf", RegNet_Y_3_2GF_Weights),
            ("regnet_y_8gf", RegNet_Y_8GF_Weights),
            ("regnet_y_16gf", RegNet_Y_16GF_Weights),
            ("regnet_y_32gf", RegNet_Y_32GF_Weights),
            ("regnet_x_400mf", RegNet_X_400MF_Weights),
            ("regnet_x_800mf", RegNet_X_800MF_Weights),
            ("regnet_x_1_6gf", RegNet_X_1_6GF_Weights),
            ("regnet_x_3_2gf", RegNet_X_3_2GF_Weights),
            ("regnet_x_8gf", RegNet_X_8GF_Weights),
            ("regnet_x_16gf", RegNet_X_16GF_Weights),
            ("regnet_x_32gf", RegNet_X_32GF_Weights),
        )
    }
)