Commit fb54db0f authored by limm

add projects code

parent 1ac2e802
# Welcome to Projects of MMPreTrain

In this folder, we welcome all contributions of vision deep-learning backbones from the community.

Here, the requirements, e.g., code standards, are not as strict as in the core package. Thus, developers from the community can implement their algorithms much more easily and efficiently in MMPreTrain. We appreciate all contributions from the community to make MMPreTrain greater.
Here is an [example project](./example_project) about how to add your algorithms easily.

We also provide some documentation listed below:

- [New Model Guide](https://mmpretrain.readthedocs.io/en/latest/advanced_guides/modules.html)

  The documentation of adding new models.

- [Contribution Guide](https://mmpretrain.readthedocs.io/en/latest/notes/contribution_guide.html)

  The guide for new contributors on how to add their projects to MMPreTrain.

- [Discussions](https://github.com/open-mmlab/mmpretrain/discussions)

  Welcome to start a discussion!
# Implementation for DINO

**NOTE**: We only guarantee the correctness of the forward pass; we are not responsible for a full reimplementation.

First, ensure you are in the root directory of MMPreTrain. Then you have two choices
to play with DINO in MMPreTrain:
## Slurm
If you are using a cluster managed by Slurm, you can use the following command to
start your job:
```shell
GPUS_PER_NODE=8 GPUS=8 CPUS_PER_TASK=16 bash projects/dino/tools/slurm_train.sh mm_model dino projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py --amp
```
The above command will pre-train the model on a single node with 8 GPUs.
## PyTorch

If you are using a single machine without any cluster management software, you can use the following command:

```shell
NNODES=1 bash projects/dino/tools/dist_train.sh projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py 8 --amp
```
The pre-training config referenced by the commands above, `projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py`:

```python
model = dict(
    type='DINO',
    data_preprocessor=dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='mmpretrain.VisionTransformer', arch='b', patch_size=16),
    neck=dict(
        type='DINONeck',
        in_channels=768,
        out_channels=65536,
        hidden_channels=2048,
        bottleneck_channels=256),
    head=dict(
        type='DINOHead',
        out_channels=65536,
        num_crops=10,
        student_temp=0.1,
        center_momentum=0.9))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='DINOMultiCrop',
        global_crops_scale=(0.4, 1.0),
        local_crops_scale=(0.05, 0.4),
        local_crops_number=8),
    dict(type='PackInputs')
]

train_dataloader = dict(
    batch_size=32,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    collate_fn=dict(type='default_collate'),
    dataset=dict(
        type='mmpretrain.ImageNet',
        data_root='/data/imagenet/',
        ann_file='meta/train.txt',
        data_prefix=dict(img_path='train/'),
        pipeline=train_pipeline,
    ))

optimizer = dict(type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05)
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.0024, betas=(0.9, 0.95), weight_decay=0.05),
    paramwise_cfg=dict(
        custom_keys=dict(
            ln=dict(decay_mult=0.0),
            bias=dict(decay_mult=0.0),
            pos_embed=dict(decay_mult=0.0),
            mask_token=dict(decay_mult=0.0),
            cls_token=dict(decay_mult=0.0))),
    loss_scale='dynamic')

param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1e-09,
        by_epoch=True,
        begin=0,
        end=10,
        convert_to_iter_based=True),
    dict(
        type='CosineAnnealingLR',
        T_max=90,
        by_epoch=True,
        begin=10,
        end=100,
        convert_to_iter_based=True)
]

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=100)
default_scope = 'mmpretrain'
default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=1),
    sampler_seed=dict(type='DistSamplerSeedHook'))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
log_processor = dict(
    window_size=10,
    custom_cfg=[dict(data_src='', method='mean', window_size='global')])
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='UniversalVisualizer',
    vis_backends=[dict(type='LocalVisBackend')],
    name='visualizer')
log_level = 'INFO'
load_from = None
resume = True
randomness = dict(seed=2, diff_rank_seed=True)
custom_hooks = [
    dict(
        type='DINOTeacherTempWarmupHook',
        warmup_teacher_temp=0.04,
        teacher_temp=0.04,
        teacher_temp_warmup_epochs=0,
        max_epochs=100)
]
```
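For reference, the same run can be launched from Python, mirroring what `projects/dino/tools/train.py` (shown later) does. A minimal sketch, assuming you are in the MMPreTrain root and `projects/dino` is on `PYTHONPATH` so the project modules are importable; the work directory is a hypothetical choice:

```python
from dataset import *  # noqa: F401,F403 (registers DINOMultiCrop)
from engine import *  # noqa: F401,F403 (registers DINOTeacherTempWarmupHook)
from models.algorithm import *  # noqa: F401,F403
from models.head import *  # noqa: F401,F403
from models.neck import *  # noqa: F401,F403

from mmengine.config import Config
from mmengine.runner import Runner
from mmpretrain.utils import register_all_modules

# Register mmpretrain modules; the default scope is set by the runner.
register_all_modules(init_default_scope=False)

cfg = Config.fromfile(
    'projects/dino/config/dino_vit-base-p16_8xb64-amp-coslr-100e_in1k.py')
cfg.work_dir = 'work_dirs/dino'  # hypothetical output directory
runner = Runner.from_cfg(cfg)
runner.train()
```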
The project's Python packages are laid out as follows (file boundaries inferred from the imports). The `dataset` package exports the multi-crop transform:

```python
# dataset/__init__.py
from .transform import *  # noqa: F401,F403

# dataset/transform/__init__.py
from .processing import DINOMultiCrop

__all__ = ['DINOMultiCrop']
```
`processing.py` implements the transform itself:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import random

from mmcv.transforms import RandomApply  # noqa: E501
from mmcv.transforms import BaseTransform, Compose, RandomFlip, RandomGrayscale

from mmpretrain.datasets.transforms import (ColorJitter, GaussianBlur,
                                            RandomResizedCrop, Solarize)
from mmpretrain.registry import TRANSFORMS


@TRANSFORMS.register_module()
class DINOMultiCrop(BaseTransform):
    """Multi-crop transform for DINO.

    This module applies the multi-crop transform for DINO.

    Args:
        global_crops_scale (tuple): Scale range of the global crops.
        local_crops_scale (tuple): Scale range of the local crops.
        local_crops_number (int): Number of local crops.
    """

    def __init__(self, global_crops_scale: tuple, local_crops_scale: tuple,
                 local_crops_number: int) -> None:
        super().__init__()
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale

        flip_and_color_jitter = Compose([
            RandomFlip(prob=0.5, direction='horizontal'),
            RandomApply([
                ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)
            ],
                        prob=0.8),
            RandomGrayscale(
                prob=0.2,
                keep_channels=True,
                channel_weights=(0.114, 0.587, 0.2989),
            )
        ])

        # Note: the blur radius is sampled once at construction time,
        # not once per call.
        self.global_transform_1 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
        ])

        self.global_transform_2 = Compose([
            RandomResizedCrop(
                224,
                crop_ratio_range=global_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
            Solarize(thr=128, prob=0.2),
        ])

        self.local_crops_number = local_crops_number
        self.local_transform = Compose([
            RandomResizedCrop(
                96,
                crop_ratio_range=local_crops_scale,
                interpolation='bicubic'),
            flip_and_color_jitter,
            GaussianBlur(prob=1.0, radius=random.uniform(0.1, 2.0)),
        ])

    def transform(self, results: dict) -> dict:
        ori_img = results['img']
        crops = []
        # Two global crops ...
        results['img'] = ori_img
        crops.append(self.global_transform_1(results)['img'])
        results['img'] = ori_img
        crops.append(self.global_transform_2(results)['img'])
        # ... followed by `local_crops_number` local crops.
        for _ in range(self.local_crops_number):
            results['img'] = ori_img
            crops.append(self.local_transform(results)['img'])
        results['img'] = crops
        return results

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(global_crops_scale = {self.global_crops_scale}, '
        repr_str += f'local_crops_scale = {self.local_crops_scale}, '
        repr_str += f'local_crops_number = {self.local_crops_number})'
        return repr_str
```
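A small sanity-check sketch of the output structure (assuming MMPreTrain is installed and `projects/dino` is on `PYTHONPATH`; the input dict mimics what `LoadImageFromFile` produces):

```python
import numpy as np
from dataset.transform import DINOMultiCrop

transform = DINOMultiCrop(
    global_crops_scale=(0.4, 1.0),
    local_crops_scale=(0.05, 0.4),
    local_crops_number=8)

results = transform(
    {'img': np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)})
# Two 224x224 global crops followed by eight 96x96 local crops,
# matching num_crops=10 in the DINOHead config.
assert len(results['img']) == 10
```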
The `engine` package exports the custom hook:

```python
# engine/__init__.py
from .hooks import *  # noqa

# engine/hooks/__init__.py
from .dino_teacher_temp_warmup_hook import DINOTeacherTempWarmupHook

__all__ = ['DINOTeacherTempWarmupHook']
```
The hook itself:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.hooks import Hook

from mmpretrain.registry import HOOKS


@HOOKS.register_module()
class DINOTeacherTempWarmupHook(Hook):
    """Warmup teacher temperature for DINO.

    This hook warms up the temperature for the teacher to stabilize the
    training process.

    Args:
        warmup_teacher_temp (float): Warmup temperature for the teacher.
        teacher_temp (float): Temperature for the teacher.
        teacher_temp_warmup_epochs (int): Warmup epochs for the teacher
            temperature.
        max_epochs (int): Maximum epochs for training.
    """

    def __init__(self, warmup_teacher_temp: float, teacher_temp: float,
                 teacher_temp_warmup_epochs: int, max_epochs: int) -> None:
        super().__init__()
        # Linear ramp from `warmup_teacher_temp` to `teacher_temp`, then
        # constant for the remaining epochs.
        self.teacher_temps = np.concatenate(
            (np.linspace(warmup_teacher_temp, teacher_temp,
                         teacher_temp_warmup_epochs),
             np.ones(max_epochs - teacher_temp_warmup_epochs) * teacher_temp))

    def before_train_epoch(self, runner) -> None:
        runner.model.module.head.teacher_temp = self.teacher_temps[
            runner.epoch]
```
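Note that this config uses a degenerate schedule (`teacher_temp_warmup_epochs=0` with equal temperatures, so the teacher temperature is a constant 0.04). For a non-trivial schedule, e.g. the DINO paper's warmup from 0.04 to 0.07 over the first 30 of 100 epochs (values assumed from the paper, not from this config), the array built in `__init__` looks like:

```python
import numpy as np

teacher_temps = np.concatenate((
    np.linspace(0.04, 0.07, 30),  # linear warmup over the first 30 epochs
    np.ones(100 - 30) * 0.07))    # constant afterwards

assert teacher_temps.shape == (100,)
assert teacher_temps[0] == 0.04 and teacher_temps[-1] == 0.07
```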
The `models` package exports the algorithm, head, and neck:

```python
# models/__init__.py
from .algorithm import *  # noqa
from .head import *  # noqa
from .neck import *  # noqa

# models/algorithm/__init__.py
from .dino import DINO

__all__ = ['DINO']
```
The DINO algorithm wraps a student (backbone + neck) and a momentum teacher:

```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union

import torch
from torch import nn

from mmpretrain.models import BaseSelfSupervisor, CosineEMA
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample


@MODELS.register_module()
class DINO(BaseSelfSupervisor):
    """Implementation for DINO.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        backbone (dict): Config for backbone.
        neck (dict): Config for neck.
        head (dict): Config for head.
        pretrained (str, optional): Path for pretrained model.
            Defaults to None.
        base_momentum (float, optional): Base momentum for momentum update.
            Defaults to 0.99.
        data_preprocessor (dict, optional): Config for data preprocessor.
            Defaults to None.
        init_cfg (list[dict] | dict, optional): Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 backbone: dict,
                 neck: dict,
                 head: dict,
                 pretrained: Optional[str] = None,
                 base_momentum: float = 0.99,
                 data_preprocessor: Optional[dict] = None,
                 init_cfg: Optional[Union[List[dict], dict]] = None) -> None:
        super().__init__(
            backbone=backbone,
            neck=neck,
            head=head,
            pretrained=pretrained,
            data_preprocessor=data_preprocessor,
            init_cfg=init_cfg)

        # create momentum model
        self.teacher = CosineEMA(
            nn.Sequential(self.backbone, self.neck), momentum=base_momentum)
        # weight normalization layer, with the norm of the weight fixed to 1
        self.neck.last_layer = nn.utils.weight_norm(self.neck.last_layer)
        self.neck.last_layer.weight_g.data.fill_(1)
        self.neck.last_layer.weight_g.requires_grad = False
        self.teacher.module[1].last_layer = nn.utils.weight_norm(
            self.teacher.module[1].last_layer)
        self.teacher.module[1].last_layer.weight_g.data.fill_(1)
        self.teacher.module[1].last_layer.weight_g.requires_grad = False

    def loss(self, inputs: torch.Tensor,
             data_samples: List[DataSample]) -> dict:
        # The first two views are the global crops; the rest are local crops.
        global_crops = torch.cat(inputs[:2])
        local_crops = torch.cat(inputs[2:])

        # teacher forward (global crops only)
        teacher_output = self.teacher(global_crops)

        # student forward global
        student_output_global = self.backbone(global_crops)
        student_output_global = self.neck(student_output_global)

        # student forward local
        student_output_local = self.backbone(local_crops)
        student_output_local = self.neck(student_output_local)

        student_output = torch.cat(
            (student_output_global, student_output_local))

        # compute loss
        loss = self.head(student_output, teacher_output)

        return dict(loss=loss)
```
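Because these classes are registered in `mmpretrain.registry.MODELS`, the whole model can be built from the same dictionaries the config uses. A minimal sketch, assuming `projects/dino` is on `PYTHONPATH` so the registrations below run:

```python
from models.algorithm import *  # noqa: F401,F403 (registers DINO)
from models.head import *  # noqa: F401,F403 (registers DINOHead)
from models.neck import *  # noqa: F401,F403 (registers DINONeck)

from mmpretrain.registry import MODELS

model = MODELS.build(
    dict(
        type='DINO',
        backbone=dict(
            type='mmpretrain.VisionTransformer', arch='b', patch_size=16),
        neck=dict(
            type='DINONeck',
            in_channels=768,
            out_channels=65536,
            hidden_channels=2048,
            bottleneck_channels=256),
        head=dict(
            type='DINOHead',
            out_channels=65536,
            num_crops=10,
            student_temp=0.1,
            center_momentum=0.9)))
print(type(model).__name__)  # DINO
```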
The `head` package:

```python
# models/head/__init__.py
from .dino_head import DINOHead

__all__ = ['DINOHead']
```
```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.dist import all_reduce, get_world_size
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class DINOHead(BaseModule):
    """Implementation for DINO head.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        out_channels (int): Output channels of the head.
        num_crops (int): Number of crops.
        student_temp (float): Temperature for student output.
        center_momentum (float): Momentum for center update.
    """

    def __init__(self, out_channels: int, num_crops: int, student_temp: float,
                 center_momentum: float) -> None:
        super().__init__()
        self.student_temp = student_temp
        # Placeholder; the actual value is set each epoch by
        # DINOTeacherTempWarmupHook before it is used.
        self.teacher_temp = 0
        self.center_momentum = center_momentum
        self.num_crops = num_crops
        self.register_buffer('center', torch.zeros(1, out_channels))

    def forward(self, student_output: torch.Tensor,
                teacher_output: torch.Tensor) -> torch.Tensor:
        current_teacher_output = teacher_output
        student_output = student_output / self.student_temp
        student_output = student_output.chunk(self.num_crops, dim=0)

        # teacher centering and sharpening
        teacher_output = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1)
        teacher_output = teacher_output.detach().chunk(2, dim=0)

        total_loss = 0
        n_loss_terms = 0
        # Cross-entropy between every teacher (global) view and every
        # student view, skipping pairs that share the same view.
        for i in range(len(teacher_output)):
            for j in range(len(student_output)):
                if i == j:
                    continue
                total_loss += (-teacher_output[i] *
                               student_output[j].log_softmax(dim=-1)).sum(
                                   dim=-1).mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(current_teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output: torch.Tensor) -> None:
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * get_world_size())

        # ema update batch center
        self.center = self.center * self.center_momentum + batch_center * (
            1 - self.center_momentum)
```
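To make the pairing in `forward` concrete: with `num_crops=10`, the teacher output is chunked into 2 global views and the student output into 10 views, and same-index pairs are skipped, giving 2 × 9 = 18 loss terms. A quick sketch of the same enumeration:

```python
# Crop counts taken from the config above.
num_teacher_views, num_student_views = 2, 10
pairs = [(i, j) for i in range(num_teacher_views)
         for j in range(num_student_views) if i != j]
# Each teacher view is compared against every student view except itself.
assert len(pairs) == 18
```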
The `neck` package:

```python
# models/neck/__init__.py
from .dino_neck import DINONeck

__all__ = ['DINONeck']
```
```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


@MODELS.register_module()
class DINONeck(BaseModule):
    """Implementation for DINO neck.

    This module is proposed in `DINO: Emerging Properties in Self-Supervised
    Vision Transformers <https://arxiv.org/abs/2104.14294>`_.

    Args:
        in_channels (int): Input channels.
        hidden_channels (int): Hidden channels.
        out_channels (int): Output channels.
        bottleneck_channels (int): Bottleneck channels.
    """

    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, bottleneck_channels: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(*[
            nn.Linear(in_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, bottleneck_channels),
        ])

        self.last_layer = nn.Linear(
            bottleneck_channels, out_channels, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The backbone returns a tuple of features; use the first one.
        x = self.mlp(x[0])
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
```
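A quick shape check of the neck (a sketch, assuming `projects/dino` is on `PYTHONPATH`; the one-element tuple mimics the backbone's output format, and the channel sizes are the ones from the config):

```python
import torch
from models.neck import DINONeck

neck = DINONeck(
    in_channels=768,
    hidden_channels=2048,
    out_channels=65536,
    bottleneck_channels=256)

# A batch of 4 backbone features, passed as a tuple as in `DINO.loss`.
out = neck((torch.randn(4, 768), ))
assert out.shape == (4, 65536)
```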
The launcher scripts. `projects/dino/tools/dist_train.sh`:

```shell
#!/usr/bin/env bash

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch \
    --nnodes=$NNODES \
    --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --nproc_per_node=$GPUS \
    --master_port=$PORT \
    $(dirname "$0")/train.py \
    $CONFIG \
    --launcher pytorch ${@:3}
```
`projects/dino/tools/slurm_train.sh`:

```shell
#!/usr/bin/env bash

set -x

PARTITION=$1
JOB_NAME=$2
CONFIG=$3
GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}
PY_ARGS=${@:4}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
srun -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u projects/dino/tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
```
`projects/dino/tools/train.py`:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import os.path as osp

from dataset import *  # noqa: F401,F403
from engine import *  # noqa: F401,F403
from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from models.algorithm import *  # noqa: F401,F403
from models.head import *  # noqa: F401,F403
from models.neck import *  # noqa: F401,F403

from mmpretrain.utils import register_all_modules


def parse_args():
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        nargs='?',
        type=str,
        const='auto',
        help='If a checkpoint path is specified, resume from it; otherwise '
        'try to auto resume from the latest checkpoint in the work '
        'directory.')
    parser.add_argument(
        '--amp',
        action='store_true',
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args


def main():
    args = parse_args()

    # register all modules in mmpretrain into the registries
    # do not init the default scope here because it will be init in the runner
    register_all_modules(init_default_scope=False)

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        work_type = args.config.split('/')[1]
        cfg.work_dir = osp.join('./work_dirs', work_type,
                                osp.splitext(osp.basename(args.config))[0])

    # enable automatic-mixed-precision training
    if args.amp is True:
        optim_wrapper = cfg.optim_wrapper.get('type', 'OptimWrapper')
        assert optim_wrapper in ['OptimWrapper', 'AmpOptimWrapper'], \
            '`--amp` is not supported for custom optimizer wrapper type ' \
            f'`{optim_wrapper}`.'
        cfg.optim_wrapper.type = 'AmpOptimWrapper'
        cfg.optim_wrapper.setdefault('loss_scale', 'dynamic')

    # resume training
    if args.resume == 'auto':
        cfg.resume = True
        cfg.load_from = None
    elif args.resume is not None:
        cfg.resume = True
        cfg.load_from = args.resume

    # build the runner from config
    runner = Runner.from_cfg(cfg)

    # start training
    runner.train()


if __name__ == '__main__':
    main()
```
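The `--cfg-options` flag relies on mmengine's `Config.merge_from_dict`, which accepts dot-separated keys. A small sketch of that behavior with a hypothetical in-memory config:

```python
from mmengine.config import Config

cfg = Config(dict(optimizer=dict(type='AdamW', lr=0.0024)))
# Equivalent to passing `--cfg-options optimizer.lr=0.001` on the CLI.
cfg.merge_from_dict({'optimizer.lr': 0.001})
assert cfg.optimizer.lr == 0.001
```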
# Example Project

This is an example README for community `projects/`. You can write your README in your own project. Here are
some recommended parts of a README for others to understand and use your project; you can copy or modify them
according to your project.
## Usage
### Setup Environment
Please refer to [Get Started](https://mmpretrain.readthedocs.io/en/latest/get_started.html) to install
MMPreTrain.
First, add the current folder to `PYTHONPATH`, so that Python can find your code. Run the following command in the current directory to add it.

> Please run it every time after you open a new shell.
```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
### Data Preparation
Prepare the ImageNet-2012 dataset according to the [instruction](https://mmpretrain.readthedocs.io/en/latest/user_guides/dataset_prepare.html#imagenet).
### Training commands
**To train with a single GPU:**
```bash
mim train mmpretrain configs/examplenet_8xb32_in1k.py
```
**To train with multiple GPUs:**
```bash
mim train mmpretrain configs/examplenet_8xb32_in1k.py --launcher pytorch --gpus 8
```
**To train with multiple GPUs by slurm:**
```bash
mim train mmpretrain configs/examplenet_8xb32_in1k.py --launcher slurm \
--gpus 16 --gpus-per-node 8 --partition $PARTITION
```
### Testing commands
**To test with a single GPU:**
```bash
mim test mmpretrain configs/examplenet_8xb32_in1k.py --checkpoint $CHECKPOINT
```
**To test with multiple GPUs:**
```bash
mim test mmpretrain configs/examplenet_8xb32_in1k.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
```
**To test with multiple GPUs by slurm:**
```bash
mim test mmpretrain configs/examplenet_8xb32_in1k.py --checkpoint $CHECKPOINT --launcher slurm \
--gpus 16 --gpus-per-node 8 --partition $PARTITION
```
## Results
| Model | Pretrain | Top-1 (%) | Top-5 (%) | Config | Download |
| :----------------: | :----------: | :-------: | :-------: | :-------------------------------------: | :------------------------------------: |
| ExampleNet-tiny | From scratch | 82.33 | 96.15 | [config](./mvitv2-tiny_8xb256_in1k.py) | [model](MODEL-LINK) \| [log](LOG-LINK) |
| ExampleNet-small\* | From scratch | 83.63 | 96.51 | [config](./mvitv2-small_8xb256_in1k.py) | [model](MODEL-LINK) |
| ExampleNet-base\* | From scratch | 84.34 | 96.86 | [config](./mvitv2-base_8xb256_in1k.py) | [model](MODEL-LINK) |
*Models with \* are converted from the [official repo](REPO-LINK). The config files of these models are only for inference. We don't guarantee the training accuracy of these config files and welcome you to contribute your reproduction results.*
## Citation
<!-- Replace to the citation of the paper your project refers to. -->
```BibTeX
@misc{2023mmpretrain,
title={OpenMMLab's Pre-training Toolbox and Benchmark},
author={MMPreTrain Contributors},
howpublished = {\url{https://github.com/open-mmlab/mmpretrain}},
year={2023}
}
```
## Checklist

Here is a checklist of this project's progress. You can ignore this part if you don't plan to contribute
to MMPreTrain projects.
- [ ] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
- [ ] Finish the code
<!-- The code's design shall follow existing interfaces and conventions. For example, each model component should be registered into `mmpretrain.registry.MODELS` and configurable via a config file. -->
- [ ] Basic docstrings & proper citation
<!-- Each major class should contain a docstring, describing its functionality and arguments. If your code is copied or modified from other open-source projects, don't forget to cite the source project in the docstring and make sure your usage does not violate its license. Typically, we do not accept any code snippet under GPL license. [A Short Guide to Open Source Licenses](https://medium.com/nationwide-technology/a-short-guide-to-open-source-licenses-cf5b1c329edd) -->
- [ ] Converted checkpoint and results (Only for reproduction)
<!-- If you are reproducing the result from a paper, make sure the model in the project can match the reported results. Also please provide checkpoint links or a checkpoint conversion script for others to get the pre-trained model. -->
- [ ] Milestone 2: Indicates a successful model implementation.
- [ ] Training results
<!-- If you are reproducing the result from a paper, train your model from scratch and verify that the final result matches the original result. Usually, ±0.1% is acceptable for the image classification task on ImageNet-1k. -->
- [ ] Milestone 3: Good to be a part of our core package!
- [ ] Unit tests
<!-- Unit tests for the major module are required. [Example](https://github.com/open-mmlab/mmpretrain/blob/main/tests/test_models/test_backbones/test_vision_transformer.py) -->
- [ ] Code style
<!-- Refactor your code according to reviewer's comment. -->
- [ ] `metafile.yml` and `README.md`
<!-- It will be used by MMPreTrain to acquire your models. [Example](https://github.com/open-mmlab/mmpretrain/blob/main/configs/mvit/metafile.yml). In particular, you may have to refactor this README into a standard one. [Example](https://github.com/open-mmlab/mmpretrain/blob/main/configs/swin_transformer/README.md) -->