Commit 5521e9d0, authored by Local State, committed by GitHub

Add SwinV2 (#6246)



* init submit

* fix typo

* support ufmt and mypy

* fix 2 unittest errors

* fix ufmt issue

* Apply suggestions from code review
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* unify codes

* fix meshgrid indexing

* fix a bug

* fix type check

* add type_annotation

* add slow model

* fix device issue

* fix ufmt issue

* add expect pickle file

* fix jit script issue

* fix type check

* keep consistent argument order

* add support for pretrained_window_size

* avoid code duplication

* a better code reuse

* update window_size argument

* make permute and flatten operations modular

* add PatchMergingV2

* modify expect.pkl

* use None as default argument value

* fix type check

* fix indent

* fix window_size (temporarily)

* remove "v2_" related prefix and add v2 builder

* remove v2 builder

* keep default value consistent with official repo

* deprecate dropout

* deprecate pretrained_window_size

* fix dynamic padding edge case

* remove unused imports

* remove doc modification

* Revert "deprecate dropout"

This reverts commit 8a13f932815ae25655c07430d52929f86b1ca479.

* Revert "fix dynamic padding edge case"

This reverts commit 1c7579cb1bd7bf2f0f94907f39bee6ed707a97a8.

* remove unused kwargs

* add downsample docs

* revert block default value

* revert argument order change

* explicitly specify start_dim

* add small and base variants

* add expect files and slow_models

* Add model weights and documentation for swin v2

* fix lint

* fix end of files line
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Joao Gomes <jdsgomes@fb.com>
parent 7e8186e0
......@@ -3,16 +3,18 @@ SwinTransformer
.. currentmodule:: torchvision.models
The SwinTransformer model is based on the `Swin Transformer: Hierarchical Vision
The SwinTransformer models are based on the `Swin Transformer: Hierarchical Vision
Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`__
paper.
SwinTransformer V2 models are based on the `Swin Transformer V2: Scaling Up Capacity
and Resolution <https://openaccess.thecvf.com/content/CVPR2022/papers/Liu_Swin_Transformer_V2_Scaling_Up_Capacity_and_Resolution_CVPR_2022_paper.pdf>`__
paper.
Model builders
--------------
The following model builders can be used to instantiate an SwinTransformer model.
`swin_t` can be instantiated with pre-trained weights and all others without.
The following model builders can be used to instantiate an SwinTransformer model (original and V2) with and without pre-trained weights.
All the model builders internally rely on the ``torchvision.models.swin_transformer.SwinTransformer``
base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_ for
......@@ -25,3 +27,6 @@ more details about this class.
swin_t
swin_s
swin_b
swin_v2_t
swin_v2_s
swin_v2_b
......@@ -236,6 +236,17 @@ Note that `--val-resize-size` was optimized in a post-training step, see their `
### SwinTransformer V2
```
torchrun --nproc_per_node=8 train.py \
--model $MODEL --epochs 300 --batch-size 128 --opt adamw --lr 0.001 --weight-decay 0.05 --norm-weight-decay 0.0 --bias-weight-decay 0.0 --transformer-embedding-decay 0.0 --lr-scheduler cosineannealinglr --lr-min 0.00001 --lr-warmup-method linear --lr-warmup-epochs 20 --lr-warmup-decay 0.01 --amp --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 5.0 --cutmix-alpha 1.0 --random-erase 0.25 --interpolation bicubic --auto-augment ta_wide --model-ema --ra-sampler --ra-reps 4 --val-resize-size 256 --val-crop-size 256 --train-crop-size 256
```
Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`.
Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
### ShuffleNet V2
```
torchrun --nproc_per_node=8 train.py \
......
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
......@@ -332,6 +332,9 @@ slow_models = [
"swin_t",
"swin_s",
"swin_b",
"swin_v2_t",
"swin_v2_s",
"swin_v2_b",
]
for m in slow_models:
_model_params[m] = {"input_shape": (1, 3, 64, 64)}
......
This diff is collapsed.