Unverified commit 6b1646ca authored by Ponku, committed by GitHub

MaxVit model (#6342)



* Added maxvit architecture and tests

* rebased + addressed comments

* Revert "rebased + addresed comments"

This reverts commit c5b28398cd48d2f3403c7c8eeefbaba9df05fcfe.

* Re-added model changes after revert

* aligned with partial original implementation

* removed submitit script, fixed lint

* mypy fix for too many arguments

* updated old tests

* removed per batch lr scheduler and seed setting

* removed ontap

* added docs, validated weights

* fixed test expect, moved shape assertions to the beginning for torch.fx compatibility

* mypy fix

* lint fix

* added legacy interface

* added weight link

* updated docs

* Update references/classification/train.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/models/maxvit.py
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* addressed comments

* update ra_magnitude and augmix_severity default values

* addressed some comments

* remove input_channels parameter
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent d65e286f
@@ -207,6 +207,7 @@ weights:
 models/efficientnetv2
 models/googlenet
 models/inception
+models/maxvit
 models/mnasnet
 models/mobilenetv2
 models/mobilenetv3
MaxVit
======

.. currentmodule:: torchvision.models

The MaxVit transformer models are based on the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`__
paper.

Model builders
--------------

The following model builders can be used to instantiate a MaxVit model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.maxvit.MaxVit`` base class. Please refer to the `source code
<https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_ for
more details about this class.

.. autosummary::
    :toctree: generated/
    :template: function.rst

    maxvit_t
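A minimal usage sketch for the new builder (assuming the `maxvit_t` function and `MaxVit_T_Weights` enum exported by this PR, following torchvision's multi-weight API; the weight name and shapes are illustrative):

```
import torch
from torchvision.models import maxvit_t, MaxVit_T_Weights

# Build the tiny MaxViT variant with the ImageNet-1K weights added in this PR.
weights = MaxVit_T_Weights.IMAGENET1K_V1
model = maxvit_t(weights=weights).eval()

# The weights ship their own preprocessing preset.
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)

with torch.inference_mode():
    logits = model(batch)
print(logits.shape)  # expected: torch.Size([1, 1000])
```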
@@ -245,6 +245,14 @@ Here `$MODEL` is one of `swin_v2_t`, `swin_v2_s` or `swin_v2_b`.
 Note that `--val-resize-size` was optimized in a post-training step, see their `Weights` entry for the exact value.
 
+### MaxViT
+```
+torchrun --nproc_per_node=8 --nnodes=4 train.py\
+--model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --ra-magnitude 15 --train-center-crop --model-ema --val-resize-size 224\
+--val-crop-size 224 --train-crop-size 224 --amp --model-ema-steps 32 --transformer-embedding-decay 0 --sync-bn
+```
+Here `$MODEL` is `maxvit_t`.
+
+Note that `--val-resize-size` was not optimized in a post-training step.
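For orientation, the recipe runs on 4 nodes with 8 processes each and a per-GPU batch size of 128; a quick sanity check of the implied global batch size (a sketch, not part of the reference scripts):

```
# Effective global batch size for the MaxViT recipe above.
nnodes, nproc_per_node, per_gpu_batch = 4, 8, 128
print(nnodes * nproc_per_node * per_gpu_batch)  # 4096 images per optimizer step
```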
 ### ShuffleNet V2
@@ -13,18 +13,25 @@ class ClassificationPresetTrain:
         interpolation=InterpolationMode.BILINEAR,
         hflip_prob=0.5,
         auto_augment_policy=None,
+        ra_magnitude=9,
+        augmix_severity=3,
         random_erase_prob=0.0,
+        center_crop=False,
     ):
-        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        trans = (
+            [transforms.CenterCrop(crop_size)]
+            if center_crop
+            else [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
+        )
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation))
+                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
                 trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation))
+                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
                 aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
                 trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
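A sketch of how the new preset arguments are intended to be used from the references scripts (the import path and the example values are illustrative, not part of the diff):

```
from torchvision.transforms import InterpolationMode

import presets  # references/classification/presets.py

# Center-crop training pipeline with an explicit RandAugment magnitude,
# mirroring the new --train-center-crop / --ra-magnitude flags.
train_transform = presets.ClassificationPresetTrain(
    crop_size=224,
    interpolation=InterpolationMode.BICUBIC,
    auto_augment_policy="ra",
    ra_magnitude=15,
    augmix_severity=3,
    center_crop=True,
)
```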
@@ -113,7 +113,12 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    val_resize_size, val_crop_size, train_crop_size, center_crop = (
+        args.val_resize_size,
+        args.val_crop_size,
+        args.train_crop_size,
+        args.train_center_crop,
+    )
     interpolation = InterpolationMode(args.interpolation)
 
     print("Loading training data")
@@ -126,13 +131,18 @@ def load_data(traindir, valdir, args):
     else:
         auto_augment_policy = getattr(args, "auto_augment", None)
         random_erase_prob = getattr(args, "random_erase", 0.0)
+        ra_magnitude = args.ra_magnitude
+        augmix_severity = args.augmix_severity
         dataset = torchvision.datasets.ImageFolder(
             traindir,
             presets.ClassificationPresetTrain(
+                center_crop=center_crop,
                 crop_size=train_crop_size,
                 interpolation=interpolation,
                 auto_augment_policy=auto_augment_policy,
                 random_erase_prob=random_erase_prob,
+                ra_magnitude=ra_magnitude,
+                augmix_severity=augmix_severity,
             ),
         )
 
     if args.cache_dataset:
@@ -207,7 +217,10 @@ def main(args):
         mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
     if mixup_transforms:
         mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
-        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
+
+        def collate_fn(batch):
+            return mixupcutmix(*default_collate(batch))
+
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=args.batch_size,
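The lambda assignment is swapped for a named function, presumably so the `# noqa: E731` lint suppression can be dropped; the batch-level hook handed to the `DataLoader` is functionally unchanged. A self-contained sketch of the same pattern (names are illustrative):

```
from torch.utils.data.dataloader import default_collate

def make_mixup_collate_fn(batch_transform):
    # Collate individual samples into (images, targets), then apply a
    # batch-level transform such as the RandomChoice over MixUp/CutMix above.
    def collate_fn(batch):
        return batch_transform(*default_collate(batch))

    return collate_fn

# data_loader = torch.utils.data.DataLoader(
#     dataset, batch_size=args.batch_size, collate_fn=make_mixup_collate_fn(mixupcutmix)
# )
```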
@@ -448,6 +461,8 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
+    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
@@ -486,13 +501,17 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument(
+        "--train-center-crop",
+        action="store_true",
+        help="use center crop instead of random crop for training (default: False)",
+    )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
     parser.add_argument(
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
     return parser
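A quick sketch of the new flags flowing through the parser (assuming `get_args_parser` from the surrounding train.py is in scope and no other option is required; values are illustrative):

```
# Parse only the MaxViT-related additions; everything else falls back to defaults.
args = get_args_parser().parse_args(
    ["--ra-magnitude", "15", "--augmix-severity", "4", "--train-center-crop"]
)
print(args.ra_magnitude, args.augmix_severity, args.train_center_crop)  # 15 4 True
```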
File suppressed by a .gitattributes entry or the file's encoding is unsupported.
import unittest

import pytest
import torch

from torchvision.models.maxvit import SwapAxes, WindowDepartition, WindowPartition


class MaxvitTester(unittest.TestCase):
    def test_maxvit_window_partition(self):
        input_shape = (1, 3, 224, 224)
        partition_size = 7
        n_partitions = input_shape[3] // partition_size

        x = torch.randn(input_shape)

        partition = WindowPartition()
        departition = WindowDepartition()

        x_hat = partition(x, partition_size)
        x_hat = departition(x_hat, partition_size, n_partitions, n_partitions)

        assert torch.allclose(x, x_hat)

    def test_maxvit_grid_partition(self):
        input_shape = (1, 3, 224, 224)
        partition_size = 7
        n_partitions = input_shape[3] // partition_size

        x = torch.randn(input_shape)
        pre_swap = SwapAxes(-2, -3)
        post_swap = SwapAxes(-2, -3)

        partition = WindowPartition()
        departition = WindowDepartition()

        x_hat = partition(x, n_partitions)
        x_hat = pre_swap(x_hat)
        x_hat = post_swap(x_hat)
        x_hat = departition(x_hat, n_partitions, partition_size, partition_size)

        assert torch.allclose(x, x_hat)


if __name__ == "__main__":
    pytest.main([__file__])
@@ -13,5 +13,6 @@ from .squeezenet import *
 from .vgg import *
 from .vision_transformer import *
 from .swin_transformer import *
+from .maxvit import *
 from . import detection, optical_flow, quantization, segmentation, video
 from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
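With the new `from .maxvit import *` re-export, the builder should also be reachable through the registration helpers imported on the last line above; a quick sketch:

```
import torchvision.models as models

# The MaxViT builder is now part of the public models namespace.
print("maxvit_t" in models.list_models())            # expected: True
model = models.get_model("maxvit_t", weights=None)   # equivalent to models.maxvit_t()
```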
This diff is collapsed.