regnet.py 52.3 KB
Newer Older
1
2
3
import math
from collections import OrderedDict
from functools import partial
4
from typing import Any, Callable, Dict, List, Optional, Tuple
5
6

import torch
7
8
from torch import nn, Tensor

9
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
10
from ..transforms._presets import ImageClassification, InterpolationMode
11
from ..utils import _log_api_usage_once
12
13
14
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible
15
16


17
18
# Public names exported by ``from <this module> import *``.
__all__ = [
    "RegNet",
    "RegNet_Y_400MF_Weights",
    "RegNet_Y_800MF_Weights",
    "RegNet_Y_1_6GF_Weights",
    "RegNet_Y_3_2GF_Weights",
    "RegNet_Y_8GF_Weights",
    "RegNet_Y_16GF_Weights",
    "RegNet_Y_32GF_Weights",
    "RegNet_Y_128GF_Weights",
    "RegNet_X_400MF_Weights",
    "RegNet_X_800MF_Weights",
    "RegNet_X_1_6GF_Weights",
    "RegNet_X_3_2GF_Weights",
    "RegNet_X_8GF_Weights",
    "RegNet_X_16GF_Weights",
    "RegNet_X_32GF_Weights",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_1_6gf",
    "regnet_y_3_2gf",
    "regnet_y_8gf",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_y_128gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_1_6gf",
    "regnet_x_3_2gf",
    "regnet_x_8gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
]
class SimpleStemIN(Conv2dNormActivation):
    """Simple stem for ImageNet: a stride-2 3x3 conv followed by norm and activation.

    Args:
        width_in: Number of input channels.
        width_out: Number of output channels.
        norm_layer: Normalization layer constructor forwarded to ``Conv2dNormActivation``.
        activation_layer: Activation constructor forwarded to ``Conv2dNormActivation``.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__(
            width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer
        )


class BottleneckTransform(nn.Sequential):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1.

    Registered sub-modules, in order:

    * ``a``: 1x1 ``Conv2dNormActivation``, ``width_in -> w_b``.
    * ``b``: 3x3 grouped ``Conv2dNormActivation`` carrying ``stride``, ``w_b -> w_b``.
    * ``se``: optional ``SqueezeExcitation`` on ``w_b`` channels, added only when
      ``se_ratio`` is truthy.
    * ``c``: 1x1 ``Conv2dNormActivation`` with ``activation_layer=None``,
      ``w_b -> width_out``.

    Here ``w_b = int(round(width_out * bottleneck_multiplier))`` and the group
    count of ``b`` is ``w_b // group_width``.

    Args:
        width_in: Input channels.
        width_out: Output channels.
        stride: Stride of the 3x3 convolution.
        norm_layer: Normalization layer constructor.
        activation_layer: Activation constructor.
        group_width: Channels per group in the 3x3 convolution.
        bottleneck_multiplier: Ratio of bottleneck width to ``width_out``.
        se_ratio: Squeeze-excitation ratio relative to ``width_in``; falsy disables SE.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float],
    ) -> None:
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        w_b = int(round(width_out * bottleneck_multiplier))
        # group_width is channels-per-group, so convert it to a group count.
        g = w_b // group_width

        layers["a"] = Conv2dNormActivation(
            width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
        )
        layers["b"] = Conv2dNormActivation(
            w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
        )

        if se_ratio:
            # The SE reduction ratio is defined with respect to the
            # beginning of the block
            width_se_out = int(round(se_ratio * width_in))
            layers["se"] = SqueezeExcitation(
                input_channels=w_b,
                squeeze_channels=width_se_out,
                activation=activation_layer,
            )

        # No activation here: ResBottleneckBlock applies it after the residual sum.
        layers["c"] = Conv2dNormActivation(
            w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
        )
        super().__init__(layers)


class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: ``activation(shortcut(x) + F(x))``, F = bottleneck transform.

    The shortcut is the identity unless the block changes the channel count
    (``width_in != width_out``) or downsamples (``stride != 1``). In that case
    it is a strided 1x1 ``Conv2dNormActivation`` without activation (``self.proj``).

    Args:
        width_in: Input channels.
        width_out: Output channels.
        stride: Stride of the block (applied in both F and the projection).
        norm_layer: Normalization layer constructor.
        activation_layer: Activation constructor. It is also called as
            ``activation_layer(inplace=True)`` for the post-sum activation, so it
            must accept ``inplace``.
        group_width: Channels per group in the 3x3 convolution of F.
        bottleneck_multiplier: Ratio of bottleneck width to ``width_out``.
        se_ratio: Optional squeeze-excitation ratio passed to F.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int = 1,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
    ) -> None:
        super().__init__()

        # Use skip connection with projection if shape changes
        self.proj = None
        should_proj = (width_in != width_out) or (stride != 1)
        if should_proj:
            self.proj = Conv2dNormActivation(
                width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
            )
        self.f = BottleneckTransform(
            width_in,
            width_out,
            stride,
            norm_layer,
            activation_layer,
            group_width,
            bottleneck_multiplier,
            se_ratio,
        )
        self.activation = activation_layer(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``activation(shortcut(x) + f(x))``."""
        if self.proj is not None:
            x = self.proj(x) + self.f(x)
        else:
            x = x + self.f(x)
        return self.activation(x)


class AnyStage(nn.Sequential):
    """One AnyNet/RegNet stage: ``depth`` blocks that all produce ``width_out`` channels.

    Only the first block sees ``width_in`` and ``stride``; every following block
    maps ``width_out -> width_out`` with stride 1. Each block is built by calling
    ``block_constructor`` positionally and is registered as
    ``"block{stage_index}-{i}"``.
    """

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        block_constructor: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float] = None,
        stage_index: int = 0,
    ) -> None:
        super().__init__()

        channels_in, block_stride = width_in, stride
        for position in range(depth):
            self.add_module(
                f"block{stage_index}-{position}",
                block_constructor(
                    channels_in,
                    width_out,
                    block_stride,
                    norm_layer,
                    activation_layer,
                    group_width,
                    bottleneck_multiplier,
                    se_ratio,
                ),
            )
            # After the first block, the stage keeps its width and resolution.
            channels_in, block_stride = width_out, 1


class BlockParams:
    """Per-stage configuration of a RegNet trunk.

    ``depths``, ``widths``, ``group_widths``, ``bottleneck_multipliers`` and
    ``strides`` are parallel lists with one entry per stage; ``se_ratio`` is
    shared by every stage.
    """

    def __init__(
        self,
        depths: List[int],
        widths: List[int],
        group_widths: List[int],
        bottleneck_multipliers: List[float],
        strides: List[int],
        se_ratio: Optional[float] = None,
    ) -> None:
        self.depths = depths
        self.widths = widths
        self.group_widths = group_widths
        self.bottleneck_multipliers = bottleneck_multipliers
        self.strides = strides
        self.se_ratio = se_ratio

    @classmethod
    def from_init_params(
        cls,
        depth: int,
        w_0: int,
        w_a: float,
        w_m: float,
        group_width: int,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
        **kwargs: Any,
    ) -> "BlockParams":
        """
        Programmatically compute all the per-stage settings,
        given the RegNet parameters.

        The first step is to compute the quantized linear block widths,
        in log space. Key parameters are:
        - `w_a` is the width progression slope
        - `w_0` is the initial width
        - `w_m` is the width stepping in the log space

        In other terms, for block index `j`,
        `block_capacity = round(log((w_0 + w_a * j) / w_0) / log(w_m))` and
        `log(block_width) = log(w_0) + log(w_m) * block_capacity`.
        This block width is finally quantized to multiples of 8.

        The second step groups consecutive blocks that share the same width
        into one stage (the output width is constant within a stage). Every
        stage gets stride 2, the given bottleneck multiplier and group width,
        after which stage and group widths are made compatible.

        Args:
            depth: Total number of blocks across all stages.
            w_0: Initial width; must be a positive multiple of 8.
            w_a: Width slope; must be non-negative.
            w_m: Width multiplier between stages; must be greater than 1.
            group_width: Channels per group for every stage.
            bottleneck_multiplier: Bottleneck ratio for every stage.
            se_ratio: Optional squeeze-excitation ratio.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If the settings above are violated.
        """

        QUANT = 8
        STRIDE = 2

        if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % QUANT != 0:
            raise ValueError("Invalid RegNet settings")
        # Compute the block widths. Each stage has one unique block width
        widths_cont = torch.arange(depth) * w_a + w_0
        block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
        block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()

        # A stage starts wherever the width differs from the previous block's.
        # Padding with 0 on both ends marks the first stage's start and the
        # position right after the last block.
        splits = [w != w_prev for w, w_prev in zip(block_widths + [0], [0] + block_widths)]

        stage_widths = [w for w, starts_stage in zip(block_widths, splits[:-1]) if starts_stage]
        boundaries = [i for i, is_split in enumerate(splits) if is_split]
        stage_depths = [end - start for start, end in zip(boundaries, boundaries[1:])]

        # Derive the stage count from the detected stages so every per-stage
        # list below is guaranteed to have the same length.
        num_stages = len(stage_widths)
        strides = [STRIDE] * num_stages
        bottleneck_multipliers = [bottleneck_multiplier] * num_stages
        group_widths = [group_width] * num_stages

        # Adjust the compatibility of stage widths and group widths
        stage_widths, group_widths = cls._adjust_widths_groups_compatibilty(
            stage_widths, bottleneck_multipliers, group_widths
        )

        return cls(
            depths=stage_depths,
            widths=stage_widths,
            group_widths=group_widths,
            bottleneck_multipliers=bottleneck_multipliers,
            strides=strides,
            se_ratio=se_ratio,
        )

    def _get_expanded_params(self):
        """Iterate per stage over ``(width, stride, depth, group_width, bottleneck_multiplier)``."""
        return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)

    @staticmethod
    def _adjust_widths_groups_compatibilty(
        stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
    ) -> Tuple[List[int], List[int]]:
        """
        Adjusts the compatibility of widths and groups,
        depending on the bottleneck ratio.

        Each group width is capped at its stage's bottleneck width. The
        bottleneck width is rounded by ``_make_divisible`` to a multiple of the
        (capped) group width, and the stage width is recomputed from it.

        Returns:
            ``(adjusted_stage_widths, capped_group_widths)``.
        """
        # Compute all widths for the current settings
        widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
        group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]

        # Compute the adjusted widths so that stage and group widths fit
        ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)]
        stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
        return stage_widths, group_widths_min


class RegNet(nn.Module):
    """RegNet classifier: ``stem -> trunk_output -> avgpool -> flatten -> fc``.

    ``trunk_output`` holds one ``AnyStage`` per stage described by
    ``block_params``, registered as ``block1``, ``block2``, ...

    Args:
        block_params: Per-stage widths, strides, depths, group widths and
            bottleneck multipliers, plus the shared squeeze-excitation ratio.
        num_classes: Output features of the final linear layer.
        stem_width: Output channels of the stem (input channels are fixed at 3).
        stem_type: Stem constructor; defaults to ``SimpleStemIN``.
        block_type: Block constructor used in every stage; defaults to ``ResBottleneckBlock``.
        norm_layer: Normalization constructor; defaults to ``nn.BatchNorm2d``.
        activation: Activation constructor; defaults to ``nn.ReLU``.
    """

    def __init__(
        self,
        block_params: BlockParams,
        num_classes: int = 1000,
        stem_width: int = 32,
        stem_type: Optional[Callable[..., nn.Module]] = None,
        block_type: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if stem_type is None:
            stem_type = SimpleStemIN
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if block_type is None:
            block_type = ResBottleneckBlock
        if activation is None:
            activation = nn.ReLU

        # Ad hoc stem
        self.stem = stem_type(
            3,  # width_in
            stem_width,
            norm_layer,
            activation,
        )

        current_width = stem_width

        blocks = []
        for i, (
            width_out,
            stride,
            depth,
            group_width,
            bottleneck_multiplier,
        ) in enumerate(block_params._get_expanded_params()):
            blocks.append(
                (
                    f"block{i+1}",
                    AnyStage(
                        current_width,
                        width_out,
                        stride,
                        depth,
                        block_type,
                        norm_layer,
                        activation,
                        group_width,
                        bottleneck_multiplier,
                        block_params.se_ratio,
                        stage_index=i + 1,
                    ),
                )
            )

            # The next stage consumes this stage's output width.
            current_width = width_out

        self.trunk_output = nn.Sequential(OrderedDict(blocks))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=current_width, out_features=num_classes)

        # Performs ResNet-style weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Note that there is no bias due to BN
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.stem(x)
        x = self.trunk_output(x)

        x = self.avgpool(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)

        return x
def _regnet(
    block_params: BlockParams,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> RegNet:
    """Build a :class:`RegNet` and optionally load pretrained weights.

    Args:
        block_params: Stage configuration forwarded to :class:`RegNet`.
        weights: Weights to load, or ``None`` to keep the random initialization.
            When given, ``num_classes`` in ``kwargs`` is set through
            ``_ovewrite_named_param`` to ``len(weights.meta["categories"])``.
        progress: Forwarded to ``weights.get_state_dict``.
        **kwargs: Extra :class:`RegNet` arguments. ``norm_layer`` defaults to
            ``partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1)``.

    Returns:
        The constructed model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
    model = RegNet(block_params, norm_layer=norm_layer, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
# Metadata merged into the ``meta`` dict of every RegNet weights entry.
_COMMON_META: Dict[str, Any] = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}
# SWAG weights reuse the common metadata and add the SWAG recipe and license links.
_COMMON_SWAG_META = dict(
    _COMMON_META,
    recipe="https://github.com/facebookresearch/SWAG",
    license="https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
)

class RegNet_Y_400MF_Weights(WeightsEnum):
    """Pretrained weights for RegNetY-400MF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "metrics": {
                "acc@1": 74.046,
                "acc@5": 91.716,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 75.804,
                "acc@5": 92.742,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_800MF_Weights(WeightsEnum):
    """Pretrained weights for RegNetY-800MF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "metrics": {
                "acc@1": 76.420,
                "acc@5": 93.136,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 78.828,
                "acc@5": 94.502,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_1_6GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetY-1.6GF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "metrics": {
                "acc@1": 77.950,
                "acc@5": 93.966,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 80.876,
                "acc@5": 95.444,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_3_2GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetY-3.2GF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "metrics": {
                "acc@1": 78.948,
                "acc@5": 94.576,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 81.982,
                "acc@5": 95.972,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_8GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetY-8GF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "metrics": {
                "acc@1": 80.032,
                "acc@5": 95.048,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 82.828,
                "acc@5": 96.330,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_16GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetY-16GF (incl. SWAG variants); ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "metrics": {
                "acc@1": 80.424,
                "acc@5": 95.240,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 82.886,
                "acc@5": 96.328,
            },
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 83590140,
            "metrics": {
                "acc@1": 86.012,
                "acc@5": 98.054,
            },
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 83590140,
            "metrics": {
                "acc@1": 83.976,
                "acc@5": 97.244,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_32GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetY-32GF (incl. SWAG variants); ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "metrics": {
                "acc@1": 80.878,
                "acc@5": 95.340,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 83.368,
                "acc@5": 96.498,
            },
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 145046770,
            "metrics": {
                "acc@1": 86.838,
                "acc@5": 98.362,
            },
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 145046770,
            "metrics": {
                "acc@1": 84.622,
                "acc@5": 97.480,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_128GF_Weights(WeightsEnum):
    """SWAG weights for RegNetY-128GF; ``DEFAULT`` is ``IMAGENET1K_SWAG_E2E_V1``."""

    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 644812894,
            "metrics": {
                "acc@1": 88.228,
                "acc@5": 98.682,
            },
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 644812894,
            "metrics": {
                "acc@1": 86.068,
                "acc@5": 97.844,
            },
        },
    )
    DEFAULT = IMAGENET1K_SWAG_E2E_V1


class RegNet_X_400MF_Weights(WeightsEnum):
    """Pretrained weights for RegNetX-400MF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "metrics": {
                "acc@1": 72.834,
                "acc@5": 90.950,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "metrics": {
                "acc@1": 74.864,
                "acc@5": 92.322,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_800MF_Weights(WeightsEnum):
    """Pretrained weights for RegNetX-800MF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "metrics": {
                "acc@1": 75.212,
                "acc@5": 92.348,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "metrics": {
                "acc@1": 77.522,
                "acc@5": 93.826,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_1_6GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetX-1.6GF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "metrics": {
                "acc@1": 77.040,
                "acc@5": 93.440,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "metrics": {
                "acc@1": 79.668,
                "acc@5": 94.922,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_3_2GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetX-3.2GF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "metrics": {
                "acc@1": 78.364,
                "acc@5": 93.992,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 81.196,
                "acc@5": 95.430,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_8GF_Weights(WeightsEnum):
    """Pretrained weights for RegNetX-8GF; ``DEFAULT`` is ``IMAGENET1K_V2``."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "metrics": {
                "acc@1": 79.344,
                "acc@5": 94.686,
            },
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 81.682,
                "acc@5": 95.678,
            },
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_16GF_Weights(WeightsEnum):
    # Checkpoint trained with the reference "medium-models" recipe; evaluated with crop_size=224.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "metrics": {
                "acc@1": 80.058,
                "acc@5": 94.944,
            },
        },
    )
    # Same architecture (identical num_params) retrained with the "new-recipe" from
    # pytorch/vision#3995; evaluated with resize_size=232 and crop_size=224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 82.716,
                "acc@5": 96.196,
            },
        },
    )
    # DEFAULT is an alias of the V2 (higher-accuracy) checkpoint.
    DEFAULT = IMAGENET1K_V2


class RegNet_X_32GF_Weights(WeightsEnum):
    # Checkpoint trained with the reference "large-models" recipe; evaluated with crop_size=224.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "metrics": {
                "acc@1": 80.622,
                "acc@5": 95.248,
            },
        },
    )
    # Same architecture (identical num_params) retrained with the "new-recipe" from
    # pytorch/vision#3995; evaluated with resize_size=232 and crop_size=224.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "metrics": {
                "acc@1": 83.014,
                "acc@5": 96.288,
            },
        },
    )
    # DEFAULT is an alias of the V2 (higher-accuracy) checkpoint.
    DEFAULT = IMAGENET1K_V2


@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_400MF_Weights
        :members:
    """
    weights = RegNet_Y_400MF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_800MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_800MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_800MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_800MF_Weights
        :members:
    """
    weights = RegNet_Y_800MF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_1.6GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_1_6GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_1_6GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_1_6GF_Weights
        :members:
    """
    weights = RegNet_Y_1_6GF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(
        depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_3.2GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_3_2GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_3_2GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_3_2GF_Weights
        :members:
    """
    weights = RegNet_Y_3_2GF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(
        depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_8GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_8GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_8GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_8GF_Weights
        :members:
    """
    weights = RegNet_Y_8GF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(
        depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_16GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_16GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_16GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_16GF_Weights
        :members:
    """
    weights = RegNet_Y_16GF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(
        depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_32GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_32GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_32GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_32GF_Weights
        :members:
    """
    weights = RegNet_Y_32GF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(
        depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


# NOTE: unlike the sibling builders, the legacy ``pretrained`` alias is paired with ``None`` rather than
# an IMAGENET1K_V1 member — presumably because RegNet_Y_128GF_Weights has no such entry; confirm there.
@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_128GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_128GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_128GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_128GF_Weights
        :members:
    """
    weights = RegNet_Y_128GF_Weights.verify(weights)

    # RegNetY design-space parameters for this size; se_ratio is set only on the Y variants.
    params = BlockParams.from_init_params(
        depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_400MF_Weights
        :members:
    """
    weights = RegNet_X_400MF_Weights.verify(weights)

    # RegNetX design-space parameters for this size (no se_ratio, unlike the Y variants).
    params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_800MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_800MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_800MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_800MF_Weights
        :members:
    """
    weights = RegNet_X_800MF_Weights.verify(weights)

    # RegNetX design-space parameters for this size (no se_ratio, unlike the Y variants).
    params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_1.6GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_1_6GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_1_6GF_Weights
        :members:
    """
    weights = RegNet_X_1_6GF_Weights.verify(weights)

    # RegNetX design-space parameters for this size (no se_ratio, unlike the Y variants).
    params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_3.2GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_3_2GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_3_2GF_Weights
        :members:
    """
    weights = RegNet_X_3_2GF_Weights.verify(weights)

    # RegNetX design-space parameters for this size (no se_ratio, unlike the Y variants).
    params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_8GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_8GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_8GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_8GF_Weights
        :members:
    """
    weights = RegNet_X_8GF_Weights.verify(weights)

    # RegNetX design-space parameters for this size (no se_ratio, unlike the Y variants).
    params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_16GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_16GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_16GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_16GF_Weights
        :members:
    """
    weights = RegNet_X_16GF_Weights.verify(weights)

    # RegNetX design-space parameters for this size (no se_ratio, unlike the Y variants).
    params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_32GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_32GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_32GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_32GF_Weights
        :members:
    """
    weights = RegNet_X_32GF_Weights.verify(weights)

    # RegNetX design-space parameters for this size (no se_ratio, unlike the Y variants).
    params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs


# Legacy mapping from builder name to the URL of that builder's IMAGENET1K_V1 checkpoint,
# listed in the same order as the public builders. regnet_y_128gf is not listed; its builder
# pairs the legacy ``pretrained`` alias with ``None`` instead of an IMAGENET1K_V1 member.
model_urls = _ModelURLs(
    {
        builder_name: weights_enum.IMAGENET1K_V1.url
        for builder_name, weights_enum in (
            ("regnet_y_400mf", RegNet_Y_400MF_Weights),
            ("regnet_y_800mf", RegNet_Y_800MF_Weights),
            ("regnet_y_1_6gf", RegNet_Y_1_6GF_Weights),
            ("regnet_y_3_2gf", RegNet_Y_3_2GF_Weights),
            ("regnet_y_8gf", RegNet_Y_8GF_Weights),
            ("regnet_y_16gf", RegNet_Y_16GF_Weights),
            ("regnet_y_32gf", RegNet_Y_32GF_Weights),
            ("regnet_x_400mf", RegNet_X_400MF_Weights),
            ("regnet_x_800mf", RegNet_X_800MF_Weights),
            ("regnet_x_1_6gf", RegNet_X_1_6GF_Weights),
            ("regnet_x_3_2gf", RegNet_X_3_2GF_Weights),
            ("regnet_x_8gf", RegNet_X_8GF_Weights),
            ("regnet_x_16gf", RegNet_X_16GF_Weights),
            ("regnet_x_32gf", RegNet_X_32GF_Weights),
        )
    }
)