Commit c218d1c5 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #1192 canceled with stages
[isort]
line_length=100
multi_line_output=3
include_trailing_comma=True
known_standard_library=numpy,setuptools
skip_glob=*/__init__.py
known_myself=repvit_sam
known_third_party=matplotlib,cv2,torch,torchvision,pycocotools,onnx,black,isort
no_lines_before=STDLIB,THIRDPARTY
sections=FUTURE,STDLIB,THIRDPARTY,MYSELF,FIRSTPARTY,LOCALFOLDER
default_section=FIRSTPARTY
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from setuptools import find_packages, setup
setup(
name="repvit_sam",
version="1.0",
install_requires=[],
packages=find_packages(exclude="notebooks"),
extras_require={
"all": ["matplotlib", "pycocotools", "opencv-python", "onnx", "onnxruntime"],
"dev": ["flake8", "isort", "black", "mypy"],
},
)
pretrain
work_dirs
data
seg_pretrain
\ No newline at end of file
# Semantic Segmentation
Segmentation on ADE20K is implemented based on [MMSegmentation](https://github.com/open-mmlab/mmsegmentation).
## Models
| Model | mIoU | Latency | Ckpt | Log |
|:---------------|:----:|:---:|:--:|:--:|
| RepViT-M1.1 | 40.6 | 4.9ms | [M1.1](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_ade20k.pth) | [M1.1](./logs/repvit_m1_1_ade20k.json) |
| RepViT-M1.5 | 43.6 | 6.4ms | [M1.5](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_ade20k.pth) | [M1.5](./logs/repvit_m1_5_ade20k.json) |
| RepViT-M2.3 | 46.1 | 9.9ms | [M2.3](https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_ade20k.pth) | [M2.3](./logs/repvit_m2_3_ade20k.json) |
The backbone latency is measured with image crops of 512x512 on iPhone 12 by Core ML Tools.
## Requirements
Install [mmcv-full](https://github.com/open-mmlab/mmcv) and [MMSegmentation v0.30.0](https://github.com/open-mmlab/mmsegmentation/tree/v0.30.0).
Later versions should work as well.
The easiest way is to install via [MIM](https://github.com/open-mmlab/mim)
```
pip install -U openmim
mim install mmcv-full==1.7.1
mim install mmseg==0.30.0
```
## Data preparation
We benchmark RepViT on the challenging ADE20K dataset, which can be downloaded and prepared following [insructions in MMSeg](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/dataset_prepare.md#prepare-datasets).
The data should appear as:
```
├── segmentation
│ ├── data
│ │ ├── ade
│ │ │ ├── ADEChallengeData2016
│ │ │ │ ├── annotations
│ │ │ │ │ ├── training
│ │ │ │ │ ├── validation
│ │ │ │ ├── images
│ │ │ │ │ ├── training
│ │ │ │ │ ├── validation
```
## Testing
We provide a multi-GPU testing script, specify config file, checkpoint, and number of GPUs to use:
```
./tools/dist_test.sh config_file path/to/checkpoint #GPUs --eval mIoU
```
For example, to test RepViT-M1.1 on ADE20K on an 8-GPU machine,
```
./tools/dist_test.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py path/to/repvit_m1_1_ade20k.pth 8 --eval mIoU
```
## Training
Download ImageNet-1K pretrained weights into `./pretrain`
We provide PyTorch distributed data parallel (DDP) training script `dist_train.sh`, for example, to train RepViT-M1.1 on an 8-GPU machine:
```
./tools/dist_train.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py 8
```
Tips: specify configs and #GPUs!
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning, is_tuple_of
from numpy import random
from mmseg.datasets.builder import PIPELINES
@PIPELINES.register_module()
class AlignResize(object):
"""Resize images & seg. Align
"""
def __init__(self,
img_scale=None,
multiscale_mode='range',
ratio_range=None,
keep_ratio=True,
size_divisor=32):
if img_scale is None:
self.img_scale = None
else:
if isinstance(img_scale, list):
self.img_scale = img_scale
else:
self.img_scale = [img_scale]
assert mmcv.is_list_of(self.img_scale, tuple)
if ratio_range is not None:
# mode 1: given img_scale=None and a range of image ratio
# mode 2: given a scale and a range of image ratio
assert self.img_scale is None or len(self.img_scale) == 1
else:
# mode 3 and 4: given multiple scales or a range of scales
assert multiscale_mode in ['value', 'range']
self.multiscale_mode = multiscale_mode
self.ratio_range = ratio_range
self.keep_ratio = keep_ratio
self.size_divisor = size_divisor
@staticmethod
def random_select(img_scales):
"""Randomly select an img_scale from given candidates.
Args:
img_scales (list[tuple]): Images scales for selection.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_dix)``,
where ``img_scale`` is the selected image scale and
``scale_idx`` is the selected index in the given candidates.
"""
assert mmcv.is_list_of(img_scales, tuple)
scale_idx = np.random.randint(len(img_scales))
img_scale = img_scales[scale_idx]
return img_scale, scale_idx
@staticmethod
def random_sample(img_scales):
"""Randomly sample an img_scale when ``multiscale_mode=='range'``.
Args:
img_scales (list[tuple]): Images scale range for sampling.
There must be two tuples in img_scales, which specify the lower
and uper bound of image scales.
Returns:
(tuple, None): Returns a tuple ``(img_scale, None)``, where
``img_scale`` is sampled scale and None is just a placeholder
to be consistent with :func:`random_select`.
"""
assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales]
long_edge = np.random.randint(
min(img_scale_long),
max(img_scale_long) + 1)
short_edge = np.random.randint(
min(img_scale_short),
max(img_scale_short) + 1)
img_scale = (long_edge, short_edge)
return img_scale, None
@staticmethod
def random_sample_ratio(img_scale, ratio_range):
"""Randomly sample an img_scale when ``ratio_range`` is specified.
A ratio will be randomly sampled from the range specified by
``ratio_range``. Then it would be multiplied with ``img_scale`` to
generate sampled scale.
Args:
img_scale (tuple): Images scale base to multiply with ratio.
ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``img_scale``.
Returns:
(tuple, None): Returns a tuple ``(scale, None)``, where
``scale`` is sampled ratio multiplied with ``img_scale`` and
None is just a placeholder to be consistent with
:func:`random_select`.
"""
assert isinstance(img_scale, tuple) and len(img_scale) == 2
min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio
ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
return scale, None
def _random_scale(self, results):
"""Randomly sample an img_scale according to ``ratio_range`` and
``multiscale_mode``.
If ``ratio_range`` is specified, a ratio will be sampled and be
multiplied with ``img_scale``.
If multiple scales are specified by ``img_scale``, a scale will be
sampled according to ``multiscale_mode``.
Otherwise, single scale will be used.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: Two new keys 'scale` and 'scale_idx` are added into
``results``, which would be used by subsequent pipelines.
"""
if self.ratio_range is not None:
if self.img_scale is None:
h, w = results['img'].shape[:2]
scale, scale_idx = self.random_sample_ratio((w, h),
self.ratio_range)
else:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)
elif len(self.img_scale) == 1:
scale, scale_idx = self.img_scale[0], 0
elif self.multiscale_mode == 'range':
scale, scale_idx = self.random_sample(self.img_scale)
elif self.multiscale_mode == 'value':
scale, scale_idx = self.random_select(self.img_scale)
else:
raise NotImplementedError
results['scale'] = scale
results['scale_idx'] = scale_idx
def _align(self, img, size_divisor, interpolation=None):
align_h = int(np.ceil(img.shape[0] / size_divisor)) * size_divisor
align_w = int(np.ceil(img.shape[1] / size_divisor)) * size_divisor
if interpolation == None:
img = mmcv.imresize(img, (align_w, align_h))
else:
img = mmcv.imresize(img, (align_w, align_h), interpolation=interpolation)
return img
def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
if self.keep_ratio:
img, scale_factor = mmcv.imrescale(
results['img'], results['scale'], return_scale=True)
#### align ####
img = self._align(img, self.size_divisor)
# the w_scale and h_scale has minor difference
# a real fix should be done in the mmcv.imrescale in the future
new_h, new_w = img.shape[:2]
h, w = results['img'].shape[:2]
w_scale = new_w / w
h_scale = new_h / h
else:
img, w_scale, h_scale = mmcv.imresize(
results['img'], results['scale'], return_scale=True)
h, w = img.shape[:2]
assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \
int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \
"img size not align. h:{} w:{}".format(h, w)
scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
dtype=np.float32)
results['img'] = img
results['img_shape'] = img.shape
results['pad_shape'] = img.shape # in case that there is no padding
results['scale_factor'] = scale_factor
results['keep_ratio'] = self.keep_ratio
def _resize_seg(self, results):
"""Resize semantic segmentation map with ``results['scale']``."""
for key in results.get('seg_fields', []):
if self.keep_ratio:
gt_seg = mmcv.imrescale(
results[key], results['scale'], interpolation='nearest')
gt_seg = self._align(gt_seg, self.size_divisor, interpolation='nearest')
else:
gt_seg = mmcv.imresize(
results[key], results['scale'], interpolation='nearest')
h, w = gt_seg.shape[:2]
assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \
int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \
"gt_seg size not align. h:{} w:{}".format(h, w)
results[key] = gt_seg
def __call__(self, results):
"""Call function to resize images, bounding boxes, masks, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
'keep_ratio' keys are added into result dict.
"""
if 'scale' not in results:
self._random_scale(results)
self._resize_img(results)
self._resize_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += (f'(img_scale={self.img_scale}, '
f'multiscale_mode={self.multiscale_mode}, '
f'ratio_range={self.ratio_range}, '
f'keep_ratio={self.keep_ratio})')
return repr_str
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 512),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='AlignResize', keep_ratio=True, size_divisor=32),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=50,
dataset=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/training',
ann_dir='annotations/training',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
data_root=data_root,
img_dir='images/validation',
ann_dir='annotations/validation',
pipeline=test_pipeline))
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='TextLoggerHook', by_epoch=False),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
decode_head=dict(
type='FPNHead',
in_channels=[256, 256, 256, 256],
in_index=[0, 1, 2, 3],
feature_strides=[4, 8, 16, 32],
channels=128,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=160000)
checkpoint_config = dict(by_epoch=False, interval=16000)
evaluation = dict(interval=16000, metric='mIoU')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=20000)
checkpoint_config = dict(by_epoch=False, interval=2000)
evaluation = dict(interval=2000, metric='mIoU')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=40000)
checkpoint_config = dict(by_epoch=False, interval=4000)
evaluation = dict(interval=4000, metric='mIoU')
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000)
checkpoint_config = dict(by_epoch=False, interval=8000)
evaluation = dict(interval=8000, metric='mIoU')
_base_ = [
'../_base_/models/fpn_r50.py',
'../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py'
]
# model settings
model = dict(
type='EncoderDecoder',
backbone=dict(
type='repvit_m1_1',
style='pytorch',
init_cfg=dict(
type='Pretrained',
checkpoint='pretrain/repvit_m1_1_distill_300e.pth',
),
out_indices = [3,7,21,24]
),
neck=dict(in_channels=[64, 128, 256, 512]),
decode_head=dict(num_classes=150))
gpu_multiples = 2 # we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2
# optimizer
optimizer = dict(type='AdamW', lr=0.0001 * gpu_multiples, weight_decay=0.0001)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-6, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000 // gpu_multiples)
checkpoint_config = dict(by_epoch=False, interval=8000 // gpu_multiples)
evaluation = dict(interval=8000 // gpu_multiples, metric='mIoU')
_base_ = [
'../_base_/models/fpn_r50.py',
'../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py'
]
# model settings
model = dict(
type='EncoderDecoder',
backbone=dict(
type='repvit_m1_5',
style='pytorch',
init_cfg=dict(
type='Pretrained',
checkpoint='pretrain/repvit_m1_5_distill_300e.pth',
),
out_indices=[5, 11, 37, 42]
),
neck=dict(in_channels=[64, 128, 256, 512]),
decode_head=dict(num_classes=150))
gpu_multiples = 2 # we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2
# optimizer
optimizer = dict(type='AdamW', lr=0.0001 * gpu_multiples, weight_decay=0.0001)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-6, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000 // gpu_multiples)
checkpoint_config = dict(by_epoch=False, interval=8000 // gpu_multiples)
evaluation = dict(interval=8000 // gpu_multiples, metric='mIoU')
_base_ = [
'../_base_/models/fpn_r50.py',
'../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py'
]
# model settings
model = dict(
type='EncoderDecoder',
backbone=dict(
type='repvit_m2_3',
style='pytorch',
init_cfg=dict(
type='Pretrained',
checkpoint='pretrain/repvit_m2_3_distill_450e.pth',
),
out_indices=[7, 15, 51, 54]
),
neck=dict(in_channels=[80, 160, 320, 640]),
decode_head=dict(num_classes=150))
gpu_multiples = 2 # we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2
# optimizer
optimizer = dict(type='AdamW', lr=0.0001 * gpu_multiples, weight_decay=0.0001)
optimizer_config = dict()
# learning policy
lr_config = dict(policy='poly', power=0.9, min_lr=1e-6, by_epoch=False)
# runtime settings
runner = dict(type='IterBasedRunner', max_iters=80000 // gpu_multiples)
checkpoint_config = dict(by_epoch=False, interval=8000 // gpu_multiples)
evaluation = dict(interval=8000 // gpu_multiples, metric='mIoU')
PORT=12345 ./tools/dist_test.sh configs/sem_fpn/fpn_repvit_m1_1_ade20k_40k.py seg_pretrain/repvit_m1_1_ade20k.pth 8 --eval mIoU
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
import torch.nn as nn
import numpy as np
import itertools
from mmseg.models.builder import BACKBONES
from mmseg.utils import get_root_logger
from mmcv.runner import _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
from timm.models.layers import SqueezeExcite
import torch
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module('bn', torch.nn.BatchNorm2d(b))
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
torch.nn.init.constant_(self.bn.bias, 0)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
device=c.weight.device)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class Residual(torch.nn.Module):
def __init__(self, m, drop=0.):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
@torch.no_grad()
def fuse(self):
if isinstance(self.m, Conv2d_BN):
m = self.m.fuse()
assert(m.groups == m.in_channels)
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
identity = torch.nn.functional.pad(identity, [1,1,1,1])
m.weight += identity.to(m.weight.device)
return m
elif isinstance(self.m, torch.nn.Conv2d):
m = self.m
assert(m.groups != m.in_channels)
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
identity = torch.nn.functional.pad(identity, [1,1,1,1])
m.weight += identity.to(m.weight.device)
return m
else:
return self
class RepVGGDW(torch.nn.Module):
def __init__(self, ed) -> None:
super().__init__()
self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
self.dim = ed
self.bn = torch.nn.BatchNorm2d(ed)
def forward(self, x):
return self.bn((self.conv(x) + self.conv1(x)) + x)
@torch.no_grad()
def fuse(self):
conv = self.conv.fuse()
conv1 = self.conv1
conv_w = conv.weight
conv_b = conv.bias
conv1_w = conv1.weight
conv1_b = conv1.bias
conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
final_conv_w = conv_w + conv1_w + identity
final_conv_b = conv_b + conv1_b
conv.weight.data.copy_(final_conv_w)
conv.bias.data.copy_(final_conv_b)
bn = self.bn
w = bn.weight / (bn.running_var + bn.eps)**0.5
w = conv.weight * w[:, None, None, None]
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
(bn.running_var + bn.eps)**0.5
conv.weight.data.copy_(w)
conv.bias.data.copy_(b)
return conv
class RepViTBlock(nn.Module):
def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
super(RepViTBlock, self).__init__()
assert stride in [1, 2]
self.identity = stride == 1 and inp == oup
assert(hidden_dim == 2 * inp)
if stride == 2:
self.token_mixer = nn.Sequential(
Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
))
else:
assert(self.identity)
self.token_mixer = nn.Sequential(
RepVGGDW(inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
))
def forward(self, x):
return self.channel_mixer(self.token_mixer(x))
from timm.models.vision_transformer import trunc_normal_
class BN_Linear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
trunc_normal_(self.l.weight, std=std)
if bias:
torch.nn.init.constant_(self.l.bias, 0)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps)**0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class RepViT(nn.Module):
def __init__(self, cfgs, distillation=False, pretrained=None, init_cfg=None, out_indices=[]):
super(RepViT, self).__init__()
# setting of inverted residual blocks
self.cfgs = cfgs
# building first layer
input_channel = self.cfgs[0][2]
patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
layers = [patch_embed]
# building inverted residual blocks
block = RepViTBlock
for k, t, c, use_se, use_hs, s in self.cfgs:
output_channel = _make_divisible(c, 8)
exp_size = _make_divisible(input_channel * t, 8)
layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
input_channel = output_channel
self.features = nn.ModuleList(layers)
self.init_cfg = init_cfg
assert(self.init_cfg is not None)
self.out_indices = out_indices
self.init_weights()
self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)
self.train()
def init_weights(self, pretrained=None):
logger = get_root_logger()
if self.init_cfg is None and pretrained is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
pass
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
if self.init_cfg is not None:
ckpt_path = self.init_cfg['checkpoint']
elif pretrained is not None:
ckpt_path = pretrained
ckpt = _load_checkpoint(
ckpt_path, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
_state_dict = ckpt['model']
else:
_state_dict = ckpt
state_dict = _state_dict
missing_keys, unexpected_keys = \
self.load_state_dict(state_dict, False)
logger.info(f"Miss {missing_keys}")
logger.info(f"Unexpected {unexpected_keys}")
def train(self, mode=True):
"""Convert the model into training mode while keep layers freezed."""
super(RepViT, self).train(mode)
if mode:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
def forward(self, x):
outs = []
for i, f in enumerate(self.features):
x = f(x)
if i in self.out_indices:
outs.append(x)
assert(len(outs) == 4)
return outs
from timm.models import register_model
@BACKBONES.register_module()
def repvit_m1_1(pretrained=False, num_classes = 1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
"""
Constructs a MobileNetV3-Large model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 64, 1, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 128, 0, 0, 2],
[3, 2, 128, 1, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 256, 0, 1, 2],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 512, 0, 1, 2],
[3, 2, 512, 1, 1, 1],
[3, 2, 512, 0, 1, 1]
]
return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)
@BACKBONES.register_module()
def repvit_m1_5(pretrained=False, num_classes = 1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
"""
Constructs a MobileNetV3-Large model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 64, 1, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 64, 1, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 128, 0, 0, 2],
[3, 2, 128, 1, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 128, 1, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 256, 0, 1, 2],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 512, 0, 1, 2],
[3, 2, 512, 1, 1, 1],
[3, 2, 512, 0, 1, 1],
[3, 2, 512, 1, 1, 1],
[3, 2, 512, 0, 1, 1]
]
return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)
@BACKBONES.register_module()
def repvit_m2_3(pretrained=False, num_classes = 1000, distillation=False, init_cfg=None, out_indices=[], **kwargs):
"""
Constructs a MobileNetV3-Large model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 80, 1, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 80, 1, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 80, 1, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 160, 0, 0, 2],
[3, 2, 160, 1, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 160, 1, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 160, 1, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 320, 0, 1, 2],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
# [3, 2, 320, 1, 1, 1],
# [3, 2, 320, 0, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 640, 0, 1, 2],
[3, 2, 640, 1, 1, 1],
[3, 2, 640, 0, 1, 1],
# [3, 2, 640, 1, 1, 1],
# [3, 2, 640, 0, 1, 1]
]
return RepViT(cfgs, init_cfg=init_cfg, pretrained=pretrained, distillation=distillation, out_indices=out_indices)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment