Unverified Commit 9d9cfab2 authored by Joao Gomes's avatar Joao Gomes Committed by GitHub
Browse files

add swin_s and swin_b variants and improved swin_t (#6048)



* add swin_s and swin_b variants

* fix swin_b params

* fix n parameters and acc numbers

* adding missing acc numbers

* apply ufmt

* Updating `_docs` to reflect training recipe

* Fix exted for swin_b
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 52a4480d
...@@ -23,3 +23,5 @@ more details about this class. ...@@ -23,3 +23,5 @@ more details about this class.
:template: function.rst :template: function.rst
swin_t swin_t
swin_s
swin_b
...@@ -228,14 +228,14 @@ and `--batch_size 64`. ...@@ -228,14 +228,14 @@ and `--batch_size 64`.
### SwinTransformer
```
torchrun --nproc_per_node=8 train.py\
--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 224
```
Here `$MODEL` is one of `swin_t`, `swin_s` or `swin_b`.
Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
### ShuffleNet V2 ### ShuffleNet V2
``` ```
torchrun --nproc_per_node=8 train.py \ torchrun --nproc_per_node=8 train.py \
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
...@@ -18,7 +18,11 @@ from .vision_transformer import MLPBlock ...@@ -18,7 +18,11 @@ from .vision_transformer import MLPBlock
__all__ = [ __all__ = [
"SwinTransformer", "SwinTransformer",
"Swin_T_Weights", "Swin_T_Weights",
"Swin_S_Weights",
"Swin_B_Weights",
"swin_t", "swin_t",
"swin_s",
"swin_b",
] ]
...@@ -408,9 +412,9 @@ _COMMON_META = { ...@@ -408,9 +412,9 @@ _COMMON_META = {
class Swin_T_Weights(WeightsEnum): class Swin_T_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_t-81486767.pth", url="https://download.pytorch.org/models/swin_t-4c37bd06.pth",
transforms=partial( transforms=partial(
ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC
), ),
meta={ meta={
**_COMMON_META, **_COMMON_META,
...@@ -419,11 +423,57 @@ class Swin_T_Weights(WeightsEnum): ...@@ -419,11 +423,57 @@ class Swin_T_Weights(WeightsEnum):
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
"_metrics": { "_metrics": {
"ImageNet-1K": { "ImageNet-1K": {
"acc@1": 81.358, "acc@1": 81.474,
"acc@5": 95.526, "acc@5": 95.776,
}
},
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
},
)
DEFAULT = IMAGENET1K_V1
class Swin_S_Weights(WeightsEnum):
    """Pretrained weight enum for the Swin-S (small) SwinTransformer."""

    # Weights trained on ImageNet-1K with the reference classification recipe
    # linked under "recipe" below.
    IMAGENET1K_V1 = Weights(
        # Checkpoint hosted on the official torchvision download server.
        url="https://download.pytorch.org/models/swin_s-30134662.pth",
        # Inference preprocessing: ImageClassification with a 246px resize and
        # 224px crop, using bicubic interpolation.
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            # Shared metadata fields (defined at module level) are merged in first
            # so the variant-specific entries below can extend/override them.
            **_COMMON_META,
            "num_params": 49606258,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
            # Reported ImageNet-1K validation accuracy for this checkpoint.
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.196,
                    "acc@5": 96.360,
                }
            },
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
        },
    )
    # Default weights returned when callers request pretrained weights generically.
    DEFAULT = IMAGENET1K_V1
class Swin_B_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/swin_b-1f1feb5c.pth",
transforms=partial(
ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC
),
meta={
**_COMMON_META,
"num_params": 87768224,
"min_size": (224, 224),
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer",
"_metrics": {
"ImageNet-1K": {
"acc@1": 83.582,
"acc@5": 96.640,
} }
}, },
"_docs": """These weights reproduce closely the results of the paper using its training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -463,3 +513,75 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * ...@@ -463,3 +513,75 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
progress=progress, progress=progress,
**kwargs, **kwargs,
) )
def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_small architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_S_Weights
        :members:
    """
    weights = Swin_S_Weights.verify(weights)

    # Architecture hyper-parameters defining the "small" Swin variant.
    small_variant = {
        "patch_size": 4,
        "embed_dim": 96,
        "depths": [2, 2, 18, 2],
        "num_heads": [3, 6, 12, 24],
        "window_size": 7,
        "stochastic_depth_prob": 0.3,
    }
    return _swin_transformer(weights=weights, progress=progress, **small_variant, **kwargs)
def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
    """
    Constructs a swin_base architecture from
    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.

    Args:
        weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.Swin_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.Swin_B_Weights
        :members:
    """
    weights = Swin_B_Weights.verify(weights)

    # Architecture hyper-parameters defining the "base" Swin variant.
    base_variant = {
        "patch_size": 4,
        "embed_dim": 128,
        "depths": [2, 2, 18, 2],
        "num_heads": [4, 8, 16, 32],
        "window_size": 7,
        "stochastic_depth_prob": 0.5,
    }
    return _swin_transformer(weights=weights, progress=progress, **base_variant, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment