Unverified Commit 82929ae1 authored by Vasilis Vryniotis, committed by GitHub

Adding more ConvNeXt variants + Speed optimizations (#5253)

* Refactor model builder

* Add 3 more convnext variants.

* Adding weights for convnext_small.

* Fix minor bug.

* Fix number of parameters for small model.

* Adding weights for the base variant.

* Adding weights for the large variant.

* Simplify LayerNorm2d implementation.

* Optimize speed of CNBlock.

* Repackage weights.
parent 8aad0e0a
@@ -249,6 +249,9 @@ vit_b_32 75.912 92.466
vit_l_16                         79.662        94.638
vit_l_32                         76.972        93.070
convnext_tiny (prototype)        82.520        96.146
convnext_small (prototype)       83.616        96.650
convnext_base (prototype)        84.062        96.870
convnext_large (prototype)       84.414        96.976
================================ ============= =============
...
@@ -201,11 +201,12 @@ and `--batch_size 64`.
### ConvNeXt
```
torchrun --nproc_per_node=8 train.py\
--model $MODEL --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \
--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \
--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \
--train-crop-size 176 --model-ema --val-resize-size 232 --ra-sampler --ra-reps 4
```
Here `$MODEL` is one of `convnext_tiny`, `convnext_small`, `convnext_base` and `convnext_large`. Note that each variant had its `--val-resize-size` optimized in a post-training step; see its `Weights` entry for the exact value.
Note that the above command corresponds to training on a single node with 8 GPUs.
For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
...
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
@@ -15,47 +15,56 @@ from ._meta import _IMAGENET_CATEGORIES
from ._utils import handle_legacy_interface, _ovewrite_named_param
__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"] __all__ = [
"ConvNeXt",
"ConvNeXt_Tiny_Weights",
"ConvNeXt_Small_Weights",
"ConvNeXt_Base_Weights",
"ConvNeXt_Large_Weights",
"convnext_tiny",
"convnext_small",
"convnext_base",
"convnext_large",
]
class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        # F.layer_norm normalizes over the trailing dimensions, so move the
        # channel dimension last, normalize, then restore the NCHW layout.
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x
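For reference, `F.layer_norm` normalizes over the trailing dimensions, so the permute/normalize/permute above applies LayerNorm across the channel vector at every spatial position of an NCHW tensor. A minimal equivalence check (a sketch; variable names are illustrative, not from this commit):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)  # NCHW input

# LayerNorm2d path: channels last, normalize over the 8 channels, back to NCHW.
y = F.layer_norm(x.permute(0, 2, 3, 1), (8,)).permute(0, 3, 1, 2)

# Manual per-position normalization over the channel dimension.
mu = x.mean(dim=1, keepdim=True)
var = x.var(dim=1, unbiased=False, keepdim=True)
torch.testing.assert_close(y, (x - mu) / torch.sqrt(var + 1e-5))
```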
class Permute(nn.Module):
    def __init__(self, dims: List[int]):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return torch.permute(x, self.dims)
class CNBlock(nn.Module):
    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
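This rewrite is the "Optimize speed of CNBlock" item from the commit message: the depthwise 7x7 conv stays in NCHW, while the norm and the two pointwise projections run as `nn.Linear` on a channels-last view, so the tensor is permuted only once in each direction instead of inside every `ConvNormActivation`. This is valid because a 1x1 convolution is a per-position linear map; a small illustrative check (shapes and names are mine, not from the commit):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 64, kernel_size=1)  # pointwise conv on NCHW
linear = nn.Linear(16, 64)               # same map on a channels-last view

# Copy the conv's (64, 16, 1, 1) kernel into the linear's (64, 16) weight.
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 16, 8, 8)
out_conv = conv(x)
out_linear = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
torch.testing.assert_close(out_conv, out_linear)
```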
@@ -138,7 +147,7 @@ class ConvNeXt(nn.Module):
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
@@ -177,20 +186,43 @@ class ConvNeXt(nn.Module):
        return self._forward_impl(x)
def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
_COMMON_META = {
    "task": "image_classification",
    "architecture": "ConvNeXt",
    "publication_year": 2022,
    "size": (224, 224),
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "interpolation": InterpolationMode.BILINEAR,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
}
class ConvNeXt_Tiny_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "acc@1": 82.520,
            "acc@5": 96.146,
        },
@@ -198,9 +230,51 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
    DEFAULT = IMAGENET1K_V1
class ConvNeXt_Small_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "acc@1": 83.616,
            "acc@5": 96.650,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Base_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "acc@1": 84.062,
            "acc@5": 96.870,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Large_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "acc@1": 84.414,
            "acc@5": 96.976,
        },
    )
    DEFAULT = IMAGENET1K_V1
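As a usage sketch of the new enums (assuming the prototype namespace this file lives in, `torchvision.prototype.models`, and its `Weights` API at the time; the random tensor stands in for a real image):

```python
import torch
from torchvision.prototype import models as PM

weights = PM.ConvNeXt_Small_Weights.IMAGENET1K_V1
model = PM.convnext_small(weights=weights).eval()

# Each Weights entry bundles its inference preset (the crop/resize sizes above).
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)
with torch.inference_mode():
    scores = model(batch).softmax(dim=-1)
print(weights.meta["categories"][int(scores.argmax())])
```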
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
r"""ConvNeXt model architecture from the r"""ConvNeXt Tiny model architecture from the
`"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper. `"A ConvNet for the 2020s" <https://arxiv.org/abs/2201.03545>`_ paper.
Args: Args:
...@@ -209,9 +283,6 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: ...@@ -209,9 +283,6 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
""" """
weights = ConvNeXt_Tiny_Weights.verify(weights) weights = ConvNeXt_Tiny_Weights.verify(weights)
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
@@ -219,9 +290,50 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress:
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    weights = ConvNeXt_Small_Weights.verify(weights)
    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    weights = ConvNeXt_Base_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    weights = ConvNeXt_Large_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
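The `num_params` values in the meta entries above are straightforward to sanity-check against freshly built models (random init, no weight download needed); a quick sketch, reusing the import assumption from the earlier snippet:

```python
from torchvision.prototype import models as PM

for builder, enum in [
    (PM.convnext_small, PM.ConvNeXt_Small_Weights),
    (PM.convnext_base, PM.ConvNeXt_Base_Weights),
    (PM.convnext_large, PM.ConvNeXt_Large_Weights),
]:
    model = builder()  # random init is enough for counting parameters
    n = sum(p.numel() for p in model.parameters())
    assert n == enum.IMAGENET1K_V1.meta["num_params"], (builder.__name__, n)
```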