Unverified Commit e6d82f7d authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Adding EfficientNetV2 architecture (#5450)

* Extend the EfficientNet class to support v1 and v2.

* Refactor config/builder methods and add prototype builders

* Refactoring weight info.

* Update dropouts based on TF config ref

* Update BN eps on TF base_config

* Use Conv2dNormActivation.

* Adding pre-trained weights for EfficientNetV2-s

* Add Medium and Large weights

* Update stats with single batch run.

* Add accuracies in the docs.
parent a2b70758
......@@ -38,7 +38,7 @@ architectures for image classification:
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
- `EfficientNet`_
- `EfficientNet`_ v1 & v2
- `RegNet`_
- `VisionTransformer`_
- `ConvNeXt`_
......@@ -70,6 +70,9 @@ You can construct a model with random weights by calling its constructor:
efficientnet_b5 = models.efficientnet_b5()
efficientnet_b6 = models.efficientnet_b6()
efficientnet_b7 = models.efficientnet_b7()
efficientnet_v2_s = models.efficientnet_v2_s()
efficientnet_v2_m = models.efficientnet_v2_m()
efficientnet_v2_l = models.efficientnet_v2_l()
regnet_y_400mf = models.regnet_y_400mf()
regnet_y_800mf = models.regnet_y_800mf()
regnet_y_1_6gf = models.regnet_y_1_6gf()
......@@ -122,6 +125,9 @@ These can be constructed by passing ``pretrained=True``:
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True)
efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True)
efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
......@@ -238,6 +244,9 @@ EfficientNet-B4 83.384 96.594
EfficientNet-B5 83.444 96.628
EfficientNet-B6 84.008 96.916
EfficientNet-B7 84.122 96.908
EfficientNetV2-s 84.228 96.878
EfficientNetV2-m 85.112 97.156
EfficientNetV2-l 85.808 97.788
regnet_x_400mf 72.834 90.950
regnet_x_800mf 75.212 92.348
regnet_x_1_6gf 77.040 93.440
......@@ -439,6 +448,9 @@ EfficientNet
efficientnet_b5
efficientnet_b6
efficientnet_b7
efficientnet_v2_s
efficientnet_v2_m
efficientnet_v2_l
RegNet
------------
......
......@@ -13,6 +13,9 @@ from torchvision.models.efficientnet import (
efficientnet_b5,
efficientnet_b6,
efficientnet_b7,
efficientnet_v2_s,
efficientnet_v2_m,
efficientnet_v2_l,
)
from torchvision.models.googlenet import googlenet
from torchvision.models.inception import inception_v3
......
......@@ -88,7 +88,7 @@ Then we averaged the parameters of the last 3 checkpoints that improved the Acc@
and [#3354](https://github.com/pytorch/vision/pull/3354) for details.
### EfficientNet
### EfficientNet-V1
The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](https://github.com/rwightman/pytorch-image-models/blob/01cb46a9a50e3ba4be167965b5764e9702f09b30/timm/models/efficientnet.py#L95-L108).
......@@ -114,6 +114,26 @@ torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bic
--val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
```
### EfficientNet-V2
```
torchrun --nproc_per_node=8 train.py \
--model $MODEL --batch-size 128 --lr 0.5 --lr-scheduler cosineannealinglr \
--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.00002 --norm-weight-decay 0.0 \
--train-crop-size $TRAIN_SIZE --model-ema --val-crop-size $EVAL_SIZE --val-resize-size $EVAL_SIZE \
--ra-sampler --ra-reps 4
```
Here `$MODEL` is either `efficientnet_v2_s` or `efficientnet_v2_m`.
Note that the Small variant had a `$TRAIN_SIZE` of `300` and a `$EVAL_SIZE` of `384`, while the Medium `384` and `480` respectively.
Note that the above command corresponds to training on a single node with 8 GPUs.
For generating the pre-trained weights, we trained with 4 nodes, each with 8 GPUs (for a total of 32 GPUs),
and `--batch-size 32`.
The weights of the Large variant are ported from the original paper rather than trained from scratch. See the `EfficientNet_V2_L_Weights` entry for their exact preprocessing transforms.
### RegNet
#### Small models
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
This diff is collapsed.
from functools import partial
from typing import Any, Optional
from typing import Any, Optional, Sequence, Union
from torch import nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig
from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf
from ._api import WeightsEnum, Weights
from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
......@@ -21,6 +21,9 @@ __all__ = [
"EfficientNet_B5_Weights",
"EfficientNet_B6_Weights",
"EfficientNet_B7_Weights",
"EfficientNet_V2_S_Weights",
"EfficientNet_V2_M_Weights",
"EfficientNet_V2_L_Weights",
"efficientnet_b0",
"efficientnet_b1",
"efficientnet_b2",
......@@ -29,13 +32,16 @@ __all__ = [
"efficientnet_b5",
"efficientnet_b6",
"efficientnet_b7",
"efficientnet_v2_s",
"efficientnet_v2_m",
"efficientnet_v2_l",
]
def _efficientnet(
width_mult: float,
depth_mult: float,
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
dropout: float,
last_channel: Optional[int],
weights: Optional[WeightsEnum],
progress: bool,
**kwargs: Any,
......@@ -43,18 +49,7 @@ def _efficientnet(
if weights is not None:
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult)
inverted_residual_setting = [
bneck_conf(1, 3, 1, 32, 16, 1),
bneck_conf(6, 3, 2, 16, 24, 2),
bneck_conf(6, 5, 2, 24, 40, 2),
bneck_conf(6, 3, 2, 40, 80, 3),
bneck_conf(6, 5, 1, 80, 112, 3),
bneck_conf(6, 5, 2, 112, 192, 4),
bneck_conf(6, 3, 1, 192, 320, 1),
]
model = EfficientNet(inverted_residual_setting, dropout, **kwargs)
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress))
......@@ -64,12 +59,26 @@ def _efficientnet(
_COMMON_META = {
"task": "image_classification",
"categories": _IMAGENET_CATEGORIES,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
}
_COMMON_META_V1 = {
**_COMMON_META,
"architecture": "EfficientNet",
"publication_year": 2019,
"min_size": (1, 1),
"categories": _IMAGENET_CATEGORIES,
"interpolation": InterpolationMode.BICUBIC,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet",
"min_size": (1, 1),
}
_COMMON_META_V2 = {
**_COMMON_META,
"architecture": "EfficientNetV2",
"publication_year": 2021,
"interpolation": InterpolationMode.BILINEAR,
"min_size": (33, 33),
}
......@@ -78,7 +87,7 @@ class EfficientNet_B0_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 5288548,
"size": (224, 224),
"acc@1": 77.692,
......@@ -93,7 +102,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 7794184,
"size": (240, 240),
"acc@1": 78.642,
......@@ -104,7 +113,7 @@ class EfficientNet_B1_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 7794184,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
"interpolation": InterpolationMode.BILINEAR,
......@@ -121,7 +130,7 @@ class EfficientNet_B2_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 9109994,
"size": (288, 288),
"acc@1": 80.608,
......@@ -136,7 +145,7 @@ class EfficientNet_B3_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 12233232,
"size": (300, 300),
"acc@1": 82.008,
......@@ -151,7 +160,7 @@ class EfficientNet_B4_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 19341616,
"size": (380, 380),
"acc@1": 83.384,
......@@ -166,7 +175,7 @@ class EfficientNet_B5_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 30389784,
"size": (456, 456),
"acc@1": 83.444,
......@@ -181,7 +190,7 @@ class EfficientNet_B6_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 43040704,
"size": (528, 528),
"acc@1": 84.008,
......@@ -196,7 +205,7 @@ class EfficientNet_B7_Weights(WeightsEnum):
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
meta={
**_COMMON_META,
**_COMMON_META_V1,
"num_params": 66347960,
"size": (600, 600),
"acc@1": 84.122,
......@@ -206,13 +215,76 @@ class EfficientNet_B7_Weights(WeightsEnum):
DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_S_Weights(WeightsEnum):
    # Pretrained ImageNet-1k weights for the EfficientNetV2-S (Small) variant.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
        # Evaluation preprocessing: resize and center-crop to the same 384px
        # size (no resize/crop mismatch), using bilinear interpolation.
        transforms=partial(
            ImageNetEval,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            # Shared V2 metadata (architecture name, publication year,
            # interpolation, min input size) plus per-variant stats.
            **_COMMON_META_V2,
            "num_params": 21458488,
            "size": (384, 384),
            "acc@1": 84.228,
            "acc@5": 96.878,
        },
    )
    # Alias resolved when callers request the default weights for this model.
    DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_M_Weights(WeightsEnum):
    # Pretrained ImageNet-1k weights for the EfficientNetV2-M (Medium) variant.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
        # Evaluation preprocessing: resize and center-crop to the same 480px
        # size, using bilinear interpolation.
        transforms=partial(
            ImageNetEval,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            # Shared V2 metadata plus per-variant stats.
            **_COMMON_META_V2,
            "num_params": 54139356,
            "size": (480, 480),
            "acc@1": 85.112,
            "acc@5": 97.156,
        },
    )
    # Alias resolved when callers request the default weights for this model.
    DEFAULT = IMAGENET1K_V1
class EfficientNet_V2_L_Weights(WeightsEnum):
    # Pretrained ImageNet-1k weights for the EfficientNetV2-L (Large) variant.
    # NOTE: unlike the S/M variants, these weights are ported (per the
    # references README) and use the original preprocessing: bicubic
    # interpolation and mean/std of 0.5 instead of the ImageNet defaults.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
        transforms=partial(
            ImageNetEval,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BICUBIC,
            # Non-standard normalization — must match the ported weights.
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            # Shared V2 metadata plus per-variant stats.
            **_COMMON_META_V2,
            "num_params": 118515272,
            "size": (480, 480),
            "acc@1": 85.808,
            "acc@5": 97.788,
        },
    )
    # Alias resolved when callers request the default weights for this model.
    DEFAULT = IMAGENET1K_V1
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
weights = EfficientNet_B0_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.0, dropout=0.2, weights=weights, progress=progress, **kwargs)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
......@@ -221,7 +293,8 @@ def efficientnet_b1(
) -> EfficientNet:
weights = EfficientNet_B1_Weights.verify(weights)
return _efficientnet(width_mult=1.0, depth_mult=1.1, dropout=0.2, weights=weights, progress=progress, **kwargs)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
......@@ -230,7 +303,8 @@ def efficientnet_b2(
) -> EfficientNet:
weights = EfficientNet_B2_Weights.verify(weights)
return _efficientnet(width_mult=1.1, depth_mult=1.2, dropout=0.3, weights=weights, progress=progress, **kwargs)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
......@@ -239,7 +313,8 @@ def efficientnet_b3(
) -> EfficientNet:
weights = EfficientNet_B3_Weights.verify(weights)
return _efficientnet(width_mult=1.2, depth_mult=1.4, dropout=0.3, weights=weights, progress=progress, **kwargs)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
......@@ -248,7 +323,8 @@ def efficientnet_b4(
) -> EfficientNet:
weights = EfficientNet_B4_Weights.verify(weights)
return _efficientnet(width_mult=1.4, depth_mult=1.8, dropout=0.4, weights=weights, progress=progress, **kwargs)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs)
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
......@@ -257,12 +333,13 @@ def efficientnet_b5(
) -> EfficientNet:
weights = EfficientNet_B5_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
return _efficientnet(
width_mult=1.6,
depth_mult=2.2,
dropout=0.4,
weights=weights,
progress=progress,
inverted_residual_setting,
0.4,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
......@@ -274,12 +351,13 @@ def efficientnet_b6(
) -> EfficientNet:
weights = EfficientNet_B6_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
return _efficientnet(
width_mult=1.8,
depth_mult=2.6,
dropout=0.5,
weights=weights,
progress=progress,
inverted_residual_setting,
0.5,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
......@@ -291,12 +369,67 @@ def efficientnet_b7(
) -> EfficientNet:
weights = EfficientNet_B7_Weights.verify(weights)
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
return _efficientnet(
width_mult=2.0,
depth_mult=3.1,
dropout=0.5,
weights=weights,
progress=progress,
inverted_residual_setting,
0.5,
last_channel,
weights,
progress,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
**kwargs,
)
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
    *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Construct an EfficientNetV2-S architecture.

    Args:
        weights: Pretrained weights to load, or ``None`` for random
            initialization. The legacy ``pretrained=True`` kwarg is mapped to
            ``IMAGENET1K_V1`` by the decorator.
        progress: If ``True``, display a progress bar while downloading weights.
        **kwargs: Forwarded to the ``EfficientNet`` constructor.

    Returns:
        The constructed ``EfficientNet`` model.
    """
    weights = EfficientNet_V2_S_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
    return _efficientnet(
        inverted_residual_setting,
        0.2,  # classifier dropout rate
        last_channel,
        weights,
        progress,
        # BN eps of 1e-3 — presumably matching the TF reference config
        # mentioned in the commit notes; verify against upstream.
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
    *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Construct an EfficientNetV2-M architecture.

    Args:
        weights: Pretrained weights to load, or ``None`` for random
            initialization. The legacy ``pretrained=True`` kwarg is mapped to
            ``IMAGENET1K_V1`` by the decorator.
        progress: If ``True``, display a progress bar while downloading weights.
        **kwargs: Forwarded to the ``EfficientNet`` constructor.

    Returns:
        The constructed ``EfficientNet`` model.
    """
    weights = EfficientNet_V2_M_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
    return _efficientnet(
        inverted_residual_setting,
        0.3,  # classifier dropout rate (larger variant -> more dropout)
        last_channel,
        weights,
        progress,
        # BN eps of 1e-3 — presumably matching the TF reference config
        # mentioned in the commit notes; verify against upstream.
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
    *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """Construct an EfficientNetV2-L architecture.

    Args:
        weights: Pretrained weights to load, or ``None`` for random
            initialization. The legacy ``pretrained=True`` kwarg is mapped to
            ``IMAGENET1K_V1`` by the decorator. Note these weights use
            non-default preprocessing (see ``EfficientNet_V2_L_Weights``).
        progress: If ``True``, display a progress bar while downloading weights.
        **kwargs: Forwarded to the ``EfficientNet`` constructor.

    Returns:
        The constructed ``EfficientNet`` model.
    """
    weights = EfficientNet_V2_L_Weights.verify(weights)

    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
    return _efficientnet(
        inverted_residual_setting,
        0.4,  # classifier dropout rate (largest variant -> most dropout)
        last_channel,
        weights,
        progress,
        # BN eps of 1e-3 — presumably matching the TF reference config
        # mentioned in the commit notes; verify against upstream.
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment