OpenDAS / mmpretrain

Commit cbc25585 · authored Jun 24, 2025 by limm
Parent: 1baf0566 · Pipeline #2801 canceled with stages

    add mmpretrain/ part

Showing 8 changed files with 4142 additions and 0 deletions.
mmpretrain/models/backbones/tnt.py                  +368 −0
mmpretrain/models/backbones/twins.py                +721 −0
mmpretrain/models/backbones/van.py                  +434 −0
mmpretrain/models/backbones/vgg.py                  +183 −0
mmpretrain/models/backbones/vig.py                  +852 −0
mmpretrain/models/backbones/vision_transformer.py   +537 −0
mmpretrain/models/backbones/vit_eva02.py            +350 −0
mmpretrain/models/backbones/vit_sam.py              +697 −0
mmpretrain/models/backbones/tnt.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from ..utils import to_2tuple
from .base_backbone import BaseBackbone


class TransformerBlock(BaseModule):
    """Implement a transformer block in TnTLayer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
            Defaults to 4.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        act_cfg (dict): The activation config for FFNs. Defaults to GELU.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            layer normalization.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim) or (n, batch, embed_dim).
            (batch, n, embed_dim) is the common case in CV. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 ffn_ratio=4,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 init_cfg=None):
        super(TransformerBlock, self).__init__(init_cfg=init_cfg)

        self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first)

        self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * ffn_ratio,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        if not qkv_bias:
            self.attn.attn.in_proj_bias = None

    def forward(self, x):
        x = self.attn(self.norm_attn(x), identity=x)
        x = self.ffn(self.norm_ffn(x), identity=x)
        return x


class TnTLayer(BaseModule):
    """Implement one encoder layer in Transformer in Transformer.

    Args:
        num_pixel (int): The pixel number in target patch transformed with
            a linear projection in inner transformer.
        embed_dims_inner (int): Feature dimension in inner transformer block.
        embed_dims_outer (int): Feature dimension in outer transformer block.
        num_heads_inner (int): Parallel attention heads in inner transformer.
        num_heads_outer (int): Parallel attention heads in outer transformer.
        inner_block_cfg (dict): Extra config of inner transformer block.
            Defaults to empty dict.
        outer_block_cfg (dict): Extra config of outer transformer block.
            Defaults to empty dict.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            layer normalization.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 num_pixel,
                 embed_dims_inner,
                 embed_dims_outer,
                 num_heads_inner,
                 num_heads_outer,
                 inner_block_cfg=dict(),
                 outer_block_cfg=dict(),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TnTLayer, self).__init__(init_cfg=init_cfg)

        self.inner_block = TransformerBlock(
            embed_dims=embed_dims_inner,
            num_heads=num_heads_inner,
            **inner_block_cfg)

        self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1]
        self.projection = nn.Linear(
            embed_dims_inner * num_pixel, embed_dims_outer, bias=True)

        self.outer_block = TransformerBlock(
            embed_dims=embed_dims_outer,
            num_heads=num_heads_outer,
            **outer_block_cfg)

    def forward(self, pixel_embed, patch_embed):
        pixel_embed = self.inner_block(pixel_embed)

        B, N, C = patch_embed.size()
        patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection(
            self.norm_proj(pixel_embed).reshape(B, N - 1, -1))
        patch_embed = self.outer_block(patch_embed)

        return pixel_embed, patch_embed


class PixelEmbed(BaseModule):
    """Image to Pixel Embedding.

    Args:
        img_size (int | tuple): The size of input image.
        patch_size (int): The size of one patch.
        in_channels (int): The num of input channels.
        embed_dims_inner (int): The num of channels of the target patch
            transformed with a linear projection in inner transformer.
        stride (int): The stride of the conv2d layer. We use a conv2d layer
            and an unfold layer to implement image to pixel embedding.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims_inner=48,
                 stride=4,
                 init_cfg=None):
        super(PixelEmbed, self).__init__(init_cfg=init_cfg)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # patches_resolution property necessary for resizing
        # positional embedding
        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        num_patches = patches_resolution[0] * patches_resolution[1]

        self.img_size = img_size
        self.num_patches = num_patches
        self.embed_dims_inner = embed_dims_inner

        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
        self.new_patch_size = new_patch_size

        self.proj = nn.Conv2d(
            in_channels,
            self.embed_dims_inner,
            kernel_size=7,
            padding=3,
            stride=stride)
        self.unfold = nn.Unfold(
            kernel_size=new_patch_size, stride=new_patch_size)

    def forward(self, x, pixel_pos):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model " \
            f'({self.img_size[0]}*{self.img_size[1]}).'
        x = self.proj(x)
        x = self.unfold(x)
        x = x.transpose(1, 2).reshape(B * self.num_patches,
                                      self.embed_dims_inner,
                                      self.new_patch_size[0],
                                      self.new_patch_size[1])
        x = x + pixel_pos
        x = x.reshape(B * self.num_patches, self.embed_dims_inner,
                      -1).transpose(1, 2)
        return x
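To make the conv + unfold pixel embedding above concrete, here is a quick shape walk-through under the default settings (img_size=224, patch_size=16, stride=4, embed_dims_inner=48); this is an illustration, not part of the committed file:

    import torch
    import torch.nn as nn

    B = 2
    proj = nn.Conv2d(3, 48, kernel_size=7, padding=3, stride=4)  # [B, 3, 224, 224] -> [B, 48, 56, 56]
    unfold = nn.Unfold(kernel_size=(4, 4), stride=(4, 4))        # one 4x4 block per 16x16 patch
    x = proj(torch.rand(B, 3, 224, 224))
    x = unfold(x)                                     # [B, 48*4*4, 196]
    x = x.transpose(1, 2).reshape(B * 196, 48, 4, 4)  # 16 "pixels" of dim 48 per patch
    print(x.shape)  # torch.Size([392, 48, 4, 4])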
@MODELS.register_module()
class TNT(BaseBackbone):
    """Transformer in Transformer.

    A PyTorch implementation of: `Transformer in Transformer
    <https://arxiv.org/abs/2103.00112>`_

    Inspiration from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py

    Args:
        arch (str | dict): Vision Transformer architecture. Defaults to 'b'.
        img_size (int | tuple): Input image size. Defaults to 224.
        patch_size (int | tuple): The patch size. Defaults to 16.
        in_channels (int): Number of input channels. Defaults to 3.
        ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
            Defaults to 4.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        act_cfg (dict): The activation config for FFNs. Defaults to GELU.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            layer normalization.
        first_stride (int): The stride of the conv2d layer. We use a conv2d
            layer and an unfold layer to implement image to pixel embedding.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        init_cfg (dict, optional): Initialization config dict.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims_outer': 384,
                'embed_dims_inner': 24,
                'num_layers': 12,
                'num_heads_outer': 6,
                'num_heads_inner': 4
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims_outer': 640,
                'embed_dims_inner': 40,
                'num_layers': 12,
                'num_heads_outer': 10,
                'num_heads_inner': 4
            })
    }

    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 ffn_ratio=4,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 first_stride=4,
                 num_fcs=2,
                 init_cfg=[
                     dict(type='TruncNormal', layer='Linear', std=.02),
                     dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
                 ]):
        super(TNT, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims_outer', 'embed_dims_inner', 'num_layers',
                'num_heads_inner', 'num_heads_outer'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims_inner = self.arch_settings['embed_dims_inner']
        self.embed_dims_outer = self.arch_settings['embed_dims_outer']
        # embed_dims for consistency with other models
        self.embed_dims = self.embed_dims_outer
        self.num_layers = self.arch_settings['num_layers']
        self.num_heads_inner = self.arch_settings['num_heads_inner']
        self.num_heads_outer = self.arch_settings['num_heads_outer']

        self.pixel_embed = PixelEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims_inner=self.embed_dims_inner,
            stride=first_stride)
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size[0] * new_patch_size[1]

        self.norm1_proj = build_norm_layer(
            norm_cfg, num_pixel * self.embed_dims_inner)[1]
        self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
                                    self.embed_dims_outer)
        self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.embed_dims_outer))
        self.patch_pos = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dims_outer))
        self.pixel_pos = nn.Parameter(
            torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
                        new_patch_size[1]))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, self.num_layers)
        ]  # stochastic depth decay rule
        self.layers = ModuleList()
        for i in range(self.num_layers):
            block_cfg = dict(
                ffn_ratio=ffn_ratio,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[i],
                num_fcs=num_fcs,
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
                batch_first=True)
            self.layers.append(
                TnTLayer(
                    num_pixel=num_pixel,
                    embed_dims_inner=self.embed_dims_inner,
                    embed_dims_outer=self.embed_dims_outer,
                    num_heads_inner=self.num_heads_inner,
                    num_heads_outer=self.num_heads_outer,
                    inner_block_cfg=block_cfg,
                    outer_block_cfg=block_cfg,
                    norm_cfg=norm_cfg))

        self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.patch_pos, std=.02)
        trunc_normal_(self.pixel_pos, std=.02)

    def forward(self, x):
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        patch_embed = self.norm2_proj(
            self.projection(
                self.norm1_proj(pixel_embed.reshape(B, self.num_patches,
                                                    -1))))
        patch_embed = torch.cat(
            (self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.drop_after_pos(patch_embed)

        for layer in self.layers:
            pixel_embed, patch_embed = layer(pixel_embed, patch_embed)

        patch_embed = self.norm(patch_embed)
        return (patch_embed[:, 0], )
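Since tnt.py is a new file in this commit, a minimal smoke test is a single forward pass. The sketch below assumes mmpretrain and its dependencies are importable; with the 's' arch, the backbone returns a 1-tuple holding the class token of dimension embed_dims_outer=384:

    import torch
    from mmpretrain.models.backbones import TNT

    model = TNT(arch='s')
    model.eval()
    with torch.no_grad():
        outputs = model(torch.rand(1, 3, 224, 224))
    print(outputs[0].shape)  # torch.Size([1, 384])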
mmpretrain/models/backbones/twins.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)
from torch.nn.modules.batchnorm import _BatchNorm

from mmpretrain.registry import MODELS
from ..utils import ConditionalPositionEncoding, MultiheadAttention


class GlobalSubsampledAttention(MultiheadAttention):
    """Global Sub-sampled Attention (GSA) module.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        sr_ratio (float): The ratio of spatial reduction in attention modules.
            Defaults to 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 norm_cfg=dict(type='LN'),
                 qkv_bias=True,
                 sr_ratio=1,
                 **kwargs):
        super(GlobalSubsampledAttention,
              self).__init__(embed_dims, num_heads, **kwargs)

        self.qkv_bias = qkv_bias
        self.q = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
        self.kv = nn.Linear(self.input_dims, embed_dims * 2, bias=qkv_bias)

        # remove self.qkv, here split into self.q, self.kv
        delattr(self, 'qkv')

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # use a conv as the spatial-reduction operation, the kernel_size
            # and stride in conv are equal to the sr_ratio.
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x, hw_shape):
        B, N, C = x.shape
        H, W = hw_shape
        assert H * W == N, 'The product of h and w of hw_shape must be N, ' \
            'which is the 2nd dim number of the input Tensor x.'

        q = self.q(x).reshape(B, N, self.num_heads,
                              C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x = x.permute(0, 2, 1).reshape(B, C, *hw_shape)  # BNC_2_BCHW
            x = self.sr(x)
            x = x.reshape(B, C, -1).permute(0, 2, 1)  # BCHW_2_BNC
            x = self.norm(x)

        kv = self.kv(x).reshape(B, -1, 2, self.num_heads,
                                self.head_dims).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn_drop = self.attn_drop if self.training else 0.
        x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
        x = x.transpose(1, 2).reshape(B, N, self.embed_dims)

        x = self.proj(x)
        x = self.out_drop(self.proj_drop(x))

        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x
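The key idea in GSA is that only the keys and values are sub-sampled: with sr_ratio=r, the strided r×r conv shrinks the K/V token set from H*W to (H/r)*(W/r) while queries keep full resolution. A standalone sketch of just that reduction step, with illustrative sizes (not part of the committed file):

    import torch
    import torch.nn as nn

    B, H, W, C, r = 2, 56, 56, 64, 8
    x = torch.rand(B, H * W, C)                    # [B, N, C] tokens
    sr = nn.Conv2d(C, C, kernel_size=r, stride=r)  # spatial-reduction conv
    x_2d = x.transpose(1, 2).reshape(B, C, H, W)   # BNC -> BCHW
    kv = sr(x_2d).reshape(B, C, -1).transpose(1, 2)  # BCHW -> BNC
    print(kv.shape)  # torch.Size([2, 49, 64]): N shrank by r*r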
class GSAEncoderLayer(BaseModule):
    """Implements one encoder layer with GlobalSubsampledAttention (GSA).

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        sr_ratio (float): The ratio of spatial reduction in attention modules.
            Defaults to 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1.,
                 init_cfg=None):
        super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = GlobalSubsampledAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x


class LocallyGroupedSelfAttention(BaseModule):
    """Locally-grouped Self Attention (LSA) module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads. Defaults to 8.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Defaults to False.
        qk_scale (float | None, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Defaults to 0.
        proj_drop_rate (float, optional): Dropout ratio of output.
            Defaults to 0.
        window_size (int): Window size of LSA. Defaults to 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 window_size=1,
                 init_cfg=None):
        super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg)

        assert embed_dims % num_heads == 0, \
            f'dim {embed_dims} should be divided by num_heads {num_heads}'

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)
        self.window_size = window_size

    def forward(self, x, hw_shape):
        B, N, C = x.shape
        H, W = hw_shape
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of Local-groups
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))

        # calculate attention mask for LSA
        Hp, Wp = x.shape[1:-1]
        _h, _w = Hp // self.window_size, Wp // self.window_size
        mask = torch.zeros((1, Hp, Wp), device=x.device)
        mask[:, -pad_b:, :].fill_(1)
        mask[:, :, -pad_r:].fill_(1)

        # [B, _h, _w, window_size, window_size, C]
        x = x.reshape(B, _h, self.window_size, _w, self.window_size,
                      C).transpose(2, 3)
        mask = mask.reshape(1, _h, self.window_size, _w,
                            self.window_size).transpose(2, 3).reshape(
                                1, _h * _w,
                                self.window_size * self.window_size)
        # [1, _h*_w, window_size*window_size, window_size*window_size]
        attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-1000.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        # [3, B, _w*_h, nhead, window_size*window_size, dim]
        qkv = self.qkv(x).reshape(B, _h * _w,
                                  self.window_size * self.window_size, 3,
                                  self.num_heads, C // self.num_heads).permute(
                                      3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + attn_mask.unsqueeze(2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.window_size,
                                                  self.window_size, C)
        x = attn.transpose(2, 3).reshape(B, _h * self.window_size,
                                         _w * self.window_size, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LSAEncoderLayer(BaseModule):
    """Implements one encoder layer with LocallyGroupedSelfAttention (LSA).

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        qk_scale (float | None, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        window_size (int): Window size of LSA. Defaults to 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=1,
                 init_cfg=None):
        super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
                                                qkv_bias, qk_scale,
                                                attn_drop_rate, drop_rate,
                                                window_size)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x


@MODELS.register_module()
class PCPVT(BaseModule):
    """The backbone of Twins-PCPVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        arch (dict, str): PCPVT architecture, a str value in arch zoo or a
            detailed configuration dict with 7 keys, and the length of all
            the values in dict should be the same:

            - depths (List[int]): The number of encoder layers in each stage.
            - embed_dims (List[int]): Embedding dimension in each stage.
            - patch_sizes (List[int]): The patch sizes in each stage.
            - num_heads (List[int]): Numbers of attention head in each stage.
            - strides (List[int]): The strides in each stage.
            - mlp_ratios (List[int]): The ratios of mlp in each stage.
            - sr_ratios (List[int]): The ratios of GSA-encoder layers in each
              stage.

        in_channels (int): Number of input channels. Defaults to 3.
        out_indices (tuple[int]): Output from which stages.
            Defaults to ``(3, )``.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        norm_after_stage (bool, List[bool]): Add extra norm after each stage.
            Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmpretrain.models import PCPVT
        >>> import torch
        >>> pcpvt_cfg = {'arch': "small",
        ...              'norm_after_stage': [False, False, False, True]}
        >>> model = PCPVT(**pcpvt_cfg)
        >>> x = torch.rand(1, 3, 224, 224)
        >>> outputs = model(x)
        >>> print(outputs[-1].shape)
        torch.Size([1, 512, 7, 7])
        >>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True]
        >>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3)
        >>> model = PCPVT(**pcpvt_cfg)
        >>> outputs = model(x)
        >>> for feat in outputs:
        ...     print(feat.shape)
        torch.Size([1, 64, 56, 56])
        torch.Size([1, 128, 28, 28])
        torch.Size([1, 320, 14, 14])
        torch.Size([1, 512, 7, 7])
    """
    arch_zoo = {
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 4, 6, 3],
                         'num_heads': [1, 2, 5, 8],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [8, 8, 4, 4],
                         'sr_ratios': [8, 4, 2, 1]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 4, 18, 3],
                         'num_heads': [1, 2, 5, 8],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [8, 8, 4, 4],
                         'sr_ratios': [8, 4, 2, 1]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 8, 27, 3],
                         'num_heads': [1, 2, 5, 8],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [8, 8, 4, 4],
                         'sr_ratios': [8, 4, 2, 1]}),
    }  # yapf: disable

    essential_keys = {
        'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
        'mlp_ratios', 'sr_ratios'
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 out_indices=(3, ),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 norm_after_stage=False,
                 init_cfg=None):
        super(PCPVT, self).__init__(init_cfg=init_cfg)
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            assert isinstance(arch, dict) and (
                set(arch) == self.essential_keys), \
                f'Custom arch needs a dict with keys {self.essential_keys}.'
            self.arch_settings = arch

        self.depths = self.arch_settings['depths']
        self.embed_dims = self.arch_settings['embed_dims']
        self.patch_sizes = self.arch_settings['patch_sizes']
        self.strides = self.arch_settings['strides']
        self.mlp_ratios = self.arch_settings['mlp_ratios']
        self.num_heads = self.arch_settings['num_heads']
        self.sr_ratios = self.arch_settings['sr_ratios']

        self.num_extra_tokens = 0  # there is no cls-token in Twins
        self.num_stage = len(self.depths)
        for key, value in self.arch_settings.items():
            assert isinstance(value, list) and len(value) == self.num_stage, (
                'Length of setting item in arch dict must be type of list and'
                ' have the same length.')

        # patch_embeds
        self.patch_embeds = ModuleList()
        self.position_encoding_drops = ModuleList()
        self.stages = ModuleList()

        for i in range(self.num_stage):
            # use in_channels of the model in the first stage
            if i == 0:
                stage_in_channels = in_channels
            else:
                stage_in_channels = self.embed_dims[i - 1]

            self.patch_embeds.append(
                PatchEmbed(
                    in_channels=stage_in_channels,
                    embed_dims=self.embed_dims[i],
                    conv_type='Conv2d',
                    kernel_size=self.patch_sizes[i],
                    stride=self.strides[i],
                    padding='corner',
                    norm_cfg=dict(type='LN')))

            self.position_encoding_drops.append(nn.Dropout(p=drop_rate))

        # PEGs
        self.position_encodings = ModuleList([
            ConditionalPositionEncoding(embed_dim, embed_dim)
            for embed_dim in self.embed_dims
        ])

        # stochastic depth
        total_depth = sum(self.depths)
        self.dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        cur = 0
        for k in range(len(self.depths)):
            _block = ModuleList([
                GSAEncoderLayer(
                    embed_dims=self.embed_dims[k],
                    num_heads=self.num_heads[k],
                    feedforward_channels=self.mlp_ratios[k] *
                    self.embed_dims[k],
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=self.dpr[cur + i],
                    num_fcs=2,
                    qkv_bias=qkv_bias,
                    act_cfg=dict(type='GELU'),
                    norm_cfg=norm_cfg,
                    sr_ratio=self.sr_ratios[k]) for i in range(self.depths[k])
            ])
            self.stages.append(_block)
            cur += self.depths[k]

        self.out_indices = out_indices

        assert isinstance(norm_after_stage, (bool, list))
        if isinstance(norm_after_stage, bool):
            self.norm_after_stage = [norm_after_stage] * self.num_stage
        else:
            self.norm_after_stage = norm_after_stage
        assert len(self.norm_after_stage) == self.num_stage, \
            (f'Number of norm_after_stage({len(self.norm_after_stage)}) should'
             f' be equal to the number of stages({self.num_stage}).')

        for i, has_norm in enumerate(self.norm_after_stage):
            assert isinstance(has_norm, bool), 'norm_after_stage should be ' \
                'bool or List[bool].'
            if has_norm and norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm_after_stage{i}', norm_layer)

    def init_weights(self):
        if self.init_cfg is not None:
            super(PCPVT, self).init_weights()
        else:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)

    def forward(self, x):
        outputs = list()

        b = x.shape[0]

        for i in range(self.num_stage):
            x, hw_shape = self.patch_embeds[i](x)
            h, w = hw_shape
            x = self.position_encoding_drops[i](x)
            for j, blk in enumerate(self.stages[i]):
                x = blk(x, hw_shape)
                if j == 0:
                    x = self.position_encodings[i](x, hw_shape)

            norm_layer = getattr(self, f'norm_after_stage{i}')
            x = norm_layer(x)
            x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

            if i in self.out_indices:
                outputs.append(x)

        return tuple(outputs)


@MODELS.register_module()
class SVT(PCPVT):
    """The backbone of Twins-SVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        arch (dict, str): SVT architecture, a str value in arch zoo or a
            detailed configuration dict with 8 keys, and the length of all
            the values in dict should be the same:

            - depths (List[int]): The number of encoder layers in each stage.
            - embed_dims (List[int]): Embedding dimension in each stage.
            - patch_sizes (List[int]): The patch sizes in each stage.
            - num_heads (List[int]): Numbers of attention head in each stage.
            - strides (List[int]): The strides in each stage.
            - mlp_ratios (List[int]): The ratios of mlp in each stage.
            - sr_ratios (List[int]): The ratios of GSA-encoder layers in each
              stage.
            - window_sizes (List[int]): The window sizes in LSA-encoder layers
              in each stage.

        in_channels (int): Number of input channels. Defaults to 3.
        out_indices (tuple[int]): Output from which stages.
            Defaults to (3, ).
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        drop_rate (float): Dropout rate. Defaults to 0.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        norm_after_stage (bool, List[bool]): Add extra norm after each stage.
            Defaults to False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmpretrain.models import SVT
        >>> import torch
        >>> svt_cfg = {'arch': "small",
        ...            'norm_after_stage': [False, False, False, True]}
        >>> model = SVT(**svt_cfg)
        >>> x = torch.rand(1, 3, 224, 224)
        >>> outputs = model(x)
        >>> print(outputs[-1].shape)
        torch.Size([1, 512, 7, 7])
        >>> svt_cfg["out_indices"] = (0, 1, 2, 3)
        >>> svt_cfg["norm_after_stage"] = [True, True, True, True]
        >>> model = SVT(**svt_cfg)
        >>> output = model(x)
        >>> for feat in output:
        ...     print(feat.shape)
        torch.Size([1, 64, 56, 56])
        torch.Size([1, 128, 28, 28])
        torch.Size([1, 320, 14, 14])
        torch.Size([1, 512, 7, 7])
    """
    arch_zoo = {
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': [64, 128, 256, 512],
                         'depths': [2, 2, 10, 4],
                         'num_heads': [2, 4, 8, 16],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [4, 4, 4, 4],
                         'sr_ratios': [8, 4, 2, 1],
                         'window_sizes': [7, 7, 7, 7]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': [96, 192, 384, 768],
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [4, 4, 4, 4],
                         'sr_ratios': [8, 4, 2, 1],
                         'window_sizes': [7, 7, 7, 7]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': [128, 256, 512, 1024],
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [4, 4, 4, 4],
                         'sr_ratios': [8, 4, 2, 1],
                         'window_sizes': [7, 7, 7, 7]}),
    }  # yapf: disable

    essential_keys = {
        'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
        'mlp_ratios', 'sr_ratios', 'window_sizes'
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 out_indices=(3, ),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.0,
                 norm_cfg=dict(type='LN'),
                 norm_after_stage=False,
                 init_cfg=None):
        super(SVT, self).__init__(arch, in_channels, out_indices, qkv_bias,
                                  drop_rate, attn_drop_rate, drop_path_rate,
                                  norm_cfg, norm_after_stage, init_cfg)

        self.window_sizes = self.arch_settings['window_sizes']

        for k in range(self.num_stage):
            for i in range(self.depths[k]):
                # in even-numbered layers of each stage, replace GSA with LSA
                if i % 2 == 0:
                    ffn_channels = self.mlp_ratios[k] * self.embed_dims[k]
                    self.stages[k][i] = \
                        LSAEncoderLayer(
                            embed_dims=self.embed_dims[k],
                            num_heads=self.num_heads[k],
                            feedforward_channels=ffn_channels,
                            drop_rate=drop_rate,
                            norm_cfg=norm_cfg,
                            attn_drop_rate=attn_drop_rate,
                            drop_path_rate=self.dpr[sum(self.depths[:k]) + i],
                            qkv_bias=qkv_bias,
                            window_size=self.window_sizes[k])
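A quick way to verify the LSA/GSA interleaving performed in SVT.__init__ above (a sketch, assuming mmpretrain is importable): even-indexed blocks of each stage should be LSA layers, odd-indexed ones GSA layers.

    from mmpretrain.models import SVT

    model = SVT(arch='s')  # depths [2, 2, 10, 4]
    print([type(blk).__name__ for blk in model.stages[0]])
    # ['LSAEncoderLayer', 'GSAEncoderLayer']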
mmpretrain/models/backbones/van.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone


class MixFFN(BaseModule):
    """An implementation of MixFFN of VAN. Refer to
    mmdetection/mmdet/models/backbones/pvt.py.

    The differences between MixFFN & FFN:

    1. Use 1x1 Conv to replace Linear layer.
    2. Introduce 3x3 depth-wise Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            ``MultiheadAttention``.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Defaults to 0.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 init_cfg=None):
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        self.fc1 = Conv2d(
            in_channels=embed_dims,
            out_channels=feedforward_channels,
            kernel_size=1)
        self.dwconv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=feedforward_channels)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LKA(BaseModule):
    """Large Kernel Attention (LKA) of VAN.

    .. code:: text

        DW_conv (depth-wise convolution)
                        |
                        |
        DW_D_conv (depth-wise dilation convolution)
                        |
                        |
        Transition Convolution (1×1 convolution)

    Args:
        embed_dims (int): Number of input channels.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self, embed_dims, init_cfg=None):
        super(LKA, self).__init__(init_cfg=init_cfg)

        # a spatial local convolution (depth-wise convolution)
        self.DW_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=5,
            padding=2,
            groups=embed_dims)

        # a spatial long-range convolution (depth-wise dilation convolution)
        self.DW_D_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=7,
            stride=1,
            padding=9,
            groups=embed_dims,
            dilation=3)
        self.conv1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        u = x.clone()
        attn = self.DW_conv(x)
        attn = self.DW_D_conv(attn)
        attn = self.conv1(attn)

        return u * attn
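To see why this decomposition is attractive, compare its parameter count against a single dense large-kernel conv. The snippet below is an illustration only (C=64 is an arbitrary choice, not taken from the file):

    import torch.nn as nn

    C = 64
    lka = nn.Sequential(
        nn.Conv2d(C, C, 5, padding=2, groups=C),              # depth-wise 5x5
        nn.Conv2d(C, C, 7, padding=9, groups=C, dilation=3),  # depth-wise dilated 7x7
        nn.Conv2d(C, C, 1),                                   # 1x1 transition
    )
    dense = nn.Conv2d(C, C, 21, padding=10)  # dense large-kernel conv for comparison

    count = lambda m: sum(p.numel() for p in m.parameters())
    print(count(lka), count(dense))  # 9024 vs 1806400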
class SpatialAttention(BaseModule):
    """Basic attention module in VANBlock.

    Args:
        embed_dims (int): Number of input channels.
        act_cfg (dict, optional): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
        super(SpatialAttention, self).__init__(init_cfg=init_cfg)
        self.proj_1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = LKA(embed_dims)
        self.proj_2 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        return x


class VANBlock(BaseModule):
    """A block of VAN.

    Args:
        embed_dims (int): Number of input channels.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        act_cfg (dict, optional): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer in the block.
            Defaults to ``dict(type='BN', eps=1e-5)``.
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-2.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 ffn_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN', eps=1e-5),
                 layer_scale_init_value=1e-2,
                 init_cfg=None):
        super(VANBlock, self).__init__(init_cfg=init_cfg)
        self.out_channels = embed_dims

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        mlp_hidden_dim = int(embed_dims * ffn_ratio)
        self.mlp = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            ffn_drop=drop_rate)
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((embed_dims)),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((embed_dims)),
            requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        identity = x
        x = self.norm1(x)
        x = self.attn(x)
        if self.layer_scale_1 is not None:
            x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x
        x = identity + self.drop_path(x)

        identity = x
        x = self.norm2(x)
        x = self.mlp(x)
        if self.layer_scale_2 is not None:
            x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x
        x = identity + self.drop_path(x)

        return x


class VANPatchEmbed(PatchEmbed):
    """Image to Patch Embedding of VAN.

    The differences between VANPatchEmbed & PatchEmbed:

    1. Use BN.
    2. Do not use 'flatten' and 'transpose'.
    """

    def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
        super(VANPatchEmbed, self).__init__(
            *args, norm_cfg=norm_cfg, **kwargs)

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
            - out_size (tuple[int]): Spatial shape of x, arrange as
              (out_h, out_w).
        """
        if self.adaptive_padding:
            x = self.adaptive_padding(x)
        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size


@MODELS.register_module()
class VAN(BaseBackbone):
    """Visual Attention Network.

    A PyTorch implementation of: `Visual Attention Network
    <https://arxiv.org/pdf/2202.09741v2.pdf>`_

    Inspiration from
    https://github.com/Visual-Attention-Network/VAN-Classification

    Args:
        arch (str | dict): Visual Attention Network architecture.
            If use string, choose from 'tiny', 'small', 'base' and 'large'.
            If use dict, it should have below keys:

            - **embed_dims** (List[int]): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **ffn_ratios** (List[int]): The number of expansion ratio of
              feedforward network hidden layer channels.

            Defaults to 'tiny'.
        patch_sizes (List[int | tuple]): The patch size in patch embeddings.
            Defaults to [7, 3, 3, 3].
        in_channels (int): The num of input channels. Defaults to 3.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmpretrain.models import VAN
        >>> import torch
        >>> cfg = dict(arch='tiny')
        >>> model = VAN(**cfg)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        ...     print(out.size())
        (1, 256, 7, 7)
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': [32, 64, 160, 256],
                         'depths': [3, 3, 5, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [2, 2, 4, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 3, 12, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 5, 27, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
    }  # yapf: disable

    def __init__(self,
                 arch='tiny',
                 patch_sizes=[7, 3, 3, 3],
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 norm_eval=False,
                 norm_cfg=dict(type='LN'),
                 block_cfgs=dict(),
                 init_cfg=None):
        super(VAN, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.ffn_ratios = self.arch_settings['ffn_ratios']
        self.num_stages = len(self.depths)
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        total_depth = sum(self.depths)
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        cur_block_idx = 0
        for i, depth in enumerate(self.depths):
            patch_embed = VANPatchEmbed(
                in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
                input_size=None,
                embed_dims=self.embed_dims[i],
                kernel_size=patch_sizes[i],
                stride=patch_sizes[i] // 2 + 1,
                padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
                norm_cfg=dict(type='BN'))

            blocks = ModuleList([
                VANBlock(
                    embed_dims=self.embed_dims[i],
                    ffn_ratio=self.ffn_ratios[i],
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur_block_idx + j],
                    **block_cfgs) for j in range(depth)
            ])
            cur_block_idx += depth
            norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]

            self.add_module(f'patch_embed{i + 1}', patch_embed)
            self.add_module(f'blocks{i + 1}', blocks)
            self.add_module(f'norm{i + 1}', norm)

    def train(self, mode=True):
        super(VAN, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _freeze_stages(self):
        for i in range(0, self.frozen_stages + 1):
            # freeze patch embed
            m = getattr(self, f'patch_embed{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # freeze blocks
            m = getattr(self, f'blocks{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # freeze norm
            m = getattr(self, f'norm{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def forward(self, x):
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            blocks = getattr(self, f'blocks{i + 1}')
            norm = getattr(self, f'norm{i + 1}')
            x, hw_shape = patch_embed(x)
            for block in blocks:
                x = block(x)
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            x = x.reshape(-1, *hw_shape,
                          block.out_channels).permute(0, 3, 1,
                                                      2).contiguous()
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)
mmpretrain/models/backbones/vgg.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone


def make_vgg_layer(in_channels,
                   out_channels,
                   num_blocks,
                   conv_cfg=None,
                   norm_cfg=None,
                   act_cfg=dict(type='ReLU'),
                   dilation=1,
                   with_norm=False,
                   ceil_mode=False):
    layers = []
    for _ in range(num_blocks):
        layer = ConvModule(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
            bias=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        layers.append(layer)
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))

    return layers


@MODELS.register_module()
class VGG(BaseBackbone):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        with_norm (bool): Use BatchNorm or not.
        num_classes (int): number of classes for classification.
        num_stages (int): VGG stages, normally 5.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int], optional): Output from which stages.
            When it is None, the default behavior depends on whether
            num_classes is specified. If num_classes <= 0, the default value
            is (4, ), output the last feature map before classifier. If
            num_classes > 0, the default value is (5, ), output the
            classification score. Defaults to None.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        ceil_mode (bool): Whether to use ceil_mode of MaxPool.
            Defaults to False.
        with_last_pool (bool): Whether to keep the last pooling before
            classifier. Defaults to True.
    """

    # Parameters to build layers. Each element specifies the number of conv
    # in each stage. For example, VGG11 contains 11 layers with learnable
    # parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3,
    # where 3 indicates the last three fully-connected layers.
    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=None,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 ceil_mode=False,
                 with_last_pool=True,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(type='Constant', val=1., layer=['_BatchNorm']),
                     dict(type='Normal', std=0.01, layer=['Linear'])
                 ]):
        super(VGG, self).__init__(init_cfg)
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for vgg')
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages

        self.num_classes = num_classes
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        with_norm = norm_cfg is not None

        if out_indices is None:
            out_indices = (5, ) if num_classes > 0 else (4, )
        assert max(out_indices) <= num_stages
        self.out_indices = out_indices

        self.in_channels = 3
        start_idx = 0
        vgg_layers = []
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            num_modules = num_blocks + 1
            end_idx = start_idx + num_modules
            dilation = dilations[i]
            out_channels = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.in_channels,
                out_channels,
                num_blocks,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                dilation=dilation,
                with_norm=with_norm,
                ceil_mode=ceil_mode)
            vgg_layers.extend(vgg_layer)
            self.in_channels = out_channels
            self.range_sub_modules.append([start_idx, end_idx])
            start_idx = end_idx
        if not with_last_pool:
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

    def forward(self, x):
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i in range(len(self.stage_blocks)):
            for j in range(*self.range_sub_modules[i]):
                vgg_layer = vgg_layers[j]
                x = vgg_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if self.num_classes > 0:
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        vgg_layers = getattr(self, self.module_name)
        for i in range(self.frozen_stages):
            for j in range(*self.range_sub_modules[i]):
                m = vgg_layers[j]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(VGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
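As with the other new files, a forward-pass sanity check is straightforward (a sketch, assuming mmpretrain is importable). With the default num_classes=-1, out_indices falls back to (4, ), so the output is the last feature map before the classifier:

    import torch
    from mmpretrain.models.backbones import VGG

    model = VGG(depth=16)
    model.eval()
    with torch.no_grad():
        outs = model(torch.rand(1, 3, 224, 224))
    print(outs[-1].shape)  # torch.Size([1, 512, 7, 7])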
mmpretrain/models/backbones/vig.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
# modified from
# https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer
from mmcv.cnn.bricks import DropPath
from mmengine.model import ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm

from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer


def get_2d_relative_pos_embed(embed_dim, grid_size):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, grid_size*grid_size]
    """
    pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
    relative_pos = 2 * np.matmul(pos_embed,
                                 pos_embed.transpose()) / pos_embed.shape[1]
    return relative_pos


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
    [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
                                   axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
                                              grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
                                              grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb
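
# --- Hedged sanity-check sketch (not part of the original file): the helpers
# above build a standard 2D sin-cos table. For a 14x14 grid and 192 channels,
# get_2d_sincos_pos_embed yields a (196, 192) array, and
# get_2d_relative_pos_embed turns it into a (196, 196) node-to-node table.
def _demo_sincos_pos_embed():
    pe = get_2d_sincos_pos_embed(192, 14)
    rel = get_2d_relative_pos_embed(192, 14)
    assert pe.shape == (196, 192)
    assert rel.shape == (196, 196)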
def xy_pairwise_distance(x, y):
    """Compute pairwise distance of a point cloud.

    Args:
        x: tensor (batch_size, num_points, num_dims)
        y: tensor (batch_size, num_points, num_dims)
    Returns:
        pairwise distance: (batch_size, num_points, num_points)
    """
    with torch.no_grad():
        xy_inner = -2 * torch.matmul(x, y.transpose(2, 1))
        x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
        y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
        return x_square + xy_inner + y_square.transpose(2, 1)


def xy_dense_knn_matrix(x, y, k=16, relative_pos=None):
    """Get KNN based on the pairwise distance.

    Args:
        x: (batch_size, num_dims, num_points, 1)
        y: (batch_size, num_dims, num_points, 1)
        k: int
        relative_pos: Whether to use relative_pos
    Returns:
        nearest neighbors:
        (batch_size, num_points, k) (batch_size, num_points, k)
    """
    with torch.no_grad():
        x = x.transpose(2, 1).squeeze(-1)
        y = y.transpose(2, 1).squeeze(-1)
        batch_size, n_points, n_dims = x.shape
        dist = xy_pairwise_distance(x.detach(), y.detach())
        if relative_pos is not None:
            dist += relative_pos
        _, nn_idx = torch.topk(-dist, k=k)
        center_idx = torch.arange(
            0, n_points, device=x.device).repeat(batch_size, k,
                                                 1).transpose(2, 1)
    return torch.stack((nn_idx, center_idx), dim=0)
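
# --- Hedged sanity-check sketch (not part of the original file): shows the
# expected shapes of xy_dense_knn_matrix. Inputs are (B, C, N, 1) feature
# maps flattened to point sets; the result stacks neighbor indices and
# center indices into (2, B, N, k).
def _demo_dense_knn():
    x = torch.rand(2, 32, 49, 1)
    edge_index = xy_dense_knn_matrix(x, x, k=9)
    assert edge_index.shape == (2, 2, 49, 9)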
class DenseDilated(nn.Module):
    """Find dilated neighbor from neighbor list.

    edge_index: (2, batch_size, num_points, k)
    """

    def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0):
        super(DenseDilated, self).__init__()
        self.dilation = dilation
        self.use_stochastic = use_stochastic
        self.epsilon = epsilon
        self.k = k

    def forward(self, edge_index):
        if self.use_stochastic:
            if torch.rand(1) < self.epsilon and self.training:
                num = self.k * self.dilation
                randnum = torch.randperm(num)[:self.k]
                edge_index = edge_index[:, :, :, randnum]
            else:
                edge_index = edge_index[:, :, :, ::self.dilation]
        else:
            edge_index = edge_index[:, :, :, ::self.dilation]
        return edge_index


class DenseDilatedKnnGraph(nn.Module):
    """Find the neighbors' indices based on dilated knn."""

    def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0):
        super(DenseDilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.use_stochastic = use_stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = DenseDilated(k, dilation, use_stochastic, epsilon)

    def forward(self, x, y=None, relative_pos=None):
        if y is not None:
            x = F.normalize(x, p=2.0, dim=1)
            y = F.normalize(y, p=2.0, dim=1)

            edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation,
                                             relative_pos)
        else:
            x = F.normalize(x, p=2.0, dim=1)
            y = x.clone()

            edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation,
                                             relative_pos)
        return self._dilated(edge_index)


class BasicConv(Sequential):

    def __init__(self,
                 channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True,
                 drop=0.):
        m = []
        for i in range(1, len(channels)):
            m.append(
                nn.Conv2d(
                    channels[i - 1],
                    channels[i],
                    1,
                    bias=graph_conv_bias,
                    groups=4))
            if norm_cfg is not None:
                m.append(build_norm_layer(norm_cfg, channels[-1]))
            if act_cfg is not None:
                m.append(build_activation_layer(act_cfg))
            if drop > 0:
                m.append(nn.Dropout2d(drop))

        super(BasicConv, self).__init__(*m)
def batched_index_select(x, idx):
    r"""fetches neighbors features from a given neighbor idx

    Args:
        x (Tensor): input feature Tensor
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
        idx (Tensor): edge_idx
            :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`.
    Returns:
        Tensor: output neighbors features
            :math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`.
    """
    batch_size, num_dims, num_vertices_reduced = x.shape[:3]
    _, num_vertices, k = idx.shape
    idx_base = torch.arange(
        0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
    idx = idx + idx_base
    idx = idx.contiguous().view(-1)

    x = x.transpose(2, 1)
    feature = x.contiguous().view(batch_size * num_vertices_reduced,
                                  -1)[idx, :]
    feature = feature.view(batch_size, num_vertices, k,
                           num_dims).permute(0, 3, 1, 2).contiguous()
    return feature
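
# --- Hedged sanity-check sketch (not part of the original file): gathers k
# neighbor features per node. With x of shape (B, C, N, 1) and idx of shape
# (B, N, k), the result is (B, C, N, k).
def _demo_batched_index_select():
    x = torch.rand(2, 32, 49, 1)
    idx = torch.randint(0, 49, (2, 49, 9))
    out = batched_index_select(x, idx)
    assert out.shape == (2, 32, 49, 9)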
class MRConv2d(nn.Module):
    """Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751)
    for dense data type."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super(MRConv2d, self).__init__()
        self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg,
                            graph_conv_bias)

    def forward(self, x, edge_index, y=None):
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
        b, c, n, _ = x.shape
        x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)],
                      dim=2).reshape(b, 2 * c, n, _)
        return self.nn(x)


class EdgeConv2d(nn.Module):
    """Edge convolution layer (with activation, batch normalization) for dense
    data type."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super(EdgeConv2d, self).__init__()
        self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg,
                            graph_conv_bias)

    def forward(self, x, edge_index, y=None):
        x_i = batched_index_select(x, edge_index[1])
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        max_value, _ = torch.max(
            self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
        return max_value


class GraphSAGE(nn.Module):
    """GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216)
    for dense data type."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super(GraphSAGE, self).__init__()
        self.nn1 = BasicConv([in_channels, in_channels], act_cfg, norm_cfg,
                             graph_conv_bias)
        self.nn2 = BasicConv([in_channels * 2, out_channels], act_cfg,
                             norm_cfg, graph_conv_bias)

    def forward(self, x, edge_index, y=None):
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True)
        return self.nn2(torch.cat([x, x_j], dim=1))


class GINConv2d(nn.Module):
    """GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for
    dense data type."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super(GINConv2d, self).__init__()
        self.nn = BasicConv([in_channels, out_channels], act_cfg, norm_cfg,
                            graph_conv_bias)
        eps_init = 0.0
        self.eps = nn.Parameter(torch.Tensor([eps_init]))

    def forward(self, x, edge_index, y=None):
        if y is not None:
            x_j = batched_index_select(y, edge_index[0])
        else:
            x_j = batched_index_select(x, edge_index[0])
        x_j = torch.sum(x_j, -1, keepdim=True)
        return self.nn((1 + self.eps) * x + x_j)


class GraphConv2d(nn.Module):
    """Static graph convolution layer."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 graph_conv_type,
                 act_cfg,
                 norm_cfg=None,
                 graph_conv_bias=True):
        super(GraphConv2d, self).__init__()
        if graph_conv_type == 'edge':
            self.gconv = EdgeConv2d(in_channels, out_channels, act_cfg,
                                    norm_cfg, graph_conv_bias)
        elif graph_conv_type == 'mr':
            self.gconv = MRConv2d(in_channels, out_channels, act_cfg,
                                  norm_cfg, graph_conv_bias)
        elif graph_conv_type == 'sage':
            self.gconv = GraphSAGE(in_channels, out_channels, act_cfg,
                                   norm_cfg, graph_conv_bias)
        elif graph_conv_type == 'gin':
            self.gconv = GINConv2d(in_channels, out_channels, act_cfg,
                                   norm_cfg, graph_conv_bias)
        else:
            raise NotImplementedError(
                'graph_conv_type:{} is not supported'.format(graph_conv_type))

    def forward(self, x, edge_index, y=None):
        return self.gconv(x, edge_index, y)
class DyGraphConv2d(GraphConv2d):
    """Dynamic graph convolution layer."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 k=9,
                 dilation=1,
                 graph_conv_type='mr',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=None,
                 graph_conv_bias=True,
                 use_stochastic=False,
                 epsilon=0.2,
                 r=1):
        super(DyGraphConv2d,
              self).__init__(in_channels, out_channels, graph_conv_type,
                             act_cfg, norm_cfg, graph_conv_bias)
        self.k = k
        self.d = dilation
        self.r = r
        self.dilated_knn_graph = DenseDilatedKnnGraph(k, dilation,
                                                      use_stochastic, epsilon)

    def forward(self, x, relative_pos=None):
        B, C, H, W = x.shape
        y = None
        if self.r > 1:
            y = F.avg_pool2d(x, self.r, self.r)
            y = y.reshape(B, C, -1, 1).contiguous()
        x = x.reshape(B, C, -1, 1).contiguous()
        edge_index = self.dilated_knn_graph(x, y, relative_pos)
        x = super(DyGraphConv2d, self).forward(x, edge_index, y)
        return x.reshape(B, -1, H, W).contiguous()


class Grapher(nn.Module):
    """Grapher module with graph convolution and fc layers."""

    def __init__(self,
                 in_channels,
                 k=9,
                 dilation=1,
                 graph_conv_type='mr',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=None,
                 graph_conv_bias=True,
                 use_stochastic=False,
                 epsilon=0.2,
                 r=1,
                 n=196,
                 drop_path=0.0,
                 relative_pos=False):
        super(Grapher, self).__init__()
        self.channels = in_channels
        self.n = n
        self.r = r
        self.fc1 = Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), in_channels),
        )
        self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, k,
                                        dilation, graph_conv_type, act_cfg,
                                        norm_cfg, graph_conv_bias,
                                        use_stochastic, epsilon, r)
        self.fc2 = Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), in_channels),
        )
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.relative_pos = None
        if relative_pos:
            relative_pos_tensor = torch.from_numpy(
                np.float32(get_2d_relative_pos_embed(
                    in_channels, int(n**0.5)))).unsqueeze(0).unsqueeze(1)
            relative_pos_tensor = F.interpolate(
                relative_pos_tensor,
                size=(n, n // (r * r)),
                mode='bicubic',
                align_corners=False)
            self.relative_pos = nn.Parameter(
                -relative_pos_tensor.squeeze(1), requires_grad=False)

    def _get_relative_pos(self, relative_pos, H, W):
        if relative_pos is None or H * W == self.n:
            return relative_pos
        else:
            N = H * W
            N_reduced = N // (self.r * self.r)
            return F.interpolate(
                relative_pos.unsqueeze(0), size=(N, N_reduced),
                mode='bicubic').squeeze(0)

    def forward(self, x):
        B, C, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        shortcut = x
        x = self.fc1(x)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x
class FFN(nn.Module):
    """FFN block with two 1x1 convolutions; ``out_features`` and
    ``hidden_features`` default to ``in_features`` when not given."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 drop_path=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = Sequential(
            nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), hidden_features),
        )
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Sequential(
            nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
            build_norm_layer(dict(type='BN'), out_features),
        )
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop_path(x) + shortcut
        return x
@MODELS.register_module()
class Vig(BaseBackbone):
    """Vision GNN backbone.

    A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
    <https://arxiv.org/abs/2206.00272>`_.

    Modified from the official implementation
    https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch

    Args:
        arch (str): Vision GNN architecture,
            choose from 'tiny', 'small' and 'base'.
        in_channels (int): The number of channels of input images.
            Defaults to 3.
        k (int): The number of KNN's k. Defaults to 9.
        out_indices (Sequence | int): Output from which blocks.
            Defaults to -1, which means the last block.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN', eps=1e-6)``.
        graph_conv_bias (bool): Whether to use bias in the convolution
            layers in Grapher. Defaults to True.
        graph_conv_type (str): The type of graph convolution, choose
            from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
        epsilon (float): Probability of random arrangement in KNN. It only
            works when ``use_dilation=True`` and ``use_stochastic=True``.
            Defaults to 0.2.
        use_dilation (bool): Whether to use dilation in KNN. Defaults to True.
        use_stochastic (bool): Whether to use stochastic in KNN.
            Defaults to False.
        drop_path (float): stochastic depth rate. Default 0.0
        relative_pos (bool): Whether to use relative position embedding.
            Defaults to False.
        norm_eval (bool): Whether to set the normalization layer to eval mode.
            Defaults to False.
        frozen_stages (int): Blocks to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): The initialization configs.
            Defaults to None.
    """  # noqa: E501

    arch_settings = {
        'tiny': dict(num_blocks=12, channels=192),
        'small': dict(num_blocks=16, channels=320),
        'base': dict(num_blocks=16, channels=640),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 k=9,
                 out_indices=-1,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN'),
                 graph_conv_bias=True,
                 graph_conv_type='mr',
                 epsilon=0.2,
                 use_dilation=True,
                 use_stochastic=False,
                 drop_path=0.,
                 relative_pos=False,
                 norm_eval=False,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        arch = self.arch_settings[arch]
        self.num_blocks = arch['num_blocks']
        channels = arch['channels']

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        elif isinstance(out_indices, tuple):
            out_indices = list(out_indices)
        elif not isinstance(out_indices, list):
            raise TypeError('"out_indices" must be a tuple, list or int, '
                            f'got {type(out_indices)} instead.')
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_blocks + index
            assert 0 <= out_indices[i] <= self.num_blocks, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.stem = Sequential(
            nn.Conv2d(in_channels, channels // 8, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels // 8),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels // 8, channels // 4, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels // 4),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels // 4, channels // 2, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels // 2),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels // 2, channels, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels, channels, 3, stride=1, padding=1),
            build_norm_layer(norm_cfg, channels),
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)]
        # number of knn's k
        num_knn = [
            int(x.item()) for x in torch.linspace(k, 2 * k, self.num_blocks)
        ]
        max_dilation = 196 // max(num_knn)

        self.pos_embed = nn.Parameter(torch.zeros(1, channels, 14, 14))

        self.blocks = ModuleList([
            Sequential(
                Grapher(
                    in_channels=channels,
                    k=num_knn[i],
                    dilation=min(i // 4 + 1, max_dilation)
                    if use_dilation else 1,
                    graph_conv_type=graph_conv_type,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    graph_conv_bias=graph_conv_bias,
                    use_stochastic=use_stochastic,
                    epsilon=epsilon,
                    drop_path=dpr[i],
                    relative_pos=relative_pos),
                FFN(in_features=channels,
                    hidden_features=channels * 4,
                    act_cfg=act_cfg,
                    drop_path=dpr[i])) for i in range(self.num_blocks)
        ])

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

    def forward(self, inputs):
        outs = []
        x = self.stem(inputs) + self.pos_embed

        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        self.stem.eval()
        for i in range(self.frozen_stages):
            m = self.blocks[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(Vig, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers here
                if isinstance(m, _BatchNorm):
                    m.eval()
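
# --- Hedged usage sketch (not part of the original file): the isotropic ViG
# backbone above. Its stem downsamples 224x224 inputs by 16x, so every block
# works on a 14x14 (196-node) graph; with the default out_indices=-1 only
# the last block's features are returned.
def _demo_vig():
    model = Vig(arch='tiny')
    model.eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 224, 224))
    assert feats[-1].shape == (1, 192, 14, 14)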
@MODELS.register_module()
class PyramidVig(BaseBackbone):
    """Pyramid Vision GNN backbone.

    A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
    <https://arxiv.org/abs/2206.00272>`_.

    Modified from the official implementation
    https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch

    Args:
        arch (str): Vision GNN architecture, choose from 'tiny',
            'small' and 'base'.
        in_channels (int): The number of channels of input images.
            Defaults to 3.
        k (int): The number of KNN's k. Defaults to 9.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        act_cfg (dict): The config of activation functions.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): The config of normalization layers.
            Defaults to ``dict(type='BN')``.
        graph_conv_bias (bool): Whether to use bias in the convolution
            layers in Grapher. Defaults to True.
        graph_conv_type (str): The type of graph convolution, choose
            from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
        epsilon (float): Probability of random arrangement in KNN. It only
            works when ``use_stochastic=True``. Defaults to 0.2.
        use_stochastic (bool): Whether to use stochastic in KNN.
            Defaults to False.
        drop_path (float): stochastic depth rate. Default 0.0
        norm_eval (bool): Whether to set the normalization layer to eval mode.
            Defaults to False.
        frozen_stages (int): Stages to be frozen (all param fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): The initialization configs.
            Defaults to None.
    """  # noqa: E501

    arch_settings = {
        'tiny': dict(blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384]),
        'small': dict(blocks=[2, 2, 6, 2], channels=[80, 160, 400, 640]),
        'medium': dict(blocks=[2, 2, 16, 2], channels=[96, 192, 384, 768]),
        'base': dict(blocks=[2, 2, 18, 2], channels=[128, 256, 512, 1024]),
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 k=9,
                 out_indices=-1,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN'),
                 graph_conv_bias=True,
                 graph_conv_type='mr',
                 epsilon=0.2,
                 use_stochastic=False,
                 drop_path=0.,
                 norm_eval=False,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        arch = self.arch_settings[arch]
        self.blocks = arch['blocks']
        self.num_blocks = sum(self.blocks)
        self.num_stages = len(self.blocks)
        channels = arch['channels']
        self.channels = channels

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_stages + index
            assert 0 <= out_indices[i] <= self.num_stages, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        self.stem = Sequential(
            nn.Conv2d(in_channels, channels[0] // 2, 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels[0] // 2),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels[0] // 2, channels[0], 3, stride=2, padding=1),
            build_norm_layer(norm_cfg, channels[0]),
            build_activation_layer(act_cfg),
            nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1),
            build_norm_layer(norm_cfg, channels[0]),
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)]
        # number of knn's k
        num_knn = [
            int(x.item()) for x in torch.linspace(k, k, self.num_blocks)
        ]
        max_dilation = 49 // max(num_knn)

        self.pos_embed = nn.Parameter(
            torch.zeros(1, channels[0], 224 // 4, 224 // 4))
        HW = 224 // 4 * 224 // 4
        reduce_ratios = [4, 2, 1, 1]

        self.stages = ModuleList()
        block_idx = 0
        for stage_idx, num_blocks in enumerate(self.blocks):
            mid_channels = channels[stage_idx]
            reduce_ratio = reduce_ratios[stage_idx]
            blocks = []
            if stage_idx > 0:
                blocks.append(
                    Sequential(
                        nn.Conv2d(
                            self.channels[stage_idx - 1],
                            mid_channels,
                            kernel_size=3,
                            stride=2,
                            padding=1),
                        build_norm_layer(norm_cfg, mid_channels),
                    ))
                HW = HW // 4
            for _ in range(num_blocks):
                blocks.append(
                    Sequential(
                        Grapher(
                            in_channels=mid_channels,
                            k=num_knn[block_idx],
                            dilation=min(block_idx // 4 + 1, max_dilation),
                            graph_conv_type=graph_conv_type,
                            act_cfg=act_cfg,
                            norm_cfg=norm_cfg,
                            graph_conv_bias=graph_conv_bias,
                            use_stochastic=use_stochastic,
                            epsilon=epsilon,
                            r=reduce_ratio,
                            n=HW,
                            drop_path=dpr[block_idx],
                            relative_pos=True),
                        FFN(in_features=mid_channels,
                            hidden_features=mid_channels * 4,
                            act_cfg=act_cfg,
                            drop_path=dpr[block_idx])))
                block_idx += 1
            self.stages.append(Sequential(*blocks))

        self.norm_eval = norm_eval
        self.frozen_stages = frozen_stages

    def forward(self, inputs):
        outs = []
        x = self.stem(inputs) + self.pos_embed

        for i, blocks in enumerate(self.stages):
            x = blocks(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        self.stem.eval()
        for i in range(self.frozen_stages):
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(PyramidVig, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers here
                if isinstance(m, _BatchNorm):
                    m.eval()
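
# --- Hedged usage sketch (not part of the original file): the pyramid
# variant above. The stem downsamples by 4x and each later stage halves the
# resolution, so a 224x224 input ends at 7x7 with the last stage's channel
# width (384 for the 'tiny' arch).
def _demo_pyramid_vig():
    model = PyramidVig(arch='tiny')
    model.eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 224, 224))
    assert feats[-1].shape == (1, 384, 7, 7)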
mmpretrain/models/backbones/vision_transformer.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from ..utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer,
                     resize_pos_embed, to_2tuple)
from .base_backbone import BaseBackbone


class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension for FFNs
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 layer_scale_init_value=0.,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 ffn_type='origin',
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            layer_scale_init_value=layer_scale_init_value)

        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        if ffn_type == 'origin':
            self.ffn = FFN(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                num_fcs=num_fcs,
                ffn_drop=drop_rate,
                dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
                act_cfg=act_cfg,
                layer_scale_init_value=layer_scale_init_value)
        elif ffn_type == 'swiglu_fused':
            self.ffn = SwiGLUFFNFused(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                layer_scale_init_value=layer_scale_init_value)
        else:
            raise NotImplementedError

    @property
    def norm1(self):
        return self.ln1

    @property
    def norm2(self):
        return self.ln2

    def init_weights(self):
        super(TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = self.ffn(self.ln2(x), identity=x)
        return x
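
# --- Hedged usage sketch (not part of the original file): one pre-norm ViT
# block as defined above. Both the attention and FFN branches are residual,
# so the (B, L, C) token shape is preserved.
def _demo_encoder_layer():
    layer = TransformerEncoderLayer(
        embed_dims=768, num_heads=12, feedforward_channels=3072)
    tokens = torch.rand(2, 197, 768)
    assert layer(tokens).shape == (2, 197, 768)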
@MODELS.register_module()
class VisionTransformer(BaseBackbone):
    """Vision Transformer.

    A PyTorch implementation of `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If a string,
            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
            and 'deit-base'. If a dict, it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Defaults to True.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).

            Defaults to ``"cls_token"``.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        patch_cfg (dict): Configs of patch embedding.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 768,
                'num_layers': 8,
                'num_heads': 8,
                'feedforward_channels': 768 * 3,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
        **dict.fromkeys(
            ['h', 'huge'],
            {
                # The same as the implementation in MAE
                # <https://arxiv.org/abs/2111.06377>
                'embed_dims': 1280,
                'num_layers': 32,
                'num_heads': 16,
                'feedforward_channels': 5120
            }),
        **dict.fromkeys(
            ['eva-g', 'eva-giant'],
            {
                # The implementation in EVA
                # <https://arxiv.org/abs/2211.07636>
                'embed_dims': 1408,
                'num_layers': 40,
                'num_heads': 16,
                'feedforward_channels': 6144
            }),
        **dict.fromkeys(
            ['deit-t', 'deit-tiny'], {
                'embed_dims': 192,
                'num_layers': 12,
                'num_heads': 3,
                'feedforward_channels': 192 * 4
            }),
        **dict.fromkeys(
            ['deit-s', 'deit-small', 'dinov2-s', 'dinov2-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': 384 * 4
            }),
        **dict.fromkeys(
            ['deit-b', 'deit-base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 768 * 4
            }),
        **dict.fromkeys(
            ['dinov2-g', 'dinov2-giant'], {
                'embed_dims': 1536,
                'num_layers': 40,
                'num_heads': 24,
                'feedforward_channels': 6144
            }),
    }
    num_extra_tokens = 1  # class token
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 out_type='cls_token',
                 with_cls_token=True,
                 frozen_stages=-1,
                 interpolate_mode='bicubic',
                 layer_scale_init_value=0.,
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 pre_norm=False,
                 init_cfg=None):
        super(VisionTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm,  # disable bias if pre_norm is used (e.g., CLIP)
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        # Set cls token
        self.with_cls_token = with_cls_token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        elif out_type != 'cls_token':
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError(
                'with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                layer_scale_init_value=layer_scale_init_value,
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        self.frozen_stages = frozen_stages
        if pre_norm:
            self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
        else:
            self.pre_norm = nn.Identity()

        self.final_norm = final_norm
        if final_norm:
            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
        if self.out_type == 'avg_featmap':
            self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        # freeze stages only when self.frozen_stages > 0
        if self.frozen_stages > 0:
            self._freeze_stages()
    @property
    def norm1(self):
        return self.ln1

    @property
    def norm2(self):
        return self.ln2

    def init_weights(self):
        super(VisionTransformer, self).init_weights()
        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            if self.pos_embed is not None:
                trunc_normal_(self.pos_embed, std=0.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if (not self.with_cls_token
                and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
            # Remove cls token from state dict if it's not used.
            state_dict[name] = state_dict[name][:, 1:]
            ckpt_pos_embed_shape = state_dict[name].shape

        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility."""
        return resize_pos_embed(*args, **kwargs)
    def _freeze_stages(self):
        # freeze position embedding
        if self.pos_embed is not None:
            self.pos_embed.requires_grad = False
        # set dropout to eval mode
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze pre-norm
        for param in self.pre_norm.parameters():
            param.requires_grad = False
        # freeze cls_token
        if self.cls_token is not None:
            self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm
        if self.frozen_stages == len(self.layers):
            if self.final_norm:
                self.ln1.eval()
                for param in self.ln1.parameters():
                    param.requires_grad = False

            if self.out_type == 'avg_featmap':
                self.ln2.eval()
                for param in self.ln2.parameters():
                    param.requires_grad = False
    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_token = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        x = self.pre_norm(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)

    def _format_output(self, x, hw):
        if self.out_type == 'raw':
            return x
        if self.out_type == 'cls_token':
            return x[:, 0]

        patch_token = x[:, self.num_extra_tokens:]
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            return self.ln2(patch_token.mean(dim=1))

    def get_layer_depth(self, param_name: str, prefix: str = ''):
        """Get the layer-wise depth of a parameter.

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix for the parameter.
                Defaults to an empty string.

        Returns:
            Tuple[int, int]: The layer-wise depth and the num of layers.

        Note:
            The first depth is the stem module (``layer_depth=0``), and the
            last depth is the subsequent module (``layer_depth=num_layers-1``)
        """
        num_layers = self.num_layers + 2

        if not param_name.startswith(prefix):
            # For subsequent module like head
            return num_layers - 1, num_layers

        param_name = param_name[len(prefix):]

        if param_name in ('cls_token', 'pos_embed'):
            layer_depth = 0
        elif param_name.startswith('patch_embed'):
            layer_depth = 0
        elif param_name.startswith('layers'):
            layer_id = int(param_name.split('.')[1])
            layer_depth = layer_id + 1
        else:
            layer_depth = num_layers - 1

        return layer_depth, num_layers
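
# --- Hedged usage sketch (not part of the original file): ViT-base on a
# 224x224 input. 14x14 patches plus one class token give 197 tokens, and the
# default out_type='cls_token' returns a (B, 768) vector per image.
def _demo_vision_transformer():
    model = VisionTransformer(arch='base')
    model.eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 224, 224))
    assert feats[-1].shape == (1, 768)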
mmpretrain/models/backbones/vit_eva02.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, ModuleList

from mmpretrain.registry import MODELS
from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer,
                     resize_pos_embed)
from .vision_transformer import VisionTransformer


class AttentionWithRoPE(BaseModule):
    """Multi-head Attention Module with 2D sincos position embedding (RoPE).

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        qkv_bias (bool): If True, add a learnable bias to q and v. Note
            that we follow the official implementation where ``k_bias``
            is 0. Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        rope (:obj:`torch.nn.Module`, optional): If it is an object of the
            ``RotaryEmbedding``, the rotation of the token position will be
            performed before the softmax. Defaults to None.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_bias=True,
                 qk_scale=None,
                 proj_bias=True,
                 rope=None,
                 with_cls_token=True,
                 init_cfg=None):
        super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.head_dims = embed_dims // num_heads
        self.scale = qk_scale or self.head_dims**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.with_cls_token = with_cls_token

        self.rope = rope

    def forward(self, x, patch_resolution):
        B, N, _ = x.shape

        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)

        if self.rope:
            if self.with_cls_token:
                q_t = q[:, :, 1:, :]
                ro_q_t = self.rope(q_t, patch_resolution)
                q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

                k_t = k[:, :, 1:, :] if self.with_cls_token else k
                ro_k_t = self.rope(k_t, patch_resolution)
                k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
            else:
                q = self.rope(q, patch_resolution)
                k = self.rope(k, patch_resolution)

        q = q * self.scale

        attn = (q @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1).type_as(x)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x
class EVA02EndcoderLayer(BaseModule):
    """Implements one encoder layer in EVA02.

    Args:
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension of FFNs.
        sub_ln (bool): Whether to add the sub layer normalization
            in the attention module. Defaults to False.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): enable bias for projection in the attention module
            if True. Defaults to True.
        rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object
            in the attention module. Defaults to None.
        drop_rate (float): Dropout rate in the mlp module. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 sub_ln=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_bias=False,
                 qk_scale=None,
                 proj_bias=True,
                 rope=None,
                 with_cls_token=True,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)

        self.attn = AttentionWithRoPE(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            proj_bias=proj_bias,
            rope=rope,
            with_cls_token=with_cls_token)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate))

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)

        if drop_rate > 0:
            dropout_layer = dict(type='Dropout', drop_prob=drop_rate)
        else:
            dropout_layer = None

        if sub_ln:
            ffn_norm = norm_cfg
        else:
            ffn_norm = None

        self.mlp = SwiGLUFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            dropout_layer=dropout_layer,
            norm_cfg=ffn_norm,
            add_identity=False,
        )

    def forward(self, x, patch_resolution):
        inputs = x
        x = self.norm1(x)
        x = self.attn(x, patch_resolution)
        x = self.drop_path(x)
        x = inputs + x

        inputs = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = inputs + x

        return x
@MODELS.register_module()
class ViTEVA02(VisionTransformer):
    """EVA02 Vision Transformer.

    A PyTorch implementation of `EVA-02: A Visual Representation for Neon
    Genesis <https://arxiv.org/abs/2303.11331>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If a string,
            choose from 'tiny', 'small', 'base', 'large'. If a dict,
            it should have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **mlp_ratio** (float): The ratio of the mlp module.

            Defaults to 'tiny'.
        sub_ln (bool): Whether to add the sub layer normalization in swiglu.
            Defaults to False.
        drop_rate (float): Probability of an element to be zeroed in the
            mlp module. Defaults to 0.
        attn_drop_rate (float): Probability of an element to be zeroed after
            the softmax in the attention. Defaults to 0.
        proj_drop_rate (float): Probability of an element to be zeroed after
            projection in the attention. Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        **kwargs(dict, optional): Other args for Vision Transformer.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['t', 'ti', 'tiny'], {
                'embed_dims': 192,
                'num_layers': 12,
                'num_heads': 3,
                'feedforward_channels': int(192 * 4 * 2 / 3)
            }),
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': int(384 * 4 * 2 / 3)
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': int(768 * 4 * 2 / 3)
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': int(1024 * 4 * 2 / 3)
            })
    }
    num_extra_tokens = 1  # class token
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch='tiny',
                 sub_ln=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 with_cls_token=True,
                 layer_cfgs=dict(),
                 **kwargs):
        # set essential args for Vision Transformer
        kwargs.update(
            arch=arch,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            with_cls_token=with_cls_token)
        super(ViTEVA02, self).__init__(**kwargs)

        self.num_heads = self.arch_settings['num_heads']

        # Set RoPE
        head_dim = self.embed_dims // self.num_heads
        self.rope = RotaryEmbeddingFast(
            embed_dims=head_dim, patch_resolution=self.patch_resolution)

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)
        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.num_heads,
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                sub_ln=sub_ln,
                norm_cfg=norm_cfg,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_rate=drop_rate,
                qkv_bias=qkv_bias,
                rope=self.rope,
                with_cls_token=with_cls_token,
                drop_path_rate=dpr[i])
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(EVA02EndcoderLayer(**_layer_cfg))

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        x = self.pre_norm(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x, patch_resolution)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)
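
# --- Hedged usage sketch (not part of the original file): EVA-02 tiny with
# an explicit 16-pixel patch size, so the token layout matches the parent
# ViT (196 patches plus a class token); RoPE is applied to the patch tokens
# inside attention and the default out_type returns the class token.
def _demo_vit_eva02():
    model = ViTEVA02(arch='tiny', img_size=224, patch_size=16)
    model.eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 224, 224))
    assert feats[-1].shape == (1, 192)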
mmpretrain/models/backbones/vit_sam.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmpretrain.registry import MODELS
from ..utils import LayerNorm2d, build_norm_layer, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone


def window_partition(x: torch.Tensor,
                     window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Partition into non-overlapping windows with padding if needed.

    Borrowed from https://github.com/facebookresearch/segment-anything/

    Args:
        x (torch.Tensor): Input tokens with [B, H, W, C].
        window_size (int): Window size.

    Returns:
        Tuple[torch.Tensor, Tuple[int, int]]:

        - ``windows``: Windows after partition with
          [B * num_windows, window_size, window_size, C].
        - ``(Hp, Wp)``: Padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size,
               window_size, C)
    windows = x.permute(0, 1, 3, 2, 4,
                        5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows: torch.Tensor, window_size: int,
                       pad_hw: Tuple[int, int],
                       hw: Tuple[int, int]) -> torch.Tensor:
    """Window unpartition into original sequences and removing padding.

    Borrowed from https://github.com/facebookresearch/segment-anything/

    Args:
        windows (torch.Tensor): Input tokens with
            [B * num_windows, window_size, window_size, C].
        window_size (int): Window size.
        pad_hw (tuple): Padded height and width (Hp, Wp).
        hw (tuple): Original height and width (H, W) before padding.

    Returns:
        torch.Tensor: Unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size,
                     window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x
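
# --- Hedged sanity-check sketch (not part of the original file):
# window_partition pads H and W up to multiples of the window size and cuts
# windows; window_unpartition reverses both steps exactly.
def _demo_window_roundtrip():
    x = torch.rand(1, 13, 13, 32)
    windows, pad_hw = window_partition(x, window_size=7)
    assert windows.shape == (4, 7, 7, 32)  # padded to 14x14 -> 2x2 windows
    restored = window_unpartition(windows, 7, pad_hw, (13, 13))
    assert torch.equal(restored, x)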
def
get_rel_pos
(
q_size
:
int
,
k_size
:
int
,
rel_pos
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Get relative positional embeddings according to the relative positions
of query and key sizes.
Borrowed from https://github.com/facebookresearch/segment-anything/
Args:
q_size (int): Size of query q.
k_size (int): Size of key k.
rel_pos (torch.Tensor): Relative position embeddings (L, C).
Returns:
torch.Tensor: Extracted positional embeddings according to relative
positions.
"""
max_rel_dist
=
int
(
2
*
max
(
q_size
,
k_size
)
-
1
)
# Interpolate rel pos if needed.
if
rel_pos
.
shape
[
0
]
!=
max_rel_dist
:
# Interpolate rel pos.
rel_pos_resized
=
F
.
interpolate
(
rel_pos
.
reshape
(
1
,
rel_pos
.
shape
[
0
],
-
1
).
permute
(
0
,
2
,
1
),
size
=
max_rel_dist
,
mode
=
'linear'
,
)
rel_pos_resized
=
rel_pos_resized
.
reshape
(
-
1
,
max_rel_dist
).
permute
(
1
,
0
)
else
:
rel_pos_resized
=
rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords
=
torch
.
arange
(
q_size
)[:,
None
]
*
max
(
k_size
/
q_size
,
1.0
)
k_coords
=
torch
.
arange
(
k_size
)[
None
,
:]
*
max
(
q_size
/
k_size
,
1.0
)
relative_coords
=
(
q_coords
-
k_coords
)
+
(
k_size
-
1
)
*
max
(
q_size
/
k_size
,
1.0
)
return
rel_pos_resized
[
relative_coords
.
long
()]


def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """Borrowed from https://github.com/facebookresearch/segment-anything/

    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

    Args:
        attn (torch.Tensor): Attention map.
        q (torch.Tensor): Query q in the attention layer with shape
            (B, q_h * q_w, C).
        rel_pos_h (torch.Tensor): Relative position embeddings (Lh, C) for
            height axis.
        rel_pos_w (torch.Tensor): Relative position embeddings (Lw, C) for
            width axis.
        q_size (tuple): Spatial sequence size of query q with (q_h, q_w).
        k_size (tuple): Spatial sequence size of key k with (k_h, k_w).

    Returns:
        torch.Tensor: Attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
    rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)

    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] +
            rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)

    return attn
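

# Shape-bookkeeping sketch (all sizes are arbitrary illustration values): the
# bias is decomposed into a height term and a width term, so only
# (2 * q_h - 1) + (2 * q_w - 1) embedding rows are needed instead of a full
# 2-D relative-position table.
def _demo_decomposed_rel_pos():
    B, q_h, q_w, C = 2, 4, 4, 8
    q = torch.randn(B, q_h * q_w, C)
    attn = torch.zeros(B, q_h * q_w, q_h * q_w)
    rel_pos_h = torch.randn(2 * q_h - 1, C)
    rel_pos_w = torch.randn(2 * q_w - 1, C)
    out = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w),
                                 (q_h, q_w))
    assert out.shape == (B, q_h * q_w, q_h * q_w)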


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings.

    Borrowed from https://github.com/facebookresearch/segment-anything/

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        use_rel_pos (bool): Whether to use relative position embedding.
            Defaults to False.
        input_size (int, optional): Input resolution for calculating the
            relative positional parameter size. Defaults to None.
    """

    def __init__(
        self,
        embed_dims: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = head_embed_dims**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dims, embed_dims)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert input_size is not None, \
                'Input size must be provided if using relative position embed.'
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(
                torch.zeros(2 * input_size[0] - 1, head_embed_dims))
            self.rel_pos_w = nn.Parameter(
                torch.zeros(2 * input_size[1] - 1, head_embed_dims))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,
                                  -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.use_rel_pos:
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h,
                                          self.rel_pos_w, (H, W), (H, W))

        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W,
                            -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)

        return x
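

# Hedged usage sketch: the block operates directly on (B, H, W, C) feature
# maps rather than flattened token sequences; all sizes below are arbitrary.
def _demo_attention():
    attn = Attention(
        embed_dims=64, num_heads=8, use_rel_pos=True, input_size=(14, 14))
    x = torch.randn(2, 14, 14, 64)
    assert attn(x).shape == (2, 14, 14, 64)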


class TransformerEncoderLayer(BaseModule):
    """Encoder layer with window attention in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        use_rel_pos (bool): Whether to use relative position embedding.
            Defaults to False.
        window_size (int): Window size for window attention. Defaults to 0.
        input_size (int, optional): Input resolution for calculating the
            relative positional parameter size. Defaults to None.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims: int,
                 num_heads: int,
                 feedforward_channels: int,
                 drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_fcs: int = 2,
                 qkv_bias: bool = True,
                 act_cfg: dict = dict(type='GELU'),
                 norm_cfg: dict = dict(type='LN'),
                 use_rel_pos: bool = False,
                 window_size: int = 0,
                 input_size: Optional[Tuple[int, int]] = None,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.window_size = window_size

        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        self.attn = Attention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            input_size=input_size if window_size == 0 else
            (window_size, window_size),
        )

        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return self.ln1

    @property
    def norm2(self):
        return self.ln2

    def forward(self, x):
        shortcut = x
        x = self.ln1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = self.ffn(self.ln2(x), identity=x)

        return x
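

# Hedged sketch: with window_size > 0 the layer attends within local windows;
# with window_size == 0 it attends globally. Either way the (B, H, W, C)
# shape is preserved. All sizes below are illustrative only.
def _demo_encoder_layer():
    layer = TransformerEncoderLayer(
        embed_dims=64, num_heads=8, feedforward_channels=256, window_size=7)
    x = torch.randn(2, 28, 28, 64)
    assert layer(x).shape == (2, 28, 28, 64)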


@MODELS.register_module()
class ViTSAM(BaseBackbone):
    """Vision Transformer as image encoder used in SAM.

    A PyTorch implementation of the backbone in `Segment Anything
    <https://arxiv.org/abs/2304.02643>`_.

    Args:
        arch (str | dict): Vision Transformer architecture. If use string,
            choose from 'base', 'large', 'huge'. If use dict, it should have
            below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.
            - **global_attn_indexes** (int): The index of layers with global
              attention.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The number of input channels. Defaults to 3.
        out_channels (int): The number of output channels. If equal to 0, the
            channel reduction layer is disabled. Defaults to 256.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        out_type (str): The type of output features. Please choose from

            - ``"raw"`` or ``"featmap"``: The feature map tensor from the
              patch tokens with shape (B, C, H, W).
            - ``"avg_featmap"``: The global averaged feature map tensor
              with shape (B, C).

            Defaults to ``"raw"``.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        use_abs_pos (bool): Whether to use absolute position embedding.
            Defaults to True.
        use_rel_pos (bool): Whether to use relative position embedding.
            Defaults to True.
        window_size (int): Window size for window attention. Defaults to 14.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
                'global_attn_indexes': [2, 5, 8, 11]
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096,
                'global_attn_indexes': [5, 11, 17, 23]
            }),
        **dict.fromkeys(
            ['h', 'huge'], {
                'embed_dims': 1280,
                'num_layers': 32,
                'num_heads': 16,
                'feedforward_channels': 5120,
                'global_attn_indexes': [7, 15, 23, 31]
            }),
    }
    OUT_TYPES = {'raw', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch: str = 'base',
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 out_channels: int = 256,
                 out_indices: int = -1,
                 out_type: str = 'raw',
                 drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 qkv_bias: bool = True,
                 use_abs_pos: bool = True,
                 use_rel_pos: bool = True,
                 window_size: int = 14,
                 norm_cfg: dict = dict(type='LN', eps=1e-6),
                 frozen_stages: int = -1,
                 interpolate_mode: str = 'bicubic',
                 patch_cfg: dict = dict(),
                 layer_cfgs: dict = dict(),
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.global_attn_indexes = self.arch_settings['global_attn_indexes']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(f'Unsupported `out_type` {out_type}, please '
                             f'choose from {self.OUT_TYPES}')
        self.out_type = out_type

        self.use_abs_pos = use_abs_pos
        self.interpolate_mode = interpolate_mode
        if use_abs_pos:
            # Set position embedding
            self.pos_embed = nn.Parameter(
                torch.zeros(1, *self.patch_resolution, self.embed_dims))
            self.drop_after_pos = nn.Dropout(p=drop_rate)
            self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        if use_rel_pos:
            self._register_load_state_dict_pre_hook(
                self._prepare_relative_position)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.arch_settings[
                    'feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                window_size=window_size
                if i not in self.global_attn_indexes else 0,
                input_size=self.patch_resolution,
                use_rel_pos=use_rel_pos,
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        self.out_channels = out_channels
        if self.out_channels > 0:
            self.channel_reduction = nn.Sequential(
                nn.Conv2d(
                    self.embed_dims,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                ),
                LayerNorm2d(out_channels, eps=1e-6),
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    padding=1,
                    bias=False,
                ),
                LayerNorm2d(out_channels, eps=1e-6),
            )

        # freeze stages only when self.frozen_stages > 0
        self.frozen_stages = frozen_stages
        if self.frozen_stages > 0:
            self._freeze_stages()

    def init_weights(self):
        super().init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            if self.pos_embed is not None:
                trunc_normal_(self.pos_embed, std=0.02)

    def _freeze_stages(self):
        # freeze position embedding
        if self.pos_embed is not None:
            self.pos_embed.requires_grad = False
        # set dropout to eval mode
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze channel_reduction module
        if self.frozen_stages == self.num_layers and self.out_channels > 0:
            m = self.channel_reduction
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)
        x = x.view(B, patch_resolution[0], patch_resolution[1],
                   self.embed_dims)

        if self.use_abs_pos:
            # 'resize_pos_embed' only supports 'pos_embed' with ndim==3, but
            # in ViTSAM, the 'pos_embed' has 4 dimensions (1, H, W, C), so it
            # is flattened. Besides, ViTSAM doesn't have any extra token.
            resized_pos_embed = resize_pos_embed(
                self.pos_embed.flatten(1, 2),
                self.patch_resolution,
                patch_resolution,
                mode=self.interpolate_mode,
                num_extra_tokens=0)
            x = x + resized_pos_embed.view(1, *patch_resolution,
                                           self.embed_dims)
            x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i in self.out_indices:
                # (B, H, W, C) -> (B, C, H, W)
                x_reshape = x.permute(0, 3, 1, 2)

                if self.out_channels > 0:
                    x_reshape = self.channel_reduction(x_reshape)
                outs.append(self._format_output(x_reshape))

        return tuple(outs)

    def _format_output(self, x) -> torch.Tensor:
        if self.out_type == 'raw' or self.out_type == 'featmap':
            return x
        elif self.out_type == 'avg_featmap':
            # (B, C, H, W) -> (B, C, N) -> (B, N, C)
            x = x.flatten(2).permute(0, 2, 1)
            return x.mean(dim=1)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmengine.logging import MMLogger
            logger = MMLogger.get_current_instance()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            ckpt_pos_embed_shape = ckpt_pos_embed_shape[1:3]
            pos_embed_shape = self.patch_embed.init_out_size

            flattened_pos_embed = state_dict[name].flatten(1, 2)
            resized_pos_embed = resize_pos_embed(flattened_pos_embed,
                                                 ckpt_pos_embed_shape,
                                                 pos_embed_shape,
                                                 self.interpolate_mode, 0)
            state_dict[name] = resized_pos_embed.view(1, *pos_embed_shape,
                                                      self.embed_dims)

    def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs):
        state_dict_model = self.state_dict()
        all_keys = list(state_dict_model.keys())
        for key in all_keys:
            if 'rel_pos_' in key:
                ckpt_key = prefix + key
                if ckpt_key not in state_dict:
                    continue
                relative_position_pretrained = state_dict[ckpt_key]
                relative_position_current = state_dict_model[key]
                L1, _ = relative_position_pretrained.size()
                L2, _ = relative_position_current.size()
                if L1 != L2:
                    new_rel_pos = F.interpolate(
                        relative_position_pretrained.reshape(1, L1, -1).
                        permute(0, 2, 1),
                        size=L2,
                        mode='linear',
                    )
                    new_rel_pos = new_rel_pos.reshape(-1, L2).permute(1, 0)
                    from mmengine.logging import MMLogger
                    logger = MMLogger.get_current_instance()
                    logger.info(f'Resize the {ckpt_key} from '
                                f'{state_dict[ckpt_key].shape} to '
                                f'{new_rel_pos.shape}')
                    state_dict[ckpt_key] = new_rel_pos

    def get_layer_depth(self, param_name: str, prefix: str = ''):
        """Get the layer-wise depth of a parameter.

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix for the parameter.
                Defaults to an empty string.

        Returns:
            Tuple[int, int]: The layer-wise depth and the num of layers.

        Note:
            The first depth is the stem module (``layer_depth=0``), and the
            last depth is the subsequent module (``layer_depth=num_layers-1``)
        """
        num_layers = self.num_layers + 2

        if not param_name.startswith(prefix):
            # For subsequent module like head
            return num_layers - 1, num_layers

        param_name = param_name[len(prefix):]

        if param_name in ('cls_token', 'pos_embed'):
            layer_depth = 0
        elif param_name.startswith('patch_embed'):
            layer_depth = 0
        elif param_name.startswith('layers'):
            layer_id = int(param_name.split('.')[1])
            layer_depth = layer_id + 1
        else:
            layer_depth = num_layers - 1

        return layer_depth, num_layers
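

# Hedged usage sketch for the backbone above (a minimal example, not part of
# the library): with the default out_channels=256 and a 16x16 patch size, a
# 224x224 input yields a single (B, 256, 14, 14) feature map from the last
# stage.
def _demo_vitsam():
    model = ViTSAM(arch='base', img_size=224, patch_size=16)
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    assert feats[0].shape == (1, 256, 14, 14)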