Unverified Commit e6d82f7d authored by Vasilis Vryniotis, committed by GitHub

Adding EfficientNetV2 architecture (#5450)

* Extend the EfficientNet class to support v1 and v2.

* Refactor config/builder methods and add prototype builders

* Refactoring weight info.

* Update dropouts based on TF config ref

* Update BN eps on TF base_config

* Use Conv2dNormActivation.

* Adding pre-trained weights for EfficientNetV2-s

* Add Medium and Large weights

* Update stats with single batch run.

* Add accuracies in the docs.
parent a2b70758
@@ -38,7 +38,7 @@ architectures for image classification:
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
- `EfficientNet`_ v1 & v2
- `RegNet`_
- `VisionTransformer`_
- `ConvNeXt`_
@@ -70,6 +70,9 @@ You can construct a model with random weights by calling its constructor:
efficientnet_b5 = models.efficientnet_b5()
efficientnet_b6 = models.efficientnet_b6()
efficientnet_b7 = models.efficientnet_b7()
efficientnet_v2_s = models.efficientnet_v2_s()
efficientnet_v2_m = models.efficientnet_v2_m()
efficientnet_v2_l = models.efficientnet_v2_l()
regnet_y_400mf = models.regnet_y_400mf()
regnet_y_800mf = models.regnet_y_800mf()
regnet_y_1_6gf = models.regnet_y_1_6gf()
@@ -122,6 +125,9 @@ These can be constructed by passing ``pretrained=True``:
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True)
efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True)
efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
@@ -238,6 +244,9 @@ EfficientNet-B4 83.384 96.594
EfficientNet-B5 83.444 96.628
EfficientNet-B6 84.008 96.916
EfficientNet-B7 84.122 96.908
EfficientNet-B7 84.122 96.908 EfficientNet-B7 84.122 96.908
EfficientNetV2-s 84.228 96.878
EfficientNetV2-m 85.112 97.156
EfficientNetV2-l 85.810 97.792
regnet_x_400mf 72.834 90.950
regnet_x_800mf 75.212 92.348
regnet_x_1_6gf 77.040 93.440
@@ -439,6 +448,9 @@ EfficientNet
efficientnet_b5
efficientnet_b6
efficientnet_b7
efficientnet_v2_s
efficientnet_v2_m
efficientnet_v2_l
RegNet
------------
...
@@ -13,6 +13,9 @@ from torchvision.models.efficientnet import (
efficientnet_b5,
efficientnet_b6,
efficientnet_b7,
efficientnet_v2_s,
efficientnet_v2_m,
efficientnet_v2_l,
)
from torchvision.models.googlenet import googlenet
from torchvision.models.inception import inception_v3
...
@@ -88,7 +88,7 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@
and [#3354](https://github.com/pytorch/vision/pull/3354) for details.

### EfficientNet-V1

The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108).
@@ -114,6 +114,26 @@ torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bic
--val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
```
### EfficientNet-V2
```
torchrun --nproc_per_node=8 train.py \
--model $MODEL --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr \
--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.00002 --norm-weight-decay 0.0 \
--train-crop-size $TRAIN_SIZE --model-ema --val-crop-size $EVAL_SIZE --val-resize-size $EVAL_SIZE \
--ra-sampler --ra-reps 4
```
Here `$MODEL` is one of `efficientnet_v2_s` or `efficientnet_v2_m`.
Note that the Small variant was trained with a `$TRAIN_SIZE` of `300` and an `$EVAL_SIZE` of `384`, while the Medium variant used `384` and `480` respectively.
Note that the above command corresponds to training on a single node with 8 GPUs.
For generating the pre-trained weights, we trained with 4 nodes, each with 8 GPUs (for a total of 32 GPUs),
and `--batch_size 32`.
The weights of the Large variant are ported from the original paper rather than trained from scratch. See the `EfficientNet_V2_L_Weights` entry for their exact preprocessing transforms.
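For reference, the snippet below is an illustrative sketch (not part of the reference scripts) of the evaluation preprocessing described by that weights entry, applied through the legacy `pretrained=True` interface; `img` stands in for any PIL image:

```
import torch
from torchvision import models, transforms
from torchvision.transforms.functional import InterpolationMode

# 480px bicubic resize + center crop, normalized with mean/std of 0.5,
# as listed in the EfficientNet_V2_L_Weights entry.
preprocess = transforms.Compose([
    transforms.Resize(480, interpolation=InterpolationMode.BICUBIC),
    transforms.CenterCrop(480),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

model = models.efficientnet_v2_l(pretrained=True).eval()
with torch.no_grad():
    logits = model(preprocess(img).unsqueeze(0))
```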
### RegNet

#### Small models
...
import copy
import math
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional, List, Sequence, Tuple, Union
import torch
from torch import nn, Tensor
@@ -23,6 +25,9 @@ __all__ = [
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
"efficientnet_v2_s",
"efficientnet_v2_m",
"efficientnet_v2_l",
]
@@ -37,11 +42,31 @@ model_urls = {
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
# Weights trained with TorchVision
"efficientnet_v2_s": "https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
"efficientnet_v2_m": "https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
# Weights ported from TF
"efficientnet_v2_l": "https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
}

@dataclass
class _MBConvConfig:
expand_ratio: float
kernel: int
stride: int
input_channels: int
out_channels: int
num_layers: int
block: Callable[..., nn.Module]
@staticmethod
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
return _make_divisible(channels * width_mult, 8, min_value)
class MBConvConfig(_MBConvConfig):
# Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
def __init__(
self,
expand_ratio: float,
@@ -50,38 +75,39 @@ class MBConvConfig:
input_channels: int,
out_channels: int,
num_layers: int,
width_mult: float = 1.0,
depth_mult: float = 1.0,
block: Optional[Callable[..., nn.Module]] = None,
) -> None:
input_channels = self.adjust_channels(input_channels, width_mult)
out_channels = self.adjust_channels(out_channels, width_mult)
num_layers = self.adjust_depth(num_layers, depth_mult)
if block is None:
block = MBConv
super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}("
f"expand_ratio={self.expand_ratio}"
f", kernel={self.kernel}"
f", stride={self.stride}"
f", input_channels={self.input_channels}"
f", out_channels={self.out_channels}"
f", num_layers={self.num_layers}"
f")"
)
return s
@staticmethod
def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
return _make_divisible(channels * width_mult, 8, min_value)
@staticmethod
def adjust_depth(num_layers: int, depth_mult: float):
return int(math.ceil(num_layers * depth_mult))
class FusedMBConvConfig(_MBConvConfig):
# Stores information listed at Table 4 of the EfficientNetV2 paper
def __init__(
self,
expand_ratio: float,
kernel: int,
stride: int,
input_channels: int,
out_channels: int,
num_layers: int,
block: Optional[Callable[..., nn.Module]] = None,
) -> None:
if block is None:
block = FusedMBConv
super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
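# Illustrative sketch, not part of this diff: MBConvConfig scales its channel and
# layer counts by the width/depth multipliers (channels rounded via _make_divisible,
# layers rounded up), while FusedMBConvConfig above stores the literal Table 4 values.
# Both record which block class should consume them.
_cnf_v1 = MBConvConfig(6, 5, 2, 24, 40, 2, width_mult=1.1, depth_mult=1.2)
_cnf_v2 = FusedMBConvConfig(4, 3, 2, 24, 48, 4)
print(_cnf_v1.block.__name__, _cnf_v2.block.__name__)  # MBConv FusedMBConv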
class MBConv(nn.Module):
def __init__(
self,
@@ -149,27 +175,88 @@ class MBConv(nn.Module):
return result
class FusedMBConv(nn.Module):
def __init__(
self,
cnf: FusedMBConvConfig,
stochastic_depth_prob: float,
norm_layer: Callable[..., nn.Module],
) -> None:
super().__init__()
if not (1 <= cnf.stride <= 2):
raise ValueError("illegal stride value")
self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
layers: List[nn.Module] = []
activation_layer = nn.SiLU
expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
if expanded_channels != cnf.input_channels:
# fused expand
layers.append(
Conv2dNormActivation(
cnf.input_channels,
expanded_channels,
kernel_size=cnf.kernel,
stride=cnf.stride,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
# project
layers.append(
Conv2dNormActivation(
expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
)
)
else:
layers.append(
Conv2dNormActivation(
cnf.input_channels,
cnf.out_channels,
kernel_size=cnf.kernel,
stride=cnf.stride,
norm_layer=norm_layer,
activation_layer=activation_layer,
)
)
self.block = nn.Sequential(*layers)
self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
self.out_channels = cnf.out_channels
def forward(self, input: Tensor) -> Tensor:
result = self.block(input)
if self.use_res_connect:
result = self.stochastic_depth(result)
result += input
return result
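# Illustrative sketch, not part of this diff: exercising the new FusedMBConv block
# on its own. With stride 2 the spatial resolution is halved and the channel count
# changes, so no residual connection is used.
_cnf = FusedMBConvConfig(4, 3, 2, 24, 48, 1)
_blk = FusedMBConv(_cnf, stochastic_depth_prob=0.0, norm_layer=nn.BatchNorm2d)
print(_blk(torch.randn(1, 24, 64, 64)).shape)  # torch.Size([1, 48, 32, 32])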
class EfficientNet(nn.Module):
def __init__(
self,
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
dropout: float,
stochastic_depth_prob: float = 0.2,
num_classes: int = 1000,
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
last_channel: Optional[int] = None,
**kwargs: Any,
) -> None:
"""
EfficientNet V1 and V2 main class

Args:
inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
dropout (float): The dropout probability
stochastic_depth_prob (float): The stochastic depth probability
num_classes (int): Number of classes
block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
last_channel (int): The number of channels on the penultimate layer
""" """
super().__init__() super().__init__()
_log_api_usage_once(self) _log_api_usage_once(self)
...@@ -178,12 +265,19 @@ class EfficientNet(nn.Module): ...@@ -178,12 +265,19 @@ class EfficientNet(nn.Module):
raise ValueError("The inverted_residual_setting should not be empty") raise ValueError("The inverted_residual_setting should not be empty")
elif not ( elif not (
isinstance(inverted_residual_setting, Sequence) isinstance(inverted_residual_setting, Sequence)
and all([isinstance(s, MBConvConfig) for s in inverted_residual_setting]) and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
): ):
raise TypeError("The inverted_residual_setting should be List[MBConvConfig]") raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
if "block" in kwargs:
warnings.warn(
"The parameter 'block' is deprecated since 0.13 and will be removed in 0.15. "
"Please pass this information on 'MBConvConfig.block' instead."
)
if kwargs["block"] is not None:
for s in inverted_residual_setting:
if isinstance(s, MBConvConfig):
s.block = kwargs["block"]
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -215,14 +309,14 @@ class EfficientNet(nn.Module):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
stage_block_id += 1
layers.append(nn.Sequential(*stage))
# building last several layers
lastconv_input_channels = inverted_residual_setting[-1].out_channels
lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
layers.append(
Conv2dNormActivation(
lastconv_input_channels,
@@ -269,24 +363,14 @@ class EfficientNet(nn.Module):
def _efficientnet(
arch: str,
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
dropout: float,
last_channel: Optional[int],
pretrained: bool,
progress: bool,
**kwargs: Any,
) -> EfficientNet:
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
inverted_residual_setting = [
bneck_conf(1, 3, 1, 32, 16, 1),
bneck_conf(6, 3, 2, 16, 24, 2),
bneck_conf(6, 5, 2, 24, 40, 2),
bneck_conf(6, 3, 2, 40, 80, 3),
bneck_conf(6, 5, 1, 80, 112, 3),
bneck_conf(6, 5, 2, 112, 192, 4),
bneck_conf(6, 3, 1, 192, 320, 1),
]
model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
if pretrained:
if model_urls.get(arch, None) is None:
raise ValueError(f"No checkpoint is available for model type {arch}")
@@ -295,6 +379,61 @@ def _efficientnet(
return model
def _efficientnet_conf(
arch: str,
**kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
if arch.startswith("efficientnet_b"):
bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
inverted_residual_setting = [
bneck_conf(1, 3, 1, 32, 16, 1),
bneck_conf(6, 3, 2, 16, 24, 2),
bneck_conf(6, 5, 2, 24, 40, 2),
bneck_conf(6, 3, 2, 40, 80, 3),
bneck_conf(6, 5, 1, 80, 112, 3),
bneck_conf(6, 5, 2, 112, 192, 4),
bneck_conf(6, 3, 1, 192, 320, 1),
]
last_channel = None
elif arch.startswith("efficientnet_v2_s"):
inverted_residual_setting = [
FusedMBConvConfig(1, 3, 1, 24, 24, 2),
FusedMBConvConfig(4, 3, 2, 24, 48, 4),
FusedMBConvConfig(4, 3, 2, 48, 64, 4),
MBConvConfig(4, 3, 2, 64, 128, 6),
MBConvConfig(6, 3, 1, 128, 160, 9),
MBConvConfig(6, 3, 2, 160, 256, 15),
]
last_channel = 1280
elif arch.startswith("efficientnet_v2_m"):
inverted_residual_setting = [
FusedMBConvConfig(1, 3, 1, 24, 24, 3),
FusedMBConvConfig(4, 3, 2, 24, 48, 5),
FusedMBConvConfig(4, 3, 2, 48, 80, 5),
MBConvConfig(4, 3, 2, 80, 160, 7),
MBConvConfig(6, 3, 1, 160, 176, 14),
MBConvConfig(6, 3, 2, 176, 304, 18),
MBConvConfig(6, 3, 1, 304, 512, 5),
]
last_channel = 1280
elif arch.startswith("efficientnet_v2_l"):
inverted_residual_setting = [
FusedMBConvConfig(1, 3, 1, 32, 32, 4),
FusedMBConvConfig(4, 3, 2, 32, 64, 7),
FusedMBConvConfig(4, 3, 2, 64, 96, 7),
MBConvConfig(4, 3, 2, 96, 192, 10),
MBConvConfig(6, 3, 1, 192, 224, 19),
MBConvConfig(6, 3, 2, 224, 384, 25),
MBConvConfig(6, 3, 1, 384, 640, 7),
]
last_channel = 1280
else:
raise ValueError(f"Unsupported model type {arch}")
return inverted_residual_setting, last_channel
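# Illustrative sketch, not part of this diff: wiring the helpers above together by
# hand, the same way the efficientnet_v2_s builder below does (dropout 0.2, BN eps 1e-3).
_setting, _last_channel = _efficientnet_conf("efficientnet_v2_s")
_model = EfficientNet(
    _setting, dropout=0.2, last_channel=_last_channel, norm_layer=partial(nn.BatchNorm2d, eps=1e-03)
).eval()
print(_model(torch.randn(1, 3, 384, 384)).shape)  # torch.Size([1, 1000])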
def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
"""
Constructs an EfficientNet B0 architecture from
@@ -304,7 +443,9 @@ def efficientnet_b0(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b0"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.0)
return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs)
def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -316,7 +457,9 @@ def efficientnet_b1(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b1"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.0, depth_mult=1.1)
return _efficientnet(arch, inverted_residual_setting, 0.2, last_channel, pretrained, progress, **kwargs)
def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -328,7 +471,9 @@ def efficientnet_b2(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b2"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.1, depth_mult=1.2)
return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs)
def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -340,7 +485,9 @@ def efficientnet_b3(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b3"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.2, depth_mult=1.4)
return _efficientnet(arch, inverted_residual_setting, 0.3, last_channel, pretrained, progress, **kwargs)
def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -352,7 +499,9 @@ def efficientnet_b4(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b4"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.4, depth_mult=1.8)
return _efficientnet(arch, inverted_residual_setting, 0.4, last_channel, pretrained, progress, **kwargs)
def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
@@ -364,11 +513,13 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b5"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.6, depth_mult=2.2)
return _efficientnet(
arch,
inverted_residual_setting,
0.4,
last_channel,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
@@ -385,11 +536,13 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b6"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=1.8, depth_mult=2.6)
return _efficientnet(
arch,
inverted_residual_setting,
0.5,
last_channel,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
@@ -406,13 +559,84 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_b7"
inverted_residual_setting, last_channel = _efficientnet_conf(arch, width_mult=2.0, depth_mult=3.1)
return _efficientnet(
arch,
inverted_residual_setting,
0.5,
last_channel,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
def efficientnet_v2_s(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
"""
Constructs an EfficientNetV2-S architecture from
`"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_v2_s"
inverted_residual_setting, last_channel = _efficientnet_conf(arch)
return _efficientnet(
arch,
inverted_residual_setting,
0.2,
last_channel,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
def efficientnet_v2_m(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
"""
Constructs an EfficientNetV2-M architecture from
`"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_v2_m"
inverted_residual_setting, last_channel = _efficientnet_conf(arch)
return _efficientnet(
arch,
inverted_residual_setting,
0.3,
last_channel,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
def efficientnet_v2_l(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet:
"""
Constructs an EfficientNetV2-L architecture from
`"EfficientNetV2: Smaller Models and Faster Training" <https://arxiv.org/abs/2104.00298>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
arch = "efficientnet_v2_l"
inverted_residual_setting, last_channel = _efficientnet_conf(arch)
return _efficientnet(
arch,
inverted_residual_setting,
0.4,
last_channel,
pretrained,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
from functools import partial
from typing import Any, Optional, Sequence, Union
from torch import nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
@@ -21,6 +21,9 @@ __all__ = [
"EfficientNet_B5_Weights",
"EfficientNet_B6_Weights",
"EfficientNet_B7_Weights",
"EfficientNet_V2_S_Weights",
"EfficientNet_V2_M_Weights",
"EfficientNet_V2_L_Weights",
"efficientnet_b0", "efficientnet_b0",
"efficientnet_b1", "efficientnet_b1",
"efficientnet_b2", "efficientnet_b2",
...@@ -29,13 +32,16 @@ __all__ = [ ...@@ -29,13 +32,16 @@ __all__ = [
"efficientnet_b5", "efficientnet_b5",
"efficientnet_b6", "efficientnet_b6",
"efficientnet_b7", "efficientnet_b7",
"efficientnet_v2_s",
"efficientnet_v2_m",
"efficientnet_v2_l",
]

def _efficientnet(
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
dropout: float,
last_channel: Optional[int],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
@@ -43,18 +49,7 @@ def _efficientnet(
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
inverted_residual_setting = [
bneck_conf(1, 3, 1, 32, 16, 1),
bneck_conf(6, 3, 2, 16, 24, 2),
bneck_conf(6, 5, 2, 24, 40, 2),
bneck_conf(6, 3, 2, 40, 80, 3),
bneck_conf(6, 5, 1, 80, 112, 3),
bneck_conf(6, 5, 2, 112, 192, 4),
bneck_conf(6, 3, 1, 192, 320, 1),
]
model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
@@ -64,12 +59,26 @@ def _efficientnet(
_COMMON_META = {
"task": "image_classification",
"categories": _IMAGENET_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
}
_COMMON_META_V1 = {
**_COMMON_META,
"architecture": "EfficientNet", "architecture": "EfficientNet",
"publication_year": 2019, "publication_year": 2019,
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BICUBIC, "interpolation": InterpolationMode.BICUBIC,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet", "min_size": (1, 1),
}
_COMMON_META_V2 = {
**_COMMON_META,
"architecture": "EfficientNetV2",
"publication_year": 2021,
"interpolation": InterpolationMode.BILINEAR,
"min_size": (33, 33),
}
@@ -78,7 +87,7 @@ class EfficientNet_B0_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 5288548,
"size": (224, 224),
"acc@1": 77.692,
@@ -93,7 +102,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 7794184,
"size": (240, 240),
"acc@1": 78.642,
@@ -104,7 +113,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR),
meta={
**_COMMON_META_V1,
"num_params": 7794184,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
"interpolation": InterpolationMode.BILINEAR,
@@ -121,7 +130,7 @@ class EfficientNet_B2_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 9109994,
"size": (288, 288),
"acc@1": 80.608,
@@ -136,7 +145,7 @@ class EfficientNet_B3_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 12233232,
"size": (300, 300),
"acc@1": 82.008,
@@ -151,7 +160,7 @@ class EfficientNet_B4_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 19341616,
"size": (380, 380),
"acc@1": 83.384,
@@ -166,7 +175,7 @@ class EfficientNet_B5_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 30389784,
"size": (456, 456),
"acc@1": 83.444,
@@ -181,7 +190,7 @@ class EfficientNet_B6_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 43040704,
"size": (528, 528),
"acc@1": 84.008,
@@ -196,7 +205,7 @@ class EfficientNet_B7_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META_V1,
"num_params": 66347960,
"size": (600, 600),
"acc@1": 84.122,
@@ -206,13 +215,76 @@ class EfficientNet_B7_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_S_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
transforms=partial(
ImageNetEval,
crop_size=384,
resize_size=384,
interpolation=InterpolationMode.BILINEAR,
),
meta={
**_COMMON_META_V2,
"num_params": 21458488,
"size": (384, 384),
"acc@1": 84.228,
"acc@5": 96.878,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_M_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
transforms=partial(
ImageNetEval,
crop_size=480,
resize_size=480,
interpolation=InterpolationMode.BILINEAR,
),
meta={
**_COMMON_META_V2,
"num_params": 54139356,
"size": (480, 480),
"acc@1": 85.112,
"acc@5": 97.156,
},
)
DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_L_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
transforms=partial(
ImageNetEval,
crop_size=480,
resize_size=480,
interpolation=InterpolationMode.BICUBIC,
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
),
meta={
**_COMMON_META_V2,
"num_params": 118515272,
"size": (480, 480),
"acc@1": 85.808,
"acc@5": 97.788,
},
)
DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B0_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
@@ -221,7 +293,8 @@ def efficientnet_b1(
) -> EfficientNet:
weights = EfficientNet_B1_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
@@ -230,7 +303,8 @@ def efficientnet_b2(
) -> EfficientNet:
weights = EfficientNet_B2_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
@@ -239,7 +313,8 @@ def efficientnet_b3(
) -> EfficientNet:
weights = EfficientNet_B3_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
@@ -248,7 +323,8 @@ def efficientnet_b4(
) -> EfficientNet:
weights = EfficientNet_B4_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
@@ -257,12 +333,13 @@ def efficientnet_b5(
) -> EfficientNet:
weights = EfficientNet_B5_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
return _efficientnet(
inverted_residual_setting,
0.4,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
@@ -274,12 +351,13 @@ def efficientnet_b6(
) -> EfficientNet:
weights = EfficientNet_B6_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
return _efficientnet(
inverted_residual_setting,
0.5,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
@@ -291,12 +369,67 @@ def efficientnet_b7(
) -> EfficientNet:
weights = EfficientNet_B7_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
return _efficientnet(
inverted_residual_setting,
0.5,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
*, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_V2_S_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
return _efficientnet(
inverted_residual_setting,
0.2,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
*, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_V2_M_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
return _efficientnet(
inverted_residual_setting,
0.3,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
*, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_V2_L_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
return _efficientnet(
inverted_residual_setting,
0.4,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
**kwargs,
)
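# Illustrative sketch, not part of this diff: with the prototype multi-weight API the
# new V2 weights are requested explicitly and carry their own evaluation transforms.
# The prototype interface is experimental and may still change.
_weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1
_model = efficientnet_v2_s(weights=_weights).eval()
_preprocess = _weights.transforms()  # ImageNetEval configured for 384px bilinear evaluation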