Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmpretrain
Commits
cbc25585
Commit
cbc25585
authored
Jun 24, 2025
by
limm
Browse files
add mmpretrain/ part
parent
1baf0566
Pipeline
#2801
canceled with stages
Changes
268
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
7776 additions
and
0 deletions
+7776
-0
mmpretrain/models/backbones/repmlp.py
mmpretrain/models/backbones/repmlp.py
+578
-0
mmpretrain/models/backbones/repvgg.py
mmpretrain/models/backbones/repvgg.py
+622
-0
mmpretrain/models/backbones/res2net.py
mmpretrain/models/backbones/res2net.py
+317
-0
mmpretrain/models/backbones/resnest.py
mmpretrain/models/backbones/resnest.py
+339
-0
mmpretrain/models/backbones/resnet.py
mmpretrain/models/backbones/resnet.py
+768
-0
mmpretrain/models/backbones/resnet_cifar.py
mmpretrain/models/backbones/resnet_cifar.py
+81
-0
mmpretrain/models/backbones/resnext.py
mmpretrain/models/backbones/resnext.py
+148
-0
mmpretrain/models/backbones/revvit.py
mmpretrain/models/backbones/revvit.py
+671
-0
mmpretrain/models/backbones/riformer.py
mmpretrain/models/backbones/riformer.py
+390
-0
mmpretrain/models/backbones/seresnet.py
mmpretrain/models/backbones/seresnet.py
+125
-0
mmpretrain/models/backbones/seresnext.py
mmpretrain/models/backbones/seresnext.py
+155
-0
mmpretrain/models/backbones/shufflenet_v1.py
mmpretrain/models/backbones/shufflenet_v1.py
+321
-0
mmpretrain/models/backbones/shufflenet_v2.py
mmpretrain/models/backbones/shufflenet_v2.py
+305
-0
mmpretrain/models/backbones/sparse_convnext.py
mmpretrain/models/backbones/sparse_convnext.py
+298
-0
mmpretrain/models/backbones/sparse_resnet.py
mmpretrain/models/backbones/sparse_resnet.py
+179
-0
mmpretrain/models/backbones/swin_transformer.py
mmpretrain/models/backbones/swin_transformer.py
+585
-0
mmpretrain/models/backbones/swin_transformer_v2.py
mmpretrain/models/backbones/swin_transformer_v2.py
+567
-0
mmpretrain/models/backbones/t2t_vit.py
mmpretrain/models/backbones/t2t_vit.py
+447
-0
mmpretrain/models/backbones/timm_backbone.py
mmpretrain/models/backbones/timm_backbone.py
+111
-0
mmpretrain/models/backbones/tinyvit.py
mmpretrain/models/backbones/tinyvit.py
+769
-0
No files found.
Too many changes to show.
To preserve performance only
268 of 268+
files are displayed.
Plain diff
Email patch
mmpretrain/models/backbones/repmlp.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from official impl at https://github.com/DingXiaoH/RepMLP.
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mmcv.cnn
import
(
ConvModule
,
build_activation_layer
,
build_conv_layer
,
build_norm_layer
)
from
mmcv.cnn.bricks.transformer
import
PatchEmbed
as
_PatchEmbed
from
mmengine.model
import
BaseModule
,
ModuleList
,
Sequential
from
mmpretrain.models.utils
import
SELayer
,
to_2tuple
from
mmpretrain.registry
import
MODELS
def fuse_bn(conv_or_fc, bn):
    """Fold a BatchNorm layer into the preceding conv/fc weights.

    Returns a ``(weight, bias)`` pair such that a convolution with these
    parameters is equivalent to ``bn(conv_or_fc(x))`` in eval mode. When
    the conv has more output channels than the BN has features (as with
    ``fc3``/``fc3_bn`` in ``RepMLPBlock``), the BN scale and shift are
    tiled along dim 0 so each group of output channels shares one
    statistic.
    """
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = (bn.weight / std).reshape(-1, 1, 1, 1)
    shift = bn.bias - bn.running_mean * bn.weight / std

    out_channels = conv_or_fc.weight.size(0)
    if out_channels == len(scale):
        return conv_or_fc.weight * scale, shift

    # dim0 of the conv weights and of the BN statistics differ: tile the
    # BN parameters to match the conv's output channels.
    times = out_channels // len(scale)
    fused_weight = conv_or_fc.weight * scale.repeat_interleave(times, 0)
    fused_bias = shift.repeat_interleave(times, 0)
    return fused_weight, fused_bias
class PatchEmbed(_PatchEmbed):
    """Image to Patch Embedding for RepMLP.

    Compared with the default ViT patch embedding, this variant appends a
    ReLU after the (optionally normalized) projection and keeps the output
    in (N, C, H, W) layout instead of flattening it to (N, L, C).

    All constructor arguments are forwarded unchanged to
    ``mmcv.cnn.bricks.transformer.PatchEmbed`` (in_channels, embed_dims,
    conv_type, kernel_size, stride, padding, dilation, bias, norm_cfg,
    input_size, init_cfg).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Embed an image batch.

        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.

        Returns:
            tuple: The embedded feature map and its spatial shape.

            - x (Tensor): The output tensor, still in (N, C, H, W) layout.
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (out_h, out_w).
        """
        if self.adaptive_padding:
            x = self.adaptive_padding(x)

        x = self.projection(x)
        if self.norm is not None:
            x = self.norm(x)
        x = self.relu(x)

        _, _, out_h, out_w = x.shape
        return x, (out_h, out_w)
class GlobalPerceptron(SELayer):
    """GlobalPerceptron implemented with ``mmpretrain.models.SELayer``.

    A squeeze-and-excitation block configured to return the channel
    attention weights (``return_weight=True``) with a ReLU squeeze and a
    Sigmoid excitation.

    Args:
        input_channels (int): The number of input (and output) channels
            in the GlobalPerceptron.
        ratio (int): Squeeze ratio in GlobalPerceptron, the intermediate
            channel will be ``make_divisible(channels // ratio, divisor)``.
    """

    def __init__(self, input_channels: int, ratio: int, **kwargs) -> None:
        # The two act_cfg entries configure the squeeze (ReLU) and the
        # excitation (Sigmoid) layers respectively.
        gate_acts = (dict(type='ReLU'), dict(type='Sigmoid'))
        super().__init__(
            channels=input_channels,
            ratio=ratio,
            return_weight=True,
            act_cfg=gate_acts,
            **kwargs)
class RepMLPBlock(BaseModule):
    """Basic RepMLPNet, consists of PartitionPerceptron and GlobalPerceptron.

    Args:
        channels (int): The number of input and the output channels of the
            block.
        path_h (int): The height of patches.
        path_w (int): The width of patches.
        reparam_conv_kernels (Sequence[int] | None): The kernel sizes of the
            conv branches forming the Local Perceptron. Default: None.
        globalperceptron_ratio (int): The reduction ratio in the
            GlobalPerceptron. Default: 4.
        num_sharesets (int): The number of sharesets in the
            PartitionPerceptron. Default 1.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 channels,
                 path_h,
                 path_w,
                 reparam_conv_kernels=None,
                 globalperceptron_ratio=4,
                 num_sharesets=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 deploy=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.deploy = deploy
        self.channels = channels
        self.num_sharesets = num_sharesets
        self.path_h, self.path_w = path_h, path_w
        # the input channel of fc3 (kept with its historical spelling for
        # backward compatibility with existing checkpoints/subclasses)
        self._path_vec_channles = path_h * path_w * num_sharesets

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.gp = GlobalPerceptron(
            input_channels=channels, ratio=globalperceptron_ratio)

        # using a conv layer to implement a fc layer
        self.fc3 = build_conv_layer(
            conv_cfg,
            in_channels=self._path_vec_channles,
            out_channels=self._path_vec_channles,
            kernel_size=1,
            stride=1,
            padding=0,
            # in deploy mode the BN is fused into fc3, so fc3 needs a bias
            bias=deploy,
            groups=num_sharesets)
        if deploy:
            self.fc3_bn = nn.Identity()
        else:
            norm_layer = build_norm_layer(norm_cfg, num_sharesets)[1]
            self.add_module('fc3_bn', norm_layer)

        self.reparam_conv_kernels = reparam_conv_kernels
        if not deploy and reparam_conv_kernels is not None:
            # one depthwise conv-bn branch per requested kernel size
            for k in reparam_conv_kernels:
                conv_branch = ConvModule(
                    in_channels=num_sharesets,
                    out_channels=num_sharesets,
                    kernel_size=k,
                    stride=1,
                    padding=k // 2,
                    norm_cfg=dict(type='BN', requires_grad=True),
                    groups=num_sharesets,
                    act_cfg=None)
                self.__setattr__('repconv{}'.format(k), conv_branch)

    def partition(self, x, h_parts, w_parts):
        """Split a feature map into patches.

        Converts (N, C, H, W) to (N, h_parts, w_parts, C, path_h, path_w).
        """
        x = x.reshape(-1, self.channels, h_parts, self.path_h, w_parts,
                      self.path_w)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x

    def partition_affine(self, x, h_parts, w_parts):
        """Perform Partition Perceptron: fc3 (grouped 1x1 conv) + BN."""
        fc_inputs = x.reshape(-1, self._path_vec_channles, 1, 1)
        out = self.fc3(fc_inputs)
        out = out.reshape(-1, self.num_sharesets, self.path_h, self.path_w)
        out = self.fc3_bn(out)
        out = out.reshape(-1, h_parts, w_parts, self.num_sharesets,
                          self.path_h, self.path_w)
        return out

    def forward(self, inputs):
        """Forward: Global + Partition (+ Local in training mode)."""
        # Global Perceptron: channel attention weights
        global_vec = self.gp(inputs)

        origin_shape = inputs.size()
        h_parts = origin_shape[2] // self.path_h
        w_parts = origin_shape[3] // self.path_w

        partitions = self.partition(inputs, h_parts, w_parts)

        # Channel Perceptron
        fc3_out = self.partition_affine(partitions, h_parts, w_parts)

        # perform Local Perceptron (only present before re-parameterization)
        if self.reparam_conv_kernels is not None and not self.deploy:
            conv_inputs = partitions.reshape(-1, self.num_sharesets,
                                             self.path_h, self.path_w)
            conv_out = 0
            for k in self.reparam_conv_kernels:
                conv_branch = self.__getattr__('repconv{}'.format(k))
                conv_out += conv_branch(conv_inputs)
            conv_out = conv_out.reshape(-1, h_parts, w_parts,
                                        self.num_sharesets, self.path_h,
                                        self.path_w)
            fc3_out += conv_out

        # N, h_parts, w_parts, num_sharesets, out_h, out_w
        fc3_out = fc3_out.permute(0, 3, 1, 4, 2, 5)
        out = fc3_out.reshape(*origin_shape)
        out = out * global_vec
        return out

    def get_equivalent_fc3(self):
        """Get the equivalent fc3 weight and bias after fusing BN and the
        Local Perceptron conv branches."""
        fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn)
        if self.reparam_conv_kernels is not None:
            largest_k = max(self.reparam_conv_kernels)
            largest_branch = self.__getattr__('repconv{}'.format(largest_k))
            total_kernel, total_bias = fuse_bn(largest_branch.conv,
                                               largest_branch.bn)
            # zero-pad the smaller kernels up to the largest one and sum
            for k in self.reparam_conv_kernels:
                if k != largest_k:
                    k_branch = self.__getattr__('repconv{}'.format(k))
                    kernel, bias = fuse_bn(k_branch.conv, k_branch.bn)
                    total_kernel += F.pad(kernel, [(largest_k - k) // 2] * 4)
                    total_bias += bias
            rep_weight, rep_bias = self._convert_conv_to_fc(
                total_kernel, total_bias)
            final_fc3_weight = rep_weight.reshape_as(fc_weight) + fc_weight
            final_fc3_bias = rep_bias + fc_bias
        else:
            final_fc3_weight = fc_weight
            final_fc3_bias = fc_bias
        return final_fc3_weight, final_fc3_bias

    def local_inject(self):
        """Inject the Local Perceptron into the Partition Perceptron."""
        self.deploy = True
        # Locality Injection
        fc3_weight, fc3_bias = self.get_equivalent_fc3()
        # Remove Local Perceptron
        if self.reparam_conv_kernels is not None:
            for k in self.reparam_conv_kernels:
                self.__delattr__('repconv{}'.format(k))
        self.__delattr__('fc3')
        self.__delattr__('fc3_bn')
        self.fc3 = build_conv_layer(
            self.conv_cfg,
            self._path_vec_channles,
            self._path_vec_channles,
            1,
            1,
            0,
            bias=True,
            groups=self.num_sharesets)
        self.fc3_bn = nn.Identity()
        self.fc3.weight.data = fc3_weight
        self.fc3.bias.data = fc3_bias

    def _convert_conv_to_fc(self, conv_kernel, conv_bias):
        """Convert a conv kernel to an equivalent fc (as a 1x1 grouped conv
        weight) by convolving an identity stack of patch positions."""
        in_channels = torch.eye(self.path_h * self.path_w).repeat(
            1, self.num_sharesets).reshape(self.path_h * self.path_w,
                                           self.num_sharesets, self.path_h,
                                           self.path_w).to(conv_kernel.device)
        fc_k = F.conv2d(
            in_channels,
            conv_kernel,
            padding=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2),
            groups=self.num_sharesets)
        # BUGFIX: the first reshape dim must be ``path_h * path_w`` (number
        # of rows of the identity stack). The previous ``path_w * path_w``
        # only worked for square patches and broke when path_h != path_w.
        fc_k = fc_k.reshape(self.path_h * self.path_w, self.num_sharesets *
                            self.path_h * self.path_w).t()
        fc_bias = conv_bias.repeat_interleave(self.path_h * self.path_w)
        return fc_k, fc_bias
class RepMLPNetUnit(BaseModule):
    """A basic unit of RepMLPNet: [RepMLPBlock + BN + ConvFFN + BN],
    with residual connections around both sub-blocks.

    Args:
        channels (int): The number of input and the output channels of the
            unit.
        path_h (int): The height of patches.
        path_w (int): The width of patches.
        reparam_conv_kernels (Sequence[int] | None): The kernel sizes of the
            conv branches in the RepMLPBlock's Local Perceptron.
        globalperceptron_ratio (int): The reduction ratio in the
            GlobalPerceptron.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        ffn_expand (int): Expansion ratio of the ConvFFN hidden channels.
            Default: 4.
        num_sharesets (int): The number of sharesets in the
            PartitionPerceptron. Default 1.
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 channels,
                 path_h,
                 path_w,
                 reparam_conv_kernels,
                 globalperceptron_ratio,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 ffn_expand=4,
                 num_sharesets=1,
                 deploy=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.repmlp_block = RepMLPBlock(
            channels=channels,
            path_h=path_h,
            path_w=path_w,
            reparam_conv_kernels=reparam_conv_kernels,
            globalperceptron_ratio=globalperceptron_ratio,
            num_sharesets=num_sharesets,
            deploy=deploy)
        self.ffn_block = ConvFFN(channels, channels * ffn_expand)
        # register both pre-norms via add_module to match checkpoint keys
        self.add_module('norm1', build_norm_layer(norm_cfg, channels)[1])
        self.add_module('norm2', build_norm_layer(norm_cfg, channels)[1])

    def forward(self, x):
        """Pre-norm residual forward through RepMLP block then ConvFFN."""
        y = x + self.repmlp_block(self.norm1(x))
        return y + self.ffn_block(self.norm2(y))
class ConvFFN(nn.Module):
    """Feed-forward network implemented with point-wise (1x1) convs.

    Structure: 1x1 conv + norm -> activation -> 1x1 conv + norm.
    ``hidden_channels`` and ``out_channels`` both default to
    ``in_channels`` when not given.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='GELU')):
        super().__init__()
        out_features = out_channels or in_channels
        hidden_features = hidden_channels or in_channels
        self.ffn_fc1 = self._pointwise(in_channels, hidden_features, norm_cfg)
        self.ffn_fc2 = self._pointwise(hidden_features, out_features, norm_cfg)
        self.act = build_activation_layer(act_cfg)

    @staticmethod
    def _pointwise(in_ch, out_ch, norm_cfg):
        # 1x1 conv followed by norm, no activation (applied separately).
        return ConvModule(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=1,
            stride=1,
            padding=0,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):
        """Apply fc1 -> activation -> fc2."""
        return self.ffn_fc2(self.act(self.ffn_fc1(x)))
@MODELS.register_module()
class RepMLPNet(BaseModule):
    """RepMLPNet backbone.

    A PyTorch impl of : `RepMLP: Re-parameterizing Convolutions into
    Fully-connected Layers for Image Recognition
    <https://arxiv.org/abs/2105.01883>`_

    Args:
        arch (str | dict): RepMLP architecture. If use string, choose
            from 'base' and 'b'. If use dict, it should have below keys:

            - channels (List[int]): Number of channels in each stage.
            - depths (List[int]): Number of units in each stage.
            - sharesets_nums (List[int]): Number of sharesets of the
              PartitionPerceptron in each stage.
        img_size (int | tuple): The size of input image. Defaults: 224.
        in_channels (int): Number of input image channels. Default: 3.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        out_indices (Sequence[int]): Output from which stages.
            Default: ``(3, )``.
        reparam_conv_kernels (Sequence[int] | None): The kernel sizes of
            the Local Perceptron conv branches. Default: ``(3, )``.
        globalperceptron_ratio (int): The reduction ratio in the
            GlobalPerceptron. Default: 4.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
            Default: dict(type='BN', requires_grad=True).
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """
    arch_zoo = {
        **dict.fromkeys(['b', 'base'],
                        {'channels':       [96, 192, 384, 768],
                         'depths':         [2, 2, 12, 2],
                         'sharesets_nums': [1, 4, 32, 128]}),
    }  # yapf: disable

    num_extra_tokens = 0  # there is no cls-token in RepMLP

    def __init__(self,
                 arch,
                 img_size=224,
                 in_channels=3,
                 patch_size=4,
                 out_indices=(3, ),
                 reparam_conv_kernels=(3, ),
                 globalperceptron_ratio=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 patch_cfg=dict(),
                 final_norm=True,
                 deploy=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'channels', 'depths', 'sharesets_nums'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}.'
            self.arch_settings = arch

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.num_stage = len(self.arch_settings['channels'])
        for value in self.arch_settings.values():
            assert isinstance(value, list) and len(value) == self.num_stage, (
                'Length of setting item in arch dict must be type of list and'
                ' have the same length.')

        self.channels = self.arch_settings['channels']
        self.depths = self.arch_settings['depths']
        self.sharesets_nums = self.arch_settings['sharesets_nums']

        # Patch embedding: user-supplied patch_cfg overrides the defaults.
        embed_cfg = dict(
            in_channels=in_channels,
            input_size=self.img_size,
            embed_dims=self.channels[0],
            conv_type='Conv2d',
            kernel_size=self.patch_size,
            stride=self.patch_size,
            norm_cfg=self.norm_cfg,
            bias=False)
        embed_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**embed_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        # Patch grid is halved by the 2x2/stride-2 downsample at each stage.
        self.patch_hs = [
            self.patch_resolution[0] // 2**i for i in range(self.num_stage)
        ]
        self.patch_ws = [
            self.patch_resolution[1] // 2**i for i in range(self.num_stage)
        ]

        self.stages = ModuleList()
        self.downsample_layers = ModuleList()
        for i in range(self.num_stage):
            # make stage layers
            unit_cfg = dict(
                channels=self.channels[i],
                path_h=self.patch_hs[i],
                path_w=self.patch_ws[i],
                reparam_conv_kernels=reparam_conv_kernels,
                globalperceptron_ratio=globalperceptron_ratio,
                norm_cfg=self.norm_cfg,
                ffn_expand=4,
                num_sharesets=self.sharesets_nums[i],
                deploy=deploy)
            units = [RepMLPNetUnit(**unit_cfg) for _ in range(self.depths[i])]
            self.stages.append(Sequential(*units))

            # make downsample layers (between stages only)
            if i < self.num_stage - 1:
                self.downsample_layers.append(
                    ConvModule(
                        in_channels=self.channels[i],
                        out_channels=self.channels[i + 1],
                        kernel_size=2,
                        stride=2,
                        padding=0,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        inplace=True))

        # NOTE: attribute name kept as 'out_indice' for compatibility.
        self.out_indice = out_indices

        if final_norm:
            final_layer = build_norm_layer(norm_cfg, self.channels[-1])[1]
        else:
            final_layer = nn.Identity()
        self.add_module('final_norm', final_layer)

    def forward(self, x):
        assert x.shape[2:] == self.img_size, \
            "The Rep-MLP doesn't support dynamic input shape. " \
            f'Please input images with shape {self.img_size}'

        outs = []
        x, _ = self.patch_embed(x)
        last = len(self.stages) - 1
        for i, stage in enumerate(self.stages):
            x = stage(x)

            # downsample after each stage except last stage
            if i < last:
                x = self.downsample_layers[i](x)

            if i in self.out_indice:
                # apply final norm only to the last stage's output
                out = self.final_norm(x) if (self.final_norm
                                             and i == last) else x
                outs.append(out)

        return tuple(outs)

    def switch_to_deploy(self):
        """Re-parameterize every sub-module that supports locality
        injection (see ``RepMLPBlock.local_inject``)."""
        for m in self.modules():
            if hasattr(m, 'local_inject'):
                m.local_inject()
mmpretrain/models/backbones/repvgg.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
(
ConvModule
,
build_activation_layer
,
build_conv_layer
,
build_norm_layer
)
from
mmengine.model
import
BaseModule
,
Sequential
from
mmengine.utils.dl_utils.parrots_wrapper
import
_BatchNorm
from
torch
import
nn
from
mmpretrain.registry
import
MODELS
from
..utils.se_layer
import
SELayer
from
.base_backbone
import
BaseBackbone
class RepVGGBlock(BaseModule):
    """RepVGG block for RepVGG backbone.

    In training mode the block sums three branches (3x3 conv+BN, 1x1
    conv+BN, and an identity BN when shapes allow); in deploy mode they
    are fused into a single 3x3 convolution.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1.
        padding (int): Padding of the 3x3 convolution layer.
        dilation (int): Dilation of the 3x3 convolution layer.
        groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1.
        padding_mode (str): Padding mode of the 3x3 convolution layer.
            Default: 'zeros'.
        se_cfg (None or dict): The configuration of the se module.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 se_cfg=None,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 deploy=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        assert se_cfg is None or isinstance(se_cfg, dict)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.se_cfg = se_cfg
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.deploy = deploy

        if deploy:
            # single fused 3x3 conv; no BN branches exist in deploy mode
            self.branch_reparam = build_conv_layer(
                conv_cfg,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)
        else:
            # judge if input shape and output shape are the same.
            # If true, add a normalized identity shortcut.
            same_shape = (out_channels == in_channels and stride == 1
                          and padding == dilation)
            if same_shape:
                self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
            else:
                self.branch_norm = None

            self.branch_3x3 = self.create_conv_bn(
                kernel_size=3,
                dilation=dilation,
                padding=padding,
            )
            self.branch_1x1 = self.create_conv_bn(kernel_size=1)

        self.se_layer = (SELayer(channels=out_channels, **se_cfg)
                         if se_cfg is not None else None)

        self.act = build_activation_layer(act_cfg)

    def create_conv_bn(self, kernel_size, dilation=1, padding=0):
        """Build a conv + BN branch with the block's shared settings."""
        branch = Sequential()
        branch.add_module(
            'conv',
            build_conv_layer(
                self.conv_cfg,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=dilation,
                padding=padding,
                groups=self.groups,
                bias=False))
        branch.add_module(
            'norm',
            build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
        return branch

    def forward(self, x):

        def _inner_forward(inputs):
            # deploy mode: a single fused conv replaces all branches
            if self.deploy:
                return self.branch_reparam(inputs)

            branch_norm_out = (0 if self.branch_norm is None else
                               self.branch_norm(inputs))

            inner_out = (self.branch_3x3(inputs) + self.branch_1x1(inputs) +
                         branch_norm_out)

            if self.se_cfg is not None:
                inner_out = self.se_layer(inner_out)

            return inner_out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return self.act(out)

    def switch_to_deploy(self):
        """Switch the model structure from training mode to deployment mode."""
        if self.deploy:
            return
        assert self.norm_cfg['type'] == 'BN', \
            "Switch is not allowed when norm_cfg['type'] != 'BN'."

        reparam_weight, reparam_bias = self.reparameterize()
        self.branch_reparam = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True)
        self.branch_reparam.weight.data = reparam_weight
        self.branch_reparam.bias.data = reparam_bias

        # detach everything and drop the now-redundant training branches
        for param in self.parameters():
            param.detach_()
        delattr(self, 'branch_3x3')
        delattr(self, 'branch_1x1')
        delattr(self, 'branch_norm')

        self.deploy = True

    def reparameterize(self):
        """Fuse all the parameters of all branches.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
                branches. the first element is the weights and the second is
                the bias.
        """
        weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3)
        weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1)
        # pad a conv1x1 weight to a conv3x3 weight
        weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0)

        weight_norm, bias_norm = 0, 0
        if self.branch_norm:
            tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm)
            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)

        return (weight_3x3 + weight_1x1 + weight_norm,
                bias_3x3 + bias_1x1 + bias_norm)

    def _fuse_conv_bn(self, branch):
        """Fuse the parameters in a branch with a conv and bn.

        Args:
            branch (mmcv.runner.Sequential): A branch with conv and bn.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
                fusing the parameters of conv and bn in one branch.
                The first element is the weight and the second is the bias.
        """
        if branch is None:
            return 0, 0
        conv_weight = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps

        std = (running_var + eps).sqrt()
        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight
        fused_bias = -running_mean * gamma / std + beta

        return fused_weight, fused_bias

    def _norm_to_conv3x3(self, branch_norm):
        """Convert a norm layer to a conv3x3-bn sequence.

        Args:
            branch_norm (nn.BatchNorm2d): A branch only with bn in the block.

        Returns:
            tmp_conv3x3 (mmcv.runner.Sequential): a sequential with conv3x3
                and bn.
        """
        input_dim = self.in_channels // self.groups
        conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
                                  dtype=branch_norm.weight.dtype)

        # identity kernel: a single 1 at the kernel center per channel
        for i in range(self.in_channels):
            conv_weight[i, i % input_dim, 1, 1] = 1
        conv_weight = conv_weight.to(branch_norm.weight.device)

        tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
        tmp_conv3x3.conv.weight.data = conv_weight
        tmp_conv3x3.norm = branch_norm
        return tmp_conv3x3
class MTSPPF(BaseModule):
    """MTSPPF block for YOLOX-PAI RepVGG backbone.

    Applies three successive max-poolings to a channel-reduced feature map
    and concatenates all four tensors before a final 1x1 conv.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of pooling. Default: 5.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 kernel_size=5):
        super().__init__()
        hidden_features = in_channels // 2  # hidden channels
        self.conv1 = ConvModule(
            in_channels,
            hidden_features,
            1,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # 4x hidden channels: the reduced map plus its three pooled copies
        self.conv2 = ConvModule(
            hidden_features * 4,
            out_channels,
            1,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # stride-1 pooling with "same" padding keeps the spatial size
        self.maxpool = nn.MaxPool2d(
            kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Reduce channels, pool three times, concat, and project."""
        reduced = self.conv1(x)
        pooled_once = self.maxpool(reduced)
        pooled_twice = self.maxpool(pooled_once)
        pooled_thrice = self.maxpool(pooled_twice)
        features = [reduced, pooled_once, pooled_twice, pooled_thrice]
        return self.conv2(torch.cat(features, 1))
@MODELS.register_module()
class RepVGG(BaseBackbone):
    """RepVGG backbone.

    A PyTorch impl of : `RepVGG: Making VGG-style ConvNets Great Again
    <https://arxiv.org/abs/2101.03697>`_

    Args:
        arch (str | dict): RepVGG architecture. If use string, choose from
            'A0', 'A1`', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2', 'B2g2',
            'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If use dict, it should
            have below keys:

            - **num_blocks** (Sequence[int]): Number of blocks in each stage.
            - **width_factor** (Sequence[float]): Width deflator in each stage.
            - **group_layer_map** (dict | None): RepVGG Block that declares
              the need to apply group convolution.
            - **se_cfg** (dict | None): SE Layer config.
            - **stem_channels** (int, optional): The stem channels, the final
              stem channels will be
              ``min(stem_channels, base_channels*width_factor[0])``.
              If not set here, 64 is used by default in the code.

        in_channels (int): Number of input image channels. Defaults to 3.
        base_channels (int): Base channels of RepVGG backbone, work with
            width_factor together. Defaults to 64.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        strides (Sequence[int]): Strides of the first block of each stage.
            Defaults to ``(2, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to ``(1, 1, 1, 1)``.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='ReLU')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Defaults to False.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        add_ppf (bool): Whether to use the MTSPPF block. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    # 1-based block indices (counted across all stages) that use grouped
    # convolution in the 'g2'/'g4' variants; keys of the layer maps below.
    groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
    g2_layer_map = {layer: 2 for layer in groupwise_layers}
    g4_layer_map = {layer: 4 for layer in groupwise_layers}

    arch_settings = {
        'A0':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[0.75, 0.75, 0.75, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'A1':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[1, 1, 1, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'A2':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[1.5, 1.5, 1.5, 2.75],
            group_layer_map=None,
            se_cfg=None),
        'B0':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[1, 1, 1, 2.5],
            group_layer_map=None,
            se_cfg=None,
            stem_channels=64),
        'B1':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=None,
            se_cfg=None),
        'B1g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B1g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'B2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=None,
            se_cfg=None),
        'B2g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B2g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'B3':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=None,
            se_cfg=None),
        'B3g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B3g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'D2se':
        dict(
            num_blocks=[8, 14, 24, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=None,
            se_cfg=dict(ratio=16, divisor=1)),
        'yolox-pai-small':
        dict(
            num_blocks=[3, 5, 7, 3],
            width_factor=[1, 1, 1, 1],
            group_layer_map=None,
            se_cfg=None,
            stem_channels=32),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 base_channels=64,
                 out_indices=(3, ),
                 strides=(2, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 deploy=False,
                 norm_eval=False,
                 add_ppf=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepVGG, self).__init__(init_cfg)

        # Resolve a string arch name into its settings dict; a dict is used
        # as-is, anything else is rejected.
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['num_blocks']) == len(
            arch['width_factor']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['num_blocks'])
        if arch['group_layer_map'] is not None:
            # Group map keys are 1-based global block indices, so they must
            # not exceed the total number of blocks.
            assert max(arch['group_layer_map'].keys()) <= sum(
                arch['num_blocks'])

        if arch['se_cfg'] is not None:
            assert isinstance(arch['se_cfg'], dict)

        self.base_channels = base_channels
        self.arch = arch
        self.in_channels = in_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval

        # defaults to 64 to prevent BC-breaking if stem_channels
        # not in arch dict;
        # the stem channels should not be larger than that of stage1.
        channels = min(
            arch.get('stem_channels', 64),
            int(self.base_channels * self.arch['width_factor'][0]))
        self.stem = RepVGGBlock(
            self.in_channels,
            channels,
            stride=2,
            se_cfg=arch['se_cfg'],
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            deploy=deploy)

        # Global 1-based block counter used to look up grouped-conv blocks
        # in arch['group_layer_map'].
        next_create_block_idx = 1
        self.stages = []
        for i in range(len(arch['num_blocks'])):
            num_blocks = self.arch['num_blocks'][i]
            stride = self.strides[i]
            dilation = self.dilations[i]
            # Stage width doubles each stage, scaled by the width factor.
            out_channels = int(self.base_channels * 2**i *
                               self.arch['width_factor'][i])

            stage, next_create_block_idx = self._make_stage(
                channels, out_channels, num_blocks, stride, dilation,
                next_create_block_idx, init_cfg)
            stage_name = f'stage_{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

            channels = out_channels

        # Optional MTSPPF head applied after the last stage (YOLOX-PAI).
        if add_ppf:
            self.ppf = MTSPPF(
                out_channels,
                out_channels,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                kernel_size=5)
        else:
            self.ppf = nn.Identity()

    def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                    dilation, next_create_block_idx, init_cfg):
        """Build one stage of ``num_blocks`` RepVGG blocks.

        Only the first block of the stage uses ``stride``; the rest use 1.
        Returns the stage module and the updated global block counter.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        dilations = [dilation] * num_blocks

        blocks = []
        for i in range(num_blocks):
            groups = self.arch['group_layer_map'].get(
                next_create_block_idx,
                1) if self.arch['group_layer_map'] is not None else 1
            blocks.append(
                RepVGGBlock(
                    in_channels,
                    out_channels,
                    stride=strides[i],
                    padding=dilations[i],
                    dilation=dilations[i],
                    groups=groups,
                    se_cfg=self.arch['se_cfg'],
                    with_cp=self.with_cp,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy,
                    init_cfg=init_cfg))
            in_channels = out_channels
            next_create_block_idx += 1

        return Sequential(*blocks), next_create_block_idx

    def forward(self, x):
        """Forward through stem and stages, collecting ``out_indices``."""
        x = self.stem(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            # The (optional) MTSPPF head only follows the final stage.
            if i + 1 == len(self.stages):
                x = self.ppf(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` stages."""
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage_{i + 1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode, honoring frozen stages and norm_eval."""
        super(RepVGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            # Keep BN running stats frozen during training if requested.
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def switch_to_deploy(self):
        """Fuse all RepVGG blocks into single-conv deployment form."""
        for m in self.modules():
            if isinstance(m, RepVGGBlock):
                m.switch_to_deploy()
        self.deploy = True
mmpretrain/models/backbones/res2net.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
math
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmengine.model
import
ModuleList
,
Sequential
from
mmpretrain.registry
import
MODELS
from
.resnet
import
Bottleneck
as
_Bottleneck
from
.resnet
import
ResNet
class Bottle2neck(_Bottleneck):
    """Bottle2neck block used by Res2Net.

    Replaces the single 3x3 conv of a ResNet bottleneck with ``scales - 1``
    smaller 3x3 convs applied to channel splits in a hierarchical-residual
    fashion.
    """

    # Output channels = mid_channels * expansion, as in ResNet Bottleneck.
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 scales=4,
                 base_width=26,
                 base_channels=64,
                 stage_type='normal',
                 **kwargs):
        """Bottle2neck block for Res2Net."""
        super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs)
        assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.'

        mid_channels = out_channels // self.expansion
        # Per-split channel width, scaled by base_width/base_channels.
        width = int(math.floor(mid_channels * (base_width / base_channels)))

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, width * scales, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.out_channels, postfix=3)

        # 1x1 conv expanding the input into `scales` splits of `width` each.
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            width * scales,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)

        if stage_type == 'stage':
            # First block of a stage: downsample the last split with pooling
            # instead of adding it to the previous split's output.
            self.pool = nn.AvgPool2d(
                kernel_size=3, stride=self.conv2_stride, padding=1)

        # One 3x3 conv + bn per split except the last (which passes through
        # or gets pooled).
        self.convs = ModuleList()
        self.bns = ModuleList()
        for i in range(scales - 1):
            self.convs.append(
                build_conv_layer(
                    self.conv_cfg,
                    width,
                    width,
                    kernel_size=3,
                    stride=self.conv2_stride,
                    padding=self.dilation,
                    dilation=self.dilation,
                    bias=False))
            self.bns.append(
                build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])

        # 1x1 conv merging the concatenated splits back to out_channels.
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            width * scales,
            self.out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        self.stage_type = stage_type
        self.scales = scales
        self.width = width
        # conv2/norm2 from the parent Bottleneck are replaced by the split
        # convs above, so remove them from the module tree.
        delattr(self, 'conv2')
        delattr(self, self.norm2_name)

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            # Split channels into `scales` groups of `width` channels.
            spx = torch.split(out, self.width, 1)
            sp = self.convs[0](spx[0].contiguous())
            sp = self.relu(self.bns[0](sp))
            out = sp
            for i in range(1, self.scales - 1):
                if self.stage_type == 'stage':
                    # Stage-entry block: splits are processed independently.
                    sp = spx[i]
                else:
                    # Normal block: hierarchical residual across splits.
                    sp = sp + spx[i]
                sp = self.convs[i](sp.contiguous())
                sp = self.relu(self.bns[i](sp))
                out = torch.cat((out, sp), 1)

            # Append the last split: as-is for 'normal', pooled for 'stage'.
            if self.stage_type == 'normal' and self.scales != 1:
                out = torch.cat((out, spx[self.scales - 1]), 1)
            elif self.stage_type == 'stage' and self.scales != 1:
                out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            # Trade compute for memory via gradient checkpointing.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
class Res2Layer(Sequential):
    """Res2Layer to build Res2Net style backbone.

    Args:
        block (nn.Module): block used to build ResLayer.
        in_channels (int): Input channels of this layer.
        out_channels (int): Output channels of this layer.
        num_blocks (int): number of blocks.
        stride (int): stride of the first block. Default: 1
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottle2neck. Defaults to True.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        scales (int): Scales used in Res2Net. Default: 4
        base_width (int): Basic width of each scale. Default: 26
        drop_path_rate (float or np.ndarray): stochastic depth rate.
            Default: 0.
    """

    def __init__(self,
                 block,
                 in_channels,
                 out_channels,
                 num_blocks,
                 stride=1,
                 avg_down=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 scales=4,
                 base_width=26,
                 drop_path_rate=0.0,
                 **kwargs):
        self.block = block

        # A scalar rate applies uniformly to every block of the layer.
        if isinstance(drop_path_rate, float):
            drop_path_rate = [drop_path_rate] * num_blocks

        assert len(drop_path_rate
                   ) == num_blocks, 'Please check the length of drop_path_rate'

        # The identity branch of the first block needs a projection whenever
        # resolution or channel count changes.
        downsample = None
        if stride != 1 or in_channels != out_channels:
            if avg_down:
                # Anti-aliased shortcut: average-pool then 1x1 conv.
                downsample = nn.Sequential(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False),
                    build_conv_layer(
                        conv_cfg,
                        in_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                        bias=False),
                    build_norm_layer(norm_cfg, out_channels)[1],
                )
            else:
                # Classic strided 1x1 conv shortcut.
                downsample = nn.Sequential(
                    build_conv_layer(
                        conv_cfg,
                        in_channels,
                        out_channels,
                        kernel_size=1,
                        stride=stride,
                        bias=False),
                    build_norm_layer(norm_cfg, out_channels)[1],
                )

        layers = []
        # First block handles stride/downsample and runs in 'stage' mode.
        layers.append(
            block(
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                scales=scales,
                base_width=base_width,
                stage_type='stage',
                drop_path_rate=drop_path_rate[0],
                **kwargs))
        in_channels = out_channels
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    scales=scales,
                    base_width=base_width,
                    drop_path_rate=drop_path_rate[i],
                    **kwargs))
        super(Res2Layer, self).__init__(*layers)
@MODELS.register_module()
class Res2Net(ResNet):
    """Res2Net backbone.

    A PyTorch implement of : `Res2Net: A New Multi-scale Backbone
    Architecture <https://arxiv.org/pdf/1904.01169.pdf>`_

    Args:
        scales (int): Scales used in Res2Net. Defaults to 4.
        base_width (int): Basic width of each scale. Defaults to 26.
        style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer. Defaults to "pytorch".
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Defaults to True.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottle2neck. Defaults to True.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
        **kwargs: Remaining :class:`ResNet` arguments, e.g. ``depth``
            (choose from {50, 101, 152}), ``in_channels``, ``num_stages``,
            ``strides``, ``dilations``, ``out_indices``, ``frozen_stages``,
            ``norm_cfg``, ``norm_eval``, ``with_cp`` and
            ``zero_init_residual``.

    Example:
        >>> from mmpretrain.models import Res2Net
        >>> import torch
        >>> model = Res2Net(depth=50,
        ...                 scales=4,
        ...                 base_width=26,
        ...                 out_indices=(0, 1, 2, 3))
        >>> model.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = model.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 256, 8, 8)
        (1, 512, 4, 4)
        (1, 1024, 2, 2)
        (1, 2048, 1, 1)
    """

    arch_settings = {
        50: (Bottle2neck, (3, 4, 6, 3)),
        101: (Bottle2neck, (3, 4, 23, 3)),
        152: (Bottle2neck, (3, 8, 36, 3))
    }

    def __init__(self,
                 scales=4,
                 base_width=26,
                 style='pytorch',
                 deep_stem=True,
                 avg_down=True,
                 init_cfg=None,
                 **kwargs):
        # Record the Res2Net-specific widths first: ResNet.__init__ calls
        # make_res_layer(), which reads these attributes.
        self.scales = scales
        self.base_width = base_width
        super().__init__(
            style=style,
            deep_stem=deep_stem,
            avg_down=avg_down,
            init_cfg=init_cfg,
            **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a residual stage as a :class:`Res2Layer`."""
        return Res2Layer(
            scales=self.scales,
            base_width=self.base_width,
            base_channels=self.base_channels,
            **kwargs)
mmpretrain/models/backbones/resnest.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmpretrain.registry
import
MODELS
from
.resnet
import
Bottleneck
as
_Bottleneck
from
.resnet
import
ResLayer
,
ResNetV1d
class RSoftmax(nn.Module):
    """Radix Softmax module in ``SplitAttentionConv2d``.

    Normalizes attention logits across the radix dimension with a softmax
    when ``radix > 1``; degenerates to an element-wise sigmoid otherwise.

    Args:
        radix (int): Radix of input.
        groups (int): Groups of input.
    """

    def __init__(self, radix, groups):
        super().__init__()
        self.radix = radix
        self.groups = groups

    def forward(self, x):
        if self.radix <= 1:
            # Single-radix case: a plain sigmoid gate.
            return torch.sigmoid(x)
        batch_size = x.size(0)
        # Reorder to (batch, radix, groups, rest) so the softmax competes
        # across radix slots within each group.
        grouped = x.view(batch_size, self.groups, self.radix, -1)
        grouped = grouped.transpose(1, 2)
        normalized = F.softmax(grouped, dim=1)
        return normalized.reshape(batch_size, -1)
class SplitAttentionConv2d(nn.Module):
    """Split-Attention Conv2d.

    Applies a grouped conv producing ``radix`` feature splits, computes a
    per-split attention via a squeeze (global pool) and two 1x1 convs, and
    returns the attention-weighted sum of the splits.

    Args:
        in_channels (int): Same as nn.Conv2d.
        channels (int): Output channels of the attended feature map.
        kernel_size (int | tuple[int]): Same as nn.Conv2d.
        stride (int | tuple[int]): Same as nn.Conv2d.
        padding (int | tuple[int]): Same as nn.Conv2d.
        dilation (int | tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
            Default: 4.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 radix=2,
                 reduction_factor=4,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN')):
        super(SplitAttentionConv2d, self).__init__()
        # Bottleneck width of the attention MLP, floored at 32 channels.
        inter_channels = max(in_channels * radix // reduction_factor, 32)
        self.radix = radix
        self.groups = groups
        self.channels = channels
        # Main conv produces `radix` splits at once (channels * radix).
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            channels * radix,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups * radix,
            bias=False)
        self.norm0_name, norm0 = build_norm_layer(
            norm_cfg, channels * radix, postfix=0)
        self.add_module(self.norm0_name, norm0)
        self.relu = nn.ReLU(inplace=True)
        # Attention MLP implemented as two grouped 1x1 convs.
        self.fc1 = build_conv_layer(
            None, channels, inter_channels, 1, groups=self.groups)
        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, inter_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.fc2 = build_conv_layer(
            None, inter_channels, channels * radix, 1, groups=self.groups)
        self.rsoftmax = RSoftmax(radix, groups)

    @property
    def norm0(self):
        """nn.Module: the normalization layer after the main conv."""
        return getattr(self, self.norm0_name)

    @property
    def norm1(self):
        """nn.Module: the normalization layer inside the attention MLP."""
        return getattr(self, self.norm1_name)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm0(x)
        x = self.relu(x)

        # Fix: the second element of x.shape[:2] (channel count) was bound
        # to an unused local; only the batch size is needed.
        batch = x.size(0)
        if self.radix > 1:
            # Separate the radix splits: (batch, radix, channels, H, W).
            splits = x.view(batch, self.radix, -1, *x.shape[2:])
            gap = splits.sum(dim=1)
        else:
            gap = x
        # Squeeze: global average pooling to (batch, channels, 1, 1).
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        gap = self.norm1(gap)
        gap = self.relu(gap)
        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            # Weighted sum of splits by their radix-softmax attention.
            attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
            out = torch.sum(attens * splits, dim=1)
        else:
            out = atten * x
        return out.contiguous()
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeSt.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        groups (int): Groups of conv2.
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
            Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module, optional): downsample operation on identity
            branch. Default: None
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        conv_cfg (dict, optional): dictionary to construct and config conv
            layer. Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 groups=1,
                 width_per_group=4,
                 base_channels=64,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
        self.groups = groups
        self.width_per_group = width_per_group

        # For ResNet bottleneck, middle channels are determined by expansion
        # and out_channels, but for ResNeXt bottleneck, it is determined by
        # groups and width_per_group and the stage it is located in.
        if groups != 1:
            assert self.mid_channels % base_channels == 0
            self.mid_channels = (
                groups * width_per_group * self.mid_channels // base_channels)

        # Only strided blocks actually use the avg-pool downsampling path.
        self.avg_down_stride = avg_down_stride and self.conv2_stride > 1

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=1)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.out_channels, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        # Split-attention conv replaces the plain 3x3 conv2; with
        # avg_down_stride the striding is delegated to avd_layer below.
        self.conv2 = SplitAttentionConv2d(
            self.mid_channels,
            self.mid_channels,
            kernel_size=3,
            stride=1 if self.avg_down_stride else self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            radix=radix,
            reduction_factor=reduction_factor,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)
        # SplitAttentionConv2d has its own norm; drop the inherited norm2.
        delattr(self, self.norm2_name)

        if self.avg_down_stride:
            self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)

        self.conv3 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)

            if self.avg_down_stride:
                # Stride via average pooling after the split-attention conv.
                out = self.avd_layer(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            # Trade compute for memory via gradient checkpointing.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
@MODELS.register_module()
class ResNeSt(ResNetV1d):
    """ResNeSt backbone.

    Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`__ for
    details.

    Args:
        depth (int): Network depth, from {50, 101, 152, 200}.
        groups (int): Groups of conv2 in Bottleneck. Default: 32.
        width_per_group (int): Width per group of conv2 in Bottleneck.
            Default: 4.
        radix (int): Radix of SpltAtConv2d. Default: 2
        reduction_factor (int): Reduction factor of SplitAttentionConv2d.
            Default: 4.
        avg_down_stride (bool): Whether to use average pool for stride in
            Bottleneck. Default: True.
        **kwargs: Remaining :class:`ResNetV1d` arguments, e.g.
            ``in_channels``, ``stem_channels``, ``num_stages``, ``strides``,
            ``dilations``, ``out_indices``, ``style``, ``deep_stem``,
            ``avg_down``, ``frozen_stages``, ``conv_cfg``, ``norm_cfg``,
            ``norm_eval``, ``with_cp`` and ``zero_init_residual``.
    """

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3)),
        200: (Bottleneck, (3, 24, 36, 3)),
        269: (Bottleneck, (3, 30, 48, 8))
    }

    def __init__(self,
                 depth,
                 groups=1,
                 width_per_group=4,
                 radix=2,
                 reduction_factor=4,
                 avg_down_stride=True,
                 **kwargs):
        # These attributes must exist before ResNetV1d.__init__ runs,
        # because it invokes make_res_layer(), which reads them.
        self.groups = groups
        self.width_per_group = width_per_group
        self.radix = radix
        self.reduction_factor = reduction_factor
        self.avg_down_stride = avg_down_stride
        super().__init__(depth=depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a residual stage of split-attention bottlenecks."""
        return ResLayer(
            groups=self.groups,
            width_per_group=self.width_per_group,
            base_channels=self.base_channels,
            radix=self.radix,
            reduction_factor=self.reduction_factor,
            avg_down_stride=self.avg_down_stride,
            **kwargs)
mmpretrain/models/backbones/resnet.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
math
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
(
ConvModule
,
build_activation_layer
,
build_conv_layer
,
build_norm_layer
)
from
mmcv.cnn.bricks
import
DropPath
from
mmengine.model
import
BaseModule
from
mmengine.model.weight_init
import
constant_init
from
mmengine.utils.dl_utils.parrots_wrapper
import
_BatchNorm
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
# Threshold for stochastic depth: a drop_path_rate at or below this value
# yields an nn.Identity instead of a DropPath layer in the blocks below.
eps = 1.0e-5
class BasicBlock(BaseModule):
    """BasicBlock for ResNet.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the output channels of conv1. This is a
            reserved argument in BasicBlock and should always be 1. Default: 1.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module, optional): downsample operation on identity
            branch. Default: None.
        style (str): `pytorch` or `caffe`. It is unused and reserved for
            unified API with Bottleneck.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
        conv_cfg (dict, optional): dictionary to construct and config conv
            layer. Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        drop_path_rate (float): stochastic depth rate. Default: 0.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=1,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 drop_path_rate=0.0,
                 act_cfg=dict(type='ReLU', inplace=True),
                 init_cfg=None):
        super(BasicBlock, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion = expansion
        # BasicBlock has no channel expansion by design.
        assert self.expansion == 1
        assert out_channels % expansion == 0
        self.mid_channels = out_channels // expansion
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, self.mid_channels, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, out_channels, postfix=2)

        # Strided/dilated 3x3 conv followed by a plain 3x3 conv.
        self.conv1 = build_conv_layer(
            conv_cfg,
            in_channels,
            self.mid_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg,
            self.mid_channels,
            out_channels,
            3,
            padding=1,
            bias=False)
        self.add_module(self.norm2_name, norm2)

        self.relu = build_activation_layer(act_cfg)
        self.downsample = downsample
        # Stochastic depth on the residual branch; Identity when the rate
        # is effectively zero (<= eps).
        self.drop_path = DropPath(drop_prob=drop_path_rate
                                  ) if drop_path_rate > eps else nn.Identity()

    @property
    def norm1(self):
        """nn.Module: the normalization layer named by ``norm1_name``."""
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named by ``norm2_name``."""
        return getattr(self, self.norm2_name)

    def forward(self, x):

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            # Drop-path is applied to the residual branch only, before the
            # identity addition.
            out = self.drop_path(out)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            # Trade compute for memory via gradient checkpointing.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out
class
Bottleneck
(
BaseModule
):
"""Bottleneck block for ResNet.
Args:
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
expansion (int): The ratio of ``out_channels/mid_channels`` where
``mid_channels`` is the input/output channels of conv2. Default: 4.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module, optional): downsample operation on identity
branch. Default: None.
style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
stride-two layer is the 3x3 conv layer, otherwise the stride-two
layer is the first 1x1 conv layer. Default: "pytorch".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
conv_cfg (dict, optional): dictionary to construct and config conv
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
def __init__(self,
             in_channels,
             out_channels,
             expansion=4,
             stride=1,
             dilation=1,
             downsample=None,
             style='pytorch',
             with_cp=False,
             conv_cfg=None,
             norm_cfg=dict(type='BN'),
             act_cfg=dict(type='ReLU', inplace=True),
             drop_path_rate=0.0,
             init_cfg=None):
    """Initialize the 1x1 -> 3x3 -> 1x1 bottleneck layers.

    ``mid_channels`` is ``out_channels // expansion``; where the stride is
    applied depends on ``style`` ('pytorch': on conv2, 'caffe': on conv1).
    """
    super(Bottleneck, self).__init__(init_cfg=init_cfg)
    assert style in ['pytorch', 'caffe']

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.expansion = expansion
    assert out_channels % expansion == 0
    self.mid_channels = out_channels // expansion
    self.stride = stride
    self.dilation = dilation
    self.style = style
    self.with_cp = with_cp
    self.conv_cfg = conv_cfg
    self.norm_cfg = norm_cfg

    # 'pytorch' places the stride on the 3x3 conv, 'caffe' on the first
    # 1x1 conv.
    if self.style == 'pytorch':
        self.conv1_stride = 1
        self.conv2_stride = stride
    else:
        self.conv1_stride = stride
        self.conv2_stride = 1

    self.norm1_name, norm1 = build_norm_layer(
        norm_cfg, self.mid_channels, postfix=1)
    self.norm2_name, norm2 = build_norm_layer(
        norm_cfg, self.mid_channels, postfix=2)
    self.norm3_name, norm3 = build_norm_layer(
        norm_cfg, out_channels, postfix=3)

    self.conv1 = build_conv_layer(
        conv_cfg,
        in_channels,
        self.mid_channels,
        kernel_size=1,
        stride=self.conv1_stride,
        bias=False)
    self.add_module(self.norm1_name, norm1)
    self.conv2 = build_conv_layer(
        conv_cfg,
        self.mid_channels,
        self.mid_channels,
        kernel_size=3,
        stride=self.conv2_stride,
        padding=dilation,
        dilation=dilation,
        bias=False)
    self.add_module(self.norm2_name, norm2)
    self.conv3 = build_conv_layer(
        conv_cfg,
        self.mid_channels,
        out_channels,
        kernel_size=1,
        bias=False)
    self.add_module(self.norm3_name, norm3)

    self.relu = build_activation_layer(act_cfg)
    self.downsample = downsample
    # Stochastic depth on the residual branch; Identity when the rate is
    # effectively zero (<= eps).
    self.drop_path = DropPath(drop_prob=drop_path_rate
                              ) if drop_path_rate > eps else nn.Identity()
@
property
def
norm1
(
self
):
return
getattr
(
self
,
self
.
norm1_name
)
@
property
def
norm2
(
self
):
return
getattr
(
self
,
self
.
norm2_name
)
@
property
def
norm3
(
self
):
return
getattr
(
self
,
self
.
norm3_name
)
def
forward
(
self
,
x
):
def
_inner_forward
(
x
):
identity
=
x
out
=
self
.
conv1
(
x
)
out
=
self
.
norm1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
norm2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
norm3
(
out
)
if
self
.
downsample
is
not
None
:
identity
=
self
.
downsample
(
x
)
out
=
self
.
drop_path
(
out
)
out
+=
identity
return
out
if
self
.
with_cp
and
x
.
requires_grad
:
out
=
cp
.
checkpoint
(
_inner_forward
,
x
)
else
:
out
=
_inner_forward
(
x
)
out
=
self
.
relu
(
out
)
return
out
def get_expansion(block, expansion=None):
    """Get the expansion of a residual block.

    The block expansion is resolved in the following order:

    1. If ``expansion`` is given, just return it.
    2. If ``block`` has the attribute ``expansion``, then return
       ``block.expansion``.
    3. Return the default value according to the block type:
       1 for ``BasicBlock`` and 4 for ``Bottleneck``.

    Args:
        block (class): The block class.
        expansion (int | None): The given expansion ratio.

    Returns:
        int: The expansion of the block.
    """
    if expansion is not None:
        # An explicit value wins; it must be a positive integer.
        if not isinstance(expansion, int):
            raise TypeError('expansion must be an integer or None')
        assert expansion > 0
        return expansion

    # No explicit value: infer it from the block class itself.
    if hasattr(block, 'expansion'):
        return block.expansion
    if issubclass(block, BasicBlock):
        return 1
    if issubclass(block, Bottleneck):
        return 4
    raise TypeError(f'expansion is not specified for {block.__name__}')
class ResLayer(nn.Sequential):
    """ResLayer to build ResNet style backbone.

    A stage of residual blocks: the first block may change stride/channels
    (with an optional downsample branch on the identity path), the remaining
    blocks keep stride 1 and constant channels.

    Args:
        block (nn.Module): Residual block used to build ResLayer.
        num_blocks (int): Number of blocks.
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int, optional): The expansion for BasicBlock/Bottleneck.
            If not specified, it will firstly be obtained via
            ``block.expansion``. If the block has no attribute "expansion",
            the following default values will be used: 1 for BasicBlock and
            4 for Bottleneck. Default: None.
        stride (int): stride of the first block. Default: 1.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict, optional): dictionary to construct and config conv
            layer. Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        drop_path_rate (float or list): stochastic depth rate.
            Default: 0.
    """

    def __init__(self,
                 block,
                 num_blocks,
                 in_channels,
                 out_channels,
                 expansion=None,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 drop_path_rate=0.0,
                 **kwargs):
        self.block = block
        self.expansion = get_expansion(block, expansion)

        # A scalar rate is broadcast to every block of the stage.
        if isinstance(drop_path_rate, float):
            drop_path_rate = [drop_path_rate] * num_blocks
        assert len(drop_path_rate) == num_blocks, \
            'Please check the length of drop_path_rate'

        downsample = None
        if stride != 1 or in_channels != out_channels:
            # The identity branch needs its own projection when the spatial
            # size or the channel count changes.
            ds_modules = []
            conv_stride = stride
            if avg_down and stride != 1:
                # Downsample spatially with an AvgPool and keep the 1x1 conv
                # at stride 1 (ResNet-D style).
                conv_stride = 1
                ds_modules.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            ds_modules.append(
                build_conv_layer(
                    conv_cfg,
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False))
            ds_modules.append(build_norm_layer(norm_cfg, out_channels)[1])
            downsample = nn.Sequential(*ds_modules)

        # First block carries the stride and the downsample branch.
        blocks = [
            block(
                in_channels=in_channels,
                out_channels=out_channels,
                expansion=self.expansion,
                stride=stride,
                downsample=downsample,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                drop_path_rate=drop_path_rate[0],
                **kwargs)
        ]
        # Remaining blocks: stride 1, channels unchanged.
        for idx in range(1, num_blocks):
            blocks.append(
                block(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    expansion=self.expansion,
                    stride=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    drop_path_rate=drop_path_rate[idx],
                    **kwargs))
        super(ResLayer, self).__init__(*blocks)
@MODELS.register_module()
class ResNet(BaseBackbone):
    """ResNet backbone.

    Please refer to the `paper <https://arxiv.org/abs/1512.03385>`__ for
    details.

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        base_channels (int): Middle channels of the first stage. Default: 64.
        expansion (int, optional): Block expansion ratio; inferred from the
            block type when None. Default: None.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages.
            Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.
        init_cfg (list[dict]): Initialization config. Default: Kaiming for
            convs, constant 1 for norm layers.
        drop_path_rate (float): Maximum stochastic depth rate; linearly
            increased over the blocks. Default: 0.

    Example:
        >>> from mmpretrain.models import ResNet
        >>> import torch
        >>> self = ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 8, 8)
        (1, 128, 4, 4)
        (1, 256, 2, 2)
        (1, 512, 1, 1)
    """

    # depth -> (block class, blocks per stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 stem_channels=64,
                 base_channels=64,
                 expansion=None,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(3, ),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=True,
                 # NOTE(review): mutable default (list); assumed never mutated
                 # in place — the usual mmengine convention.
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ],
                 drop_path_rate=0.0):
        super(ResNet, self).__init__(init_cfg)
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        self.stem_channels = stem_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert num_stages >= 1 and num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.block, stage_blocks = self.arch_settings[depth]
        # Keep only as many stages as requested.
        self.stage_blocks = stage_blocks[:num_stages]
        self.expansion = get_expansion(self.block, expansion)

        self._make_stem_layer(in_channels, stem_channels)

        self.res_layers = []
        _in_channels = stem_channels
        _out_channels = base_channels * self.expansion
        # stochastic depth decay rule: drop rate grows linearly from 0 to
        # drop_path_rate over all blocks of the whole network.
        total_depth = sum(stage_blocks)
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, total_depth)
        ]
        for i, num_blocks in enumerate(self.stage_blocks):
            stride = strides[i]
            dilation = dilations[i]
            res_layer = self.make_res_layer(
                block=self.block,
                num_blocks=num_blocks,
                in_channels=_in_channels,
                out_channels=_out_channels,
                expansion=self.expansion,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                # hand this stage its slice of the decay schedule
                drop_path_rate=dpr[:num_blocks])
            _in_channels = _out_channels
            # channels double at every stage
            _out_channels *= 2
            dpr = dpr[num_blocks:]
            # Register as 'layer1' ... 'layer4' — the names are part of the
            # checkpoint format.
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        # Number of channels of the last stage's output.
        self.feat_dim = res_layer[-1].out_channels

    def make_res_layer(self, **kwargs):
        """Build one stage; overridden by subclasses (e.g. ResNeXt)."""
        return ResLayer(**kwargs)

    @property
    def norm1(self):
        # Resolve the name-registered stem norm layer.
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels):
        """Build the input stem: either three 3x3 convs (deep stem) or a
        single 7x7 conv, each followed by norm+ReLU, then a 3x3 max pool."""
        if self.deep_stem:
            self.stem = nn.Sequential(
                ConvModule(
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True),
                ConvModule(
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True),
                ConvModule(
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True))
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Set the stem and the first ``frozen_stages`` stages to eval mode
        and stop their gradients."""
        if self.frozen_stages >= 0:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initialize weights; optionally zero-init the last norm of each
        residual block so blocks start as identity mappings."""
        super(ResNet, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress zero_init_residual if use pretrained model.
            return

        if self.zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    constant_init(m.norm3, 0)
                elif isinstance(m, BasicBlock):
                    constant_init(m.norm2, 0)

    def forward(self, x):
        """Forward through stem and all stages; return a tuple of feature
        maps from the stages listed in ``out_indices``."""
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode, re-apply stage freezing, and optionally
        keep all BatchNorm layers in eval mode (``norm_eval``)."""
        super(ResNet, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def get_layer_depth(self, param_name: str, prefix: str = ''):
        """Get the layer id to set the different learning rates for ResNet.

        ResNet stages:
        50  :    [3, 4, 6, 3]
        101 :    [3, 4, 23, 3]
        152 :    [3, 8, 36, 3]
        200 :    [3, 24, 36, 3]
        eca269d: [3, 30, 48, 8]

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix for the parameter.
                Defaults to an empty string.

        Returns:
            Tuple[int, int]: The layer-wise depth and the num of layers.
        """
        depths = self.stage_blocks
        # Choose group sizes (blk2/blk3) for stages 2 and 3 based on the
        # known architecture variants above; several blocks share one
        # learning-rate "layer".
        if depths[1] == 4 and depths[2] == 6:
            blk2, blk3 = 2, 3
        elif depths[1] == 4 and depths[2] == 23:
            blk2, blk3 = 2, 3
        elif depths[1] == 8 and depths[2] == 36:
            blk2, blk3 = 4, 4
        elif depths[1] == 24 and depths[2] == 36:
            blk2, blk3 = 4, 4
        elif depths[1] == 30 and depths[2] == 48:
            blk2, blk3 = 5, 6
        else:
            raise NotImplementedError

        # The 1e-5 nudge guards against float rounding in the division.
        N2, N3 = math.ceil(depths[1] / blk2 -
                           1e-5), math.ceil(depths[2] / blk3 - 1e-5)
        N = 2 + N2 + N3  # r50: 2 + 2 + 2 = 6
        max_layer_id = N + 1  # r50: 2 + 2 + 2 + 1(like head) = 7

        if not param_name.startswith(prefix):
            # For subsequent module like head
            return max_layer_id, max_layer_id + 1

        if param_name.startswith('backbone.layer'):
            stage_id = int(param_name.split('.')[1][5:])
            block_id = int(param_name.split('.')[2])

            if stage_id == 1:
                layer_id = 1
            elif stage_id == 2:
                layer_id = 2 + block_id // blk2  # r50: 2, 3
            elif stage_id == 3:
                layer_id = 2 + N2 + block_id // blk3  # r50: 4, 5
            else:  # stage_id == 4
                layer_id = N  # r50: 6
            return layer_id, max_layer_id + 1

        else:
            # Stem / other backbone parameters map to layer 0.
            return 0, max_layer_id + 1
@MODELS.register_module()
class ResNetV1c(ResNet):
    """ResNetV1c backbone.

    This variant is described in `Bag of Tricks.
    <https://arxiv.org/pdf/1812.01187.pdf>`_.

    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
    in the input stem with three 3x3 convs.
    """

    def __init__(self, **kwargs):
        # Force the deep (3x 3x3 conv) stem; keep strided-conv downsampling.
        super().__init__(deep_stem=True, avg_down=False, **kwargs)
@MODELS.register_module()
class ResNetV1d(ResNet):
    """ResNetV1d backbone.

    This variant is described in `Bag of Tricks.
    <https://arxiv.org/pdf/1812.01187.pdf>`_.

    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv
    in the input stem with three 3x3 convs. And in the downsampling block, a
    2x2 avg_pool with stride 2 is added before conv, whose stride is changed
    to 1.
    """

    def __init__(self, **kwargs):
        # Deep stem plus AvgPool-based downsampling on the identity branch.
        super().__init__(deep_stem=True, avg_down=True, **kwargs)
mmpretrain/models/backbones/resnet_cifar.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.nn
as
nn
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmpretrain.registry
import
MODELS
from
.resnet
import
ResNet
@MODELS.register_module()
class ResNet_CIFAR(ResNet):
    """ResNet backbone for CIFAR.

    Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in
    conv1, and does not apply MaxPoolinng after stem. It has been proven to
    be more efficient than standard ResNet in other public codebase, e.g.,
    `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`.

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        base_channels (int): Middle channels of the first stage. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): This network has specific designed stem, thus it is
            asserted to be False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.
    """

    def __init__(self, depth, deep_stem=False, **kwargs):
        super().__init__(depth, deep_stem=deep_stem, **kwargs)
        # The CIFAR stem is fixed; a deep stem makes no sense here.
        assert not self.deep_stem, 'ResNet_CIFAR do not support deep_stem'

    def _make_stem_layer(self, in_channels, base_channels):
        # CIFAR stem: a single 3x3 stride-1 conv + norm + ReLU, and no
        # max-pooling — inputs are small (32x32).
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            base_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, base_channels, postfix=1)
        self.add_module(self.norm1_name, norm1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward through the CIFAR stem and all residual stages."""
        x = self.relu(self.norm1(self.conv1(x)))
        outs = []
        for stage_idx, layer_name in enumerate(self.res_layers):
            stage = getattr(self, layer_name)
            x = stage(x)
            if stage_idx in self.out_indices:
                outs.append(x)
        return tuple(outs)
mmpretrain/models/backbones/resnext.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmpretrain.registry
import
MODELS
from
.resnet
import
Bottleneck
as
_Bottleneck
from
.resnet
import
ResLayer
,
ResNet
class Bottleneck(_Bottleneck):
    """Bottleneck block for ResNeXt.

    Same layout as the ResNet bottleneck, but conv2 is a grouped convolution
    and the middle width is derived from ``groups * width_per_group``.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        base_channels (int): Reference channels used to scale the middle
            width. Default: 64.
        groups (int): Groups of conv2.
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module, optional): downsample operation on identity
            branch. Default: None
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        conv_cfg (dict, optional): dictionary to construct and config conv
            layer. Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 base_channels=64,
                 groups=32,
                 width_per_group=4,
                 **kwargs):
        # NOTE: the parent __init__ already builds conv1/conv2/conv3 and the
        # norm layers with the ResNet middle width; they are rebuilt below
        # with the ResNeXt width, replacing the parent's modules.
        super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
        self.groups = groups
        self.width_per_group = width_per_group

        # For ResNet bottleneck, middle channels are determined by expansion
        # and out_channels, but for ResNeXt bottleneck, it is determined by
        # groups and width_per_group and the stage it is located in.
        if groups != 1:
            assert self.mid_channels % base_channels == 0
            self.mid_channels = (
                groups * width_per_group * self.mid_channels // base_channels)

        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.out_channels, postfix=3)

        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        # conv2 is the grouped 3x3 convolution — the defining ResNeXt change.
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.mid_channels,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)
@MODELS.register_module()
class ResNeXt(ResNet):
    """ResNeXt backbone.

    Please refer to the `paper <https://arxiv.org/abs/1611.05431>`__ for
    details.

    Args:
        depth (int): Network depth, from {50, 101, 152}.
        groups (int): Groups of conv2 in Bottleneck. Default: 32.
        width_per_group (int): Width per group of conv2 in Bottleneck.
            Default: 4.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.
    """

    # depth -> (block class, blocks per stage); ResNeXt only has bottleneck
    # variants.
    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
        # These attributes MUST be set before the parent constructor runs,
        # because ResNet.__init__ calls make_res_layer(), which reads them.
        self.groups = groups
        self.width_per_group = width_per_group
        super().__init__(depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a stage of grouped bottleneck blocks."""
        return ResLayer(
            groups=self.groups,
            width_per_group=self.width_per_group,
            base_channels=self.base_channels,
            **kwargs)
mmpretrain/models/backbones/revvit.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
sys
import
numpy
as
np
import
torch
from
mmcv.cnn.bricks.drop
import
build_dropout
from
mmcv.cnn.bricks.transformer
import
FFN
,
PatchEmbed
from
mmengine.model
import
BaseModule
,
ModuleList
from
mmengine.model.weight_init
import
trunc_normal_
from
torch
import
nn
from
torch.autograd
import
Function
as
Function
from
mmpretrain.models.backbones.base_backbone
import
BaseBackbone
from
mmpretrain.registry
import
MODELS
from
..utils
import
(
MultiheadAttention
,
build_norm_layer
,
resize_pos_embed
,
to_2tuple
)
class RevBackProp(Function):
    """Custom Backpropagation function to allow (A) flushing memory in forward
    and (B) activation recomputation reversibly in backward for gradient
    calculation.

    Inspired by
    https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
    """

    @staticmethod
    def forward(
        ctx,
        x,
        layers,
        buffer_layers,  # List of layer ids for int activation to buffer
    ):
        """Reversible Forward pass.

        Any intermediate activations from `buffer_layers` are cached in ctx
        for forward pass. This is not necessary for standard usecases. Each
        reversible layer implements its own forward pass logic.
        """
        # Sort in place so ctx layout and backward's index() lookups agree.
        buffer_layers.sort()
        # Split the token features into the two reversible streams.
        x1, x2 = torch.chunk(x, 2, dim=-1)

        intermediate = []

        for layer in layers:
            x1, x2 = layer(x1, x2)
            if layer.layer_id in buffer_layers:
                # Buffer this layer's outputs for exact reuse in backward.
                intermediate.extend([x1.detach(), x2.detach()])

        # ctx layout: [x1, x2] or [x1, x2, buffer_ids, buffered pairs...].
        if len(buffer_layers) == 0:
            all_tensors = [x1.detach(), x2.detach()]
        else:
            intermediate = [torch.LongTensor(buffer_layers), *intermediate]
            all_tensors = [x1.detach(), x2.detach(), *intermediate]

        ctx.save_for_backward(*all_tensors)
        ctx.layers = layers

        return torch.cat([x1, x2], dim=-1)

    @staticmethod
    def backward(ctx, dx):
        """Reversible Backward pass.

        Any intermediate activations from `buffer_layers` are recovered from
        ctx. Each layer implements its own loic for backward pass (both
        activation recomputation and grad calculation).
        """
        # Split the incoming gradient to match the two streams.
        d_x1, d_x2 = torch.chunk(dx, 2, dim=-1)
        # retrieve params from ctx for backward
        x1, x2, *int_tensors = ctx.saved_tensors

        # no buffering
        if len(int_tensors) != 0:
            # First saved extra tensor is the sorted list of buffered ids.
            buffer_layers = int_tensors[0].tolist()
        else:
            buffer_layers = []

        layers = ctx.layers

        # Walk the layers in reverse, recomputing activations as we go.
        for _, layer in enumerate(layers[::-1]):
            if layer.layer_id in buffer_layers:
                # Use the exact buffered activations for this layer
                # (offsets +1/+2 skip the leading id tensor).
                x1, x2, d_x1, d_x2 = layer.backward_pass(
                    y1=int_tensors[buffer_layers.index(layer.layer_id) * 2 +
                                   1],
                    y2=int_tensors[buffer_layers.index(layer.layer_id) * 2 +
                                   2],
                    d_y1=d_x1,
                    d_y2=d_x2,
                )
            else:
                # Reconstruct inputs reversibly from the current outputs.
                x1, x2, d_x1, d_x2 = layer.backward_pass(
                    y1=x1,
                    y2=x2,
                    d_y1=d_x1,
                    d_y2=d_x2,
                )

        dx = torch.cat([d_x1, d_x2], dim=-1)

        # Free references early to reduce peak memory.
        del int_tensors
        del d_x1, d_x2, x1, x2

        # Gradients only flow to x; layers and buffer_layers get None.
        return dx, None, None
class RevTransformerEncoderLayer(BaseModule):
    """Reversible Transformer Encoder Layer.

    This module is a building block of Reversible Transformer Encoder,
    which support backpropagation without storing activations.

    The residual connection is not applied to the FFN layer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.0
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0
        drop_path_rate (float): stochastic depth rate.
            Default 0.0
        num_fcs (int): The number of linear in FFN
            Default: 2
        qkv_bias (bool): enable bias for qkv if True.
            Default: True
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU')
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN')
        layer_id (int): The layer id of current layer. Used in RevBackProp.
            Default: 0
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 layer_id: int = 0,
                 init_cfg=None):
        super(RevTransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        # DropPath is rebuilt from this cfg on every application so the same
        # stochastic mask can be replayed under a fixed seed (see seed_cuda).
        self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate)
        self.embed_dims = embed_dims

        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            qkv_bias=qkv_bias)

        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
        # add_identity=False: the reversible coupling supplies the residual.
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            act_cfg=act_cfg,
            add_identity=False)

        self.layer_id = layer_id
        # Per-sub-module RNG seeds recorded in forward, replayed in backward.
        self.seeds = {}

    def init_weights(self):
        """Initialize weights; FFN linears get Xavier weights and small
        normal biases."""
        super(RevTransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def seed_cuda(self, key):
        """Fix seeds to allow for stochastic elements such as dropout to be
        reproduced exactly in activation recomputation in the backward pass."""
        # randomize seeds
        # use cuda generator if available
        if (hasattr(torch.cuda, 'default_generators')
                and len(torch.cuda.default_generators) > 0):
            # GPU
            device_idx = torch.cuda.current_device()
            seed = torch.cuda.default_generators[device_idx].seed()
        else:
            # CPU
            seed = int(torch.seed() % sys.maxsize)

        # Remember the seed so backward_pass can replay the same randomness.
        self.seeds[key] = seed
        torch.manual_seed(self.seeds[key])

    def forward(self, x1, x2):
        """
        Implementation of Reversible TransformerEncoderLayer

        `
        x = x + self.attn(self.ln1(x))
        x = self.ffn(self.ln2(x), identity=x)
        `
        """
        self.seed_cuda('attn')
        # attention output
        f_x2 = self.attn(self.ln1(x2))
        # apply droppath on attention output
        self.seed_cuda('droppath')
        f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2)
        y1 = x1 + f_x2_dropped

        # free memory
        if self.training:
            del x1

        # ffn output
        self.seed_cuda('ffn')
        g_y1 = self.ffn(self.ln2(y1))
        # apply droppath on ffn output
        torch.manual_seed(self.seeds['droppath'])
        g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1)
        # final output
        y2 = x2 + g_y1_dropped

        # free memory
        if self.training:
            del x2

        return y1, y2

    def backward_pass(self, y1, y2, d_y1, d_y2):
        """Activation re-compute with the following equation.

        x2 = y2 - g(y1), g = FFN
        x1 = y1 - f(x2), f = MSHA
        """
        # temporarily record intermediate activation for G
        # and use them for gradient calculation of G
        with torch.enable_grad():
            y1.requires_grad = True

            # Replay the FFN with the exact seeds used in forward.
            torch.manual_seed(self.seeds['ffn'])
            g_y1 = self.ffn(self.ln2(y1))

            torch.manual_seed(self.seeds['droppath'])
            g_y1 = build_dropout(self.drop_path_cfg)(g_y1)

            # Accumulate parameter grads of G from the upstream d_y2.
            g_y1.backward(d_y2, retain_graph=True)

        # activate recomputation is by design and not part of
        # the computation graph in forward pass
        with torch.no_grad():
            # Invert the second coupling: x2 = y2 - G(y1).
            x2 = y2 - g_y1
            del g_y1

            d_y1 = d_y1 + y1.grad
            y1.grad = None

        # record F activation and calculate gradients on F
        with torch.enable_grad():
            x2.requires_grad = True

            # Replay attention with the exact seeds used in forward.
            torch.manual_seed(self.seeds['attn'])
            f_x2 = self.attn(self.ln1(x2))

            torch.manual_seed(self.seeds['droppath'])
            f_x2 = build_dropout(self.drop_path_cfg)(f_x2)

            # Accumulate parameter grads of F from the (updated) d_y1.
            f_x2.backward(d_y1, retain_graph=True)

        # propagate reverse computed activations at the
        # start of the previous block
        with torch.no_grad():
            # Invert the first coupling: x1 = y1 - F(x2).
            x1 = y1 - f_x2
            del f_x2, y1

            d_y2 = d_y2 + x2.grad
            x2.grad = None

            x2 = x2.detach()

        return x1, x2, d_y1, d_y2
class TwoStreamFusion(nn.Module):
    """A general constructor for neural modules fusing two equal sized tensors
    in forward.

    The input is split into two halves along dim 2 and the halves are
    combined according to ``mode``.

    Args:
        mode (str): The mode of fusion. Options are 'add', 'max', 'min',
            'avg', 'concat'.
    """

    def __init__(self, mode: str):
        super().__init__()
        self.mode = mode
        # Dispatch table: each entry maps the tuple of halves to one tensor.
        fuse_fns = {
            'add': lambda halves: torch.stack(halves).sum(dim=0),
            'max': lambda halves: torch.stack(halves).max(dim=0).values,
            'min': lambda halves: torch.stack(halves).min(dim=0).values,
            'avg': lambda halves: torch.stack(halves).mean(dim=0),
            'concat': lambda halves: torch.cat(halves, dim=-1),
        }
        if mode not in fuse_fns:
            raise NotImplementedError
        self.fuse_fn = fuse_fns[mode]

    def forward(self, x):
        """Split ``x`` into two halves along dim 2 and fuse them."""
        halves = torch.chunk(x, 2, dim=2)
        return self.fuse_fn(halves)
@MODELS.register_module()
class RevVisionTransformer(BaseBackbone):
    """Reversible Vision Transformer.

    A PyTorch implementation of : `Reversible Vision Transformers
    <https://openaccess.thecvf.com/content/CVPR2022/html/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.html>`_ # noqa: E501

    Args:
        arch (str | dict): Vision Transformer architecture. If use string,
            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
            and 'deit-base'. If use dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"avg_featmap"``.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        interpolate_mode (str): Select the interpolate mode for position
            embeding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embeding. Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        fusion_mode (str): The fusion mode of transformer layers.
            Defaults to 'concat'.
        no_custom_backward (bool): Whether to use custom backward.
            Defaults to False.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    # Preset architectures; aliases (e.g. 'b' / 'base') share one dict.
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 768,
                'num_layers': 8,
                'num_heads': 8,
                'feedforward_channels': 768 * 3,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
        **dict.fromkeys(
            ['h', 'huge'],
            {
                # The same as the implementation in MAE
                # <https://arxiv.org/abs/2111.06377>
                'embed_dims': 1280,
                'num_layers': 32,
                'num_heads': 16,
                'feedforward_channels': 5120
            }),
        **dict.fromkeys(
            ['deit-t', 'deit-tiny'], {
                'embed_dims': 192,
                'num_layers': 12,
                'num_heads': 3,
                'feedforward_channels': 192 * 4
            }),
        **dict.fromkeys(
            ['deit-s', 'deit-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': 384 * 4
            }),
        **dict.fromkeys(
            ['deit-b', 'deit-base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 768 * 4
            }),
    }
    num_extra_tokens = 0  # The official RevViT doesn't have class token
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 out_type='avg_featmap',
                 with_cls_token=False,
                 frozen_stages=-1,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 fusion_mode='concat',
                 no_custom_backward=False,
                 init_cfg=None):
        super(RevVisionTransformer, self).__init__(init_cfg)

        # Resolve the architecture: either look up a preset by name or
        # validate a user-supplied dict with the four essential keys.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)
        self.no_custom_backward = no_custom_backward

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token. ``cls_token`` is only created on request; the
        # instance attribute ``num_extra_tokens`` shadows the class-level 0.
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
            self.num_extra_tokens = 1
        elif out_type != 'cls_token':
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
        # Resize checkpoint pos_embed on load if the shape differs.
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule: linearly increasing drop-path rate
        # from 0 (first layer) to drop_path_rate (last layer).
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                layer_id=i,
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(RevTransformerEncoderLayer(**_layer_cfg))

        # fusion operation for the final output (merges the two reversible
        # streams back into one tensor).
        self.fusion_layer = TwoStreamFusion(mode=fusion_mode)

        self.frozen_stages = frozen_stages
        self.final_norm = final_norm
        if final_norm:
            # The reversible trunk carries two concatenated streams, hence
            # the norm operates on 2 * embed_dims channels.
            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2)

        # freeze stages only when self.frozen_stages > 0
        if self.frozen_stages > 0:
            self._freeze_stages()

    def init_weights(self):
        """Initialize pos_embed unless loading pretrained weights."""
        super(RevVisionTransformer, self).init_weights()
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.pos_embed, std=0.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """Load-state-dict pre-hook: resize a mismatched checkpoint
        pos_embed to this model's patch resolution in place."""
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # Recover the (square) source grid size from the token count,
            # excluding any extra (cls) tokens.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility."""
        return resize_pos_embed(*args, **kwargs)

    def _freeze_stages(self):
        """Freeze parameters up to ``self.frozen_stages`` layers."""
        # freeze position embedding
        self.pos_embed.requires_grad = False
        # set dropout to eval model
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze cls_token
        if self.cls_token is not None:
            self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm only when every layer is frozen
        if self.frozen_stages == len(self.layers) and self.final_norm:
            self.ln1.eval()
            for param in self.ln1.parameters():
                param.requires_grad = False

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        if self.cls_token is not None:
            cls_token = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        # Duplicate the tokens channel-wise to form the two reversible
        # streams expected by the layers (split again with torch.chunk).
        x = torch.cat([x, x], dim=-1)

        # forward with different conditions
        if not self.training or self.no_custom_backward:
            # in eval/inference model
            executing_fn = RevVisionTransformer._forward_vanilla_bp
        else:
            # use custom backward when self.training=True.
            executing_fn = RevBackProp.apply

        x = executing_fn(x, self.layers, [])

        if self.final_norm:
            x = self.ln1(x)
        x = self.fusion_layer(x)

        return (self._format_output(x, patch_resolution), )

    @staticmethod
    def _forward_vanilla_bp(hidden_state, layers, buffer=[]):
        """Using reversible layers without reversible backpropagation.

        Debugging purpose only. Activated with self.no_custom_backward
        """
        # split into ffn state(ffn_out) and attention output(attn_out)
        ffn_out, attn_out = torch.chunk(hidden_state, 2, dim=-1)
        del hidden_state

        for _, layer in enumerate(layers):
            attn_out, ffn_out = layer(attn_out, ffn_out)

        return torch.cat([attn_out, ffn_out], dim=-1)

    def _format_output(self, x, hw):
        """Format the trunk output according to ``self.out_type``."""
        if self.out_type == 'raw':
            return x
        if self.out_type == 'cls_token':
            return x[:, 0]

        patch_token = x[:, self.num_extra_tokens:]
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            return patch_token.mean(dim=1)
mmpretrain/models/backbones/riformer.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Sequence
import
torch
import
torch.nn
as
nn
from
mmcv.cnn.bricks
import
DropPath
,
build_norm_layer
from
mmengine.model
import
BaseModule
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
from
.poolformer
import
Mlp
,
PatchEmbed
class Affine(nn.Module):
    """Per-channel affine transformation, expressed as a residual.

    Implemented with a depth-wise 1x1 convolution (``groups=in_features``),
    which amounts to a per-channel scale and bias. The forward pass returns
    ``affine(x) - x``, so adding the output back onto ``x`` recovers the
    plain affine transform.

    Args:
        in_features (int): Input dimension.
    """

    def __init__(self, in_features):
        super().__init__()
        # groups == in_features makes each channel independent.
        self.affine = nn.Conv2d(
            in_features,
            in_features,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=in_features,
            bias=True)

    def forward(self, x):
        transformed = self.affine(x)
        return transformed - x
class RIFormerBlock(BaseModule):
    """RIFormer Block.

    In training mode the block is ``x + scale1 * Affine(norm1(x))`` followed
    by ``x + scale2 * MLP(norm2(x))``. In deploy mode the first norm and the
    Affine token mixer are fused into a single reparameterized norm layer
    (``norm_reparam``), see :meth:`switch_to_deploy`.

    Args:
        dim (int): Embedding dim.
        mlp_ratio (float): Mlp expansion ratio. Defaults to 4.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='GN', num_groups=1)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        drop (float): Dropout rate. Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-5.
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
    """

    def __init__(self,
                 dim,
                 mlp_ratio=4.,
                 norm_cfg=dict(type='GN', num_groups=1),
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 drop_path=0.,
                 layer_scale_init_value=1e-5,
                 deploy=False):
        super().__init__()

        # Deploy mode replaces (norm1 + token_mixer) with one fused norm.
        if deploy:
            self.norm_reparam = build_norm_layer(norm_cfg, dim)[1]
        else:
            self.norm1 = build_norm_layer(norm_cfg, dim)[1]
            self.token_mixer = Affine(in_features=dim)
        self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        # The following two techniques are useful to train deep RIFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        # Learnable per-channel scales on both residual branches.
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)

        # Kept so switch_to_deploy() can rebuild an identical norm layer.
        self.norm_cfg = norm_cfg
        self.dim = dim
        self.deploy = deploy

    def forward(self, x):
        # The presence of `norm_reparam` (not `self.deploy`) selects the
        # deployed path, so a freshly switched block works immediately.
        if hasattr(self, 'norm_reparam'):
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
                self.norm_reparam(x))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
                self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(
                self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
                self.token_mixer(self.norm1(x)))
            x = x + self.drop_path(
                self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
                self.mlp(self.norm2(x)))
        return x

    def fuse_affine(self, norm, token_mixer):
        """Fuse a norm layer and an :class:`Affine` mixer into one
        equivalent norm (scale, bias) pair.

        Since ``Affine(y) = W*y + b - y = (W-1)*y + b`` and
        ``norm(x) = gamma*x_hat + beta``, the composition equals a norm
        with weight ``gamma*(W-1)`` and bias ``beta*(W-1) + b``.
        """
        # (W - 1): the effective per-channel scale of the residual Affine.
        gamma_affn = token_mixer.affine.weight.reshape(-1)
        gamma_affn = gamma_affn - torch.ones_like(gamma_affn)
        beta_affn = token_mixer.affine.bias
        gamma_ln = norm.weight
        beta_ln = norm.bias
        return (gamma_ln * gamma_affn), (beta_ln * gamma_affn + beta_affn)

    def get_equivalent_scale_bias(self):
        """Return the fused (scale, bias) of ``norm1`` + ``token_mixer``."""
        eq_s, eq_b = self.fuse_affine(self.norm1, self.token_mixer)
        return eq_s, eq_b

    def switch_to_deploy(self):
        """Reparameterize in place: fold ``norm1`` and ``token_mixer`` into
        ``norm_reparam`` and delete the training-time modules. Idempotent."""
        if self.deploy:
            return
        eq_s, eq_b = self.get_equivalent_scale_bias()
        self.norm_reparam = build_norm_layer(self.norm_cfg, self.dim)[1]
        self.norm_reparam.weight.data = eq_s
        self.norm_reparam.bias.data = eq_b
        self.__delattr__('norm1')
        if hasattr(self, 'token_mixer'):
            self.__delattr__('token_mixer')
        self.deploy = True
def basic_blocks(dim,
                 index,
                 layers,
                 mlp_ratio=4.,
                 norm_cfg=dict(type='GN', num_groups=1),
                 act_cfg=dict(type='GELU'),
                 drop_rate=.0,
                 drop_path_rate=0.,
                 layer_scale_init_value=1e-5,
                 deploy=False):
    """Generate the RIFormer blocks for one stage.

    The stochastic-depth rate of each block scales linearly with its
    global position across all stages.
    """
    # Hoist the loop-invariant sums used by the drop-path schedule.
    blocks_before = sum(layers[:index])
    total_blocks = sum(layers)

    stage = []
    for local_idx in range(layers[index]):
        # Linear drop-path decay over the global block index.
        dpr = drop_path_rate * (local_idx + blocks_before) / (
            total_blocks - 1)
        stage.append(
            RIFormerBlock(
                dim,
                mlp_ratio=mlp_ratio,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                drop=drop_rate,
                drop_path=dpr,
                layer_scale_init_value=layer_scale_init_value,
                deploy=deploy,
            ))
    return nn.Sequential(*stage)
@MODELS.register_module()
class RIFormer(BaseBackbone):
    """RIFormer.

    A PyTorch implementation of RIFormer introduced by:
    `RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer <https://arxiv.org/abs/xxxx.xxxxx>`_

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``RIFormer.arch_settings``. And if dict, it
            should include the following two keys:

            - layers (list[int]): Number of blocks at each stage.
            - embed_dims (list[int]): The number of channels at each stage.
            - mlp_ratios (list[int]): Expansion ratio of MLPs.
            - layer_scale_init_value (float): Init value for Layer Scale.

            Defaults to 'S12'.

        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='GN', num_groups=1)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        in_patch_size (int): The patch size of input image patch embedding.
            Defaults to 7.
        in_stride (int): The stride of input image patch embedding.
            Defaults to 4.
        in_pad (int): The padding of input image patch embedding.
            Defaults to 2.
        down_patch_size (int): The patch size of downsampling patch embedding.
            Defaults to 3.
        down_stride (int): The stride of downsampling patch embedding.
            Defaults to 2.
        down_pad (int): The padding of downsampling patch embedding.
            Defaults to 1.
        drop_rate (float): Dropout rate. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        out_indices (Sequence | int): Output from which network position.
            Index 0-6 respectively corresponds to
            [stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4]
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to -1, which means not freezing any parameters.
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
        init_cfg (dict, optional): Initialization config dict
    """  # noqa: E501

    # --layers: [x,x,x,x], numbers of layers for the four stages
    # --embed_dims, --mlp_ratios:
    #     embedding dims and mlp ratios for the four stages
    # --downsamples: flags to apply downsampling or not in four blocks
    arch_settings = {
        's12': {
            'layers': [2, 2, 6, 2],
            'embed_dims': [64, 128, 320, 512],
            'mlp_ratios': [4, 4, 4, 4],
            'layer_scale_init_value': 1e-5,
        },
        's24': {
            'layers': [4, 4, 12, 4],
            'embed_dims': [64, 128, 320, 512],
            'mlp_ratios': [4, 4, 4, 4],
            'layer_scale_init_value': 1e-5,
        },
        's36': {
            'layers': [6, 6, 18, 6],
            'embed_dims': [64, 128, 320, 512],
            'mlp_ratios': [4, 4, 4, 4],
            'layer_scale_init_value': 1e-6,
        },
        'm36': {
            'layers': [6, 6, 18, 6],
            'embed_dims': [96, 192, 384, 768],
            'mlp_ratios': [4, 4, 4, 4],
            'layer_scale_init_value': 1e-6,
        },
        'm48': {
            'layers': [8, 8, 24, 8],
            'embed_dims': [96, 192, 384, 768],
            'mlp_ratios': [4, 4, 4, 4],
            'layer_scale_init_value': 1e-6,
        },
    }

    def __init__(self,
                 arch='s12',
                 in_channels=3,
                 norm_cfg=dict(type='GN', num_groups=1),
                 act_cfg=dict(type='GELU'),
                 in_patch_size=7,
                 in_stride=4,
                 in_pad=2,
                 down_patch_size=3,
                 down_stride=2,
                 down_pad=1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 out_indices=-1,
                 frozen_stages=-1,
                 init_cfg=None,
                 deploy=False):
        super().__init__(init_cfg=init_cfg)

        # Resolve string presets; a dict arch is validated and used as-is.
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'layers' in arch and 'embed_dims' in arch, \
                f'The arch dict must have "layers" and "embed_dims", ' \
                f'but got {list(arch.keys())}.'

        layers = arch['layers']
        embed_dims = arch['embed_dims']
        mlp_ratios = arch['mlp_ratios'] \
            if 'mlp_ratios' in arch else [4, 4, 4, 4]
        layer_scale_init_value = arch['layer_scale_init_value'] \
            if 'layer_scale_init_value' in arch else 1e-5

        self.patch_embed = PatchEmbed(
            patch_size=in_patch_size,
            stride=in_stride,
            padding=in_pad,
            in_chans=in_channels,
            embed_dim=embed_dims[0])

        # set the main block in network: stages interleaved with optional
        # downsampling PatchEmbeds (only where the channel count changes).
        network = []
        for i in range(len(layers)):
            stage = basic_blocks(
                embed_dims[i],
                i,
                layers,
                mlp_ratio=mlp_ratios[i],
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                drop_rate=drop_rate,
                drop_path_rate=drop_path_rate,
                layer_scale_init_value=layer_scale_init_value,
                deploy=deploy)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if embed_dims[i] != embed_dims[i + 1]:
                # downsampling between two stages
                network.append(
                    PatchEmbed(
                        patch_size=down_patch_size,
                        stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i],
                        embed_dim=embed_dims[i + 1]))

        self.network = nn.ModuleList(network)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        # NOTE(review): a tuple out_indices would fail the in-place
        # assignment below — callers appear expected to pass int or list.
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                # Negative indices count from the end of the 7 positions.
                out_indices[i] = 7 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices
        if self.out_indices:
            # One norm layer per requested output position; (i + 1) // 2
            # maps a network position back to its stage's embed dim.
            for i_layer in self.out_indices:
                layer = build_norm_layer(norm_cfg,
                                         embed_dims[(i_layer + 1) // 2])[1]
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)

        self.frozen_stages = frozen_stages
        self._freeze_stages()
        self.deploy = deploy

    def forward_embeddings(self, x):
        """Apply the input patch embedding."""
        x = self.patch_embed(x)
        return x

    def forward_tokens(self, x):
        """Run the stages and collect normalized outputs at out_indices."""
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        return tuple(outs)

    def forward(self, x):
        # input embedding
        x = self.forward_embeddings(x)
        # through backbone
        x = self.forward_tokens(x)
        return x

    def _freeze_stages(self):
        """Freeze patch embed and the first ``frozen_stages + 1`` network
        entries (stages and interleaved downsample layers alike)."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            for i in range(0, self.frozen_stages + 1):
                # Include both block and downsample layer.
                module = self.network[i]
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False
                if i in self.out_indices:
                    norm_layer = getattr(self, f'norm{i}')
                    norm_layer.eval()
                    for param in norm_layer.parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        # Re-apply freezing after every mode switch, since .train() resets
        # the eval state of frozen submodules.
        super(RIFormer, self).train(mode)
        self._freeze_stages()
        return self

    def switch_to_deploy(self):
        """Reparameterize every RIFormerBlock for inference (in place)."""
        for m in self.modules():
            if isinstance(m, RIFormerBlock):
                m.switch_to_deploy()
        self.deploy = True
mmpretrain/models/backbones/seresnet.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch.utils.checkpoint
as
cp
from
mmpretrain.registry
import
MODELS
from
..utils.se_layer
import
SELayer
from
.resnet
import
Bottleneck
,
ResLayer
,
ResNet
class SEBottleneck(Bottleneck):
    """Bottleneck block with squeeze-and-excitation, for SEResNet.

    Args:
        in_channels (int): The input channels of the SEBottleneck block.
        out_channels (int): The output channel of the SEBottleneck block.
        se_ratio (int): Squeeze ratio in SELayer. Default: 16
    """

    def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
        super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs)
        # Channel recalibration applied after the last conv/norm pair.
        self.se_layer = SELayer(out_channels, ratio=se_ratio)

    def forward(self, x):
        """Residual branch + SE recalibration + skip connection."""

        def _residual_branch(inp):
            # 1x1 reduce -> 3x3 -> 1x1 expand, SE, then the skip add.
            branch = self.relu(self.norm1(self.conv1(inp)))
            branch = self.relu(self.norm2(self.conv2(branch)))
            branch = self.norm3(self.conv3(branch))
            branch = self.se_layer(branch)

            shortcut = inp if self.downsample is None else self.downsample(inp)
            branch += shortcut
            return branch

        if self.with_cp and x.requires_grad:
            # Gradient checkpointing trades compute for memory in training.
            out = cp.checkpoint(_residual_branch, x)
        else:
            out = _residual_branch(x)

        return self.relu(out)
@MODELS.register_module()
class SEResNet(ResNet):
    """SEResNet backbone.

    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
    details.

    Args:
        depth (int): Network depth, from {50, 101, 152}.
        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.

    Example:
        >>> from mmpretrain.models import SEResNet
        >>> import torch
        >>> self = SEResNet(depth=50)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 64, 56, 56)
        (1, 128, 28, 28)
        (1, 256, 14, 14)
        (1, 512, 7, 7)
    """

    # depth -> (block class, blocks per stage)
    arch_settings = {
        50: (SEBottleneck, (3, 4, 6, 3)),
        101: (SEBottleneck, (3, 4, 23, 3)),
        152: (SEBottleneck, (3, 8, 36, 3)),
    }

    def __init__(self, depth, se_ratio=16, **kwargs):
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for SEResNet')
        # Must be set before ResNet.__init__, which builds the res layers
        # through make_res_layer below.
        self.se_ratio = se_ratio
        super().__init__(depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a residual layer, threading the SE ratio into its blocks."""
        return ResLayer(se_ratio=self.se_ratio, **kwargs)
mmpretrain/models/backbones/seresnext.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
mmcv.cnn
import
build_conv_layer
,
build_norm_layer
from
mmpretrain.registry
import
MODELS
from
.resnet
import
ResLayer
from
.seresnet
import
SEBottleneck
as
_SEBottleneck
from
.seresnet
import
SEResNet
class SEBottleneck(_SEBottleneck):
    """SEBottleneck block for SEResNeXt.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        base_channels (int): Middle channels of the first stage. Default: 64.
        groups (int): Groups of conv2.
        width_per_group (int): Width per group of conv2. 64x4d indicates
            ``groups=64, width_per_group=4`` and 32x8d indicates
            ``groups=32, width_per_group=8``.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of convolution. Default: 1
        downsample (nn.Module, optional): downsample operation on identity
            branch. Default: None
        se_ratio (int): Squeeze ratio in SELayer. Default: 16
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        conv_cfg (dict, optional): dictionary to construct and config conv
            layer. Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 base_channels=64,
                 groups=32,
                 width_per_group=4,
                 se_ratio=16,
                 **kwargs):
        # The parent builds the plain SE bottleneck first; the convs and
        # norms are then rebuilt below with the grouped mid_channels.
        super(SEBottleneck, self).__init__(in_channels, out_channels, se_ratio,
                                           **kwargs)
        self.groups = groups
        self.width_per_group = width_per_group

        # We follow the same rationale as ResNeXt to compute mid_channels.
        # For SEResNet bottleneck, middle channels are determined by expansion
        # and out_channels, but for SEResNeXt bottleneck, it is determined by
        # groups and width_per_group and the stage it is located in.
        if groups != 1:
            assert self.mid_channels % base_channels == 0
            self.mid_channels = (
                groups * width_per_group * self.mid_channels // base_channels)

        # Rebuild the norm layers for the widened middle channels; norms are
        # registered via add_module under their postfix-ed names.
        self.norm1_name, norm1 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            self.norm_cfg, self.mid_channels, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            self.norm_cfg, self.out_channels, postfix=3)

        # 1x1 reduce conv (stride follows the parent's conv1_stride, which
        # depends on the `style` option).
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        # 3x3 grouped conv — the defining ResNeXt component.
        self.conv2 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.mid_channels,
            kernel_size=3,
            stride=self.conv2_stride,
            padding=self.dilation,
            dilation=self.dilation,
            groups=groups,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        # 1x1 expand conv back to out_channels.
        self.conv3 = build_conv_layer(
            self.conv_cfg,
            self.mid_channels,
            self.out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)
@MODELS.register_module()
class SEResNeXt(SEResNet):
    """SEResNeXt backbone.

    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
    details.

    Args:
        depth (int): Network depth, from {50, 101, 152}.
        groups (int): Groups of conv2 in Bottleneck. Default: 32.
        width_per_group (int): Width per group of conv2 in Bottleneck.
            Default: 4.
        se_ratio (int): Squeeze ratio in SELayer. Default: 16.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.
    """

    # Mapping: depth -> (block class, number of blocks per stage).
    arch_settings = {
        50: (SEBottleneck, (3, 4, 6, 3)),
        101: (SEBottleneck, (3, 4, 23, 3)),
        152: (SEBottleneck, (3, 8, 36, 3))
    }

    def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
        # These attributes must exist before the parent constructor runs,
        # because it invokes ``make_res_layer`` which reads them.
        self.groups = groups
        self.width_per_group = width_per_group
        super().__init__(depth, **kwargs)

    def make_res_layer(self, **kwargs):
        """Build a residual stage carrying the ResNeXt group settings."""
        return ResLayer(
            groups=self.groups,
            width_per_group=self.width_per_group,
            base_channels=self.base_channels,
            **kwargs)
mmpretrain/models/backbones/shufflenet_v1.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
ConvModule
,
build_activation_layer
from
mmengine.model
import
BaseModule
from
mmengine.model.weight_init
import
constant_init
,
normal_init
from
torch.nn.modules.batchnorm
import
_BatchNorm
from
mmpretrain.models.utils
import
channel_shuffle
,
make_divisible
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
class ShuffleUnit(BaseModule):
    """ShuffleUnit block.

    A ShuffleNet unit built from a grouped 1x1 compression conv, a 3x3
    depthwise conv, a channel shuffle and a grouped 1x1 expansion conv,
    merged with the identity branch either by addition or concatenation.

    Args:
        in_channels (int): The input channels of the ShuffleUnit.
        out_channels (int): The output channels of the ShuffleUnit.
        groups (int): The number of groups to be used in grouped 1x1
            convolutions in each ShuffleUnit. Default: 3
        first_block (bool): Whether it is the first ShuffleUnit of a
            sequential ShuffleUnits. Default: True, which means not using the
            grouped 1x1 convolution.
        combine (str): The ways to combine the input and output
            branches. Default: 'add'.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint
            will save some memory while slowing down the training speed.
            Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 groups=3,
                 first_block=True,
                 combine='add',
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super(ShuffleUnit, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_block = first_block
        self.combine = combine
        self.groups = groups
        self.bottleneck_channels = self.out_channels // 4
        self.with_cp = with_cp

        if self.combine == 'add':
            # Residual merge keeps the spatial size: stride-1 depthwise conv
            # and matching channel counts are required.
            self.depthwise_stride = 1
            self._combine_func = self._add
            assert in_channels == out_channels, (
                'in_channels must be equal to out_channels when combine '
                'is add')
        elif self.combine == 'concat':
            # Concat merge halves the resolution; the identity branch is
            # average-pooled and contributes ``in_channels`` of the final
            # output, so the conv branch only produces the remainder.
            self.depthwise_stride = 2
            self._combine_func = self._concat
            self.out_channels -= self.in_channels
            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            raise ValueError(f'Cannot combine tensors with {self.combine}. '
                             'Only "add" and "concat" are supported')

        # The very first unit of the network skips grouping in the 1x1
        # compression conv (too few input channels to split into groups).
        self.first_1x1_groups = 1 if first_block else self.groups
        self.g_conv_1x1_compress = ConvModule(
            in_channels=self.in_channels,
            out_channels=self.bottleneck_channels,
            kernel_size=1,
            groups=self.first_1x1_groups,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.depthwise_conv3x3_bn = ConvModule(
            in_channels=self.bottleneck_channels,
            out_channels=self.bottleneck_channels,
            kernel_size=3,
            stride=self.depthwise_stride,
            padding=1,
            groups=self.bottleneck_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.g_conv_1x1_expand = ConvModule(
            in_channels=self.bottleneck_channels,
            out_channels=self.out_channels,
            kernel_size=1,
            groups=self.groups,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.act = build_activation_layer(act_cfg)

    @staticmethod
    def _add(x, out):
        # Residual connection.
        return x + out

    @staticmethod
    def _concat(x, out):
        # Concatenate along the channel axis.
        return torch.cat((x, out), 1)

    def forward(self, x):

        def _forward_impl(inp):
            identity = inp
            y = self.g_conv_1x1_compress(inp)
            y = self.depthwise_conv3x3_bn(y)

            # Shuffling is a no-op for a single group, so skip it.
            if self.groups > 1:
                y = channel_shuffle(y, self.groups)

            y = self.g_conv_1x1_expand(y)

            if self.combine == 'concat':
                # ReLU precedes the concat; the identity path is pooled to
                # the reduced resolution first.
                identity = self.avgpool(identity)
                y = self.act(y)
                return self._combine_func(identity, y)
            # For the residual merge, ReLU follows the addition.
            return self.act(self._combine_func(identity, y))

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_forward_impl, x)
        return _forward_impl(x)
@MODELS.register_module()
class ShuffleNetV1(BaseBackbone):
    """ShuffleNetV1 backbone.

    Args:
        groups (int): The number of groups to be used in grouped 1x1
            convolutions in each ShuffleUnit. Default: 3.
        widen_factor (float): Width multiplier - adjusts the number
            of channels in each layer by this amount. Default: 1.0.
        out_indices (Sequence[int]): Output from which stages.
            Default: (2, )
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 groups=3,
                 widen_factor=1.0,
                 out_indices=(2, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 with_cp=False,
                 init_cfg=None):
        super(ShuffleNetV1, self).__init__(init_cfg)
        # NOTE(review): the base class presumably already stores init_cfg;
        # this reassignment looks redundant but is kept as-is.
        self.init_cfg = init_cfg
        # Number of ShuffleUnits in each of the three stages.
        self.stage_blocks = [4, 8, 4]
        self.groups = groups

        for index in out_indices:
            if index not in range(0, 3):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 3). But received {index}')

        if frozen_stages not in range(-1, 3):
            raise ValueError('frozen_stages must be in range(-1, 3). '
                             f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Per-stage output channels from the ShuffleNet paper; they depend
        # on the group count because the grouped 1x1 convs partition them.
        if groups == 1:
            channels = (144, 288, 576)
        elif groups == 2:
            channels = (200, 400, 800)
        elif groups == 3:
            channels = (240, 480, 960)
        elif groups == 4:
            channels = (272, 544, 1088)
        elif groups == 8:
            channels = (384, 768, 1536)
        else:
            raise ValueError(f'{groups} groups is not supported for 1x1 '
                             'Grouped Convolutions')

        # Scale by the width multiplier, rounded to a multiple of 8.
        channels = [make_divisible(ch * widen_factor, 8) for ch in channels]

        self.in_channels = int(24 * widen_factor)

        # Stem: 3x3 stride-2 conv followed by 3x3 stride-2 max pooling.
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layers = nn.ModuleList()
        for i, num_blocks in enumerate(self.stage_blocks):
            # Only the first unit of the first stage skips the grouped
            # 1x1 compression conv.
            first_block = True if i == 0 else False
            layer = self.make_layer(channels[i], num_blocks, first_block)
            self.layers.append(layer)

    def _freeze_stages(self):
        # Freeze the stem and the first ``frozen_stages`` stages (eval mode
        # plus requires_grad=False).
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            layer = self.layers[i]
            layer.eval()
            for param in layer.parameters():
                param.requires_grad = False

    def init_weights(self):
        super(ShuffleNetV1, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                # The branch is selected by module *name*: the stem conv
                # (and any submodule whose name contains 'conv1') gets a
                # small fixed std, other convs are scaled by fan-in.
                if 'conv1' in name:
                    normal_init(m, mean=0, std=0.01)
                else:
                    normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, val=1, bias=0.0001)
                if isinstance(m, _BatchNorm):
                    if m.running_mean is not None:
                        nn.init.constant_(m.running_mean, 0)

    def make_layer(self, out_channels, num_blocks, first_block=False):
        """Stack ShuffleUnit blocks to make a layer.

        Args:
            out_channels (int): out_channels of the block.
            num_blocks (int): Number of blocks.
            first_block (bool): Whether is the first ShuffleUnit of a
                sequential ShuffleUnits. Default: False, which means using
                the grouped 1x1 convolution.
        """
        layers = []
        for i in range(num_blocks):
            # Only the stage's first unit may skip grouping; it also uses
            # 'concat' (downsampling) while the rest use residual 'add'.
            first_block = first_block if i == 0 else False
            combine_mode = 'concat' if i == 0 else 'add'
            layers.append(
                ShuffleUnit(
                    self.in_channels,
                    out_channels,
                    groups=self.groups,
                    first_block=first_block,
                    combine=combine_mode,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            # Track the running channel count for the next unit/stage.
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem, then the three stages; collect outputs of the requested
        # stages and return them as a tuple.
        x = self.conv1(x)
        x = self.maxpool(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def train(self, mode=True):
        super(ShuffleNetV1, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            # Keep BN running stats frozen during training if requested.
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
mmpretrain/models/backbones/shufflenet_v2.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
ConvModule
from
mmengine.model
import
BaseModule
from
mmengine.model.weight_init
import
constant_init
,
normal_init
from
torch.nn.modules.batchnorm
import
_BatchNorm
from
mmpretrain.models.utils
import
channel_shuffle
from
mmpretrain.registry
import
MODELS
from
.base_backbone
import
BaseBackbone
class InvertedResidual(BaseModule):
    """InvertedResidual block for ShuffleNetV2 backbone.

    When ``stride == 1`` the input is split channel-wise; only the second
    half goes through the conv branch. When ``stride > 1`` both branches
    process the full input and the unit downsamples.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 init_cfg=None):
        super(InvertedResidual, self).__init__(init_cfg)
        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({branch_features * 2}) '
                'when stride is 1')

        if in_channels != branch_features * 2:
            assert self.stride != 1, (
                f'stride ({self.stride}) should not equal 1 when '
                f'in_channels != branch_features * 2')

        # The shortcut branch only exists in the downsampling unit; for
        # stride 1 the untouched half of the split serves as identity.
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    in_channels,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        # Main branch input width: the full input for the downsampling unit,
        # otherwise just the second half of the channel split.
        branch2_in = in_channels if (self.stride > 1) else branch_features
        self.branch2 = nn.Sequential(
            ConvModule(
                branch2_in,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):

        def _forward_impl(inp):
            if self.stride > 1:
                merged = torch.cat(
                    (self.branch1(inp), self.branch2(inp)), dim=1)
            else:
                # Channel Split operation. using these lines of code to
                # replace ``chunk(x, 2, dim=1)`` can make it easier to
                # deploy a shufflenetv2 model by using mmdeploy.
                channels = inp.shape[1]
                split = channels // 2 + channels % 2
                left = inp[:, :split, :, :]
                right = inp[:, split:, :, :]
                merged = torch.cat((left, self.branch2(right)), dim=1)

            # Mix the two halves so information flows across branches.
            return channel_shuffle(merged, 2)

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_forward_impl, x)
        return _forward_impl(x)
@MODELS.register_module()
class ShuffleNetV2(BaseBackbone):
    """ShuffleNetV2 backbone.

    Args:
        widen_factor (float): Width multiplier - adjusts the number of
            channels in each layer by this amount. Default: 1.0.
        out_indices (Sequence[int]): Output from which stages.
            Default: (3, ).
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: -1, which means not freezing any parameters.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 widen_factor=1.0,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 with_cp=False,
                 init_cfg=None):
        super(ShuffleNetV2, self).__init__(init_cfg)
        # Number of InvertedResidual blocks in each of the three stages.
        self.stage_blocks = [4, 8, 4]
        for index in out_indices:
            if index not in range(0, 4):
                raise ValueError('the item in out_indices must in '
                                 f'range(0, 4). But received {index}')

        if frozen_stages not in range(-1, 4):
            raise ValueError('frozen_stages must be in range(-1, 4). '
                             f'But received {frozen_stages}')
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Per-stage output channels from the paper for each supported width
        # multiplier; the last entry is for the final 1x1 conv layer.
        if widen_factor == 0.5:
            channels = [48, 96, 192, 1024]
        elif widen_factor == 1.0:
            channels = [116, 232, 464, 1024]
        elif widen_factor == 1.5:
            channels = [176, 352, 704, 1024]
        elif widen_factor == 2.0:
            channels = [244, 488, 976, 2048]
        else:
            raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
                             f'But received {widen_factor}')

        self.in_channels = 24
        # Stem: 3x3 stride-2 conv followed by 3x3 stride-2 max pooling.
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layers = nn.ModuleList()
        for i, num_blocks in enumerate(self.stage_blocks):
            layer = self._make_layer(channels[i], num_blocks)
            self.layers.append(layer)

        # Final 1x1 conv, appended as a fourth "layer" so it can be selected
        # through ``out_indices`` like the stages.
        output_channels = channels[-1]
        self.layers.append(
            ConvModule(
                in_channels=self.in_channels,
                out_channels=output_channels,
                kernel_size=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def _make_layer(self, out_channels, num_blocks):
        """Stack blocks to make a layer.

        Args:
            out_channels (int): out_channels of the block.
            num_blocks (int): number of blocks.
        """
        layers = []
        for i in range(num_blocks):
            # The first block of each stage downsamples.
            stride = 2 if i == 0 else 1
            layers.append(
                InvertedResidual(
                    in_channels=self.in_channels,
                    out_channels=out_channels,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.in_channels = out_channels

        return nn.Sequential(*layers)

    def _freeze_stages(self):
        """Freeze the stem and the first ``frozen_stages`` stages."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False

        for i in range(self.frozen_stages):
            m = self.layers[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        super(ShuffleNetV2, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                if 'conv1' in name:
                    normal_init(m, mean=0, std=0.01)
                else:
                    # Scale std by fan-in (weight.shape[1] is
                    # in_channels // groups).
                    normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                # Fix: ``constant_init`` expects a *module*; the previous
                # ``constant_init(m.weight, ...)`` passed a tensor and was a
                # silent no-op. Pass ``m`` (as ShuffleNetV1 does) so
                # weight=1 / bias=1e-4 are actually applied.
                constant_init(m, val=1, bias=0.0001)
                if isinstance(m, _BatchNorm):
                    if m.running_mean is not None:
                        nn.init.constant_(m.running_mean, 0)

    def forward(self, x):
        # Stem, three stages, final 1x1 conv; collect the outputs of the
        # requested indices and return them as a tuple.
        x = self.conv1(x)
        x = self.maxpool(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def train(self, mode=True):
        super(ShuffleNetV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            # Use ``_BatchNorm`` (covers 1d/2d/3d and SyncBN) instead of
            # only ``nn.BatchNorm2d`` — matches ShuffleNetV1 and the
            # docstring's "Batch Norm and its variants".
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
mmpretrain/models/backbones/sparse_convnext.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
List
,
Optional
,
Sequence
,
Union
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmengine.model
import
ModuleList
,
Sequential
from
mmpretrain.registry
import
MODELS
from
..utils
import
(
SparseAvgPooling
,
SparseConv2d
,
SparseHelper
,
SparseMaxPooling
,
build_norm_layer
)
from
.convnext
import
ConvNeXt
,
ConvNeXtBlock
class SparseConvNeXtBlock(ConvNeXtBlock):
    """Sparse ConvNeXt Block.

    Note:
        There are two equivalent implementations:

        1. DwConv -> SparseLayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
           all outputs are in (N, C, H, W).
        2. DwConv -> SparseLayerNorm -> Permute to (N, H, W, C) -> Linear ->
           GELU -> Linear; Permute back

        As default, we use the second to align with the official repository.
        And it may be slightly faster.
    """

    def forward(self, x):
        # Residual forward with the SparK sparsity mask applied to the
        # branch output; optionally wrapped in gradient checkpointing.

        def _inner_forward(x):
            shortcut = x
            x = self.depthwise_conv(x)

            if self.linear_pw_conv:
                # Linear pointwise convs operate in channel-last layout.
                x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
                x = self.norm(x, data_format='channel_last')
                x = self.pointwise_conv1(x)
                x = self.act(x)
                if self.grn is not None:
                    x = self.grn(x, data_format='channel_last')
                x = self.pointwise_conv2(x)
                x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
            else:
                x = self.norm(x, data_format='channel_first')
                x = self.pointwise_conv1(x)
                x = self.act(x)
                if self.grn is not None:
                    x = self.grn(x, data_format='channel_first')
                x = self.pointwise_conv2(x)

            if self.gamma is not None:
                # Layer Scale: learnable per-channel scaling.
                x = x.mul(self.gamma.view(1, -1, 1, 1))

            # Zero out features at inactive spatial positions so only the
            # visible (unmasked) patches contribute; the active map is
            # looked up by feature-map height H.
            x *= SparseHelper._get_active_map_or_index(
                H=x.shape[2], returning_active_map=True)

            x = shortcut + self.drop_path(x)
            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x
@MODELS.register_module()
class SparseConvNeXt(ConvNeXt):
    """ConvNeXt with sparse module conversion function.

    Modified from
    https://github.com/keyu-tian/SparK/blob/main/models/convnext.py
    and
    https://github.com/keyu-tian/SparK/blob/main/encoder.py
    To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``.

    Args:
        arch (str | dict): The model's architecture. If string, it should be
            one of architecture in ``ConvNeXt.arch_settings``. And if dict, it
            should include the following two keys:

            - depths (list[int]): Number of blocks at each stage.
            - channels (list[int]): The number of channels at each stage.

            Defaults to 'small'.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_patch_size (int): The size of one patch in the stem layer.
            Defaults to 4.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='SparseLN2d', eps=1e-6)``.
        act_cfg (dict): The config dict for activation between pointwise
            convolution. Defaults to ``dict(type='GELU')``.
        linear_pw_conv (bool): Whether to use linear layer to do pointwise
            convolution. Defaults to True.
        use_grn (bool): Whether to add Global Response Normalization in the
            blocks. Defaults to False.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-6.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        gap_before_output (bool): Whether to globally average the feature
            map before the final norm layer. In the official repo, it's only
            used in classification task. Defaults to True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): Initialization config dict.
    """  # noqa: E501

    def __init__(self,
                 arch: str = 'small',
                 in_channels: int = 3,
                 stem_patch_size: int = 4,
                 norm_cfg: dict = dict(type='SparseLN2d', eps=1e-6),
                 act_cfg: dict = dict(type='GELU'),
                 linear_pw_conv: bool = True,
                 use_grn: bool = False,
                 drop_path_rate: float = 0,
                 layer_scale_init_value: float = 1e-6,
                 out_indices: int = -1,
                 frozen_stages: int = 0,
                 gap_before_output: bool = True,
                 with_cp: bool = False,
                 init_cfg: Optional[Union[dict, List[dict]]] = [
                     dict(
                         type='TruncNormal',
                         layer=['Conv2d', 'Linear'],
                         std=.02,
                         bias=0.),
                     dict(
                         type='Constant', layer=['LayerNorm'], val=1.,
                         bias=0.),
                 ]):
        # Deliberately skip ``ConvNeXt.__init__``: it would build dense
        # modules that this class constructs (and converts to sparse)
        # itself. Only the grandparent (BaseModule) initializer runs.
        super(ConvNeXt, self).__init__(init_cfg=init_cfg)
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'depths' in arch and 'channels' in arch, \
                f'The arch dict must have "depths" and "channels", ' \
                f'but got {list(arch.keys())}.'

        self.depths = arch['depths']
        self.channels = arch['channels']
        assert (isinstance(self.depths, Sequence)
                and isinstance(self.channels, Sequence)
                and len(self.depths) == len(self.channels)), \
            f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
            'should be both sequence with the same length.'

        self.num_stages = len(self.depths)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                # Resolve negative indices relative to the actual number of
                # stages instead of the previous hard-coded ``4`` (identical
                # for all built-in archs, but correct for custom arch dicts
                # with a different number of stages).
                out_indices[i] = self.num_stages + index
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_output = gap_before_output

        # Downsample layers between stages, including the stem layer
        # (one per stage).
        self.downsample_layers = ModuleList()
        stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.channels[0],
                kernel_size=stem_patch_size,
                stride=stem_patch_size),
            build_norm_layer(norm_cfg, self.channels[0]),
        )
        self.downsample_layers.append(stem)

        # stochastic depth decay rule
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]
        block_idx = 0

        # Feature resolution stages, each consisting of multiple residual
        # blocks.
        self.stages = nn.ModuleList()
        for i in range(self.num_stages):
            depth = self.depths[i]
            channels = self.channels[i]

            if i >= 1:
                # Norm + strided 2x2 conv halves the resolution between
                # consecutive stages.
                downsample_layer = nn.Sequential(
                    build_norm_layer(norm_cfg, self.channels[i - 1]),
                    nn.Conv2d(
                        self.channels[i - 1],
                        channels,
                        kernel_size=2,
                        stride=2),
                )
                self.downsample_layers.append(downsample_layer)

            stage = Sequential(*[
                SparseConvNeXtBlock(
                    in_channels=channels,
                    drop_path_rate=dpr[block_idx + j],
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    linear_pw_conv=linear_pw_conv,
                    layer_scale_init_value=layer_scale_init_value,
                    use_grn=use_grn,
                    with_cp=with_cp) for j in range(depth)
            ])
            block_idx += depth

            self.stages.append(stage)

        # Recursively replace dense conv/pool modules with sparse versions.
        self.dense_model_to_sparse(m=self)

    def forward(self, x):
        outs = []
        for i, stage in enumerate(self.stages):
            x = self.downsample_layers[i](x)
            x = stage(x)
            if i in self.out_indices:
                if self.gap_before_output:
                    # Global average pooling over the spatial dims, then
                    # flatten to (N, C) — used for classification.
                    gap = x.mean([-2, -1], keepdim=True)
                    outs.append(gap.flatten(1))
                else:
                    outs.append(x)

        return tuple(outs)

    def dense_model_to_sparse(self, m: nn.Module) -> nn.Module:
        """Convert regular dense modules to sparse modules."""
        output = m
        if isinstance(m, nn.Conv2d):
            m: nn.Conv2d
            bias = m.bias is not None
            output = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            # Carry over the trained parameters of the dense module.
            output.weight.data.copy_(m.weight.data)
            if bias:
                output.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            m: nn.MaxPool2d
            output = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            m: nn.AvgPool2d
            output = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override)
        # NOTE: BatchNorm conversion (to SparseBatchNorm2d /
        # SparseSyncBatchNorm2d) was present upstream but commented out,
        # i.e. deliberately disabled; the default ``norm_cfg`` already uses
        # a sparse LayerNorm, so BN modules are left untouched here.

        for name, child in m.named_children():
            output.add_module(name, self.dense_model_to_sparse(child))
        del m
        return output
mmpretrain/models/backbones/sparse_resnet.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
re
from
typing
import
Optional
,
Tuple
import
torch.nn
as
nn
from
mmpretrain.models.utils.sparse_modules
import
(
SparseAvgPooling
,
SparseBatchNorm2d
,
SparseConv2d
,
SparseMaxPooling
,
SparseSyncBatchNorm2d
)
from
mmpretrain.registry
import
MODELS
from
.resnet
import
ResNet
@MODELS.register_module()
class SparseResNet(ResNet):
    """ResNet with sparse module conversion function.

    Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_channels (int): Output channels of the stem layer. Defaults to 64.
        base_channels (int): Middle channels of the first stage.
            Defaults to 64.
        expansion (int, optional): The expansion ratio of the blocks.
            Defaults to None (inferred from the block type).
        num_stages (int): Stages of the network. Defaults to 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Defaults to ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Defaults to False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers. A type containing
            ``'Sync'`` selects the SyncBN flavour during sparse conversion.
            Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Unlike ``ResNet``,
            defaults to False here.
        init_cfg (dict | list[dict], optional): Initialization config.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
    """

    def __init__(self,
                 depth: int,
                 in_channels: int = 3,
                 stem_channels: int = 64,
                 base_channels: int = 64,
                 expansion: Optional[int] = None,
                 num_stages: int = 4,
                 strides: Tuple[int] = (1, 2, 2, 2),
                 dilations: Tuple[int] = (1, 1, 1, 1),
                 out_indices: Tuple[int] = (3, ),
                 style: str = 'pytorch',
                 deep_stem: bool = False,
                 avg_down: bool = False,
                 frozen_stages: int = -1,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
                 norm_eval: bool = False,
                 with_cp: bool = False,
                 zero_init_residual: bool = False,
                 init_cfg: Optional[dict] = [
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ],
                 drop_path_rate: float = 0,
                 **kwargs):
        super().__init__(
            depth=depth,
            in_channels=in_channels,
            stem_channels=stem_channels,
            base_channels=base_channels,
            expansion=expansion,
            num_stages=num_stages,
            strides=strides,
            dilations=dilations,
            out_indices=out_indices,
            style=style,
            deep_stem=deep_stem,
            avg_down=avg_down,
            frozen_stages=frozen_stages,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            norm_eval=norm_eval,
            with_cp=with_cp,
            zero_init_residual=zero_init_residual,
            init_cfg=init_cfg,
            drop_path_rate=drop_path_rate,
            **kwargs)
        norm_type = norm_cfg['type']
        # A plain substring test replaces the original
        # ``re.search('Sync', norm_type) is not None`` — equivalent, and
        # clearer: any norm type containing 'Sync' selects SyncBN.
        enable_sync_bn = 'Sync' in norm_type
        # Convert the fully-built dense ResNet to sparse modules in place.
        self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)

    def dense_model_to_sparse(self, m: nn.Module,
                              enable_sync_bn: bool) -> nn.Module:
        """Recursively convert regular dense modules to sparse modules.

        Supported leaf modules (``Conv2d``, ``MaxPool2d``, ``AvgPool2d``,
        ``BatchNorm2d``/``SyncBatchNorm``) are replaced by equivalent sparse
        modules with weights and statistics copied over; all children are
        converted recursively and re-attached with ``add_module``.

        Args:
            m (nn.Module): The (sub)module to convert.
            enable_sync_bn (bool): If True, batch norms become
                ``SparseSyncBatchNorm2d``, otherwise ``SparseBatchNorm2d``.

        Returns:
            nn.Module: The converted module (may be ``m`` itself when no
            replacement applies).
        """
        output = m
        if isinstance(m, nn.Conv2d):
            bias = m.bias is not None
            output = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=bias,
                padding_mode=m.padding_mode,
            )
            output.weight.data.copy_(m.weight.data)
            if bias:
                output.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            output = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            output = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            output = (SparseSyncBatchNorm2d
                      if enable_sync_bn else SparseBatchNorm2d)(
                          m.weight.shape[0],
                          eps=m.eps,
                          momentum=m.momentum,
                          affine=m.affine,
                          track_running_stats=m.track_running_stats)
            output.weight.data.copy_(m.weight.data)
            output.bias.data.copy_(m.bias.data)
            # NOTE(review): assumes ``affine`` and ``track_running_stats``
            # are both True (the buffers below are None otherwise) — holds
            # for the default BN configuration used by ResNet.
            output.running_mean.data.copy_(m.running_mean.data)
            output.running_var.data.copy_(m.running_var.data)
            output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
        elif isinstance(m, nn.Conv1d):
            # 1-D convolutions have no sparse counterpart; fail loudly.
            raise NotImplementedError

        # Recurse into children and graft the converted versions onto the
        # (possibly replaced) parent.
        for name, child in m.named_children():
            output.add_module(
                name,
                self.dense_model_to_sparse(
                    child, enable_sync_bn=enable_sync_bn))
        del m
        return output
mmpretrain/models/backbones/swin_transformer.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
from
typing
import
Sequence
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
build_norm_layer
from
mmcv.cnn.bricks.transformer
import
FFN
,
PatchEmbed
,
PatchMerging
from
mmengine.model
import
BaseModule
,
ModuleList
from
mmengine.model.weight_init
import
trunc_normal_
from
mmengine.utils.dl_utils.parrots_wrapper
import
_BatchNorm
from
mmpretrain.registry
import
MODELS
from
..utils
import
(
ShiftWindowMSA
,
resize_pos_embed
,
resize_relative_position_bias_table
,
to_2tuple
)
from
.base_backbone
import
BaseBackbone
class SwinBlock(BaseModule):
    """Swin Transformer block: (shifted-)window attention plus FFN,
    both with pre-normalization and residual connections.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        shift (bool): Shift the attention window or not. Defaults to False.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        attn_cfgs (dict): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size=7,
                 shift=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 pad_small_map=False,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super(SwinBlock, self).__init__(init_cfg)
        self.with_cp = with_cp

        # Defaults first; user-supplied ``attn_cfgs`` entries win.
        attn_kwargs = dict(
            embed_dims=embed_dims,
            num_heads=num_heads,
            shift_size=window_size // 2 if shift else 0,
            window_size=window_size,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path),
            pad_small_map=pad_small_map)
        attn_kwargs.update(attn_cfgs)
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(**attn_kwargs)

        # Defaults first; user-supplied ``ffn_cfgs`` entries win.
        ffn_kwargs = dict(
            embed_dims=embed_dims,
            feedforward_channels=int(embed_dims * ffn_ratio),
            num_fcs=2,
            ffn_drop=0,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path),
            act_cfg=dict(type='GELU'))
        ffn_kwargs.update(ffn_cfgs)
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**ffn_kwargs)

    def forward(self, x, hw_shape):

        def _block_forward(feat):
            # Pre-norm attention with residual connection.
            feat = self.attn(self.norm1(feat), hw_shape) + feat
            # Pre-norm FFN; the residual add happens inside ``FFN`` via
            # the ``identity`` argument.
            return self.ffn(self.norm2(feat), identity=feat)

        # Optionally trade compute for memory with activation checkpointing.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_block_forward, x)
        return _block_forward(x)
class SwinBlockSequence(BaseModule):
    """One Swin stage: a stack of ``SwinBlock`` modules followed by an
    optional patch-merging downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        downsample (bool): Downsample the output of blocks by patch merging.
            Defaults to False.
        downsample_cfg (dict): The extra config of the patch merging layer.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float): The drop path rate in each block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 depth,
                 num_heads,
                 window_size=7,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 with_cp=False,
                 pad_small_map=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        # Broadcast scalar/shared settings to one entry per block.
        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth
        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        self.embed_dims = embed_dims
        self.blocks = ModuleList()
        for idx in range(depth):
            block_kwargs = dict(
                embed_dims=embed_dims,
                num_heads=num_heads,
                window_size=window_size,
                # Alternate regular (even index) and shifted (odd) windows.
                shift=bool(idx % 2),
                drop_path=drop_paths[idx],
                with_cp=with_cp,
                pad_small_map=pad_small_map)
            block_kwargs.update(block_cfgs[idx])
            self.blocks.append(SwinBlock(**block_kwargs))

        if downsample:
            merge_kwargs = dict(
                in_channels=embed_dims,
                out_channels=2 * embed_dims,
                norm_cfg=dict(type='LN'))
            merge_kwargs.update(downsample_cfg)
            self.downsample = PatchMerging(**merge_kwargs)
        else:
            self.downsample = None

    def forward(self, x, in_shape, do_downsample=True):
        for blk in self.blocks:
            x = blk(x, in_shape)
        # The downsample step may be deferred by the caller (see
        # ``SwinTransformer.forward`` and ``out_after_downsample``).
        if do_downsample and self.downsample is not None:
            return self.downsample(x, in_shape)
        return x, in_shape

    @property
    def out_channels(self):
        # Channel width doubles when the stage ends in a patch merge.
        if self.downsample:
            return self.downsample.out_channels
        return self.embed_dims
@MODELS.register_module()
class SwinTransformer(BaseBackbone):
    """Swin Transformer.

    A PyTorch implement of : `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>`_

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. If use string, choose
            from 'tiny', 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.

            Defaults to 'tiny'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (int): The height and width of the window. Defaults to 7.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        out_after_downsample (bool): Whether to output the feature map of a
            stage after the following downsample layer. Defaults to False.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embeding vector resize. Defaults to "bicubic".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``
        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
            stage. Defaults to an empty dict.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmpretrain.models import SwinTransformer
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'expansion_ratio': 3}))
        >>> self = SwinTransformer(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output.shape)
        (1, 2592, 4)
    """
    # Predefined architectures; short and long aliases map to the same
    # settings dict.
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 6, 2],
                         'num_heads': [3, 6, 12, 24]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [6, 12, 24, 48]}),
    }  # yapf: disable

    # Checkpoint format version; used by ``_load_from_state_dict`` to
    # migrate keys from older checkpoints.
    _version = 3
    # No cls/distill tokens are prepended to the patch tokens.
    num_extra_tokens = 0

    def __init__(self,
                 arch='tiny',
                 img_size=224,
                 patch_size=4,
                 in_channels=3,
                 window_size=7,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 out_after_downsample=False,
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 with_cp=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 pad_small_map=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg=init_cfg)

        # Resolve ``arch`` to a settings dict, either from the zoo or from
        # a fully-specified custom dict.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.out_after_downsample = out_after_downsample
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode
        self.frozen_stages = frozen_stages

        # Patch embedding: non-overlapping conv with kernel == stride.
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            norm_cfg=dict(type='LN'),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))
            # Resize a mismatched pos-embed from the checkpoint on load.
            self._register_load_state_dict_pre_hook(
                self._prepare_abs_pos_embed)

        # Resize mismatched relative position bias tables on load.
        self._register_load_state_dict_pre_hook(
            self._prepare_relative_position_bias_table)

        self.drop_after_pos = nn.Dropout(p=drop_rate)
        self.norm_eval = norm_eval

        # stochastic depth
        total_depth = sum(self.depths)
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        embed_dims = [self.embed_dims]
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            # Per-stage overrides: either one dict per stage or a shared
            # dict copied for each stage.
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            # Every stage but the last ends with a patch-merging layer.
            downsample = True if i < self.num_layers - 1 else False
            _stage_cfg = {
                'embed_dims': embed_dims[-1],
                'depth': depth,
                'num_heads': num_heads,
                'window_size': window_size,
                'downsample': downsample,
                'drop_paths': dpr[:depth],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                **stage_cfg
            }

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            # Consume this stage's share of the drop-path schedule and
            # track the (possibly doubled) output width.
            dpr = dpr[depth:]
            embed_dims.append(stage.out_channels)

        # ``num_features[i]`` is the channel count of the i-th output,
        # depending on whether outputs are taken before or after the
        # downsample layer.
        if self.out_after_downsample:
            self.num_features = embed_dims[1:]
        else:
            self.num_features = embed_dims[:-1]

        # One norm layer per requested output stage (``norm{i}``).
        for i in out_indices:
            if norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg,
                                              self.num_features[i])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        super(SwinTransformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        x, hw_shape = self.patch_embed(x)
        if self.use_abs_pos_embed:
            # Interpolate the learned pos-embed to the actual input size.
            x = x + resize_pos_embed(
                self.absolute_pos_embed, self.patch_resolution, hw_shape,
                self.interpolate_mode, self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            # When ``out_after_downsample`` is False the stage defers its
            # downsample so the pre-downsample feature can be emitted.
            x, hw_shape = stage(
                x, hw_shape, do_downsample=self.out_after_downsample)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                # (B, H*W, C) -> (B, C, H, W) for downstream consumers.
                out = out.view(-1, *hw_shape,
                               self.num_features[i]).permute(0, 3, 1,
                                                             2).contiguous()
                outs.append(out)
            if stage.downsample is not None and not self.out_after_downsample:
                x, hw_shape = stage.downsample(x, hw_shape)

        return tuple(outs)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              *args, **kwargs):
        """load checkpoints."""
        # Names of some parameters in has been changed.
        version = local_metadata.get('version', None)
        if (version is None
                or version < 2) and self.__class__ is SwinTransformer:
            # v1 checkpoints store the final norm as ``norm``; it is now
            # ``norm{last_stage}``.
            final_stage_num = len(self.stages) - 1
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if k.startswith('norm.') or k.startswith('backbone.norm.'):
                    convert_key = k.replace('norm.',
                                            f'norm{final_stage_num}.')
                    state_dict[convert_key] = state_dict[k]
                    del state_dict[k]
        if (version is None
                or version < 3) and self.__class__ is SwinTransformer:
            # v2 checkpoints cached attention masks; they are now computed
            # on the fly, so drop them.
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if 'attn_mask' in k:
                    del state_dict[k]

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      *args, **kwargs)

    def _freeze_stages(self):
        # Freeze the stem plus stages [0, frozen_stages] and any output
        # norms attached to frozen stages.
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(0, self.frozen_stages + 1):
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        for i in self.out_indices:
            if i <= self.frozen_stages:
                for param in getattr(self, f'norm{i}').parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
        # State-dict pre-hook: resize a checkpoint's absolute pos-embed to
        # this model's patch grid before loading.
        name = prefix + 'absolute_pos_embed'
        if name not in state_dict.keys():
            return
        ckpt_pos_embed_shape = state_dict[name].shape
        if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                'Resize the absolute_pos_embed shape from '
                f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')

            # Checkpoint grid is assumed square (sqrt of token count).
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def _prepare_relative_position_bias_table(self, state_dict, prefix,
                                              *args, **kwargs):
        # State-dict pre-hook: resize relative position bias tables loaded
        # from a checkpoint trained with a different window size.
        state_dict_model = self.state_dict()
        all_keys = list(state_dict_model.keys())
        for key in all_keys:
            if 'relative_position_bias_table' in key:
                ckpt_key = prefix + key
                if ckpt_key not in state_dict:
                    continue
                relative_position_bias_table_pretrained = state_dict[ckpt_key]
                relative_position_bias_table_current = state_dict_model[key]
                L1, nH1 = relative_position_bias_table_pretrained.size()
                L2, nH2 = relative_position_bias_table_current.size()
                if L1 != L2:
                    src_size = int(L1**0.5)
                    dst_size = int(L2**0.5)
                    new_rel_pos_bias = resize_relative_position_bias_table(
                        src_size, dst_size,
                        relative_position_bias_table_pretrained, nH1)
                    from mmengine.logging import MMLogger
                    logger = MMLogger.get_current_instance()
                    logger.info(
                        'Resize the relative_position_bias_table from '
                        f'{state_dict[ckpt_key].shape} to '
                        f'{new_rel_pos_bias.shape}')
                    state_dict[ckpt_key] = new_rel_pos_bias

                    # The index buffer need to be re-generated.
                    index_buffer = ckpt_key.replace('bias_table', 'index')
                    del state_dict[index_buffer]

    def get_layer_depth(self, param_name: str, prefix: str = ''):
        """Get the layer-wise depth of a parameter.

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix for the parameter.
                Defaults to an empty string.

        Returns:
            Tuple[int, int]: The layer-wise depth and the num of layers.

        Note:
            The first depth is the stem module (``layer_depth=0``), and the
            last depth is the subsequent module (``layer_depth=num_layers-1``)
        """
        # +2 accounts for the stem (depth 0) and the subsequent module
        # (e.g. head) at the final depth.
        num_layers = sum(self.depths) + 2

        if not param_name.startswith(prefix):
            # For subsequent module like head
            return num_layers - 1, num_layers

        param_name = param_name[len(prefix):]

        if param_name.startswith('patch_embed'):
            layer_depth = 0
        elif param_name.startswith('stages'):
            stage_id = int(param_name.split('.')[1])
            block_id = param_name.split('.')[3]
            if block_id in ('reduction', 'norm'):
                # Downsample parameters sit at the end of their stage.
                layer_depth = sum(self.depths[:stage_id + 1])
            else:
                layer_depth = sum(self.depths[:stage_id]) + int(block_id) + 1
        else:
            layer_depth = num_layers - 1

        return layer_depth, num_layers
mmpretrain/models/backbones/swin_transformer_v2.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
from
typing
import
Sequence
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
cp
from
mmcv.cnn
import
build_norm_layer
from
mmcv.cnn.bricks.transformer
import
FFN
,
PatchEmbed
from
mmengine.model
import
BaseModule
,
ModuleList
from
mmengine.model.weight_init
import
trunc_normal_
from
mmengine.utils.dl_utils.parrots_wrapper
import
_BatchNorm
from
..builder
import
MODELS
from
..utils
import
(
PatchMerging
,
ShiftWindowMSA
,
WindowMSAV2
,
resize_pos_embed
,
to_2tuple
)
from
.base_backbone
import
BaseBackbone
class SwinBlockV2(BaseModule):
    """Swin Transformer V2 block. Use post normalization.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 8.
        shift (bool): Shift the attention window or not. Defaults to False.
        extra_norm (bool): Whether add extra norm at the end of main branch.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        attn_cfgs (dict): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pretrained_window_size (int): Window size in pretrained.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size=8,
                 shift=False,
                 extra_norm=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 pad_small_map=False,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 pretrained_window_size=0,
                 init_cfg=None):
        super(SwinBlockV2, self).__init__(init_cfg)
        self.with_cp = with_cp
        self.extra_norm = extra_norm

        # Defaults first, then user overrides from ``attn_cfgs``, then the
        # V2-specific settings, which always take precedence.
        attn_kwargs = dict(
            embed_dims=embed_dims,
            num_heads=num_heads,
            shift_size=window_size // 2 if shift else 0,
            window_size=window_size,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path),
            pad_small_map=pad_small_map)
        attn_kwargs.update(attn_cfgs)
        # use V2 attention implementation
        attn_kwargs.update(
            window_msa=WindowMSAV2,
            pretrained_window_size=to_2tuple(pretrained_window_size))
        self.attn = ShiftWindowMSA(**attn_kwargs)
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        # ``add_identity=False`` because the residual is added explicitly
        # in ``forward`` after the post-normalization.
        ffn_kwargs = dict(
            embed_dims=embed_dims,
            feedforward_channels=int(embed_dims * ffn_ratio),
            num_fcs=2,
            ffn_drop=0,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path),
            act_cfg=dict(type='GELU'),
            add_identity=False)
        ffn_kwargs.update(ffn_cfgs)
        self.ffn = FFN(**ffn_kwargs)
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        # add extra norm for every n blocks in huge and giant model
        if self.extra_norm:
            self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x, hw_shape):

        def _post_norm_forward(feat):
            # Post-normalization: sublayer -> norm -> residual add.
            feat = self.norm1(self.attn(feat, hw_shape)) + feat
            feat = self.norm2(self.ffn(feat)) + feat
            # Periodic extra norm used by huge/giant variants.
            if self.extra_norm:
                feat = self.norm3(feat)
            return feat

        # Optionally trade compute for memory with activation checkpointing.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_post_norm_forward, x)
        return _post_norm_forward(x)
class SwinBlockV2Sequence(BaseModule):
    """One Swin V2 stage: an optional patch-merging downsample layer
    followed by a stack of ``SwinBlockV2`` modules.

    Args:
        embed_dims (int): Number of input channels.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 8.
        downsample (bool): Downsample the output of blocks by patch merging.
            Defaults to False.
        downsample_cfg (dict): The extra config of the patch merging layer.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float): The drop path rate in each block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        extra_norm_every_n_blocks (int): Add extra norm at the end of main
            branch every n blocks. Defaults to 0, which means no needs for
            extra norm layer.
        pretrained_window_size (int): Window size in pretrained.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 depth,
                 num_heads,
                 window_size=8,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 with_cp=False,
                 pad_small_map=False,
                 extra_norm_every_n_blocks=0,
                 pretrained_window_size=0,
                 init_cfg=None):
        super().__init__(init_cfg)

        # Broadcast scalar/shared settings to one entry per block.
        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth
        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        # Unlike Swin V1, the downsample layer runs BEFORE the blocks, so
        # the blocks operate on the widened channel count.
        if downsample:
            self.out_channels = 2 * embed_dims
            merge_kwargs = dict(
                in_channels=embed_dims,
                out_channels=self.out_channels,
                norm_cfg=dict(type='LN'))
            merge_kwargs.update(downsample_cfg)
            self.downsample = PatchMerging(**merge_kwargs)
        else:
            self.out_channels = embed_dims
            self.downsample = None

        self.blocks = ModuleList()
        for idx in range(depth):
            # Extra norm after every n-th block (huge/giant models).
            use_extra_norm = bool(
                extra_norm_every_n_blocks
                and (idx + 1) % extra_norm_every_n_blocks == 0)
            block_kwargs = dict(
                embed_dims=self.out_channels,
                num_heads=num_heads,
                window_size=window_size,
                # Alternate regular (even index) and shifted (odd) windows.
                shift=bool(idx % 2),
                extra_norm=use_extra_norm,
                drop_path=drop_paths[idx],
                with_cp=with_cp,
                pad_small_map=pad_small_map,
                pretrained_window_size=pretrained_window_size)
            block_kwargs.update(block_cfgs[idx])
            self.blocks.append(SwinBlockV2(**block_kwargs))

    def forward(self, x, in_shape):
        # Downsample first (if present), then run the block stack at the
        # reduced resolution.
        if self.downsample:
            x, out_shape = self.downsample(x, in_shape)
        else:
            out_shape = in_shape

        for blk in self.blocks:
            x = blk(x, out_shape)

        return x, out_shape
@MODELS.register_module()
class SwinTransformerV2(BaseBackbone):
    """Swin Transformer V2.

    A PyTorch implement of : `Swin Transformer V2:
    Scaling Up Capacity and Resolution
    <https://arxiv.org/abs/2111.09883>`_

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. If use string, choose
            from 'tiny', 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.
            - **extra_norm_every_n_blocks** (int): Add extra norm at the end
              of main branch every n blocks.

            Defaults to 'tiny'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (int | Sequence): The height and width of the window.
            Defaults to 7.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embeding vector resize. Defaults to "bicubic".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``
        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
            stage. Defaults to an empty dict.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        pretrained_window_sizes (tuple(int)): Pretrained window sizes of
            each layer.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmpretrain.models import SwinTransformerV2
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'padding': 'same'}))
        >>> self = SwinTransformerV2(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output.shape)
        (1, 2592, 4)
    """
    # Predefined architectures; multiple aliases map to the same settings.
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 6, 2],
                         'num_heads': [3, 6, 12, 24],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [6, 12, 24, 48],
                         'extra_norm_every_n_blocks': 0}),
        # head count not certain for huge, and is employed for another
        # parallel study about self-supervised learning.
        **dict.fromkeys(['h', 'huge'],
                        {'embed_dims': 352,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [8, 16, 32, 64],
                         'extra_norm_every_n_blocks': 6}),
        **dict.fromkeys(['g', 'giant'],
                        {'embed_dims': 512,
                         'depths': [2, 2, 42, 4],
                         'num_heads': [16, 32, 64, 128],
                         'extra_norm_every_n_blocks': 6}),
    }  # yapf: disable

    # Checkpoint version tag and number of non-patch tokens (no cls token).
    _version = 1
    num_extra_tokens = 0

    def __init__(self,
                 arch='tiny',
                 img_size=256,
                 patch_size=4,
                 in_channels=3,
                 window_size=8,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 with_cp=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 pad_small_map=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 pretrained_window_sizes=[0, 0, 0, 0],
                 init_cfg=None):
        super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)

        # Resolve the architecture: a preset name or a full custom dict.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads',
                'extra_norm_every_n_blocks'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.extra_norm_every_n_blocks = self.arch_settings[
            'extra_norm_every_n_blocks']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode
        self.frozen_stages = frozen_stages

        # Normalize window_size into one value per stage.
        if isinstance(window_size, int):
            self.window_sizes = [window_size for _ in range(self.num_layers)]
        elif isinstance(window_size, Sequence):
            assert len(window_size) == self.num_layers, \
                f'Length of window_sizes {len(window_size)} is not equal to ' \
                f'length of stages {self.num_layers}.'
            self.window_sizes = window_size
        else:
            raise TypeError('window_size should be a Sequence or int.')

        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            norm_cfg=dict(type='LN'),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))
            # Resize a checkpoint's pos embed to the current resolution
            # before it is loaded.
            self._register_load_state_dict_pre_hook(
                self._prepare_abs_pos_embed)

        # Relative-position tables are always re-initialized, so drop them
        # from incoming checkpoints to avoid spurious warnings.
        self._register_load_state_dict_pre_hook(self._delete_reinit_params)

        self.drop_after_pos = nn.Dropout(p=drop_rate)
        self.norm_eval = norm_eval

        # stochastic depth
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        embed_dims = [self.embed_dims]
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            # Every stage except the first starts with a downsample layer.
            downsample = True if i > 0 else False
            _stage_cfg = {
                'embed_dims': embed_dims[-1],
                'depth': depth,
                'num_heads': num_heads,
                'window_size': self.window_sizes[i],
                'downsample': downsample,
                'drop_paths': dpr[:depth],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
                'pretrained_window_size': pretrained_window_sizes[i],
                'downsample_cfg': dict(use_post_norm=True),
                **stage_cfg
            }

            stage = SwinBlockV2Sequence(**_stage_cfg)
            self.stages.append(stage)

            # Consume this stage's drop-path rates; track output channels.
            dpr = dpr[depth:]
            embed_dims.append(stage.out_channels)

        # One norm layer per requested output stage (norm0, norm1, ...).
        for i in out_indices:
            if norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        """Initialize weights; skip custom init when loading pretrained."""
        super(SwinTransformerV2, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        """Extract features; returns a tuple of (B, C, H, W) feature maps,
        one per index in ``out_indices``."""
        x, hw_shape = self.patch_embed(x)
        if self.use_abs_pos_embed:
            # Interpolate the pos embed if the input resolution differs
            # from the configured one.
            x = x + resize_pos_embed(self.absolute_pos_embed,
                                     self.patch_resolution, hw_shape,
                                     self.interpolate_mode,
                                     self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape = stage(x, hw_shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                # (B, L, C) -> (B, C, H, W)
                out = out.view(-1, *hw_shape,
                               stage.out_channels).permute(0, 3, 1,
                                                           2).contiguous()
                outs.append(out)

        return tuple(outs)

    def _freeze_stages(self):
        """Freeze patch embedding, the first ``frozen_stages + 1`` stages and
        their output norm layers."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(0, self.frozen_stages + 1):
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

        for i in self.out_indices:
            if i <= self.frozen_stages:
                for param in getattr(self, f'norm{i}').parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        """Switch train/eval mode while keeping frozen stages and (optionally)
        BatchNorm layers in eval mode."""
        super(SwinTransformerV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """State-dict pre-hook: resize a checkpoint's absolute position
        embedding to match the current model before loading."""
        name = prefix + 'absolute_pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                'Resize the absolute_pos_embed shape from '
                f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')

            # Checkpoint pos embed is assumed to come from a square grid.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
        """State-dict pre-hook: drop buffers that are always re-initialized
        (depend on the current window size) from incoming checkpoints."""
        # delete relative_position_index since we always re-init it
        from mmengine.logging import MMLogger
        logger = MMLogger.get_current_instance()
        logger.info('Delete `relative_position_index` and '
                    '`relative_coords_table` '
                    'since we always re-init these params according to the '
                    '`window_size`, which might cause unwanted but unworried '
                    'warnings when loading checkpoint.')
        relative_position_index_keys = [
            k for k in state_dict.keys() if 'relative_position_index' in k
        ]
        for k in relative_position_index_keys:
            del state_dict[k]

        # delete relative_coords_table since we always re-init it
        relative_position_index_keys = [
            k for k in state_dict.keys() if 'relative_coords_table' in k
        ]
        for k in relative_position_index_keys:
            del state_dict[k]
mmpretrain/models/backbones/t2t_vit.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
copy
import
deepcopy
from
typing
import
Sequence
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
mmcv.cnn.bricks.transformer
import
FFN
from
mmengine.model
import
BaseModule
,
ModuleList
from
mmengine.model.weight_init
import
trunc_normal_
from
mmpretrain.registry
import
MODELS
from
..utils
import
(
MultiheadAttention
,
build_norm_layer
,
resize_pos_embed
,
to_2tuple
)
from
.base_backbone
import
BaseBackbone
class T2TTransformerLayer(BaseModule):
    """Transformer Layer for T2T_ViT.

    Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
    different ``input_dims`` and ``embed_dims``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs
        input_dims (int, optional): The input token dimension.
            Defaults to None.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Notes:
        In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
        ``(embed_dims // num_heads) ** -0.5``. However, in the official
        code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
        keep the same with the official implementation.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 input_dims=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)

        # When input_dims != embed_dims a plain residual add is impossible,
        # so the attention module shortcuts through its value tensor instead.
        self.v_shortcut = True if input_dims is not None else False
        input_dims = input_dims or embed_dims

        self.ln1 = build_norm_layer(norm_cfg, input_dims)

        self.attn = MultiheadAttention(
            input_dims=input_dims,
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            # Official T2T uses input_dims (not embed_dims) for the scale.
            qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
            v_shortcut=self.v_shortcut)

        self.ln2 = build_norm_layer(norm_cfg, embed_dims)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    def forward(self, x):
        """Apply pre-norm attention (with value- or identity-shortcut) and a
        pre-norm FFN with residual connection."""
        if self.v_shortcut:
            # Residual is handled inside MultiheadAttention via v_shortcut.
            x = self.attn(self.ln1(x))
        else:
            x = x + self.attn(self.ln1(x))
        x = self.ffn(self.ln2(x), identity=x)
        return x
class T2TModule(BaseModule):
    """Tokens-to-Token module.

    "Tokens-to-Token module" (T2T Module) can model the local structure
    information of images and reduce the length of tokens progressively.

    Args:
        img_size (int): Input image size
        in_channels (int): Number of input channels
        embed_dims (int): Embedding dimension
        token_dims (int): Tokens dimension in T2TModuleAttention.
        use_performer (bool): If True, use Performer version self-attention to
            adopt regular self-attention. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.

    Notes:
        Usually, ``token_dim`` is set as a small value (32 or 64) to reduce
        MACs
    """

    def __init__(
        self,
        img_size=224,
        in_channels=3,
        embed_dims=384,
        token_dims=64,
        use_performer=False,
        init_cfg=None,
    ):
        super(T2TModule, self).__init__(init_cfg)

        self.embed_dims = embed_dims

        # Three "soft splits": overlapping unfolds with strides 4, 2, 2,
        # for an overall 16x spatial reduction.
        self.soft_split0 = nn.Unfold(
            kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
        self.soft_split1 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.soft_split2 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        if not use_performer:
            # Each attention layer re-structures the unfolded patches into
            # token_dims-dimensional tokens before the next soft split.
            self.attention1 = T2TTransformerLayer(
                input_dims=in_channels * 7 * 7,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.attention2 = T2TTransformerLayer(
                input_dims=token_dims * 3 * 3,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
        else:
            raise NotImplementedError("Performer hasn't been implemented.")

        # there are 3 soft split, stride are 4,2,2 separately
        out_side = img_size // (4 * 2 * 2)
        self.init_out_size = [out_side, out_side]
        self.num_patches = out_side**2

    @staticmethod
    def _get_unfold_size(unfold: nn.Unfold, input_size):
        """Compute the (H, W) output grid of an ``nn.Unfold`` layer using the
        standard convolution output-size formula."""
        h, w = input_size
        kernel_size = to_2tuple(unfold.kernel_size)
        stride = to_2tuple(unfold.stride)
        padding = to_2tuple(unfold.padding)
        dilation = to_2tuple(unfold.dilation)

        h_out = (h + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        w_out = (w + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        return (h_out, w_out)

    def forward(self, x):
        """Progressively tokenize an image.

        Args:
            x (Tensor): Input image of shape (B, C, H, W).

        Returns:
            tuple: Output tokens of shape (B, L, embed_dims) and the final
            spatial shape (h, w) with L == h * w.
        """
        # step0: soft split
        hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
        x = self.soft_split0(x).transpose(1, 2)

        for step in [1, 2]:
            # re-structurization/reconstruction
            attn = getattr(self, f'attention{step}')
            x = attn(x).transpose(1, 2)
            B, C, _ = x.shape
            x = x.reshape(B, C, hw_shape[0], hw_shape[1])

            # soft split
            soft_split = getattr(self, f'soft_split{step}')
            hw_shape = self._get_unfold_size(soft_split, hw_shape)
            x = soft_split(x).transpose(1, 2)

        # final tokens
        x = self.project(x)
        return x, hw_shape
def get_sinusoid_encoding(n_position, embed_dims):
    """Generate sinusoid encoding table.

    Sinusoid encoding is a kind of relative position encoding method came from
    `Attention Is All You Need<https://arxiv.org/abs/1706.03762>`_.

    Args:
        n_position (int): The length of the input token.
        embed_dims (int): The position embedding dimension.

    Returns:
        :obj:`torch.FloatTensor`: The sinusoid encoding table of shape
        ``(1, n_position, embed_dims)``.
    """
    # Angle matrix: positions (rows) divided by per-dimension frequencies.
    # Dimension pairs (2i, 2i+1) share the frequency 10000^(2i / embed_dims).
    positions = np.arange(n_position)[:, None]
    dim_idx = np.arange(embed_dims)
    freqs = np.power(10000, 2 * (dim_idx // 2) / embed_dims)
    table = positions / freqs

    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1

    # Add a leading batch dimension for direct use as a buffer.
    return torch.FloatTensor(table).unsqueeze(0)
@MODELS.register_module()
class T2T_ViT(BaseBackbone):
    """Tokens-to-Token Vision Transformer (T2T-ViT)

    A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
    Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_

    Args:
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        in_channels (int): Number of input channels.
        embed_dims (int): Embedding dimension.
        num_layers (int): Num of transformer layers in encoder.
            Defaults to 14.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Dropout rate after position embedding.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add a additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"cls_token"``.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embeding vector resize. Defaults to "bicubic".
        t2t_cfg (dict): Extra config of Tokens-to-Token module.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 num_layers=14,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 final_norm=True,
                 out_type='cls_token',
                 with_cls_token=True,
                 interpolate_mode='bicubic',
                 t2t_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super().__init__(init_cfg)

        # Token-to-Token Module
        self.tokens_to_token = T2TModule(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            **t2t_cfg)
        self.patch_resolution = self.tokens_to_token.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
            self.num_extra_tokens = 1
        elif out_type != 'cls_token':
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding: a fixed (non-learned) sinusoid table
        # registered as a buffer, resized on checkpoint load if needed.
        self.interpolate_mode = interpolate_mode
        sinusoid_table = get_sinusoid_encoding(
            num_patches + self.num_extra_tokens, embed_dims)
        self.register_buffer('pos_embed', sinusoid_table)
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # Normalize out_indices to a list of non-negative layer indices.
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = num_layers + index
            assert 0 <= out_indices[i] <= num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule: rate grows linearly with depth
        dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
        self.encoder = ModuleList()
        for i in range(num_layers):
            if isinstance(layer_cfgs, Sequence):
                layer_cfg = layer_cfgs[i]
            else:
                layer_cfg = deepcopy(layer_cfgs)
            layer_cfg = {
                'embed_dims': embed_dims,
                'num_heads': 6,
                'feedforward_channels': 3 * embed_dims,
                'drop_path_rate': dpr[i],
                'qkv_bias': False,
                'norm_cfg': norm_cfg,
                **layer_cfg
            }

            layer = T2TTransformerLayer(**layer_cfg)
            self.encoder.append(layer)

        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)
        else:
            self.norm = nn.Identity()

    def init_weights(self):
        """Initialize the class token; skipped when loading pretrained
        weights."""
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if use pretrained model.
            return

        # BUG FIX: cls_token is None when with_cls_token=False (and
        # out_type != 'cls_token'); calling trunc_normal_(None) crashed.
        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        """State-dict pre-hook: resize a checkpoint's position embedding to
        match the current model before loading."""
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            # Checkpoint pos embed is assumed to come from a square grid.
            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.tokens_to_token.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def forward(self, x):
        """Tokenize, add positional encoding, run the encoder and collect the
        outputs selected by ``out_indices``."""
        B = x.shape[0]
        x, patch_resolution = self.tokens_to_token(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_token = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)

            if i == len(self.encoder) - 1 and self.final_norm:
                x = self.norm(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)

    def _format_output(self, x, hw):
        """Convert raw (B, L, C) tokens to the configured output format."""
        if self.out_type == 'raw':
            return x
        if self.out_type == 'cls_token':
            return x[:, 0]

        patch_token = x[:, self.num_extra_tokens:]
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            return patch_token.mean(dim=1)
mmpretrain/models/backbones/timm_backbone.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
import
warnings
from
mmengine.logging
import
MMLogger
from
mmpretrain.registry
import
MODELS
from
mmpretrain.utils
import
require
from
.base_backbone
import
BaseBackbone
def print_timm_feature_info(feature_info):
    """Print feature_info of timm backbone to help development and debug.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    logger = MMLogger.get_current_instance()

    # Guard clauses for the two simple shapes of feature_info.
    if feature_info is None:
        logger.warning('This backbone does not have feature_info')
        return

    if isinstance(feature_info, list):
        for feat_idx, each_info in enumerate(feature_info):
            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
        return

    # Otherwise assume a timm FeatureInfo-like object; fall back gracefully
    # if it lacks the expected accessors.
    try:
        logger.info(f'backbone out_indices: {feature_info.out_indices}')
        logger.info(f'backbone out_channels: {feature_info.channels()}')
        logger.info(f'backbone out_strides: {feature_info.reduction()}')
    except AttributeError:
        logger.warning('Unexpected format of backbone feature_info')
@MODELS.register_module()
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from timm library.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set this False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of checkpoint to load at the last of
            ``timm.create_model``. Defaults to empty string, which means
            not loading.
        in_channels (int): Number of input image channels. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model specific arguments.
    """

    @require('timm')
    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        # Imported lazily so the class can be registered without timm
        # installed; the @require decorator gives a friendly error.
        import timm
        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super(TIMMBackbone, self).__init__(init_cfg)
        if 'norm_layer' in kwargs:
            # Translate an OpenMMLab registry name into the callable class
            # timm expects for its norm_layer argument.
            norm_class = MODELS.get(kwargs['norm_layer'])

            def build_norm(*args, **kwargs):
                return norm_class(*args, **kwargs)

            kwargs['norm_layer'] = build_norm
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # reset classifier
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm
        if pretrained or checkpoint_path:
            # Mark as initialized so mmengine does not re-init the weights.
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        """Run the wrapped timm model and always return a tuple of features
        (single outputs are wrapped in a 1-tuple)."""
        features = self.timm_model(x)
        if isinstance(features, (list, tuple)):
            features = tuple(features)
        else:
            features = (features, )
        return features
mmpretrain/models/backbones/tinyvit.py
0 → 100644
View file @
cbc25585
# Copyright (c) OpenMMLab. All rights reserved.
from
typing
import
Sequence
,
Tuple
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
checkpoint
from
mmcv.cnn.bricks
import
DropPath
,
build_activation_layer
,
build_norm_layer
from
mmengine.model
import
BaseModule
,
ModuleList
,
Sequential
from
torch.nn
import
functional
as
F
from
mmpretrain.registry
import
MODELS
from
..utils
import
LeAttention
from
.base_backbone
import
BaseBackbone
class ConvBN2d(Sequential):
    """An implementation of Conv2d + BatchNorm2d with support of fusion.

    Modified from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int): The size of the convolution kernel.
            Default: 1.
        stride (int): The stride of the convolution.
            Default: 1.
        padding (int): The padding of the convolution.
            Default: 0.
        dilation (int): The dilation of the convolution.
            Default: 1.
        groups (int): The number of groups in the convolution.
            Default: 1.
        bn_weight_init (float): The initial value of the weight of
            the nn.BatchNorm2d layer. Default: 1.0.
        init_cfg (dict): The initialization config of the module.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1.0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Conv has no bias: the following BatchNorm provides the shift.
        self.add_module(
            'conv2d',
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=False))
        bn2d = nn.BatchNorm2d(num_features=out_channels)
        # bn initialization
        torch.nn.init.constant_(bn2d.weight, bn_weight_init)
        torch.nn.init.constant_(bn2d.bias, 0)
        self.add_module('bn2d', bn2d)

    @torch.no_grad()
    def fuse(self):
        """Fuse the Conv2d and BatchNorm2d into a single equivalent Conv2d
        (with bias) for faster inference.

        Returns:
            nn.Conv2d: The fused convolution layer.
        """
        conv2d, bn2d = self._modules.values()
        # Fold BN scale into the conv weight and derive the fused bias.
        w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
        w = conv2d.weight * w[:, None, None, None]
        b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
            (bn2d.running_var + bn2d.eps)**0.5
        # BUG FIX: the original read `self.c.groups`, but the conv module is
        # registered as 'conv2d', so `self.c` never exists and fuse() always
        # raised AttributeError. Use the unpacked `conv2d` module instead.
        m = nn.Conv2d(
            in_channels=w.size(1) * conv2d.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=conv2d.stride,
            padding=conv2d.padding,
            dilation=conv2d.dilation,
            groups=conv2d.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
class PatchEmbed(BaseModule):
    """Patch Embedding for Vision Transformer.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module use
    Conv2d and BatchNorm2d to implement PatchEmbedding, and output shape is
    (N, C, H, W).

    Args:
        in_channels (int): The number of input channels.
        embed_dim (int): The embedding dimension.
        resolution (Tuple[int, int]): The resolution of the input feature.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 embed_dim,
                 resolution,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        # Two stride-2 convs give an overall 4x spatial reduction.
        self.patches_resolution = (resolution[0] // 4, resolution[1] // 4)
        self.num_patches = (
            self.patches_resolution[0] * self.patches_resolution[1])
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        mid_channels = embed_dim // 2
        self.seq = nn.Sequential(
            ConvBN2d(
                in_channels, mid_channels, kernel_size=3, stride=2,
                padding=1),
            build_activation_layer(act_cfg),
            ConvBN2d(
                mid_channels, embed_dim, kernel_size=3, stride=2, padding=1),
        )

    def forward(self, x):
        """Embed an image (N, C, H, W) to (N, embed_dim, H // 4, W // 4)."""
        return self.seq(x)
class PatchMerging(nn.Module):
    """Convolutional patch merging for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Unlike ``mmpretrain.models.utils.PatchMerging``, downsampling is done
    with Conv2d + BatchNorm2d: a 1x1 expansion, a stride-2 depthwise 3x3
    and a 1x1 projection, halving each spatial dimension.

    Args:
        resolution (Tuple[int, int]): The resolution of the input feature.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 resolution,
                 in_channels,
                 out_channels,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.img_size = resolution
        self.act = build_activation_layer(act_cfg)
        self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1)
        # Depthwise stride-2 conv performs the actual 2x downsampling.
        self.conv2 = ConvBN2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels)
        self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1)
        self.out_resolution = (resolution[0] // 2, resolution[1] // 2)

    def forward(self, x):
        if x.ndim == 3:
            # (B, L, C) sequence input: restore the spatial layout first.
            height, width = self.img_size
            batch = x.shape[0]
            x = x.view(batch, height, width, -1).permute(0, 3, 1, 2)
        x = self.act(self.conv1(x))
        x = self.act(self.conv2(x))
        x = self.conv3(x)
        # Flatten back to the (B, L, C) sequence layout.
        return x.flatten(2).transpose(1, 2)
class MBConvBlock(nn.Module):
    """Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        expand_ratio (int): The expand ratio of the hidden channels.
        drop_path (float): The stochastic-depth rate of the block.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expand_ratio,
                 drop_path,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        hidden_channels = int(in_channels * expand_ratio)

        # 1x1 pointwise expansion.
        self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1)
        self.act = build_activation_layer(act_cfg)
        # 3x3 depthwise convolution.
        self.conv2 = ConvBN2d(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_channels)
        # 1x1 pointwise projection; the BN scale starts at 0 so the block
        # initially behaves like an identity mapping.
        self.conv3 = ConvBN2d(
            hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        identity = x
        out = self.act(self.conv1(x))
        out = self.act(self.conv2(out))
        out = self.conv3(out)
        # Residual add around the bottleneck, then a final activation.
        out = self.drop_path(out) + identity
        return self.act(out)
class ConvStage(BaseModule):
    """Convolution Stage for TinyViT.

    Adapted from
    https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of MBConv blocks in the stage.
        act_cfg (dict): The activation config of the module.
        drop_path (float | list[float]): Stochastic-depth rate, either one
            value for the whole stage or one value per block. Default: 0.
        downsample (None | nn.Module): The downsample operation class,
            instantiated after the blocks. Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
        out_channels (int): The number of output channels (consumed by the
            downsample layer only). Default: None.
        conv_expand_ratio (int): The expand ratio of the hidden channels.
            Default: 4.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 act_cfg,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 out_channels=None,
                 conv_expand_ratio=4.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        def _rate(idx):
            # Per-block stochastic-depth rate.
            return drop_path[idx] if isinstance(drop_path, list) \
                else drop_path

        # Stack of identical MBConv blocks (channels are preserved).
        self.blocks = ModuleList([
            MBConvBlock(
                in_channels=in_channels,
                out_channels=in_channels,
                expand_ratio=conv_expand_ratio,
                drop_path=_rate(i)) for i in range(depth)
        ])

        # Optional patch-merging layer; it updates the output resolution.
        if downsample is None:
            self.downsample = None
            self.resolution = resolution
        else:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution

    def forward(self, x):
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint \
                else blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
class MLP(BaseModule):
    """Feed-forward module for TinyViT: LayerNorm then a two-layer MLP.

    Args:
        in_channels (int): The number of input channels.
        hidden_channels (int, optional): The number of hidden channels;
            falls back to ``in_channels``. Default: None.
        out_channels (int, optional): The number of output channels;
            falls back to ``in_channels``. Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 hidden_channels=None,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 drop=0.,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Missing dimensions default to the input width.
        hidden_channels = hidden_channels or in_channels
        out_channels = out_channels or in_channels
        self.norm = nn.LayerNorm(in_channels)
        self.fc1 = nn.Linear(in_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, out_channels)
        self.act = build_activation_layer(act_cfg)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        out = self.fc1(self.norm(x))
        out = self.drop(self.act(out))
        return self.drop(self.fc2(out))
class TinyViTBlock(BaseModule):
    """TinyViT Block: windowed attention, a depthwise local conv and an
    MLP, each wrapped in a residual connection.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
            Default: 7.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float): The drop path of the block.
            Default: 0.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 local_conv_size=3,
                 act_cfg=dict(type='GELU')):
        super().__init__()
        self.in_channels = in_channels
        self.img_size = resolution
        self.num_heads = num_heads
        assert window_size > 0, 'window_size must be greater than 0'
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        # Stochastic depth; identity when the rate is zero.
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        assert in_channels % num_heads == 0, \
            'dim must be divisible by num_heads'
        head_dim = in_channels // num_heads

        # Attention always operates on window_size x window_size tokens.
        window_resolution = (window_size, window_size)
        self.attn = LeAttention(
            in_channels,
            head_dim,
            num_heads,
            attn_ratio=1,
            resolution=window_resolution)

        mlp_hidden_dim = int(in_channels * mlp_ratio)
        self.mlp = MLP(
            in_channels=in_channels,
            hidden_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            drop=drop)

        # Depthwise conv that injects local spatial context between the
        # attention and MLP sub-blocks.
        self.local_conv = ConvBN2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=local_conv_size,
            stride=1,
            padding=local_conv_size // 2,
            groups=in_channels)

    def forward(self, x):
        # x: (B, L, C) sequence with L == H * W of self.img_size.
        H, W = self.img_size
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        res_x = x
        if H == self.window_size and W == self.window_size:
            # The feature map is exactly one window: attend directly.
            x = self.attn(x)
        else:
            x = x.view(B, H, W, C)
            # Pad bottom/right so H and W become multiples of window_size.
            pad_b = (self.window_size -
                     H % self.window_size) % self.window_size
            pad_r = (self.window_size -
                     W % self.window_size) % self.window_size
            padding = pad_b > 0 or pad_r > 0

            if padding:
                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_size
            nW = pW // self.window_size
            # window partition:
            # (B, pH, pW, C) -> (B*nH*nW, window_size*window_size, C)
            x = x.view(B, nH, self.window_size, nW, self.window_size,
                       C).transpose(2, 3).reshape(
                           B * nH * nW,
                           self.window_size * self.window_size, C)
            x = self.attn(x)
            # window reverse: back to (B, pH, pW, C).
            x = x.view(B, nH, nW, self.window_size, self.window_size,
                       C).transpose(2, 3).reshape(B, pH, pW, C)

            if padding:
                # Strip the padding added above.
                x = x[:, :H, :W].contiguous()

            x = x.view(B, L, C)
        # Residual connection around the attention branch.
        x = res_x + self.drop_path(x)

        # The depthwise local conv runs in (B, C, H, W) layout.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.local_conv(x)
        x = x.view(B, C, L).transpose(1, 2)

        # Residual MLP branch.
        x = x + self.drop_path(self.mlp(x))
        return x
class BasicStage(BaseModule):
    """Basic (transformer) Stage for TinyViT.

    Args:
        in_channels (int): The number of input channels.
        resolution (Tuple[int, int]): The resolution of the input feature.
        depth (int): The number of blocks in the stage.
        num_heads (int): The number of heads in the multi-head attention.
        window_size (int): The size of the window.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path (float | list[float]): Stochastic-depth rate, either one
            value for the whole stage or one value per block. Default: 0.
        downsample (None | nn.Module): The downsample operation class,
            instantiated after the blocks. Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        local_conv_size (int): The size of the local convolution inside
            each block. Default: 3.
        out_channels (int): The number of output channels (consumed by the
            downsample layer only). Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 drop=0.,
                 drop_path=0.,
                 downsample=None,
                 use_checkpoint=False,
                 local_conv_size=3,
                 out_channels=None,
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.use_checkpoint = use_checkpoint

        def _rate(idx):
            # Per-block stochastic-depth rate.
            return drop_path[idx] if isinstance(drop_path, list) \
                else drop_path

        # Stack of TinyViT transformer blocks at a fixed resolution.
        self.blocks = ModuleList([
            TinyViTBlock(
                in_channels=in_channels,
                resolution=resolution,
                num_heads=num_heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                drop=drop,
                local_conv_size=local_conv_size,
                act_cfg=act_cfg,
                drop_path=_rate(i)) for i in range(depth)
        ])

        # Optional patch-merging layer; it updates the output resolution.
        if downsample is None:
            self.downsample = None
            self.resolution = resolution
        else:
            self.downsample = downsample(
                resolution=resolution,
                in_channels=in_channels,
                out_channels=out_channels,
                act_cfg=act_cfg)
            self.resolution = self.downsample.out_resolution

    def forward(self, x):
        for blk in self.blocks:
            x = checkpoint.checkpoint(blk, x) if self.use_checkpoint \
                else blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x
@MODELS.register_module()
class TinyViT(BaseBackbone):
    """TinyViT.

    A PyTorch implementation of : `TinyViT: Fast Pretraining Distillation
    for Small Vision Transformers<https://arxiv.org/abs/2207.10666>`_

    Inspiration from
    https://github.com/microsoft/Cream/blob/main/TinyViT

    Args:
        arch (str | dict): The architecture of TinyViT.
            Default: '5m'.
        img_size (tuple | int): The resolution of the input image.
            Default: (224, 224)
        window_size (list): The size of the window.
            Default: [7, 7, 14, 7]
        in_channels (int): The number of input channels.
            Default: 3.
        mlp_ratio (list[int]): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        drop_path_rate (float): The drop path of the block.
            Default: 0.1.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        mbconv_expand_ratio (int): The expand ratio of the mbconv.
            Default: 4.0
        local_conv_size (int): The size of the local conv.
            Default: 3.
        layer_lr_decay (float): The layer lr decay. Currently unused, see
            ``set_layer_lr_decay``. Default: 1.0
        out_indices (int | list[int]): Output from which stages.
            Default: -1
        frozen_stages (int): Stages to be frozen (all param fixed).
            Default: 0
        gap_before_final_norm (bool): Whether to globally average-pool the
            tokens before the final norm. Default: True.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Default: None.
    """

    # Preset channel / head / depth settings for the three model sizes.
    arch_settings = {
        '5m': {
            'channels': [64, 128, 160, 320],
            'num_heads': [2, 4, 5, 10],
            'depths': [2, 2, 6, 2],
        },
        '11m': {
            'channels': [64, 128, 256, 448],
            'num_heads': [2, 4, 8, 14],
            'depths': [2, 2, 6, 2],
        },
        '21m': {
            'channels': [96, 192, 384, 576],
            'num_heads': [3, 6, 12, 18],
            'depths': [2, 2, 6, 2],
        },
    }

    def __init__(self,
                 arch='5m',
                 img_size=(224, 224),
                 window_size=[7, 7, 14, 7],
                 in_channels=3,
                 mlp_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 use_checkpoint=False,
                 mbconv_expand_ratio=4.0,
                 local_conv_size=3,
                 layer_lr_decay=1.0,
                 out_indices=-1,
                 frozen_stages=0,
                 gap_before_final_norm=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        # Resolve the architecture preset or validate a user-supplied dict.
        # NOTE(review): the assertion messages below contain typos
        # ('Unavaiable', 'must by') and mention 'window_sizes' while the
        # code checks 'depths' — left untouched here because they are
        # runtime strings.
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavaiable arch, please choose from ' \
                f'({set(self.arch_settings)} or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'channels' in arch and 'num_heads' in arch and \
                'depths' in arch, 'The arch dict must have' \
                f'"channels", "num_heads", "window_sizes" ' \
                f'keys, but got {arch.keys()}'

        self.channels = arch['channels']
        self.num_heads = arch['num_heads']
        # NOTE(review): attribute name has a typo ('widow_sizes'); kept
        # as-is for checkpoint/API compatibility.
        self.widow_sizes = window_size
        self.img_size = img_size
        self.depths = arch['depths']

        self.num_stages = len(self.channels)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must by a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                # NOTE(review): hard-codes 4 stages; correct for the
                # preset archs — confirm for custom arch dicts with a
                # different number of stages.
                out_indices[i] = 4 + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.frozen_stages = frozen_stages
        self.gap_before_final_norm = gap_before_final_norm
        self.layer_lr_decay = layer_lr_decay

        # Stem: conv patch embedding producing a (N, C, H/4, W/4) map.
        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dim=self.channels[0],
            resolution=self.img_size,
            act_cfg=dict(type='GELU'))
        patches_resolution = self.patch_embed.patches_resolution

        # stochastic depth decay rule: linearly increasing rate per block.
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, sum(self.depths))
        ]

        # build stages: stage 0 is convolutional, the rest use attention.
        self.stages = ModuleList()
        for i in range(self.num_stages):
            depth = self.depths[i]
            channel = self.channels[i]
            # Each stage halves the resolution of the previous one.
            curr_resolution = (patches_resolution[0] // (2**i),
                               patches_resolution[1] // (2**i))
            # Slice of the global drop-path schedule for this stage.
            drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])]
            downsample = PatchMerging if (i < self.num_stages - 1) else None
            out_channels = self.channels[min(i + 1, self.num_stages - 1)]
            if i >= 1:
                stage = BasicStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    num_heads=self.num_heads[i],
                    window_size=self.widow_sizes[i],
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    local_conv_size=local_conv_size,
                    out_channels=out_channels,
                    act_cfg=act_cfg)
            else:
                stage = ConvStage(
                    in_channels=channel,
                    resolution=curr_resolution,
                    depth=depth,
                    act_cfg=act_cfg,
                    drop_path=drop_path,
                    downsample=downsample,
                    use_checkpoint=use_checkpoint,
                    out_channels=out_channels,
                    conv_expand_ratio=mbconv_expand_ratio)
            self.stages.append(stage)

            # add output norm on the stage's (downsampled) channel count.
            if i in self.out_indices:
                norm_layer = build_norm_layer(norm_cfg, out_channels)[1]
                self.add_module(f'norm{i}', norm_layer)

    def set_layer_lr_decay(self, layer_lr_decay):
        # TODO: add layer_lr_decay
        pass

    def forward(self, x):
        """Forward the image through the stem and all stages.

        Returns:
            tuple: One tensor per index in ``self.out_indices`` — either a
            pooled (B, C) vector when ``gap_before_final_norm`` is True, or
            a (B, C, H, W) feature map otherwise.
        """
        outs = []
        x = self.patch_embed(x)
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                if self.gap_before_final_norm:
                    # Global average pool over the token dimension.
                    gap = x.mean(1)
                    outs.append(norm_layer(gap))
                else:
                    out = norm_layer(x)
                    # convert the (B,L,C) format into (B,C,H,W) format
                    # which would be better for the downstream tasks.
                    B, L, C = out.shape
                    out = out.view(B, *stage.resolution, C)
                    outs.append(out.permute(0, 3, 1, 2))
        return tuple(outs)

    def _freeze_stages(self):
        # Freeze the first ``frozen_stages`` stages: eval mode (fixes BN
        # statistics) and no gradients.
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        # Re-apply freezing whenever the train/eval mode is toggled.
        super(TinyViT, self).train(mode)
        self._freeze_stages()
Prev
1
…
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment