_base_ = '../levit-256_8xb256_in1k.py'
model = dict(backbone=dict(deploy=True), head=dict(deploy=True))
_base_ = '../levit-384_8xb256_in1k.py'
model = dict(backbone=dict(deploy=True), head=dict(deploy=True))
_base_ = [
'../_base_/models/levit-256-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs2048_adamw_levit.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(backbone=dict(arch='128'), head=dict(in_channels=384))
# dataset settings
train_dataloader = dict(batch_size=256)
_base_ = [
'../_base_/models/levit-256-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs2048_adamw_levit.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(backbone=dict(arch='128s'), head=dict(in_channels=384))
# dataset settings
train_dataloader = dict(batch_size=256)
_base_ = [
'../_base_/models/levit-256-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs2048_adamw_levit.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(backbone=dict(arch='192'), head=dict(in_channels=384))
# dataset settings
train_dataloader = dict(batch_size=256)
_base_ = [
'../_base_/models/levit-256-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs2048_adamw_levit.py',
'../_base_/default_runtime.py',
]
# dataset settings
train_dataloader = dict(batch_size=256)
_base_ = [
'../_base_/models/levit-256-p16.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs2048_adamw_levit.py',
'../_base_/default_runtime.py',
]
# model settings
model = dict(
backbone=dict(arch='384', drop_path_rate=0.1),
head=dict(in_channels=768),
)
# dataset settings
train_dataloader = dict(batch_size=256)
Collections:
- Name: LeViT
Metadata:
Training Data: ImageNet-1k
Architecture:
- 1x1 Convolution
- LeViT Attention Block
Paper:
Title: "LeViT: a Vision Transformer in ConvNet\u2019s Clothing for Faster Inference"
URL: https://arxiv.org/abs/2104.01136
README: configs/levit/README.md
Code:
URL: https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/backbones/levit.py
Version: v1.0.0rc5
Models:
- Name: levit-128s_3rdparty_in1k
Metadata:
FLOPs: 310342496
Parameters: 7391290
Training Data: ImageNet-1k
In Collection: LeViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 76.51
Top 5 Accuracy: 92.90
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/levit/levit-128s_3rdparty_in1k_20230117-e9fbd209.pth
Config: configs/levit/levit-128s_8xb256_in1k.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth
Code: https://github.com/facebookresearch/LeViT
- Name: levit-128_3rdparty_in1k
Metadata:
FLOPs: 413060992
Parameters: 8828168
Training Data: ImageNet-1k
In Collection: LeViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 78.58
Top 5 Accuracy: 93.95
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/levit/levit-128_3rdparty_in1k_20230117-3be02a02.pth
Config: configs/levit/levit-128_8xb256_in1k.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth
Code: https://github.com/facebookresearch/LeViT
- Name: levit-192_3rdparty_in1k
Metadata:
FLOPs: 667860704
Parameters: 10561301
Training Data: ImageNet-1k
In Collection: LeViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.86
Top 5 Accuracy: 94.75
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/levit/levit-192_3rdparty_in1k_20230117-8217a0f9.pth
Config: configs/levit/levit-192_8xb256_in1k.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth
Code: https://github.com/facebookresearch/LeViT
- Name: levit-256_3rdparty_in1k
Metadata:
FLOPs: 1141625216
Parameters: 18379852
Training Data: ImageNet-1k
In Collection: LeViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.59
Top 5 Accuracy: 95.46
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/levit/levit-256_3rdparty_in1k_20230117-5ae2ce7d.pth
Config: configs/levit/levit-256_8xb256_in1k.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth
Code: https://github.com/facebookresearch/LeViT
- Name: levit-384_3rdparty_in1k
Metadata:
FLOPs: 2372941568
Parameters: 38358300
Training Data: ImageNet-1k
In Collection: LeViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.59
Top 5 Accuracy: 95.95
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/levit/levit-384_3rdparty_in1k_20230117-f3539cce.pth
Config: configs/levit/levit-384_8xb256_in1k.py
Converted From:
Weights: https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth
Code: https://github.com/facebookresearch/LeViT
# LLaVA
> [Visual Instruction Tuning](https://arxiv.org/abs/2304.08485)
<!-- [ALGORITHM] -->
## Abstract
Instruction tuning large language models (LLMs) using machine-generated instruction-following data has improved zero-shot capabilities on new tasks, but the idea is less explored in the multimodal field. In this paper, we present the first attempt to use language-only GPT-4 to generate multimodal language-image instruction-following data. By instruction tuning on such generated data, we introduce LLaVA: Large Language and Vision Assistant, an end-to-end trained large multimodal model that connects a vision encoder and LLM for general-purpose visual and language understanding. Our early experiments show that LLaVA demonstrates impressive multimodal chat abilities, sometimes exhibiting the behaviors of multimodal GPT-4 on unseen images/instructions, and yields an 85.1% relative score compared with GPT-4 on a synthetic multimodal instruction-following dataset. When fine-tuned on Science QA, the synergy of LLaVA and GPT-4 achieves a new state-of-the-art accuracy of 92.53%. We make GPT-4 generated visual instruction tuning data, our model and code base publicly available.
<div align=center>
<img src="https://github-production-user-asset-6210df.s3.amazonaws.com/26739999/246466979-c2f41b71-1de3-4da8-b20a-eaebe722c339.png" width="80%"/>
</div>
## How to use it?
<!-- [TABS-BEGIN] -->
**Use the model**
```python
from mmpretrain import inference_model
out = inference_model('llava-7b-v1_caption', 'demo/cat-dog.png', device='cuda')
print(out)
# {'pred_caption': 'In the image, there are two cats sitting on a blanket.'}
```
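**Use the VQA model**

A minimal sketch for the VQA variant, mirroring the caption example above. It assumes the question is forwarded as the second positional argument of `inference_model` and that the result dict carries a `pred_answer` key; both the call pattern and the sample output are illustrative and may differ slightly across mmpretrain versions.

```python
from mmpretrain import inference_model

out = inference_model(
    'llava-7b-v1.5_vqa',
    'demo/cat-dog.png',
    'What animals are in the image?',
    device='cuda')
print(out)
# e.g. {'question': 'What animals are in the image?', 'pred_answer': 'a cat and a dog'}
```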
<!-- [TABS-END] -->
## Models and results
### Image Caption on COCO
| Model | Params (M) | Config | Download |
| :---------------------- | :--------: | :--------------------------------: | :-------------------------------------------------------------------------------------------------------------: |
| `llava-7b-v1_caption` | 7045.82 | [config](llava-7b-v1_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth) |
| `llava-7b-v1.5_caption` | 7062.90 | [config](llava-7b-v1.5_caption.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
| `llava-7b-v1.5_vqa` | 7062.90 | [config](llava-7b-v1.5_vqa.py) | [ckpt](https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth) |
## Citation
```bibtex
@misc{liu2023llava,
title={Visual Instruction Tuning},
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
publisher={arXiv:2304.08485},
year={2023},
}
```
_base_ = '../_base_/default_runtime.py'
meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
Describe the image in detail. ASSISTANT:'''
# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(num_beams=3, max_new_tokens=50, length_penalty=-1.0),
)
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]
test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
# schedule settings
test_cfg = dict()
_base_ = '../_base_/default_runtime.py'
meta_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." # noqa: E501
image_size = 336
prompt_tmpl = f'''{meta_prompt} User: <image>
{{question}} ASSISTANT:'''
# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer', name_or_path='liuhaotian/llava-v1.5-7b'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained='https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_336px_20231025-fb1315ed.pth',
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=False,
mm_proj_depth=2,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='vqa',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(max_new_tokens=100),
)
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id', 'question']),
]
test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
# schedule settings
test_cfg = dict()
_base_ = '../_base_/default_runtime.py'
meta_prompt = 'You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.Follow the instructions carefully and explain your answers in detail.' # noqa: E501
image_size = 224
prompt_tmpl = f'''{meta_prompt} User: <im_start><image><im_end>
Describe the image in detail. ASSISTANT:'''
# model settings
model = dict(
type='Llava',
tokenizer=dict(
type='AutoTokenizer',
name_or_path='liuhaotian/LLaVA-Lightning-7B-delta-v1-1'),
vision_encoder=dict(
type='VisionTransformer',
arch='l',
patch_size=14,
img_size=image_size,
pre_norm=True,
norm_cfg=dict(type='LN', eps=1e-5),
layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
final_norm=False,
out_type='raw',
pretrained=(
'https://download.openmmlab.com/mmclassification/v0/clip/'
'vit-large-p14_clip-openai-pre_3rdparty_20230517-95e2af0b.pth'),
),
mm_hidden_size=1024,
use_im_patch=False,
use_im_start_end=True,
mm_proj_depth=1,
lang_encoder=dict(
type='AutoModelForCausalLM',
name_or_path='huggyllama/llama-7b',
),
task='caption',
prompt_tmpl=prompt_tmpl,
generation_cfg=dict(max_new_tokens=50),
)
# data settings
data_preprocessor = dict(
type='MultiModalDataPreprocessor',
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(image_size, image_size),
interpolation='bicubic',
backend='pillow'),
dict(type='PackInputs', meta_keys=['image_id']),
]
test_dataloader = dict(
batch_size=8,
num_workers=5,
dataset=dict(
type='COCOCaption',
data_root='data/coco',
ann_file='annotations/coco_karpathy_val.json',
pipeline=test_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = dict(
type='COCOCaption',
ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)
# schedule settings
test_cfg = dict()
Collections:
- Name: LLaVA
Metadata:
Architecture:
- LLaMA
- CLIP
Paper:
Title: Visual Instruction Tuning
URL: https://arxiv.org/abs/2304.08485
README: configs/llava/README.md
Models:
- Name: llava-7b-v1_caption
Metadata:
FLOPs: null
Parameters: 7045816320
In Collection: LLaVA
Results:
- Task: Image Caption
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1_liuhaotian_20231025-c9e119b6.pth
Config: configs/llava/llava-7b-v1_caption.py
- Name: llava-7b-v1.5_caption
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Image Caption
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_caption.py
- Name: llava-7b-v1.5_vqa
Metadata:
FLOPs: null
Parameters: 7062900736
In Collection: LLaVA
Results:
- Task: Visual Question Answering
Dataset: COCO
Metrics:
BLEU-4: null
CIDER: null
Weights: https://download.openmmlab.com/mmclassification/v1/llava/llava-7b-v1.5_liuhaotian_20231025-5828aa5a.pth
Config: configs/llava/llava-7b-v1.5_vqa.py
# MAE
> [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)
<!-- [ALGORITHM] -->
## Abstract
This paper shows that masked autoencoders (MAE) are
scalable self-supervised learners for computer vision. Our
MAE approach is simple: we mask random patches of the
input image and reconstruct the missing pixels. It is based
on two core designs. First, we develop an asymmetric
encoder-decoder architecture, with an encoder that operates only on the
visible subset of patches (without mask tokens), along with a lightweight
decoder that reconstructs the original image from the latent representation
and mask tokens. Second, we find that masking a high proportion
of the input image, e.g., 75%, yields a nontrivial and
meaningful self-supervisory task. Coupling these two designs enables us to
train large models efficiently and effectively: we accelerate
training (by 3× or more) and improve accuracy. Our scalable approach allows
for learning high-capacity models that generalize well: e.g., a vanilla
ViT-Huge model achieves the best accuracy (87.8%) among
methods that use only ImageNet-1K data. Transfer performance in downstream tasks outperforms supervised pretraining and shows promising scaling behavior.
<div align=center>
<img src="https://user-images.githubusercontent.com/30762564/150733959-2959852a-c7bd-4d3f-911f-3e8d8839fe67.png" width="80%"/>
</div>
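The 75% random masking described above reduces to a shuffle-and-keep operation on patch tokens. The following standalone sketch (an illustration roughly following the argsort trick from the paper's reference code, not the mmpretrain implementation) shows that step:

```python
import torch


def random_masking(patches: torch.Tensor, mask_ratio: float = 0.75):
    """Keep a random subset of patch tokens, MAE-style.

    patches: (B, N, D) patch embeddings. Returns the visible tokens,
    a binary mask (1 = masked) in the original patch order, and the
    indices needed to restore that order after decoding.
    """
    B, N, D = patches.shape
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N)                       # one random score per patch
    ids_shuffle = torch.argsort(noise, dim=1)      # lowest scores are kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    ids_keep = ids_shuffle[:, :len_keep]
    visible = torch.gather(
        patches, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    mask = torch.ones(B, N)
    mask[:, :len_keep] = 0                         # 0 = kept, 1 = masked
    mask = torch.gather(mask, 1, ids_restore)      # back to original order
    return visible, mask, ids_restore
```

Only the `visible` tokens are fed to the encoder; the decoder later re-inserts mask tokens at the positions given by `ids_restore` before reconstructing pixels.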
## How to use it?
<!-- [TABS-BEGIN] -->
**Predict image**
```python
from mmpretrain import inference_model
predict = inference_model('vit-base-p16_mae-300e-pre_8xb128-coslr-100e_in1k', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```
**Use the model**
```python
import torch
from mmpretrain import get_model
model = get_model('mae_vit-base-p16_8xb512-amp-coslr-300e_in1k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
**Train/Test Command**
Prepare your dataset according to the [docs](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#prepare-dataset).
Train:
```shell
python tools/train.py configs/mae/mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py
```
Test:
```shell
python tools/test.py configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py None
```
<!-- [TABS-END] -->
## Models and results
### Pretrained models
| Model | Params (M) | Flops (G) | Config | Download |
| :---------------------------------------------- | :--------: | :-------: | :--------------------------------------------------------: | :--------------------------------------------------------------------------: |
| `mae_vit-base-p16_8xb512-amp-coslr-300e_in1k` | 111.91 | 17.58 | [config](mae_vit-base-p16_8xb512-amp-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-300e_in1k/mae_vit-base-p16_8xb512-coslr-300e-fp16_in1k_20220829-c2cf66ba.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-300e_in1k/mae_vit-base-p16_8xb512-coslr-300e-fp16_in1k_20220829-c2cf66ba.json) |
| `mae_vit-base-p16_8xb512-amp-coslr-400e_in1k` | 111.91 | 17.58 | [config](mae_vit-base-p16_8xb512-amp-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k_20220825-bc79e40b.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k_20220825-bc79e40b.json) |
| `mae_vit-base-p16_8xb512-amp-coslr-800e_in1k` | 111.91 | 17.58 | [config](mae_vit-base-p16_8xb512-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-base-p16_8xb512-coslr-800e-fp16_in1k_20220825-5d81fbc4.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-base-p16_8xb512-coslr-800e-fp16_in1k_20220825-5d81fbc4.json) |
| `mae_vit-base-p16_8xb512-amp-coslr-1600e_in1k` | 111.91 | 17.58 | [config](mae_vit-base-p16_8xb512-amp-coslr-1600e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k_20220825-f7569ca2.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k_20220825-f7569ca2.json) |
| `mae_vit-large-p16_8xb512-amp-coslr-400e_in1k` | 329.54 | 61.60 | [config](mae_vit-large-p16_8xb512-amp-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k_20220825-b11d0425.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k_20220825-b11d0425.json) |
| `mae_vit-large-p16_8xb512-amp-coslr-800e_in1k` | 329.54 | 61.60 | [config](mae_vit-large-p16_8xb512-amp-coslr-800e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k_20220825-df72726a.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k_20220825-df72726a.json) |
| `mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k` | 329.54 | 61.60 | [config](mae_vit-large-p16_8xb512-amp-coslr-1600e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.json) |
| `mae_vit-huge-p14_8xb512-amp-coslr-1600e_in1k` | 657.07 | 167.40 | [config](mae_vit-huge-p14_8xb512-amp-coslr-1600e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.json) |
### Image Classification on ImageNet-1k
| Model | Pretrain | Params (M) | Flops (G) | Top-1 (%) | Config | Download |
| :---------------------------------------- | :------------------------------------------: | :--------: | :-------: | :-------: | :----------------------------------------: | :-------------------------------------------: |
| `vit-base-p16_mae-300e-pre_8xb128-coslr-100e_in1k` | [MAE 300-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-300e_in1k/mae_vit-base-p16_8xb512-coslr-300e-fp16_in1k_20220829-c2cf66ba.pth) | 86.57 | 17.58 | 83.10 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | N/A |
| `vit-base-p16_mae-400e-pre_8xb128-coslr-100e_in1k` | [MAE 400-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k_20220825-bc79e40b.pth) | 86.57 | 17.58 | 83.30 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | N/A |
| `vit-base-p16_mae-800e-pre_8xb128-coslr-100e_in1k` | [MAE 800-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-base-p16_8xb512-coslr-800e-fp16_in1k_20220825-5d81fbc4.pth) | 86.57 | 17.58 | 83.30 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | N/A |
| `vit-base-p16_mae-1600e-pre_8xb128-coslr-100e_in1k` | [MAE 1600-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k_20220825-f7569ca2.pth) | 86.57 | 17.58 | 83.50 | [config](benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k/vit-base-p16_ft-8xb128-coslr-100e_in1k_20220825-cf70aa21.json) |
| `vit-base-p16_mae-300e-pre_8xb2048-linear-coslr-90e_in1k` | [MAE 300-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-300e_in1k/mae_vit-base-p16_8xb512-coslr-300e-fp16_in1k_20220829-c2cf66ba.pth) | 86.57 | 17.58 | 60.80 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | N/A |
| `vit-base-p16_mae-400e-pre_8xb2048-linear-coslr-90e_in1k` | [MAE 400-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-base-p16_8xb512-coslr-400e-fp16_in1k_20220825-bc79e40b.pth) | 86.57 | 17.58 | 62.50 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | N/A |
| `vit-base-p16_mae-800e-pre_8xb2048-linear-coslr-90e_in1k` | [MAE 800-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-base-p16_8xb512-coslr-800e-fp16_in1k_20220825-5d81fbc4.pth) | 86.57 | 17.58 | 65.10 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | N/A |
| `vit-base-p16_mae-1600e-pre_8xb2048-linear-coslr-90e_in1k` | [MAE 1600-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-base-p16_8xb512-fp16-coslr-1600e_in1k_20220825-f7569ca2.pth) | 86.57 | 17.58 | 67.10 | [config](benchmarks/vit-base-p16_8xb2048-linear-coslr-90e_in1k.py) | N/A |
| `vit-large-p16_mae-400e-pre_8xb128-coslr-50e_in1k` | [MAE 400-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k_20220825-b11d0425.pth) | 304.32 | 61.60 | 85.20 | [config](benchmarks/vit-large-p16_8xb128-coslr-50e_in1k.py) | N/A |
| `vit-large-p16_mae-800e-pre_8xb128-coslr-50e_in1k` | [MAE 800-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k_20220825-df72726a.pth) | 304.32 | 61.60 | 85.40 | [config](benchmarks/vit-large-p16_8xb128-coslr-50e_in1k.py) | N/A |
| `vit-large-p16_mae-1600e-pre_8xb128-coslr-50e_in1k` | [MAE 1600-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth) | 304.32 | 61.60 | 85.70 | [config](benchmarks/vit-large-p16_8xb128-coslr-50e_in1k.py) | N/A |
| `vit-large-p16_mae-400e-pre_8xb2048-linear-coslr-90e_in1k` | [MAE 400-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-400e_in1k_20220825-b11d0425.pth) | 304.33 | 61.60 | 70.70 | [config](benchmarks/vit-large-p16_8xb2048-linear-coslr-90e_in1k.py) | N/A |
| `vit-large-p16_mae-800e-pre_8xb2048-linear-coslr-90e_in1k` | [MAE 800-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-800e_in1k_20220825-df72726a.pth) | 304.33 | 61.60 | 73.70 | [config](benchmarks/vit-large-p16_8xb2048-linear-coslr-90e_in1k.py) | N/A |
| `vit-large-p16_mae-1600e-pre_8xb2048-linear-coslr-90e_in1k` | [MAE 1600-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-large-p16_8xb512-fp16-coslr-1600e_in1k_20220825-cc7e98c9.pth) | 304.33 | 61.60 | 75.50 | [config](benchmarks/vit-large-p16_8xb2048-linear-coslr-90e_in1k.py) | N/A |
| `vit-huge-p14_mae-1600e-pre_8xb128-coslr-50e_in1k` | [MAE 1600-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.pth) | 632.04 | 167.40 | 86.90 | [config](benchmarks/vit-huge-p14_8xb128-coslr-50e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/vit-huge-p16_ft-8xb128-coslr-50e_in1k/vit-huge-p16_ft-8xb128-coslr-50e_in1k_20220916-0bfc9bfd.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/vit-huge-p16_ft-8xb128-coslr-50e_in1k/vit-huge-p16_ft-8xb128-coslr-50e_in1k_20220916-0bfc9bfd.json) |
| `vit-huge-p14_mae-1600e-pre_32xb8-coslr-50e_in1k-448px` | [MAE 1600-Epochs](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k_20220916-ff848775.pth) | 633.03 | 732.13 | 87.30 | [config](benchmarks/vit-huge-p14_32xb8-coslr-50e_in1k-448px.py) | [model](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/vit-huge-p16_ft-32xb8-coslr-50e_in1k-448/vit-huge-p16_ft-32xb8-coslr-50e_in1k-448_20220916-95b6a0ce.pth) \| [log](https://download.openmmlab.com/mmselfsup/1.x/mae/mae_vit-huge-p16_8xb512-fp16-coslr-1600e_in1k/vit-huge-p16_ft-32xb8-coslr-50e_in1k-448/vit-huge-p16_ft-32xb8-coslr-50e_in1k-448_20220916-95b6a0ce.json) |
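The benchmark configs above ship with an empty `model.backbone.init_cfg.checkpoint`, so the self-supervised weights from the pretrained-model table have to be plugged in before fine-tuning. A minimal sketch with mmengine, assuming ImageNet is prepared as in the Train/Test section (the work directory is arbitrary; the checkpoint URL is the 300-epoch MAE weight from the table above):

```python
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/mae/benchmarks/vit-base-p16_8xb128-coslr-100e_in1k.py')
# Point the backbone initialization at a pre-trained MAE checkpoint.
cfg.model.backbone.init_cfg.checkpoint = (
    'https://download.openmmlab.com/mmselfsup/1.x/mae/'
    'mae_vit-base-p16_8xb512-fp16-coslr-300e_in1k/'
    'mae_vit-base-p16_8xb512-coslr-300e-fp16_in1k_20220829-c2cf66ba.pth')
cfg.work_dir = 'work_dirs/vit-base-p16_mae-300e-pre_ft'  # arbitrary output dir

runner = Runner.from_cfg(cfg)
runner.train()
```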
## Citation
```bibtex
@article{He2021MaskedAA,
title={Masked Autoencoders Are Scalable Vision Learners},
author={Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and
Piotr Doll{\'a}r and Ross B. Girshick},
journal={arXiv},
year={2021}
}
```
_base_ = [
'../../_base_/datasets/imagenet_bs64_swin_224.py',
'../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../../_base_/default_runtime.py'
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs')
]
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
drop_path_rate=0.1,
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))
# optimizer wrapper
optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=2e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.65,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=95,
by_epoch=True,
begin=5,
end=100,
eta_min=1e-6,
convert_to_iter_based=True)
]
# runtime settings
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
train_cfg = dict(by_epoch=True, max_epochs=100)
randomness = dict(seed=0, diff_rank_seed=True)
_base_ = [
'../../_base_/datasets/imagenet_bs32_pil_resize.py',
'../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../../_base_/default_runtime.py'
]
# dataset settings
train_dataloader = dict(batch_size=2048, drop_last=True)
val_dataloader = dict(drop_last=False)
test_dataloader = dict(drop_last=False)
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='base',
img_size=224,
patch_size=16,
frozen_stages=12,
out_type='cls_token',
final_norm=True,
init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
neck=dict(type='ClsBatchNormNeck', input_features=768),
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
loss=dict(type='CrossEntropyLoss'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=0.01)]))
# optimizer
optim_wrapper = dict(
_delete_=True,
type='AmpOptimWrapper',
optimizer=dict(type='LARS', lr=6.4, weight_decay=0.0, momentum=0.9))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=10,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=80,
by_epoch=True,
begin=10,
end=90,
eta_min=0.0,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=90)
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3),
logger=dict(type='LoggerHook', interval=10))
randomness = dict(seed=0, diff_rank_seed=True)
_base_ = [
'../../_base_/datasets/imagenet_bs64_swin_224.py',
'../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../../_base_/default_runtime.py'
]
# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=448,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=512,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=448),
dict(type='PackInputs')
]
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='huge',
img_size=448,
patch_size=14,
drop_path_rate=0.3, # set to 0.3
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1280,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))
# optimizer wrapper
# learning rate and layer decay rate are set to 0.004 and 0.75 respectively
optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.75,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=45,
by_epoch=True,
begin=5,
end=50,
eta_min=1e-6,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=50)
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
_base_ = [
'../../_base_/datasets/imagenet_bs64_swin_224.py',
'../../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../../_base_/default_runtime.py'
]
# dataset settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(pad_val=[104, 116, 124], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=0.3333333333333333,
fill_color=[103.53, 116.28, 123.675],
fill_std=[57.375, 57.12, 58.395]),
dict(type='PackInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=256,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackInputs')
]
train_dataloader = dict(batch_size=128, dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(batch_size=128, dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='huge',
img_size=224,
patch_size=14,
drop_path_rate=0.3, # set to 0.3
out_type='avg_featmap',
final_norm=False,
init_cfg=dict(type='Pretrained', checkpoint='', prefix='backbone.')),
neck=None,
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1280,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
init_cfg=[dict(type='TruncNormal', layer='Linear', std=2e-5)]),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]))
# optimizer wrapper
# learning rate and layer decay rate are set to 0.004 and 0.75 respectively
optim_wrapper = dict(
optimizer=dict(
type='AdamW', lr=4e-3, weight_decay=0.05, betas=(0.9, 0.999)),
constructor='LearningRateDecayOptimWrapperConstructor',
paramwise_cfg=dict(
layer_decay_rate=0.75,
custom_keys={
'.ln': dict(decay_mult=0.0),
'.bias': dict(decay_mult=0.0),
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
}))
# learning rate scheduler
param_scheduler = [
dict(
type='LinearLR',
start_factor=1e-4,
by_epoch=True,
begin=0,
end=5,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=45,
by_epoch=True,
begin=5,
end=50,
eta_min=1e-6,
convert_to_iter_based=True)
]
# runtime settings
train_cfg = dict(by_epoch=True, max_epochs=50)
default_hooks = dict(
# save checkpoint per epoch.
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3))
randomness = dict(seed=0, diff_rank_seed=True)
_base_ = ['./vit-huge-p14_8xb128-coslr-50e_in1k.py']
# optimizer wrapper
optim_wrapper = dict(type='DeepSpeedOptimWrapper')
# training strategy
strategy = dict(
type='DeepSpeedStrategy',
fp16=dict(
enabled=True,
fp16_master_weights_and_grads=False,
loss_scale=0,
loss_scale_window=500,
hysteresis=2,
min_loss_scale=1,
initial_scale_power=15,
),
inputs_to_half=['inputs'],
zero_optimization=dict(
stage=1,
allgather_partitions=True,
reduce_scatter=True,
allgather_bucket_size=50000000,
reduce_bucket_size=50000000,
overlap_comm=True,
contiguous_gradients=True,
cpu_offload=False,
))
# runner which supports strategies
runner_type = 'FlexibleRunner'
_base_ = ['./vit-huge-p14_8xb128-coslr-50e_in1k.py']
strategy = dict(
type='FSDPStrategy',
model_wrapper=dict(
auto_wrap_policy=dict(
type='torch.distributed.fsdp.wrap.size_based_auto_wrap_policy',
min_num_params=1e7)))
optim_wrapper = dict(type='AmpOptimWrapper')
# runner which supports strategies
runner_type = 'FlexibleRunner'