import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "RegNet",
    "RegNet_Y_400MF_Weights",
    "RegNet_Y_800MF_Weights",
    "RegNet_Y_1_6GF_Weights",
    "RegNet_Y_3_2GF_Weights",
    "RegNet_Y_8GF_Weights",
    "RegNet_Y_16GF_Weights",
    "RegNet_Y_32GF_Weights",
    "RegNet_Y_128GF_Weights",
    "RegNet_X_400MF_Weights",
    "RegNet_X_800MF_Weights",
    "RegNet_X_1_6GF_Weights",
    "RegNet_X_3_2GF_Weights",
    "RegNet_X_8GF_Weights",
    "RegNet_X_16GF_Weights",
    "RegNet_X_32GF_Weights",
    "regnet_y_400mf",
    "regnet_y_800mf",
    "regnet_y_1_6gf",
    "regnet_y_3_2gf",
    "regnet_y_8gf",
    "regnet_y_16gf",
    "regnet_y_32gf",
    "regnet_y_128gf",
    "regnet_x_400mf",
    "regnet_x_800mf",
    "regnet_x_1_6gf",
    "regnet_x_3_2gf",
    "regnet_x_8gf",
    "regnet_x_16gf",
    "regnet_x_32gf",
]


class SimpleStemIN(Conv2dNormActivation):
    """Simple stem for ImageNet: 3x3, BN, ReLU."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__(
            width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer
        )


class BottleneckTransform(nn.Sequential):
    """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float],
    ) -> None:
        layers: OrderedDict[str, nn.Module] = OrderedDict()
        w_b = int(round(width_out * bottleneck_multiplier))
        g = w_b // group_width
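        # e.g. with bottleneck_multiplier=1.0, width_out=48 and group_width=8
        # (hypothetical numbers), the 3x3 convolution below runs on w_b=48
        # channels split into g=6 groups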

        layers["a"] = Conv2dNormActivation(
            width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer
        )
        layers["b"] = Conv2dNormActivation(
            w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer
        )

        if se_ratio:
            # The SE reduction ratio is defined with respect to the
            # beginning of the block
            width_se_out = int(round(se_ratio * width_in))
            layers["se"] = SqueezeExcitation(
                input_channels=w_b,
                squeeze_channels=width_se_out,
                activation=activation_layer,
            )
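            # e.g. with width_in=48 and se_ratio=0.25 (hypothetical numbers),
            # the SE module squeezes its w_b input channels down to 12,
            # whatever the value of w_b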

        layers["c"] = Conv2dNormActivation(
            w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None
        )
        super().__init__(layers)


class ResBottleneckBlock(nn.Module):
    """Residual bottleneck block: x + F(x), F = bottleneck transform."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int = 1,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
    ) -> None:
        super().__init__()

        # Use skip connection with projection if shape changes
        self.proj = None
        should_proj = (width_in != width_out) or (stride != 1)
        if should_proj:
            self.proj = Conv2dNormActivation(
                width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None
            )
        self.f = BottleneckTransform(
            width_in,
            width_out,
            stride,
            norm_layer,
            activation_layer,
            group_width,
            bottleneck_multiplier,
            se_ratio,
        )
        self.activation = activation_layer(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        if self.proj is not None:
            x = self.proj(x) + self.f(x)
        else:
            x = x + self.f(x)
        return self.activation(x)


class AnyStage(nn.Sequential):
    """AnyNet stage (sequence of blocks w/ the same output shape)."""

    def __init__(
        self,
        width_in: int,
        width_out: int,
        stride: int,
        depth: int,
        block_constructor: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        group_width: int,
        bottleneck_multiplier: float,
        se_ratio: Optional[float] = None,
        stage_index: int = 0,
    ) -> None:
        super().__init__()

        for i in range(depth):
            block = block_constructor(
                width_in if i == 0 else width_out,
                width_out,
                stride if i == 0 else 1,
                norm_layer,
                activation_layer,
                group_width,
                bottleneck_multiplier,
                se_ratio,
            )

            self.add_module(f"block{stage_index}-{i}", block)


class BlockParams:
    def __init__(
        self,
        depths: List[int],
        widths: List[int],
        group_widths: List[int],
        bottleneck_multipliers: List[float],
        strides: List[int],
        se_ratio: Optional[float] = None,
    ) -> None:
        self.depths = depths
        self.widths = widths
        self.group_widths = group_widths
        self.bottleneck_multipliers = bottleneck_multipliers
        self.strides = strides
        self.se_ratio = se_ratio

    @classmethod
    def from_init_params(
        cls,
        depth: int,
        w_0: int,
        w_a: float,
        w_m: float,
        group_width: int,
        bottleneck_multiplier: float = 1.0,
        se_ratio: Optional[float] = None,
        **kwargs: Any,
    ) -> "BlockParams":
        """
        Programmatically compute all the per-block settings,
        given the RegNet parameters.

        The first step is to compute the quantized linear block parameters,
        in log space. Key parameters are:
        - `w_a` is the width progression slope
        - `w_0` is the initial width
        - `w_m` is the width stepping in the log space

        In other terms
        `log(block_width) = log(w_0) + log(w_m) * block_capacity`,
        with `block_capacity` ramping up following the w_0 and w_a params.
        This block width is finally quantized to multiples of 8.

        The second step is to compute the parameters per stage,
        taking into account the skip connection and the final 1x1 convolutions.
        We use the fact that the output width is constant within a stage.
        """

        QUANT = 8
        STRIDE = 2

        if w_a < 0 or w_0 <= 0 or w_m <= 1 or w_0 % 8 != 0:
            raise ValueError("Invalid RegNet settings")
        # Compute the block widths. Each stage has one unique block width
        widths_cont = torch.arange(depth) * w_a + w_0
        block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m))
        block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist()
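        # Worked example (hypothetical settings, not one of the published
        # configurations): depth=4, w_0=24, w_a=24.0, w_m=2.0 gives
        #   widths_cont    = [24, 48, 72, 96]
        #   block_capacity = [0, 1, 2, 2]
        #   block_widths   = [24, 48, 96, 96]
        # which the stage-splitting logic below turns into stage widths
        # [24, 48, 96] with depths [1, 1, 2].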
        num_stages = len(set(block_widths))

        # Convert to per stage parameters
        split_helper = zip(
            block_widths + [0],
            [0] + block_widths,
            block_widths + [0],
            [0] + block_widths,
        )
        splits = [w != wp or r != rp for w, wp, r, rp in split_helper]

        stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
        stage_depths = torch.diff(torch.tensor([d for d, t in enumerate(splits) if t])).int().tolist()

        strides = [STRIDE] * num_stages
        bottleneck_multipliers = [bottleneck_multiplier] * num_stages
        group_widths = [group_width] * num_stages

        # Adjust the compatibility of stage widths and group widths
        stage_widths, group_widths = cls._adjust_widths_groups_compatibility(
            stage_widths, bottleneck_multipliers, group_widths
        )

        return cls(
            depths=stage_depths,
            widths=stage_widths,
            group_widths=group_widths,
            bottleneck_multipliers=bottleneck_multipliers,
            strides=strides,
            se_ratio=se_ratio,
        )

    def _get_expanded_params(self):
        return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers)

    @staticmethod
    def _adjust_widths_groups_compatibility(
        stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int]
    ) -> Tuple[List[int], List[int]]:
        """
        Adjusts the compatibility of widths and groups,
        depending on the bottleneck ratio.
        """
        # Compute all widths for the current settings
        widths = [int(w * b) for w, b in zip(stage_widths, bottleneck_ratios)]
        group_widths_min = [min(g, w_bot) for g, w_bot in zip(group_widths, widths)]

        # Compute the adjusted widths so that stage and group widths fit
        ws_bot = [_make_divisible(w_bot, g) for w_bot, g in zip(widths, group_widths_min)]
        stage_widths = [int(w_bot / b) for w_bot, b in zip(ws_bot, bottleneck_ratios)]
        return stage_widths, group_widths_min
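    # Worked example (hypothetical numbers, assuming torchvision's standard
    # `_make_divisible`, which rounds to the nearest multiple without going
    # more than 10% below the input): stage_widths=[48], bottleneck_ratios=[1.0],
    # group_widths=[32] -> the bottleneck width 48 is rounded up to 64, so the
    # method returns ([64], [32]).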


class RegNet(nn.Module):
    def __init__(
        self,
        block_params: BlockParams,
        num_classes: int = 1000,
        stem_width: int = 32,
        stem_type: Optional[Callable[..., nn.Module]] = None,
        block_type: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if stem_type is None:
            stem_type = SimpleStemIN
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if block_type is None:
            block_type = ResBottleneckBlock
        if activation is None:
            activation = nn.ReLU

        # Ad hoc stem
        self.stem = stem_type(
            3,  # width_in
            stem_width,
            norm_layer,
            activation,
        )

        current_width = stem_width

        blocks = []
        for i, (
            width_out,
            stride,
            depth,
            group_width,
            bottleneck_multiplier,
        ) in enumerate(block_params._get_expanded_params()):
            blocks.append(
                (
                    f"block{i+1}",
                    AnyStage(
                        current_width,
                        width_out,
                        stride,
                        depth,
                        block_type,
                        norm_layer,
                        activation,
                        group_width,
                        bottleneck_multiplier,
                        block_params.se_ratio,
                        stage_index=i + 1,
                    ),
                )
            )

            current_width = width_out

        self.trunk_output = nn.Sequential(OrderedDict(blocks))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=current_width, out_features=num_classes)

        # Performs ResNet-style weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Note that there is no bias due to BN
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)  # stride-2 stem halves the spatial resolution
        x = self.trunk_output(x)  # each stage downsamples via its first block

        x = self.avgpool(x)  # global average pooling -> (N, C, 1, 1)
        x = x.flatten(start_dim=1)  # -> (N, C)
        x = self.fc(x)  # linear classifier -> (N, num_classes)

        return x


def _regnet(
    block_params: BlockParams,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> RegNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
    model = RegNet(block_params, norm_layer=norm_layer, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model


_COMMON_META: Dict[str, Any] = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}

_COMMON_SWAG_META = {
    **_COMMON_META,
    "recipe": "https://github.com/facebookresearch/SWAG",
    "license": "https://github.com/facebookresearch/SWAG/blob/main/LICENSE",
}


class RegNet_Y_400MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.046,
                    "acc@5": 91.716,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 4344144,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.804,
                    "acc@5": 92.742,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_800MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.420,
                    "acc@5": 93.136,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 6432512,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.828,
                    "acc@5": 94.502,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_1_6GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.950,
                    "acc@5": 93.966,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 11202430,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.876,
                    "acc@5": 95.444,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_3_2GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.948,
                    "acc@5": 94.576,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 19436338,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.982,
                    "acc@5": 95.972,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_8GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.032,
                    "acc@5": 95.048,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39381472,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.828,
                    "acc@5": 96.330,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_16GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.424,
                    "acc@5": 95.240,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 83590140,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.886,
                    "acc@5": 96.328,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_swag-43afe44d.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.012,
                    "acc@5": 98.054,
                }
            },
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_16gf_lc_swag-f3ec0043.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 83590140,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.976,
                    "acc@5": 97.244,
                }
            },
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_32GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.878,
                    "acc@5": 95.340,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 145046770,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.368,
                    "acc@5": 96.498,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_swag-04fdfa75.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.838,
                    "acc@5": 98.362,
                }
            },
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_32gf_lc_swag-e1583746.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 145046770,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.622,
                    "acc@5": 97.480,
                }
            },
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_Y_128GF_Weights(WeightsEnum):
    IMAGENET1K_SWAG_E2E_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_swag-c8ce3e52.pth",
        transforms=partial(
            ImageClassification, crop_size=384, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 88.228,
                    "acc@5": 98.682,
                }
            },
            "_docs": """
                These weights are learnt via transfer learning by end-to-end fine-tuning the original
                `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
            """,
        },
    )
    IMAGENET1K_SWAG_LINEAR_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_y_128gf_lc_swag-cbe8ce12.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_SWAG_META,
            "recipe": "https://github.com/pytorch/vision/pull/5793",
            "num_params": 644812894,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 86.068,
                    "acc@5": 97.844,
                }
            },
            "_docs": """
                These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
            """,
        },
    )
    DEFAULT = IMAGENET1K_SWAG_E2E_V1


class RegNet_X_400MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.834,
                    "acc@5": 90.950,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5495976,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.864,
                    "acc@5": 92.322,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_800MF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.212,
                    "acc@5": 92.348,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 7259656,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.522,
                    "acc@5": 93.826,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_1_6GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#small-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.040,
                    "acc@5": 93.440,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 9190136,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.668,
                    "acc@5": 94.922,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_3_2GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.364,
                    "acc@5": 93.992,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 15296552,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.196,
                    "acc@5": 95.430,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_8GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.344,
                    "acc@5": 94.686,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 39572648,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 81.682,
                    "acc@5": 95.678,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_16GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#medium-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.058,
                    "acc@5": 94.944,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 54278536,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.716,
                    "acc@5": 96.196,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class RegNet_X_32GF_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#large-models",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.622,
                    "acc@5": 95.248,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 107811560,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.014,
                    "acc@5": 96.288,
                }
            },
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_400MF_Weights.IMAGENET1K_V1))
def regnet_y_400mf(*, weights: Optional[RegNet_Y_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_400MF_Weights
        :members:
    """
    weights = RegNet_Y_400MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
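
# A minimal usage sketch for the builders in this module (shape demo only;
# for real inference, preprocess inputs with ``weights.transforms()``):
#
#   model = regnet_y_400mf(weights=RegNet_Y_400MF_Weights.IMAGENET1K_V2)
#   model.eval()
#   with torch.no_grad():
#       logits = model(torch.rand(1, 3, 224, 224))  # -> shape (1, 1000)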


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_800MF_Weights.IMAGENET1K_V1))
def regnet_y_800mf(*, weights: Optional[RegNet_Y_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_800MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_800MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_800MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_800MF_Weights
        :members:
    """
    weights = RegNet_Y_800MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_1_6GF_Weights.IMAGENET1K_V1))
def regnet_y_1_6gf(*, weights: Optional[RegNet_Y_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_1.6GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_1_6GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_1_6GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_1_6GF_Weights
        :members:
    """
    weights = RegNet_Y_1_6GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_3_2GF_Weights.IMAGENET1K_V1))
def regnet_y_3_2gf(*, weights: Optional[RegNet_Y_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_3.2GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_3_2GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_3_2GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_3_2GF_Weights
        :members:
    """
    weights = RegNet_Y_3_2GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_8GF_Weights.IMAGENET1K_V1))
def regnet_y_8gf(*, weights: Optional[RegNet_Y_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_8GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_8GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_8GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_8GF_Weights
        :members:
    """
    weights = RegNet_Y_8GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_16GF_Weights.IMAGENET1K_V1))
def regnet_y_16gf(*, weights: Optional[RegNet_Y_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_16GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_16GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_16GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_16GF_Weights
        :members:
    """
    weights = RegNet_Y_16GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_Y_32GF_Weights.IMAGENET1K_V1))
def regnet_y_32gf(*, weights: Optional[RegNet_Y_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_32GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_32GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_32GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_32GF_Weights
        :members:
    """
    weights = RegNet_Y_32GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", None))
def regnet_y_128gf(*, weights: Optional[RegNet_Y_128GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetY_128GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_Y_128GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_Y_128GF_Weights` below for more details and possible values.
1306
1307
1308
1309
1310
1311
1312
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_Y_128GF_Weights
        :members:
    """
    weights = RegNet_Y_128GF_Weights.verify(weights)

    params = BlockParams.from_init_params(
        depth=27, w_0=456, w_a=160.83, w_m=2.52, group_width=264, se_ratio=0.25, **kwargs
    )
    return _regnet(params, weights, progress, **kwargs)
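
# Note: the decorator above maps the legacy ``pretrained`` flag to ``None``,
# since no IMAGENET1K_V1 default exists for this 128GF variant.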


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_400MF_Weights.IMAGENET1K_V1))
def regnet_x_400mf(*, weights: Optional[RegNet_X_400MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_400MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_400MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_400MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_400MF_Weights
        :members:
    """
    weights = RegNet_X_400MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_800MF_Weights.IMAGENET1K_V1))
def regnet_x_800mf(*, weights: Optional[RegNet_X_800MF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_800MF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_800MF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_800MF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_800MF_Weights
        :members:
    """
    weights = RegNet_X_800MF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_1_6GF_Weights.IMAGENET1K_V1))
def regnet_x_1_6gf(*, weights: Optional[RegNet_X_1_6GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_1.6GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_1_6GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_1_6GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_1_6GF_Weights
        :members:
    """
    weights = RegNet_X_1_6GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_3_2GF_Weights.IMAGENET1K_V1))
def regnet_x_3_2gf(*, weights: Optional[RegNet_X_3_2GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_3.2GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_3_2GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_3_2GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_3_2GF_Weights
        :members:
    """
    weights = RegNet_X_3_2GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_8GF_Weights.IMAGENET1K_V1))
def regnet_x_8gf(*, weights: Optional[RegNet_X_8GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_8GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_8GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_8GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_8GF_Weights
        :members:
    """
    weights = RegNet_X_8GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_16GF_Weights.IMAGENET1K_V1))
def regnet_x_16gf(*, weights: Optional[RegNet_X_16GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_16GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_16GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_16GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_16GF_Weights
        :members:
    """
    weights = RegNet_X_16GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs)
    return _regnet(params, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", RegNet_X_32GF_Weights.IMAGENET1K_V1))
def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: bool = True, **kwargs: Any) -> RegNet:
    """
    Constructs a RegNetX_32GF architecture from
    `Designing Network Design Spaces <https://arxiv.org/abs/2003.13678>`_.

    Args:
        weights (:class:`~torchvision.models.RegNet_X_32GF_Weights`, optional): The pretrained weights to use.
            See :class:`~torchvision.models.RegNet_X_32GF_Weights` below for more details and possible values.
            By default, no pretrained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to either ``torchvision.models.regnet.RegNet`` or
            ``torchvision.models.regnet.BlockParams`` class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/regnet.py>`_
            for more detail about the classes.

    .. autoclass:: torchvision.models.RegNet_X_32GF_Weights
        :members:
    """
    weights = RegNet_X_32GF_Weights.verify(weights)

    params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
    return _regnet(params, weights, progress, **kwargs)
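
# How ``BlockParams.from_init_params`` turns (depth, w_0, w_a, w_m) into
# per-block widths: a minimal standalone sketch of the paper's quantized
# linear parameterization, for illustration only (``_sketch_block_widths`` is
# hypothetical; the authoritative logic lives in ``BlockParams``):
#
#     import math
#
#     def _sketch_block_widths(depth, w_0, w_a, w_m, quant=8):
#         # Continuous widths grow linearly with block index j: u_j = w_0 + w_a * j.
#         u = [w_0 + w_a * j for j in range(depth)]
#         # Snap each u_j onto the geometric grid w_0 * w_m**s ...
#         s = [round(math.log(u_j / w_0, w_m)) for u_j in u]
#         # ... and round the result to a multiple of ``quant`` channels.
#         return [int(round(w_0 * w_m**s_j / quant) * quant) for s_j in s]
#
#     # Consecutive equal widths are then merged into stages, e.g. with the
#     # values above for regnet_x_32gf: _sketch_block_widths(23, 320, 69.86, 2.0)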


# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs


model_urls = _ModelURLs(
    {
        "regnet_y_400mf": RegNet_Y_400MF_Weights.IMAGENET1K_V1.url,
        "regnet_y_800mf": RegNet_Y_800MF_Weights.IMAGENET1K_V1.url,
        "regnet_y_1_6gf": RegNet_Y_1_6GF_Weights.IMAGENET1K_V1.url,
        "regnet_y_3_2gf": RegNet_Y_3_2GF_Weights.IMAGENET1K_V1.url,
        "regnet_y_8gf": RegNet_Y_8GF_Weights.IMAGENET1K_V1.url,
        "regnet_y_16gf": RegNet_Y_16GF_Weights.IMAGENET1K_V1.url,
        "regnet_y_32gf": RegNet_Y_32GF_Weights.IMAGENET1K_V1.url,
        "regnet_x_400mf": RegNet_X_400MF_Weights.IMAGENET1K_V1.url,
        "regnet_x_800mf": RegNet_X_800MF_Weights.IMAGENET1K_V1.url,
        "regnet_x_1_6gf": RegNet_X_1_6GF_Weights.IMAGENET1K_V1.url,
        "regnet_x_3_2gf": RegNet_X_3_2GF_Weights.IMAGENET1K_V1.url,
        "regnet_x_8gf": RegNet_X_8GF_Weights.IMAGENET1K_V1.url,
        "regnet_x_16gf": RegNet_X_16GF_Weights.IMAGENET1K_V1.url,
        "regnet_x_32gf": RegNet_X_32GF_Weights.IMAGENET1K_V1.url,
    }
)