"vscode:/vscode.git/clone" did not exist on "b4077af212d62222325a7a41a80f193726680f3f"
efficientnet.py 42.1 KB
Newer Older
1
2
import copy
import math
3
from dataclasses import dataclass
4
from functools import partial
5
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
6

7
8
9
10
import torch
from torch import nn, Tensor
from torchvision.ops import StochasticDepth

11
from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
12
from ..transforms._presets import ImageClassification, InterpolationMode
13
from ..utils import _log_api_usage_once
14
from ._api import register_model, Weights, WeightsEnum
15
from ._meta import _IMAGENET_CATEGORIES
16
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
17
18


19
20
__all__ = [
    # model class
    "EfficientNet",
    # pretrained-weight enums
    "EfficientNet_B0_Weights",
    "EfficientNet_B1_Weights",
    "EfficientNet_B2_Weights",
    "EfficientNet_B3_Weights",
    "EfficientNet_B4_Weights",
    "EfficientNet_B5_Weights",
    "EfficientNet_B6_Weights",
    "EfficientNet_B7_Weights",
    "EfficientNet_V2_S_Weights",
    "EfficientNet_V2_M_Weights",
    "EfficientNet_V2_L_Weights",
    # model builders
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_s",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
]
44
45


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
@dataclass
class _MBConvConfig:
    """Per-stage record shared by the MBConv and FusedMBConv stage configurations.

    ``EfficientNet`` reads these fields when it builds each stage and instantiates
    ``block`` once per layer of the stage.
    """

    # channel expansion factor applied to ``input_channels`` inside each block
    expand_ratio: float
    # spatial kernel size
    kernel: int
    # stride of the first block of the stage (later blocks of the stage use stride 1)
    stride: int
    input_channels: int
    out_channels: int
    # number of blocks stacked in the stage
    num_layers: int
    # module class instantiated for every block of the stage
    block: Callable[..., nn.Module]

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        """Scale ``channels`` by ``width_mult`` and round it with ``_make_divisible`` (divisor 8)."""
        scaled = channels * width_mult
        return _make_divisible(scaled, 8, min_value)


class MBConvConfig(_MBConvConfig):
    """Configuration of one MBConv stage.

    Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the
    EfficientNetV2 paper. Before being stored, channel counts are scaled by ``width_mult``
    (rounded via ``adjust_channels``) and the layer count is scaled by ``depth_mult``
    (rounded up via ``adjust_depth``).
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        width_mult: float = 1.0,
        depth_mult: float = 1.0,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Args:
            expand_ratio: channel expansion factor used inside each block.
            kernel: depthwise kernel size.
            stride: stride of the first block of the stage.
            input_channels: unscaled input channel count.
            out_channels: unscaled output channel count.
            num_layers: unscaled number of blocks in the stage.
            width_mult: multiplier applied to both channel counts.
            depth_mult: multiplier applied to ``num_layers`` (result is rounded up).
            block: module class used for each block; ``MBConv`` when ``None``.
        """
        input_channels = self.adjust_channels(input_channels, width_mult)
        out_channels = self.adjust_channels(out_channels, width_mult)
        num_layers = self.adjust_depth(num_layers, depth_mult)
        if block is None:
            block = MBConv
        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float) -> int:
        """Scale ``num_layers`` by ``depth_mult``, rounding up to the next integer."""
        return int(math.ceil(num_layers * depth_mult))


87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
class FusedMBConvConfig(_MBConvConfig):
    """Configuration of one FusedMBConv stage (Table 4 of the EfficientNetV2 paper).

    No width/depth scaling is applied here: the channel and layer counts are stored as given.
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        # fall back to the fused block implementation when the caller gives no block class
        block_cls = FusedMBConv if block is None else block
        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block_cls)


104
class MBConv(nn.Module):
    """MBConv block: optional 1x1 expand -> depthwise conv -> squeeze-excitation -> 1x1 project.

    A residual connection is used only when the block keeps the stride at 1 and the channel
    count unchanged. In that case ``StochasticDepth`` is applied to the block branch before
    the input is added back.
    """

    def __init__(
        self,
        cnf: MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = SqueezeExcitation,
    ) -> None:
        """
        Args:
            cnf: block configuration (kernel, stride, channel counts and expand ratio are read).
            stochastic_depth_prob: probability handed to ``StochasticDepth`` in "row" mode.
            norm_layer: normalization layer factory used by every ``Conv2dNormActivation``.
            se_layer: squeeze-and-excitation module factory.

        Raises:
            ValueError: if ``cnf.stride`` is not 1 or 2.
        """
        super().__init__()

        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        # expand: 1x1 conv, skipped when the expand ratio leaves the channel count unchanged
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        # depthwise: groups == channels; this conv carries the block's stride
        layers.append(
            Conv2dNormActivation(
                expanded_channels,
                expanded_channels,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                groups=expanded_channels,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
            )
        )

        # squeeze and excitation: bottleneck width is a quarter of the block's *input* channels
        squeeze_channels = max(1, cnf.input_channels // 4)
        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

        # project: 1x1 conv with no activation
        layers.append(
            Conv2dNormActivation(
                expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            # stochastic depth touches only the block branch; the identity path is always added
            result = self.stochastic_depth(result)
            result += input
        return result


171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
class FusedMBConv(nn.Module):
    """Fused-MBConv block from EfficientNetV2.

    The first layer is a single regular ``kernel x kernel`` convolution carrying the stride.
    When the expand ratio changes the channel count, that conv produces the expanded width and
    a 1x1 projection without activation follows. Otherwise the single conv maps straight to
    ``out_channels``.
    """

    def __init__(
        self,
        cnf: FusedMBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__()

        stride_ok = 1 <= cnf.stride <= 2
        if not stride_ok:
            raise ValueError("illegal stride value")

        keeps_resolution = cnf.stride == 1
        keeps_width = cnf.input_channels == cnf.out_channels
        self.use_res_connect = keeps_resolution and keeps_width

        act = nn.SiLU
        hidden = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        has_expansion = hidden != cnf.input_channels

        # fused conv: outputs the expanded width when expanding, otherwise out_channels directly
        fused_out = hidden if has_expansion else cnf.out_channels
        modules: List[nn.Module] = [
            Conv2dNormActivation(
                cnf.input_channels,
                fused_out,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                norm_layer=norm_layer,
                activation_layer=act,
            )
        ]
        if has_expansion:
            # project back down to out_channels with a 1x1 conv and no activation
            modules.append(
                Conv2dNormActivation(
                    hidden,
                    cnf.out_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=None,
                )
            )

        self.block = nn.Sequential(*modules)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out = self.stochastic_depth(out)
        out += input
        return out


232
233
class EfficientNet(nn.Module):
    def __init__(
        self,
        inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
        dropout: float,
        stochastic_depth_prob: float = 0.2,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        last_channel: Optional[int] = None,
    ) -> None:
        """
        EfficientNet V1 and V2 main class

        Args:
            inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
            dropout (float): The dropout probability
            stochastic_depth_prob (float): The stochastic depth probability; each block receives this value
                scaled linearly by its index among all blocks
            num_classes (int): Number of classes
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
                (``nn.BatchNorm2d`` when ``None``)
            last_channel (Optional[int]): The number of channels on the penultimate layer. When ``None``,
                four times the output channels of the last stage is used

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty.
            TypeError: if ``inverted_residual_setting`` is not a sequence of ``_MBConvConfig`` instances.
        """
        super().__init__()
        _log_api_usage_once(self)

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all(isinstance(s, _MBConvConfig) for s in inverted_residual_setting)
        ):
            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        layers: List[nn.Module] = []

        # building first layer (stem): 3x3 stride-2 conv from RGB to the first stage's input width
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
            )
        )

        # building inverted residual blocks
        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
        stage_block_id = 0
        for cnf in inverted_residual_setting:
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # copy to avoid modifications. shallow copy is enough
                block_cnf = copy.copy(cnf)

                # overwrite info if not the first conv in the stage
                if stage:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
                stage_block_id += 1

            layers.append(nn.Sequential(*stage))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
        layers.append(
            Conv2dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.SiLU,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(lastconv_output_channels, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                # uniform in [-1/sqrt(out_features), 1/sqrt(out_features)]
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


346
def _efficientnet(
    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
    dropout: float,
    last_channel: Optional[int],
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> EfficientNet:
    """Build an ``EfficientNet`` and optionally load pretrained weights into it.

    When ``weights`` is given, ``num_classes`` in ``kwargs`` is reconciled (via
    ``_ovewrite_named_param``) with the number of categories in the weights' metadata.
    The state dict obtained from ``weights.get_state_dict`` (with hash checking) is then
    loaded into the model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
def _efficientnet_conf(
    arch: str,
    **kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
    """Return ``(stage configurations, last_channel)`` for the architecture named ``arch``.

    ``efficientnet_b*`` variants:
        ``width_mult`` and ``depth_mult`` are popped from ``kwargs`` (a missing key raises
        ``KeyError``). They scale the shared B0 layout, and ``last_channel`` is ``None``.

    ``efficientnet_v2_s`` / ``_m`` / ``_l``:
        fixed layouts (FusedMBConv stages first, then MBConv stages) with ``last_channel``
        set to 1280.

    Raises:
        ValueError: for an unsupported ``arch``.
    """
    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
    last_channel: Optional[int]

    # each row is (expand_ratio, kernel, stride, input_channels, out_channels, num_layers)
    def _v2_layout(fused_rows, mbconv_rows):
        return [FusedMBConvConfig(*row) for row in fused_rows] + [MBConvConfig(*row) for row in mbconv_rows]

    if arch.startswith("efficientnet_b"):
        width = kwargs.pop("width_mult")
        depth = kwargs.pop("depth_mult")
        make_stage = partial(MBConvConfig, width_mult=width, depth_mult=depth)
        b0_rows = (
            (1, 3, 1, 32, 16, 1),
            (6, 3, 2, 16, 24, 2),
            (6, 5, 2, 24, 40, 2),
            (6, 3, 2, 40, 80, 3),
            (6, 5, 1, 80, 112, 3),
            (6, 5, 2, 112, 192, 4),
            (6, 3, 1, 192, 320, 1),
        )
        inverted_residual_setting = [make_stage(*row) for row in b0_rows]
        last_channel = None
    elif arch.startswith("efficientnet_v2_s"):
        inverted_residual_setting = _v2_layout(
            ((1, 3, 1, 24, 24, 2), (4, 3, 2, 24, 48, 4), (4, 3, 2, 48, 64, 4)),
            ((4, 3, 2, 64, 128, 6), (6, 3, 1, 128, 160, 9), (6, 3, 2, 160, 256, 15)),
        )
        last_channel = 1280
    elif arch.startswith("efficientnet_v2_m"):
        inverted_residual_setting = _v2_layout(
            ((1, 3, 1, 24, 24, 3), (4, 3, 2, 24, 48, 5), (4, 3, 2, 48, 80, 5)),
            (
                (4, 3, 2, 80, 160, 7),
                (6, 3, 1, 160, 176, 14),
                (6, 3, 2, 176, 304, 18),
                (6, 3, 1, 304, 512, 5),
            ),
        )
        last_channel = 1280
    elif arch.startswith("efficientnet_v2_l"):
        inverted_residual_setting = _v2_layout(
            ((1, 3, 1, 32, 32, 4), (4, 3, 2, 32, 64, 7), (4, 3, 2, 64, 96, 7)),
            (
                (4, 3, 2, 96, 192, 10),
                (6, 3, 1, 192, 224, 19),
                (6, 3, 2, 224, 384, 25),
                (6, 3, 1, 384, 640, 7),
            ),
        )
        last_channel = 1280
    else:
        raise ValueError(f"Unsupported model type {arch}")

    return inverted_residual_setting, last_channel


420
# Metadata shared by every EfficientNet weight entry (V1 and V2).
_COMMON_META: Dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}


# Metadata shared by the EfficientNet V1 (B0-B7) weight entries.
_COMMON_META_V1 = {
    **_COMMON_META,
    "min_size": (1, 1),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
}


# Metadata shared by the EfficientNet V2 (S/M/L) weight entries.
_COMMON_META_V2 = {
    **_COMMON_META,
    "min_size": (33, 33),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
}


class EfficientNet_B0_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b0``.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 5288548,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.692,
                    "acc@5": 93.532,
                }
            },
            "_ops": 0.386,
            "_file_size": 20.451,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B1_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b1``; DEFAULT is the retrained V2.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b1_rwightman-bac287d4.pth",
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.642,
                    "acc@5": 94.186,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.134,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.838,
                    "acc@5": 94.934,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.136,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class EfficientNet_B2_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b2``.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
        transforms=partial(
            ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 9109994,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.608,
                    "acc@5": 95.310,
                }
            },
            "_ops": 1.088,
            "_file_size": 35.174,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B3_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b3``.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
        transforms=partial(
            ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 12233232,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.008,
                    "acc@5": 96.054,
                }
            },
            "_ops": 1.827,
            "_file_size": 47.184,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B4_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b4``.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
        transforms=partial(
            ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 19341616,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.384,
                    "acc@5": 96.594,
                }
            },
            "_ops": 4.394,
            "_file_size": 74.489,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B5_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b5``.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
        transforms=partial(
            ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 30389784,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.444,
                    "acc@5": 96.628,
                }
            },
            "_ops": 10.266,
            "_file_size": 116.864,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B6_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b6``.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
        transforms=partial(
            ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 43040704,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.008,
                    "acc@5": 96.916,
                }
            },
            "_ops": 19.068,
            "_file_size": 165.362,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_B7_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_b7``.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
        transforms=partial(
            ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 66347960,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.122,
                    "acc@5": 96.908,
                }
            },
            "_ops": 37.746,
            "_file_size": 254.675,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_S_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_v2_s``.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 21458488,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.228,
                    "acc@5": 96.878,
                }
            },
            "_ops": 8.366,
            "_file_size": 82.704,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_M_Weights(WeightsEnum):
    # Pretrained ImageNet-1K weights available for ``efficientnet_v2_m``.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 54139356,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.112,
                    "acc@5": 97.156,
                }
            },
            "_ops": 24.582,
            "_file_size": 208.01,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class EfficientNet_V2_L_Weights(WeightsEnum):
    # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
    # Unlike the other entries, preprocessing normalizes with mean = std = 0.5 per channel.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BICUBIC,
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 118515272,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.808,
                    "acc@5": 97.788,
                }
            },
            "_ops": 56.08,
            "_file_size": 454.573,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1


744
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
    *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B0_Weights
        :members:
    """
    weights = EfficientNet_B0_Weights.verify(weights)

    # B0 is the unscaled baseline: width and depth multipliers are both 1.0
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
    # a caller-supplied ``dropout`` kwarg replaces the 0.2 default and is removed from kwargs
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
    )
773
774


775
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1(
    *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B1_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B1_Weights
        :members:
    """
    weights = EfficientNet_B1_Weights.verify(weights)

    # B1 keeps the baseline width and scales depth by 1.1
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
    # a caller-supplied ``dropout`` kwarg replaces the 0.2 default and is removed from kwargs
    return _efficientnet(
        inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
    )
804

805

806
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2(
    *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Build the EfficientNet-B2 classifier described in `EfficientNet: Rethinking Model
    Scaling for Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.3 when omitted. Please refer to the
            `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B2_Weights
        :members:
    """
    weights = EfficientNet_B2_Weights.verify(weights)

    # B2 scaling multipliers: width x1.1, depth x1.2.
    block_settings, head_channels = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
    # Popped so it is not forwarded a second time through **kwargs.
    drop_rate = kwargs.pop("dropout", 0.3)
    return _efficientnet(block_settings, drop_rate, head_channels, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3(
    *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Build the EfficientNet-B3 classifier described in `EfficientNet: Rethinking Model
    Scaling for Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.3 when omitted. Please refer to the
            `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B3_Weights
        :members:
    """
    weights = EfficientNet_B3_Weights.verify(weights)

    # B3 scaling multipliers: width x1.2, depth x1.4.
    block_settings, head_channels = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
    # Popped so it is not forwarded a second time through **kwargs.
    drop_rate = kwargs.pop("dropout", 0.3)
    return _efficientnet(block_settings, drop_rate, head_channels, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4(
    *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Build the EfficientNet-B4 classifier described in `EfficientNet: Rethinking Model
    Scaling for Convolutional Neural Networks <https://arxiv.org/abs/1905.11946>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B4_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.4 when omitted. Please refer to the
            `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B4_Weights
        :members:
    """
    weights = EfficientNet_B4_Weights.verify(weights)

    # B4 scaling multipliers: width x1.4, depth x1.8.
    block_settings, head_channels = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
    # Popped so it is not forwarded a second time through **kwargs.
    drop_rate = kwargs.pop("dropout", 0.4)
    return _efficientnet(block_settings, drop_rate, head_channels, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5(
    *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.4. ``norm_layer`` defaults to
            ``partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)``, and either may be
            overridden. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B5_Weights
        :members:
    """
    weights = EfficientNet_B5_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.4),
        last_channel,
        weights,
        progress,
        # Pop (rather than pass a fixed keyword next to **kwargs) so a caller-supplied
        # ``norm_layer`` overrides the default instead of raising
        # "TypeError: got multiple values for keyword argument 'norm_layer'".
        norm_layer=kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6(
    *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B6_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.5. ``norm_layer`` defaults to
            ``partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)``, and either may be
            overridden. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B6_Weights
        :members:
    """
    weights = EfficientNet_B6_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.5),
        last_channel,
        weights,
        progress,
        # Pop (rather than pass a fixed keyword next to **kwargs) so a caller-supplied
        # ``norm_layer`` overrides the default instead of raising
        # "TypeError: got multiple values for keyword argument 'norm_layer'".
        norm_layer=kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7(
    *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B7_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.5. ``norm_layer`` defaults to
            ``partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)``, and either may be
            overridden. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_B7_Weights
        :members:
    """
    weights = EfficientNet_B7_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.5),
        last_channel,
        weights,
        progress,
        # Pop (rather than pass a fixed keyword next to **kwargs) so a caller-supplied
        # ``norm_layer`` overrides the default instead of raising
        # "TypeError: got multiple values for keyword argument 'norm_layer'".
        norm_layer=kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
    *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-S architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.2. ``norm_layer`` defaults to
            ``partial(nn.BatchNorm2d, eps=1e-03)``, and either may be overridden.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
        :members:
    """
    weights = EfficientNet_V2_S_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.2),
        last_channel,
        weights,
        progress,
        # Pop (rather than pass a fixed keyword next to **kwargs) so a caller-supplied
        # ``norm_layer`` overrides the default instead of raising
        # "TypeError: got multiple values for keyword argument 'norm_layer'".
        norm_layer=kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03)),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
    *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-M architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.3. ``norm_layer`` defaults to
            ``partial(nn.BatchNorm2d, eps=1e-03)``, and either may be overridden.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
        :members:
    """
    weights = EfficientNet_V2_M_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.3),
        last_channel,
        weights,
        progress,
        # Pop (rather than pass a fixed keyword next to **kwargs) so a caller-supplied
        # ``norm_layer`` overrides the default instead of raising
        # "TypeError: got multiple values for keyword argument 'norm_layer'".
        norm_layer=kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03)),
        **kwargs,
    )


@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
    *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-L architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. ``dropout`` defaults to 0.4. ``norm_layer`` defaults to
            ``partial(nn.BatchNorm2d, eps=1e-03)``, and either may be overridden.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
        :members:
    """
    weights = EfficientNet_V2_L_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
    return _efficientnet(
        inverted_residual_setting,
        kwargs.pop("dropout", 0.4),
        last_channel,
        weights,
        progress,
        # Pop (rather than pass a fixed keyword next to **kwargs) so a caller-supplied
        # ``norm_layer`` overrides the default instead of raising
        # "TypeError: got multiple values for keyword argument 'norm_layer'".
        norm_layer=kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-03)),
        **kwargs,
    )