dcuai / dlexamples · Commits · Commit 0fd8347d
Authored Jan 08, 2023 by unknown
Add the mmclassification-0.24.1 code, remove mmclassification-speed-benchmark
Parent: cc567e9e
Changes: 839 files — showing 20 changed files with 3961 additions and 59 deletions (+3961, -59)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnest.py  (+2, -1)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnet.py  (+61, -23)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnet_cifar.py  (+2, -4)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnext.py  (+2, -1)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/seresnet.py  (+2, -1)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/seresnext.py  (+2, -1)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/shufflenet_v1.py  (+14, -8)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/shufflenet_v2.py  (+22, -9)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/swin_transformer.py  (+548, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/swin_transformer_v2.py  (+560, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/t2t_vit.py  (+440, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/timm_backbone.py  (+112, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/tnt.py  (+368, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/twins.py  (+723, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/van.py  (+445, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/vgg.py  (+8, -11)
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/vision_transformer.py  (+383, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/builder.py  (+38, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/classifiers/__init__.py  (+5, -0)
openmmlab_test/mmclassification-0.24.1/mmcls/models/classifiers/base.py  (+224, -0)
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/resnest.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnest.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
...
@@ -260,7 +261,7 @@ class Bottleneck(_Bottleneck):
 class ResNeSt(ResNetV1d):
     """ResNeSt backbone.

-    Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`_ for
+    Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`__ for
     details.

     Args:
...
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/resnet.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnet.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
-                      constant_init)
+from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
+                      build_norm_layer, constant_init)
+from mmcv.cnn.bricks import DropPath
+from mmcv.runner import BaseModule
 from mmcv.utils.parrots_wrapper import _BatchNorm

 from ..builder import BACKBONES
 from .base_backbone import BaseBackbone

+eps = 1.0e-5

-class BasicBlock(nn.Module):
+class BasicBlock(BaseModule):
     """BasicBlock for ResNet.

     Args:
...
@@ -41,8 +47,11 @@ class BasicBlock(nn.Module):
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 norm_cfg=dict(type='BN')):
-        super(BasicBlock, self).__init__()
+                 norm_cfg=dict(type='BN'),
+                 drop_path_rate=0.0,
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 init_cfg=None):
+        super(BasicBlock, self).__init__(init_cfg=init_cfg)
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.expansion = expansion
...
@@ -80,8 +89,10 @@ class BasicBlock(nn.Module):
             bias=False)
         self.add_module(self.norm2_name, norm2)

-        self.relu = nn.ReLU(inplace=True)
+        self.relu = build_activation_layer(act_cfg)
         self.downsample = downsample
+        self.drop_path = DropPath(drop_prob=drop_path_rate
+                                  ) if drop_path_rate > eps else nn.Identity()

     @property
     def norm1(self):
...
@@ -106,6 +117,8 @@ class BasicBlock(nn.Module):
             if self.downsample is not None:
                 identity = self.downsample(x)

+            out = self.drop_path(out)
+
             out += identity

             return out
...
@@ -120,7 +133,7 @@ class BasicBlock(nn.Module):
         return out


-class Bottleneck(nn.Module):
+class Bottleneck(BaseModule):
     """Bottleneck block for ResNet.

     Args:
...
@@ -153,8 +166,11 @@ class Bottleneck(nn.Module):
                  style='pytorch',
                  with_cp=False,
                  conv_cfg=None,
-                 norm_cfg=dict(type='BN')):
-        super(Bottleneck, self).__init__()
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU', inplace=True),
+                 drop_path_rate=0.0,
+                 init_cfg=None):
+        super(Bottleneck, self).__init__(init_cfg=init_cfg)
         assert style in ['pytorch', 'caffe']
         self.in_channels = in_channels
...
@@ -210,8 +226,10 @@ class Bottleneck(nn.Module):
             bias=False)
         self.add_module(self.norm3_name, norm3)

-        self.relu = nn.ReLU(inplace=True)
+        self.relu = build_activation_layer(act_cfg)
         self.downsample = downsample
+        self.drop_path = DropPath(drop_prob=drop_path_rate
+                                  ) if drop_path_rate > eps else nn.Identity()

     @property
     def norm1(self):
...
@@ -244,6 +262,8 @@ class Bottleneck(nn.Module):
             if self.downsample is not None:
                 identity = self.downsample(x)

+            out = self.drop_path(out)
+
             out += identity

             return out
...
@@ -382,7 +402,7 @@ class ResLayer(nn.Sequential):
 class ResNet(BaseBackbone):
     """ResNet backbone.

-    Please refer to the `paper <https://arxiv.org/abs/1512.03385>`_ for
+    Please refer to the `paper <https://arxiv.org/abs/1512.03385>`__ for
     details.

     Args:
...
@@ -395,10 +415,8 @@ class ResNet(BaseBackbone):
             Default: ``(1, 2, 2, 2)``.
         dilations (Sequence[int]): Dilation of each stage.
             Default: ``(1, 1, 1, 1)``.
-        out_indices (Sequence[int]): Output from which stages. If only one
-            stage is specified, a single tensor (feature map) is returned,
-            otherwise multiple stages are specified, a tuple of tensors will
-            be returned. Default: ``(3, )``.
+        out_indices (Sequence[int]): Output from which stages.
+            Default: ``(3, )``.
         style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
             layer is the 3x3 conv layer, otherwise the stride-two layer is
             the first 1x1 conv layer.
...
@@ -466,7 +484,8 @@ class ResNet(BaseBackbone):
                      type='Constant',
                      val=1,
                      layer=['_BatchNorm', 'GroupNorm'])
-                 ]):
+                 ],
+                 drop_path_rate=0.0):
         super(ResNet, self).__init__(init_cfg)

         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for resnet')
...
@@ -513,7 +532,8 @@ class ResNet(BaseBackbone):
                 avg_down=self.avg_down,
                 with_cp=with_cp,
                 conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg)
+                norm_cfg=norm_cfg,
+                drop_path_rate=drop_path_rate)
             _in_channels = _out_channels
             _out_channels *= 2
             layer_name = f'layer{i + 1}'
...
@@ -594,10 +614,14 @@ class ResNet(BaseBackbone):
                 for param in m.parameters():
                     param.requires_grad = False

+    # def init_weights(self, pretrained=None):
     def init_weights(self):
         super(ResNet, self).init_weights()

+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress zero_init_residual if use pretrained model.
+            return
+
         if self.zero_init_residual:
             for m in self.modules():
                 if isinstance(m, Bottleneck):
...
@@ -619,9 +643,6 @@ class ResNet(BaseBackbone):
             x = res_layer(x)
             if i in self.out_indices:
                 outs.append(x)
-        if len(outs) == 1:
-            return outs[0]
-        else:
-            return tuple(outs)
+        return tuple(outs)

     def train(self, mode=True):
...
@@ -634,10 +655,27 @@ class ResNet(BaseBackbone):
                     m.eval()


+@BACKBONES.register_module()
+class ResNetV1c(ResNet):
+    """ResNetV1c backbone.
+
+    This variant is described in `Bag of Tricks.
+    <https://arxiv.org/pdf/1812.01187.pdf>`_.
+
+    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
+    in the input stem with three 3x3 convs.
+    """
+
+    def __init__(self, **kwargs):
+        super(ResNetV1c, self).__init__(
+            deep_stem=True, avg_down=False, **kwargs)
+
+
 @BACKBONES.register_module()
 class ResNetV1d(ResNet):
-    """ResNetV1d variant described in `Bag of Tricks.
+    """ResNetV1d backbone.
+
+    This variant is described in `Bag of Tricks.
     <https://arxiv.org/pdf/1812.01187.pdf>`_.

     Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
...
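The resnet.py changes above thread a drop_path_rate argument through BasicBlock and Bottleneck and wrap the residual branch in DropPath (stochastic depth) before the identity is added back. Below is a minimal sketch of what DropPath does, re-implemented in plain PyTorch rather than imported from mmcv.cnn.bricks; the branch module and tensor shapes are illustrative only, not part of the diff.

import torch
import torch.nn as nn

class DropPath(nn.Module):
    """Stochastic depth: zero the whole residual branch per sample with
    probability ``drop_prob`` during training, rescaling the survivors."""

    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dims.
        shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob

# Residual block in the style of the patched BasicBlock: the branch output
# passes through drop_path before the identity is added back.
branch = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU())
drop_path = DropPath(drop_prob=0.2)

x = torch.randn(4, 8, 16, 16)
out = drop_path(branch(x)) + x  # mirrors `out = self.drop_path(out); out += identity`

With drop_path_rate=0.0 (the default) the module degenerates to nn.Identity, which is exactly the guard `if drop_path_rate > eps else nn.Identity()` in the diff.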
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/resnet_cifar.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnet_cifar.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.nn as nn
 from mmcv.cnn import build_conv_layer, build_norm_layer
...
@@ -77,7 +78,4 @@ class ResNet_CIFAR(ResNet):
             x = res_layer(x)
             if i in self.out_indices:
                 outs.append(x)
-        if len(outs) == 1:
-            return outs[0]
-        else:
-            return tuple(outs)
+        return tuple(outs)
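Both resnet.py and resnet_cifar.py now end forward with `return tuple(outs)` instead of unwrapping a single feature map, so callers always receive a tuple no matter how many stages out_indices selects. A short usage sketch, assuming mmcls's build_backbone helper and an illustrative ResNet-50 config:

import torch
from mmcls.models import build_backbone

# With the patched forward, the backbone output is always a tuple,
# even when only the last stage is requested.
backbone = build_backbone(dict(type='ResNet', depth=50, out_indices=(3, )))
feats = backbone(torch.randn(1, 3, 224, 224))
assert isinstance(feats, tuple) and len(feats) == 1
last = feats[0]  # shape (1, 2048, 7, 7) for ResNet-50 at 224x224 input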
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/resnext.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/resnext.py
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.cnn import build_conv_layer, build_norm_layer

 from ..builder import BACKBONES
...
@@ -89,7 +90,7 @@ class Bottleneck(_Bottleneck):
 class ResNeXt(ResNet):
     """ResNeXt backbone.

-    Please refer to the `paper <https://arxiv.org/abs/1611.05431>`_ for
+    Please refer to the `paper <https://arxiv.org/abs/1611.05431>`__ for
     details.

     Args:
...
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/seresnet.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/seresnet.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch.utils.checkpoint as cp

 from ..builder import BACKBONES
...
@@ -57,7 +58,7 @@ class SEBottleneck(Bottleneck):
 class SEResNet(ResNet):
     """SEResNet backbone.

-    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
+    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
     details.

     Args:
...
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/seresnext.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/seresnext.py
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmcv.cnn import build_conv_layer, build_norm_layer

 from ..builder import BACKBONES
...
@@ -95,7 +96,7 @@ class SEBottleneck(_SEBottleneck):
 class SEResNeXt(SEResNet):
     """SEResNeXt backbone.

-    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
+    Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
     details.

     Args:
...
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/shufflenet_v1.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/shufflenet_v1.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
                       normal_init)
+from mmcv.runner import BaseModule
 from torch.nn.modules.batchnorm import _BatchNorm

 from mmcls.models.utils import channel_shuffle, make_divisible
...
@@ -10,7 +12,7 @@ from ..builder import BACKBONES
 from .base_backbone import BaseBackbone


-class ShuffleUnit(nn.Module):
+class ShuffleUnit(BaseModule):
     """ShuffleUnit block.

     ShuffleNet unit with pointwise group convolution (GConv) and channel
...
@@ -22,7 +24,7 @@ class ShuffleUnit(nn.Module):
         groups (int): The number of groups to be used in grouped 1x1
             convolutions in each ShuffleUnit. Default: 3
         first_block (bool): Whether it is the first ShuffleUnit of a
-            sequential ShuffleUnits. Default: False, which means not using the
+            sequential ShuffleUnits. Default: True, which means not using the
             grouped 1x1 convolution.
         combine (str): The ways to combine the input and output
             branches. Default: 'add'.
...
@@ -184,6 +186,7 @@ class ShuffleNetV1(BaseBackbone):
                  with_cp=False,
                  init_cfg=None):
         super(ShuffleNetV1, self).__init__(init_cfg)
+        self.init_cfg = init_cfg
         self.stage_blocks = [4, 8, 4]
         self.groups = groups
...
@@ -250,6 +253,12 @@ class ShuffleNetV1(BaseBackbone):
     def init_weights(self):
         super(ShuffleNetV1, self).init_weights()

+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress default init if use pretrained model.
+            return
+
         for name, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 if 'conv1' in name:
...
@@ -257,7 +266,7 @@ class ShuffleNetV1(BaseBackbone):
                 else:
                     normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
             elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                constant_init(m.weight, val=1, bias=0.0001)
+                constant_init(m, val=1, bias=0.0001)
                 if isinstance(m, _BatchNorm):
                     if m.running_mean is not None:
                         nn.init.constant_(m.running_mean, 0)
...
@@ -269,7 +278,7 @@ class ShuffleNetV1(BaseBackbone):
             out_channels (int): out_channels of the block.
             num_blocks (int): Number of blocks.
             first_block (bool): Whether is the first ShuffleUnit of a
-                sequential ShuffleUnits. Default: False, which means not using
+                sequential ShuffleUnits. Default: False, which means using
                 the grouped 1x1 convolution.
         """
         layers = []
...
@@ -301,9 +310,6 @@ class ShuffleNetV1(BaseBackbone):
             if i in self.out_indices:
                 outs.append(x)
-        if len(outs) == 1:
-            return outs[0]
-        else:
-            return tuple(outs)
+        return tuple(outs)

     def train(self, mode=True):
...
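The init_weights fix above matters because mmcv's constant_init expects a module, not a tensor: it assigns to module.weight and module.bias itself, so the old constant_init(m.weight, ...) call site was handing it the wrong object. A hand-rolled equivalent of the mmcv semantics (a sketch of the behavior, not the mmcv source):

import torch.nn as nn

def constant_init(module, val, bias=0):
    # mmcv-style: operates on the *module*, touching .weight and .bias itself.
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)

bn = nn.BatchNorm2d(8)
constant_init(bn, val=1, bias=0.0001)  # the fixed call site passes the module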
renamed: openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/shufflenet_v2.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/shufflenet_v2.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
 from mmcv.cnn import ConvModule, constant_init, normal_init
+from mmcv.runner import BaseModule
 from torch.nn.modules.batchnorm import _BatchNorm

 from mmcls.models.utils import channel_shuffle
...
@@ -9,7 +11,7 @@ from ..builder import BACKBONES
 from .base_backbone import BaseBackbone


-class InvertedResidual(nn.Module):
+class InvertedResidual(BaseModule):
     """InvertedResidual block for ShuffleNetV2 backbone.

     Args:
...
@@ -36,8 +38,9 @@ class InvertedResidual(nn.Module):
                  conv_cfg=None,
                  norm_cfg=dict(type='BN'),
                  act_cfg=dict(type='ReLU'),
-                 with_cp=False):
-        super(InvertedResidual, self).__init__()
+                 with_cp=False,
+                 init_cfg=None):
+        super(InvertedResidual, self).__init__(init_cfg)
         self.stride = stride
         self.with_cp = with_cp
...
@@ -112,7 +115,14 @@ class InvertedResidual(nn.Module):
             if self.stride > 1:
                 out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
             else:
-                x1, x2 = x.chunk(2, dim=1)
+                # Channel Split operation. using these lines of code to replace
+                # ``chunk(x, 2, dim=1)`` can make it easier to deploy a
+                # shufflenetv2 model by using mmdeploy.
+                channels = x.shape[1]
+                c = channels // 2 + channels % 2
+                x1 = x[:, :c, :, :]
+                x2 = x[:, c:, :, :]

                 out = torch.cat((x1, self.branch2(x2)), dim=1)

             out = channel_shuffle(out, 2)
...
@@ -253,8 +263,14 @@ class ShuffleNetV2(BaseBackbone):
                 for param in m.parameters():
                     param.requires_grad = False

-    def init_weighs(self):
+    def init_weights(self):
         super(ShuffleNetV2, self).init_weights()

+        if (isinstance(self.init_cfg, dict)
+                and self.init_cfg['type'] == 'Pretrained'):
+            # Suppress default init if use pretrained model.
+            return
+
         for name, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 if 'conv1' in name:
...
@@ -277,9 +293,6 @@ class ShuffleNetV2(BaseBackbone):
             if i in self.out_indices:
                 outs.append(x)
-        if len(outs) == 1:
-            return outs[0]
-        else:
-            return tuple(outs)
+        return tuple(outs)

     def train(self, mode=True):
...
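The InvertedResidual change replaces x.chunk(2, dim=1) with explicit slicing so the channel split exports cleanly through mmdeploy, as the added comment explains. The two forms select the same halves; a quick standalone equivalence check (the tensor shape below is illustrative):

import torch

x = torch.randn(1, 116, 28, 28)  # a typical ShuffleNetV2 channel count

# original form
a1, a2 = x.chunk(2, dim=1)

# deploy-friendly form from the diff: c = ceil(channels / 2)
channels = x.shape[1]
c = channels // 2 + channels % 2
b1, b2 = x[:, :c, :, :], x[:, c:, :, :]

assert torch.equal(a1, b1) and torch.equal(a2, b2)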
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/swin_transformer.py (new file, 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from ..utils import (ShiftWindowMSA, resize_pos_embed,
                     resize_relative_position_bias_table, to_2tuple)
from .base_backbone import BaseBackbone


class SwinBlock(BaseModule):
    """Swin Transformer block.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        shift (bool): Shift the attention window or not. Defaults to False.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        attn_cfgs (dict): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size=7,
                 shift=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 pad_small_map=False,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 init_cfg=None):
        super(SwinBlock, self).__init__(init_cfg)
        self.with_cp = with_cp

        _attn_cfgs = {
            'embed_dims': embed_dims,
            'num_heads': num_heads,
            'shift_size': window_size // 2 if shift else 0,
            'window_size': window_size,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'pad_small_map': pad_small_map,
            **attn_cfgs
        }
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = ShiftWindowMSA(**_attn_cfgs)

        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            **ffn_cfgs
        }
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(**_ffn_cfgs)

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            identity = x
            x = self.norm1(x)
            x = self.attn(x, hw_shape)
            x = x + identity

            identity = x
            x = self.norm2(x)
            x = self.ffn(x, identity=identity)

            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x


class SwinBlockSequence(BaseModule):
    """Module with successive Swin Transformer blocks and downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        downsample (bool): Downsample the output of blocks by patch merging.
            Defaults to False.
        downsample_cfg (dict): The extra config of the patch merging layer.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float): The drop path rate in each block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 depth,
                 num_heads,
                 window_size=7,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 with_cp=False,
                 pad_small_map=False,
                 init_cfg=None):
        super().__init__(init_cfg)

        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth

        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        self.embed_dims = embed_dims
        self.blocks = ModuleList()
        for i in range(depth):
            _block_cfg = {
                'embed_dims': embed_dims,
                'num_heads': num_heads,
                'window_size': window_size,
                'shift': False if i % 2 == 0 else True,
                'drop_path': drop_paths[i],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                **block_cfgs[i]
            }
            block = SwinBlock(**_block_cfg)
            self.blocks.append(block)

        if downsample:
            _downsample_cfg = {
                'in_channels': embed_dims,
                'out_channels': 2 * embed_dims,
                'norm_cfg': dict(type='LN'),
                **downsample_cfg
            }
            self.downsample = PatchMerging(**_downsample_cfg)
        else:
            self.downsample = None

    def forward(self, x, in_shape, do_downsample=True):
        for block in self.blocks:
            x = block(x, in_shape)

        if self.downsample is not None and do_downsample:
            x, out_shape = self.downsample(x, in_shape)
        else:
            out_shape = in_shape
        return x, out_shape

    @property
    def out_channels(self):
        if self.downsample:
            return self.downsample.out_channels
        else:
            return self.embed_dims


@BACKBONES.register_module()
class SwinTransformer(BaseBackbone):
    """Swin Transformer.

    A PyTorch implement of : `Swin Transformer:
    Hierarchical Vision Transformer using Shifted Windows
    <https://arxiv.org/abs/2103.14030>`_

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. If use string, choose
            from 'tiny', 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.

            Defaults to 'tiny'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (int): The height and width of the window. Defaults to 7.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        out_after_downsample (bool): Whether to output the feature map of a
            stage after the following downsample layer. Defaults to False.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embeding vector resize. Defaults to "bicubic".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``
        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
            stage. Defaults to an empty dict.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SwinTransformer
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'expansion_ratio': 3}))
        >>> self = SwinTransformer(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output.shape)
        (1, 2592, 4)
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 6, 2],
                         'num_heads': [3, 6, 12, 24]}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [6, 12, 24, 48]}),
    }  # yapf: disable

    _version = 3
    num_extra_tokens = 0

    def __init__(self,
                 arch='tiny',
                 img_size=224,
                 patch_size=4,
                 in_channels=3,
                 window_size=7,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 out_after_downsample=False,
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 with_cp=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 pad_small_map=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(),
                 patch_cfg=dict(),
                 init_cfg=None):
        super(SwinTransformer, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'num_heads'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.out_after_downsample = out_after_downsample
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode
        self.frozen_stages = frozen_stages

        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            norm_cfg=dict(type='LN'),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))
            self._register_load_state_dict_pre_hook(
                self._prepare_abs_pos_embed)

        self._register_load_state_dict_pre_hook(
            self._prepare_relative_position_bias_table)

        self.drop_after_pos = nn.Dropout(p=drop_rate)
        self.norm_eval = norm_eval

        # stochastic depth
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        embed_dims = [self.embed_dims]
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            downsample = True if i < self.num_layers - 1 else False
            _stage_cfg = {
                'embed_dims': embed_dims[-1],
                'depth': depth,
                'num_heads': num_heads,
                'window_size': window_size,
                'downsample': downsample,
                'drop_paths': dpr[:depth],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                **stage_cfg
            }

            stage = SwinBlockSequence(**_stage_cfg)
            self.stages.append(stage)

            dpr = dpr[depth:]
            embed_dims.append(stage.out_channels)

        if self.out_after_downsample:
            self.num_features = embed_dims[1:]
        else:
            self.num_features = embed_dims[:-1]

        for i in out_indices:
            if norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg,
                                              self.num_features[i])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        super(SwinTransformer, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        x, hw_shape = self.patch_embed(x)
        if self.use_abs_pos_embed:
            x = x + resize_pos_embed(
                self.absolute_pos_embed, self.patch_resolution, hw_shape,
                self.interpolate_mode, self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape = stage(
                x, hw_shape, do_downsample=self.out_after_downsample)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                out = out.view(-1, *hw_shape,
                               self.num_features[i]).permute(0, 3, 1,
                                                             2).contiguous()
                outs.append(out)
            if stage.downsample is not None and not self.out_after_downsample:
                x, hw_shape = stage.downsample(x, hw_shape)

        return tuple(outs)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
                              **kwargs):
        """load checkpoints."""
        # Names of some parameters in has been changed.
        version = local_metadata.get('version', None)
        if (version is None
                or version < 2) and self.__class__ is SwinTransformer:
            final_stage_num = len(self.stages) - 1
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if k.startswith('norm.') or k.startswith('backbone.norm.'):
                    convert_key = k.replace('norm.', f'norm{final_stage_num}.')
                    state_dict[convert_key] = state_dict[k]
                    del state_dict[k]
        if (version is None
                or version < 3) and self.__class__ is SwinTransformer:
            state_dict_keys = list(state_dict.keys())
            for k in state_dict_keys:
                if 'attn_mask' in k:
                    del state_dict[k]

        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      *args, **kwargs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(0, self.frozen_stages + 1):
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        for i in self.out_indices:
            if i <= self.frozen_stages:
                for param in getattr(self, f'norm{i}').parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(SwinTransformer, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'absolute_pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
            from mmcls.utils import get_root_logger
            logger = get_root_logger()
            logger.info(
                'Resize the absolute_pos_embed shape from '
                f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')

            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
                                              **kwargs):
        state_dict_model = self.state_dict()
        all_keys = list(state_dict_model.keys())
        for key in all_keys:
            if 'relative_position_bias_table' in key:
                ckpt_key = prefix + key
                if ckpt_key not in state_dict:
                    continue

                relative_position_bias_table_pretrained = state_dict[ckpt_key]
                relative_position_bias_table_current = state_dict_model[key]
                L1, nH1 = relative_position_bias_table_pretrained.size()
                L2, nH2 = relative_position_bias_table_current.size()
                if L1 != L2:
                    src_size = int(L1**0.5)
                    dst_size = int(L2**0.5)
                    new_rel_pos_bias = resize_relative_position_bias_table(
                        src_size, dst_size,
                        relative_position_bias_table_pretrained, nH1)
                    from mmcls.utils import get_root_logger
                    logger = get_root_logger()
                    logger.info(
                        'Resize the relative_position_bias_table from '
                        f'{state_dict[ckpt_key].shape} to '
                        f'{new_rel_pos_bias.shape}')
                    state_dict[ckpt_key] = new_rel_pos_bias

                    # The index buffer need to be re-generated.
                    index_buffer = ckpt_key.replace('bias_table', 'index')
                    del state_dict[index_buffer]
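_prepare_abs_pos_embed above resizes a checkpoint's absolute position embedding to the current patch grid before loading, delegating the interpolation to mmcls's resize_pos_embed. A minimal sketch of that interpolation in plain PyTorch (the function name and shapes here are illustrative, not the mmcls helper itself):

import torch
import torch.nn.functional as F

def resize_abs_pos_embed(pos_embed, src_hw, dst_hw, mode='bicubic'):
    """pos_embed: (1, src_h*src_w, C) -> (1, dst_h*dst_w, C)."""
    _, L, C = pos_embed.shape
    assert L == src_hw[0] * src_hw[1]
    # Fold the token axis back into a 2-D grid, interpolate, then flatten.
    grid = pos_embed.reshape(1, src_hw[0], src_hw[1], C).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=dst_hw, mode=mode, align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, dst_hw[0] * dst_hw[1], C)

# e.g. adapt a 7x7 grid from a 224px checkpoint to a 12x12 grid (384px input)
pe = torch.randn(1, 49, 96)
print(resize_abs_pos_embed(pe, (7, 7), (12, 12)).shape)  # (1, 144, 96)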
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/swin_transformer_v2.py (new file, 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from ..utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2,
                     resize_pos_embed, to_2tuple)
from .base_backbone import BaseBackbone


class SwinBlockV2(BaseModule):
    """Swin Transformer V2 block. Use post normalization.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        shift (bool): Shift the attention window or not. Defaults to False.
        extra_norm (bool): Whether add extra norm at the end of main branch.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_path (float): The drop path rate after attention and ffn.
            Defaults to 0.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        attn_cfgs (dict): The extra config of Shift Window-MSA.
            Defaults to empty dict.
        ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
        norm_cfg (dict): The config of norm layers.
            Defaults to ``dict(type='LN')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pretrained_window_size (int): Window size in pretrained.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size=8,
                 shift=False,
                 extra_norm=False,
                 ffn_ratio=4.,
                 drop_path=0.,
                 pad_small_map=False,
                 attn_cfgs=dict(),
                 ffn_cfgs=dict(),
                 norm_cfg=dict(type='LN'),
                 with_cp=False,
                 pretrained_window_size=0,
                 init_cfg=None):
        super(SwinBlockV2, self).__init__(init_cfg)
        self.with_cp = with_cp
        self.extra_norm = extra_norm

        _attn_cfgs = {
            'embed_dims': embed_dims,
            'num_heads': num_heads,
            'shift_size': window_size // 2 if shift else 0,
            'window_size': window_size,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'pad_small_map': pad_small_map,
            **attn_cfgs
        }
        # use V2 attention implementation
        _attn_cfgs.update(
            window_msa=WindowMSAV2,
            msa_cfg=dict(
                pretrained_window_size=to_2tuple(pretrained_window_size)))
        self.attn = ShiftWindowMSA(**_attn_cfgs)
        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        _ffn_cfgs = {
            'embed_dims': embed_dims,
            'feedforward_channels': int(embed_dims * ffn_ratio),
            'num_fcs': 2,
            'ffn_drop': 0,
            'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
            'act_cfg': dict(type='GELU'),
            'add_identity': False,
            **ffn_cfgs
        }
        self.ffn = FFN(**_ffn_cfgs)
        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        # add extra norm for every n blocks in huge and giant model
        if self.extra_norm:
            self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x, hw_shape):

        def _inner_forward(x):
            # Use post normalization
            identity = x
            x = self.attn(x, hw_shape)
            x = self.norm1(x)
            x = x + identity

            identity = x
            x = self.ffn(x)
            x = self.norm2(x)
            x = x + identity

            if self.extra_norm:
                x = self.norm3(x)

            return x

        if self.with_cp and x.requires_grad:
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)

        return x


class SwinBlockV2Sequence(BaseModule):
    """Module with successive Swin Transformer blocks and downsample layer.

    Args:
        embed_dims (int): Number of input channels.
        depth (int): Number of successive swin transformer blocks.
        num_heads (int): Number of attention heads.
        window_size (int): The height and width of the window. Defaults to 7.
        downsample (bool): Downsample the output of blocks by patch merging.
            Defaults to False.
        downsample_cfg (dict): The extra config of the patch merging layer.
            Defaults to empty dict.
        drop_paths (Sequence[float] | float): The drop path rate in each block.
            Defaults to 0.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        extra_norm_every_n_blocks (int): Add extra norm at the end of main
            branch every n blocks. Defaults to 0, which means no needs for
            extra norm layer.
        pretrained_window_size (int): Window size in pretrained.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 depth,
                 num_heads,
                 window_size=8,
                 downsample=False,
                 downsample_cfg=dict(),
                 drop_paths=0.,
                 block_cfgs=dict(),
                 with_cp=False,
                 pad_small_map=False,
                 extra_norm_every_n_blocks=0,
                 pretrained_window_size=0,
                 init_cfg=None):
        super().__init__(init_cfg)

        if not isinstance(drop_paths, Sequence):
            drop_paths = [drop_paths] * depth

        if not isinstance(block_cfgs, Sequence):
            block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]

        if downsample:
            self.out_channels = 2 * embed_dims
            _downsample_cfg = {
                'in_channels': embed_dims,
                'out_channels': self.out_channels,
                'norm_cfg': dict(type='LN'),
                **downsample_cfg
            }
            self.downsample = PatchMerging(**_downsample_cfg)
        else:
            self.out_channels = embed_dims
            self.downsample = None

        self.blocks = ModuleList()
        for i in range(depth):
            extra_norm = True if extra_norm_every_n_blocks and \
                (i + 1) % extra_norm_every_n_blocks == 0 else False
            _block_cfg = {
                'embed_dims': self.out_channels,
                'num_heads': num_heads,
                'window_size': window_size,
                'shift': False if i % 2 == 0 else True,
                'extra_norm': extra_norm,
                'drop_path': drop_paths[i],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                'pretrained_window_size': pretrained_window_size,
                **block_cfgs[i]
            }
            block = SwinBlockV2(**_block_cfg)
            self.blocks.append(block)

    def forward(self, x, in_shape):
        if self.downsample:
            x, out_shape = self.downsample(x, in_shape)
        else:
            out_shape = in_shape

        for block in self.blocks:
            x = block(x, out_shape)

        return x, out_shape


@BACKBONES.register_module()
class SwinTransformerV2(BaseBackbone):
    """Swin Transformer V2.

    A PyTorch implement of : `Swin Transformer V2:
    Scaling Up Capacity and Resolution
    <https://arxiv.org/abs/2111.09883>`_

    Inspiration from
    https://github.com/microsoft/Swin-Transformer

    Args:
        arch (str | dict): Swin Transformer architecture. If use string, choose
            from 'tiny', 'small', 'base' and 'large'. If use dict, it should
            have below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **num_heads** (List[int]): The number of heads in attention
              modules of each stage.
            - **extra_norm_every_n_blocks** (int): Add extra norm at the end
              of main branch every n blocks.

            Defaults to 'tiny'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 4.
        in_channels (int): The num of input channels. Defaults to 3.
        window_size (int | Sequence): The height and width of the window.
            Defaults to 7.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        use_abs_pos_embed (bool): If True, add absolute position embedding to
            the patch embedding. Defaults to False.
        interpolate_mode (str): Select the interpolate mode for absolute
            position embeding vector resize. Defaults to "bicubic".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is common used in detection and segmentation. If False,
            avoid shifting window and shrink the window size to the size of
            feature map, which is common used in classification.
            Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``
        stage_cfgs (Sequence[dict] | dict): Extra config dict for each
            stage. Defaults to an empty dict.
        patch_cfg (dict): Extra config dict for patch embedding.
            Defaults to an empty dict.
        pretrained_window_sizes (tuple(int)): Pretrained window sizes of
            each layer.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SwinTransformerV2
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'padding': 'same'}))
        >>> self = SwinTransformerV2(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output.shape)
        (1, 2592, 4)
    """
    arch_zoo = {
        **dict.fromkeys(['t', 'tiny'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 6, 2],
                         'num_heads': [3, 6, 12, 24],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': 96,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': 128,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32],
                         'extra_norm_every_n_blocks': 0}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': 192,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [6, 12, 24, 48],
                         'extra_norm_every_n_blocks': 0}),
        # head count not certain for huge, and is employed for another
        # parallel study about self-supervised learning.
        **dict.fromkeys(['h', 'huge'],
                        {'embed_dims': 352,
                         'depths': [2, 2, 18, 2],
                         'num_heads': [8, 16, 32, 64],
                         'extra_norm_every_n_blocks': 6}),
        **dict.fromkeys(['g', 'giant'],
                        {'embed_dims': 512,
                         'depths': [2, 2, 42, 4],
                         'num_heads': [16, 32, 64, 128],
                         'extra_norm_every_n_blocks': 6}),
    }  # yapf: disable

    _version = 1
    num_extra_tokens = 0

    def __init__(self,
                 arch='tiny',
                 img_size=256,
                 patch_size=4,
                 in_channels=3,
                 window_size=8,
                 drop_rate=0.,
                 drop_path_rate=0.1,
                 out_indices=(3, ),
                 use_abs_pos_embed=False,
                 interpolate_mode='bicubic',
                 with_cp=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 pad_small_map=False,
                 norm_cfg=dict(type='LN'),
                 stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)),
                 patch_cfg=dict(),
                 pretrained_window_sizes=[0, 0, 0, 0],
                 init_cfg=None):
        super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'depths', 'num_heads',
                'extra_norm_every_n_blocks'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.num_heads = self.arch_settings['num_heads']
        self.extra_norm_every_n_blocks = self.arch_settings[
            'extra_norm_every_n_blocks']
        self.num_layers = len(self.depths)
        self.out_indices = out_indices
        self.use_abs_pos_embed = use_abs_pos_embed
        self.interpolate_mode = interpolate_mode
        self.frozen_stages = frozen_stages

        if isinstance(window_size, int):
            self.window_sizes = [window_size for _ in range(self.num_layers)]
        elif isinstance(window_size, Sequence):
            assert len(window_size) == self.num_layers, \
                f'Length of window_sizes {len(window_size)} is not equal to ' \
                f'length of stages {self.num_layers}.'
            self.window_sizes = window_size
        else:
            raise TypeError('window_size should be a Sequence or int.')

        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
            norm_cfg=dict(type='LN'),
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        if self.use_abs_pos_embed:
            num_patches = self.patch_resolution[0] * self.patch_resolution[1]
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dims))
            self._register_load_state_dict_pre_hook(
                self._prepare_abs_pos_embed)

        self._register_load_state_dict_pre_hook(self._delete_reinit_params)

        self.drop_after_pos = nn.Dropout(p=drop_rate)
        self.norm_eval = norm_eval

        # stochastic depth
        total_depth = sum(self.depths)
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        self.stages = ModuleList()
        embed_dims = [self.embed_dims]
        for i, (depth,
                num_heads) in enumerate(zip(self.depths, self.num_heads)):
            if isinstance(stage_cfgs, Sequence):
                stage_cfg = stage_cfgs[i]
            else:
                stage_cfg = deepcopy(stage_cfgs)
            downsample = True if i > 0 else False
            _stage_cfg = {
                'embed_dims': embed_dims[-1],
                'depth': depth,
                'num_heads': num_heads,
                'window_size': self.window_sizes[i],
                'downsample': downsample,
                'drop_paths': dpr[:depth],
                'with_cp': with_cp,
                'pad_small_map': pad_small_map,
                'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
                'pretrained_window_size': pretrained_window_sizes[i],
                **stage_cfg
            }

            stage = SwinBlockV2Sequence(**_stage_cfg)
            self.stages.append(stage)

            dpr = dpr[depth:]
            embed_dims.append(stage.out_channels)

        for i in out_indices:
            if norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm{i}', norm_layer)

    def init_weights(self):
        super(SwinTransformerV2, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        if self.use_abs_pos_embed:
            trunc_normal_(self.absolute_pos_embed, std=0.02)

    def forward(self, x):
        x, hw_shape = self.patch_embed(x)

        if self.use_abs_pos_embed:
            x = x + resize_pos_embed(
                self.absolute_pos_embed, self.patch_resolution, hw_shape,
                self.interpolate_mode, self.num_extra_tokens)
        x = self.drop_after_pos(x)

        outs = []
        for i, stage in enumerate(self.stages):
            x, hw_shape = stage(x, hw_shape)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                out = norm_layer(x)
                out = out.view(-1, *hw_shape,
                               stage.out_channels).permute(0, 3, 1,
                                                           2).contiguous()
                outs.append(out)

        return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        for i in range(0, self.frozen_stages + 1):
            m = self.stages[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        for i in self.out_indices:
            if i <= self.frozen_stages:
                for param in getattr(self, f'norm{i}').parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(SwinTransformerV2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'absolute_pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
            from mmcls.utils import get_root_logger
            logger = get_root_logger()
            logger.info(
                'Resize the absolute_pos_embed shape from '
                f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')

            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
        # delete relative_position_index since we always re-init it
        relative_position_index_keys = [
            k for k in state_dict.keys() if 'relative_position_index' in k
        ]
        for k in relative_position_index_keys:
            del state_dict[k]

        # delete relative_coords_table since we always re-init it
        relative_position_index_keys = [
            k for k in state_dict.keys() if 'relative_coords_table' in k
        ]
        for k in relative_position_index_keys:
            del state_dict[k]
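The key difference from V1 is visible in _inner_forward: SwinBlockV2 normalizes after attention and the FFN and then adds the identity (post-norm), where SwinBlock normalizes first (pre-norm). A bare-bones side-by-side sketch, with identity stand-ins for the attention and FFN sub-modules (the real blocks use ShiftWindowMSA and FFN):

import torch
import torch.nn as nn

dim = 96
attn = nn.Identity()   # stand-in for the windowed attention
ffn = nn.Identity()    # stand-in for the feedforward network
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

def v1_block(x):
    # pre-norm (SwinBlock): norm, sublayer, then residual add
    x = x + attn(norm1(x))
    return x + ffn(norm2(x))

def v2_block(x):
    # post-norm (SwinBlockV2): sublayer, norm, then residual add
    x = x + norm1(attn(x))
    return x + norm2(ffn(x))

x = torch.randn(2, 49, dim)
print(v1_block(x).shape, v2_block(x).shape)  # both (2, 49, 96)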
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/t2t_vit.py (new file, 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList

from ..builder import BACKBONES
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone


class T2TTransformerLayer(BaseModule):
    """Transformer Layer for T2T_ViT.

    Compared with :obj:`TransformerEncoderLayer` in ViT, it supports
    different ``input_dims`` and ``embed_dims``.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        input_dims (int, optional): The input token dimension.
            Defaults to None.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output
            weights. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        qk_scale (float, optional): Override default qk scale of
            ``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Notes:
        In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
        ``(embed_dims // num_heads) ** -0.5``. However, in the official
        code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
        keep the same with the official implementation.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 input_dims=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)

        self.v_shortcut = True if input_dims is not None else False
        input_dims = input_dims or embed_dims

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, input_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(
            input_dims=input_dims,
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
            v_shortcut=self.v_shortcut)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def forward(self, x):
        if self.v_shortcut:
            x = self.attn(self.norm1(x))
        else:
            x = x + self.attn(self.norm1(x))
        x = self.ffn(self.norm2(x), identity=x)
        return x
class T2TModule(BaseModule):
    """Tokens-to-Token module.

    "Tokens-to-Token module" (T2T Module) can model the local structure
    information of images and reduce the length of tokens progressively.

    Args:
        img_size (int): Input image size.
        in_channels (int): Number of input channels.
        embed_dims (int): Embedding dimension.
        token_dims (int): Tokens dimension in T2TModuleAttention.
        use_performer (bool): If True, use Performer version self-attention
            instead of regular self-attention. Defaults to False.
        init_cfg (dict, optional): The extra config for initialization.
            Defaults to None.

    Notes:
        Usually, ``token_dims`` is set as a small value (32 or 64) to reduce
        MACs.
    """

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 token_dims=64,
                 use_performer=False,
                 init_cfg=None):
        super(T2TModule, self).__init__(init_cfg)

        self.embed_dims = embed_dims

        self.soft_split0 = nn.Unfold(
            kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
        self.soft_split1 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.soft_split2 = nn.Unfold(
            kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        if not use_performer:
            self.attention1 = T2TTransformerLayer(
                input_dims=in_channels * 7 * 7,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.attention2 = T2TTransformerLayer(
                input_dims=token_dims * 3 * 3,
                embed_dims=token_dims,
                num_heads=1,
                feedforward_channels=token_dims)

            self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
        else:
            raise NotImplementedError("Performer hasn't been implemented.")

        # there are 3 soft splits, whose strides are 4, 2 and 2 respectively
        out_side = img_size // (4 * 2 * 2)
        self.init_out_size = [out_side, out_side]
        self.num_patches = out_side**2

    @staticmethod
    def _get_unfold_size(unfold: nn.Unfold, input_size):
        h, w = input_size
        kernel_size = to_2tuple(unfold.kernel_size)
        stride = to_2tuple(unfold.stride)
        padding = to_2tuple(unfold.padding)
        dilation = to_2tuple(unfold.dilation)

        h_out = (h + 2 * padding[0] - dilation[0] *
                 (kernel_size[0] - 1) - 1) // stride[0] + 1
        w_out = (w + 2 * padding[1] - dilation[1] *
                 (kernel_size[1] - 1) - 1) // stride[1] + 1
        return (h_out, w_out)

    def forward(self, x):
        # step0: soft split
        hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
        x = self.soft_split0(x).transpose(1, 2)

        for step in [1, 2]:
            # re-structurization/reconstruction
            attn = getattr(self, f'attention{step}')
            x = attn(x).transpose(1, 2)
            B, C, _ = x.shape
            x = x.reshape(B, C, hw_shape[0], hw_shape[1])

            # soft split
            soft_split = getattr(self, f'soft_split{step}')
            hw_shape = self._get_unfold_size(soft_split, hw_shape)
            x = soft_split(x).transpose(1, 2)

        # final tokens
        x = self.project(x)
        return x, hw_shape
def get_sinusoid_encoding(n_position, embed_dims):
    """Generate sinusoid encoding table.

    Sinusoid encoding is a kind of relative position encoding method that
    comes from `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        n_position (int): The length of the input token.
        embed_dims (int): The position embedding dimension.

    Returns:
        :obj:`torch.FloatTensor`: The sinusoid encoding table.
    """
    vec = torch.arange(embed_dims, dtype=torch.float64)
    vec = (vec - vec % 2) / embed_dims
    vec = torch.pow(10000, -vec).view(1, -1)

    sinusoid_table = torch.arange(n_position).view(-1, 1) * vec
    sinusoid_table[:, 0::2].sin_()  # dim 2i
    sinusoid_table[:, 1::2].cos_()  # dim 2i+1

    sinusoid_table = sinusoid_table.to(torch.float32)

    return sinusoid_table.unsqueeze(0)
@BACKBONES.register_module()
class T2T_ViT(BaseBackbone):
    """Tokens-to-Token Vision Transformer (T2T-ViT).

    A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
    Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_

    Args:
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        in_channels (int): Number of input channels.
        embed_dims (int): Embedding dimension.
        num_layers (int): Num of transformer layers in encoder.
            Defaults to 14.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, means the last stage.
        drop_rate (float): Dropout rate after position embedding.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Defaults to True.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether output the cls_token. If set True,
            ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        t2t_cfg (dict): Extra config of Tokens-to-Token module.
            Defaults to an empty dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """
    num_extra_tokens = 1  # cls_token

    def __init__(self,
                 img_size=224,
                 in_channels=3,
                 embed_dims=384,
                 num_layers=14,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 final_norm=True,
                 with_cls_token=True,
                 output_cls_token=True,
                 interpolate_mode='bicubic',
                 t2t_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(T2T_ViT, self).__init__(init_cfg)

        # Token-to-Token Module
        self.tokens_to_token = T2TModule(
            img_size=img_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            **t2t_cfg)
        self.patch_resolution = self.tokens_to_token.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set cls token
        if output_cls_token:
            assert with_cls_token is True, \
                f'with_cls_token must be True if output_cls_token is set ' \
                f'to True, but got {with_cls_token}'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        sinusoid_table = get_sinusoid_encoding(
            num_patches + self.num_extra_tokens, embed_dims)
        self.register_buffer('pos_embed', sinusoid_table)
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'get {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = num_layers + index
            assert 0 <= out_indices[i] <= num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]

        self.encoder = ModuleList()
        for i in range(num_layers):
            if isinstance(layer_cfgs, Sequence):
                layer_cfg = layer_cfgs[i]
            else:
                layer_cfg = deepcopy(layer_cfgs)
            layer_cfg = {
                'embed_dims': embed_dims,
                'num_heads': 6,
                'feedforward_channels': 3 * embed_dims,
                'drop_path_rate': dpr[i],
                'qkv_bias': False,
                'norm_cfg': norm_cfg,
                **layer_cfg
            }

            layer = T2TTransformerLayer(**layer_cfg)
            self.encoder.append(layer)

        self.final_norm = final_norm
        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
        else:
            self.norm = nn.Identity()
    def init_weights(self):
        super().init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if use pretrained model.
            return

        trunc_normal_(self.cls_token, std=.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmcls.utils import get_root_logger
            logger = get_root_logger()
            logger.info(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.')

            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.tokens_to_token.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.tokens_to_token(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.encoder):
            x = layer(x)

            if i == len(self.encoder) - 1 and self.final_norm:
                x = self.norm(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
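For a quick sense of what this new backbone returns, here is a minimal usage sketch (not part of the commit; it assumes the mmcls package from this tree is importable and only uses the defaults defined above):

import torch
from mmcls.models import build_backbone

# T2T_ViT is registered in BACKBONES above, so it can be built from a config.
model = build_backbone(dict(type='T2T_ViT', img_size=224))
model.eval()
with torch.no_grad():
    patch_token, cls_token = model(torch.rand(1, 3, 224, 224))[0]
# The three soft splits downsample by 4 * 2 * 2 = 16, so a 224x224 input
# gives a 14x14 token grid: patch_token is (1, 384, 14, 14), cls_token (1, 384).
print(patch_token.shape, cls_token.shape)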
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/timm_backbone.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
try:
    import timm
except ImportError:
    timm = None

import warnings

from mmcv.cnn.bricks.registry import NORM_LAYERS

from ...utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone


def print_timm_feature_info(feature_info):
    """Print feature_info of timm backbone to help development and debug.

    Args:
        feature_info (list[dict] | timm.models.features.FeatureInfo | None):
            feature_info of timm backbone.
    """
    logger = get_root_logger()
    if feature_info is None:
        logger.warning('This backbone does not have feature_info')
    elif isinstance(feature_info, list):
        for feat_idx, each_info in enumerate(feature_info):
            logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
    else:
        try:
            logger.info(f'backbone out_indices: {feature_info.out_indices}')
            logger.info(f'backbone out_channels: {feature_info.channels()}')
            logger.info(f'backbone out_strides: {feature_info.reduction()}')
        except AttributeError:
            logger.warning('Unexpected format of backbone feature_info')
@BACKBONES.register_module()
class TIMMBackbone(BaseBackbone):
    """Wrapper to use backbones from timm library.

    More details can be found in
    `timm <https://github.com/rwightman/pytorch-image-models>`_.
    See especially the document for `feature extraction
    <https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set this False. Defaults to False.
        pretrained (bool): Whether to load pretrained weights.
            Defaults to False.
        checkpoint_path (str): Path of checkpoint to load at the last of
            ``timm.create_model``. Defaults to empty string, which means
            not loading.
        in_channels (int): Number of input image channels. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict of
            OpenMMLab projects. Defaults to None.
        **kwargs: Other timm & model specific arguments.
    """

    def __init__(self,
                 model_name,
                 features_only=False,
                 pretrained=False,
                 checkpoint_path='',
                 in_channels=3,
                 init_cfg=None,
                 **kwargs):
        if timm is None:
            raise RuntimeError(
                'Failed to import timm. Please run "pip install timm". '
                '"pip install dataclasses" may also be needed for Python 3.6.')
        if not isinstance(pretrained, bool):
            raise TypeError('pretrained must be bool, not str for model path')
        if features_only and checkpoint_path:
            warnings.warn(
                'Using both features_only and checkpoint_path will cause error'
                ' in timm. See '
                'https://github.com/rwightman/pytorch-image-models/issues/488')

        super(TIMMBackbone, self).__init__(init_cfg)
        if 'norm_layer' in kwargs:
            kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs)

        # reset classifier
        if hasattr(self.timm_model, 'reset_classifier'):
            self.timm_model.reset_classifier(0, '')

        # Hack to use pretrained weights from timm
        if pretrained or checkpoint_path:
            self._is_init = True

        feature_info = getattr(self.timm_model, 'feature_info', None)
        print_timm_feature_info(feature_info)

    def forward(self, x):
        features = self.timm_model(x)
        if isinstance(features, (list, tuple)):
            features = tuple(features)
        else:
            features = (features, )
        return features
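A minimal sketch of how this wrapper is used (not part of the commit; it assumes `timm` is installed, and 'resnet18' is used purely as an example model name):

import torch
from mmcls.models import build_backbone

# features_only=True asks timm for the multi-scale feature pyramid.
backbone = build_backbone(
    dict(type='TIMMBackbone', model_name='resnet18', features_only=True))
feats = backbone(torch.rand(1, 3, 224, 224))
for feat in feats:
    print(feat.shape)  # one (B, C, H, W) tensor per stride level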
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/tnt.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList

from ..builder import BACKBONES
from ..utils import to_2tuple
from .base_backbone import BaseBackbone


class TransformerBlock(BaseModule):
    """Implement a transformer block in TnTLayer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
            Default: 4.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        act_cfg (dict): The activation config for FFNs. Defaults to GELU.
        norm_cfg (dict): Config dict for normalization layer. Default:
            layer normalization.
        batch_first (bool): Key, Query and Value are shape of
            (batch, n, embed_dim) or (n, batch, embed_dim).
            (batch, n, embed_dim) is the common case in CV. Defaults to True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 ffn_ratio=4,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=False,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 init_cfg=None):
        super(TransformerBlock, self).__init__(init_cfg=init_cfg)

        self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            batch_first=batch_first)

        self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * ffn_ratio,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

        if not qkv_bias:
            self.attn.attn.in_proj_bias = None

    def forward(self, x):
        x = self.attn(self.norm_attn(x), identity=x)
        x = self.ffn(self.norm_ffn(x), identity=x)
        return x
class TnTLayer(BaseModule):
    """Implement one encoder layer in Transformer in Transformer.

    Args:
        num_pixel (int): The pixel number in target patch transformed with
            a linear projection in inner transformer.
        embed_dims_inner (int): Feature dimension in inner transformer block.
        embed_dims_outer (int): Feature dimension in outer transformer block.
        num_heads_inner (int): Parallel attention heads in inner transformer.
        num_heads_outer (int): Parallel attention heads in outer transformer.
        inner_block_cfg (dict): Extra config of inner transformer block.
            Defaults to empty dict.
        outer_block_cfg (dict): Extra config of outer transformer block.
            Defaults to empty dict.
        norm_cfg (dict): Config dict for normalization layer. Default:
            layer normalization.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 num_pixel,
                 embed_dims_inner,
                 embed_dims_outer,
                 num_heads_inner,
                 num_heads_outer,
                 inner_block_cfg=dict(),
                 outer_block_cfg=dict(),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TnTLayer, self).__init__(init_cfg=init_cfg)

        self.inner_block = TransformerBlock(
            embed_dims=embed_dims_inner,
            num_heads=num_heads_inner,
            **inner_block_cfg)

        self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1]
        self.projection = nn.Linear(
            embed_dims_inner * num_pixel, embed_dims_outer, bias=True)

        self.outer_block = TransformerBlock(
            embed_dims=embed_dims_outer,
            num_heads=num_heads_outer,
            **outer_block_cfg)

    def forward(self, pixel_embed, patch_embed):
        pixel_embed = self.inner_block(pixel_embed)

        B, N, C = patch_embed.size()
        patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection(
            self.norm_proj(pixel_embed).reshape(B, N - 1, -1))
        patch_embed = self.outer_block(patch_embed)

        return pixel_embed, patch_embed
class PixelEmbed(BaseModule):
    """Image to Pixel Embedding.

    Args:
        img_size (int | tuple): The size of input image.
        patch_size (int): The size of one patch.
        in_channels (int): The num of input channels.
        embed_dims_inner (int): The num of channels of the target patch
            transformed with a linear projection in inner transformer.
        stride (int): The stride of the conv2d layer. We use a conv2d layer
            and an unfold layer to implement image to pixel embedding.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 embed_dims_inner=48,
                 stride=4,
                 init_cfg=None):
        super(PixelEmbed, self).__init__(init_cfg=init_cfg)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # patches_resolution property necessary for resizing
        # positional embedding
        patches_resolution = [
            img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        ]
        num_patches = patches_resolution[0] * patches_resolution[1]

        self.img_size = img_size
        self.num_patches = num_patches
        self.embed_dims_inner = embed_dims_inner

        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
        self.new_patch_size = new_patch_size

        self.proj = nn.Conv2d(
            in_channels,
            self.embed_dims_inner,
            kernel_size=7,
            padding=3,
            stride=stride)
        self.unfold = nn.Unfold(
            kernel_size=new_patch_size, stride=new_patch_size)

    def forward(self, x, pixel_pos):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model " \
            f'({self.img_size[0]}*{self.img_size[1]}).'
        x = self.proj(x)
        x = self.unfold(x)
        x = x.transpose(1, 2).reshape(B * self.num_patches,
                                      self.embed_dims_inner,
                                      self.new_patch_size[0],
                                      self.new_patch_size[1])
        x = x + pixel_pos
        x = x.reshape(B * self.num_patches, self.embed_dims_inner,
                      -1).transpose(1, 2)
        return x
@BACKBONES.register_module()
class TNT(BaseBackbone):
    """Transformer in Transformer.

    A PyTorch implementation of `Transformer in Transformer
    <https://arxiv.org/abs/2103.00112>`_

    Inspiration from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py

    Args:
        arch (str | dict): Vision Transformer architecture. Default: 'b'.
        img_size (int | tuple): Input image size. Default: 224.
        patch_size (int | tuple): The patch size. Default: 16.
        in_channels (int): Number of input channels. Default: 3.
        ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
            Default: 4.
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        act_cfg (dict): The activation config for FFNs. Defaults to GELU.
        norm_cfg (dict): Config dict for normalization layer. Default:
            layer normalization.
        first_stride (int): The stride of the conv2d layer. We use a conv2d
            layer and an unfold layer to implement image to pixel embedding.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        init_cfg (dict, optional): Initialization config dict.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims_outer': 384,
                'embed_dims_inner': 24,
                'num_layers': 12,
                'num_heads_outer': 6,
                'num_heads_inner': 4
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims_outer': 640,
                'embed_dims_inner': 40,
                'num_layers': 12,
                'num_heads_outer': 10,
                'num_heads_inner': 4
            })
    }

    def __init__(self,
                 arch='b',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 ffn_ratio=4,
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 first_stride=4,
                 num_fcs=2,
                 init_cfg=[
                     dict(type='TruncNormal', layer='Linear', std=.02),
                     dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
                 ]):
        super(TNT, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims_outer', 'embed_dims_inner', 'num_layers',
                'num_heads_inner', 'num_heads_outer'
            }
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims_inner = self.arch_settings['embed_dims_inner']
        self.embed_dims_outer = self.arch_settings['embed_dims_outer']
        # embed_dims for consistency with other models
        self.embed_dims = self.embed_dims_outer
        self.num_layers = self.arch_settings['num_layers']
        self.num_heads_inner = self.arch_settings['num_heads_inner']
        self.num_heads_outer = self.arch_settings['num_heads_outer']

        self.pixel_embed = PixelEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims_inner=self.embed_dims_inner,
            stride=first_stride)
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size[0] * new_patch_size[1]

        self.norm1_proj = build_norm_layer(
            norm_cfg, num_pixel * self.embed_dims_inner)[1]
        self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
                                    self.embed_dims_outer)
        self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, self.embed_dims_outer))
        self.patch_pos = nn.Parameter(
            torch.zeros(1, num_patches + 1, self.embed_dims_outer))
        self.pixel_pos = nn.Parameter(
            torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
                        new_patch_size[1]))
        self.drop_after_pos = nn.Dropout(p=drop_rate)

        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, self.num_layers)
        ]  # stochastic depth decay rule
        self.layers = ModuleList()
        for i in range(self.num_layers):
            block_cfg = dict(
                ffn_ratio=ffn_ratio,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=dpr[i],
                num_fcs=num_fcs,
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
                batch_first=True)
            self.layers.append(
                TnTLayer(
                    num_pixel=num_pixel,
                    embed_dims_inner=self.embed_dims_inner,
                    embed_dims_outer=self.embed_dims_outer,
                    num_heads_inner=self.num_heads_inner,
                    num_heads_outer=self.num_heads_outer,
                    inner_block_cfg=block_cfg,
                    outer_block_cfg=block_cfg,
                    norm_cfg=norm_cfg))

        self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.patch_pos, std=.02)
        trunc_normal_(self.pixel_pos, std=.02)
    def forward(self, x):
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        patch_embed = self.norm2_proj(
            self.projection(
                self.norm1_proj(pixel_embed.reshape(B, self.num_patches,
                                                    -1))))
        patch_embed = torch.cat(
            (self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.drop_after_pos(patch_embed)

        for layer in self.layers:
            pixel_embed, patch_embed = layer(pixel_embed, patch_embed)

        patch_embed = self.norm(patch_embed)
        return (patch_embed[:, 0], )
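Unlike most backbones in this commit, TNT's forward only returns the class token of the outer transformer. A minimal sketch (not part of the commit, assuming this mmcls tree is importable):

import torch
from mmcls.models import build_backbone

tnt = build_backbone(dict(type='TNT', arch='s', img_size=224))
tnt.eval()
with torch.no_grad():
    outs = tnt(torch.rand(1, 3, 224, 224))
# A single-element tuple holding the outer cls token: (1, 384) for arch 's'.
print(outs[0].shape)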
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/twins.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)
from mmcv.runner import BaseModule, ModuleList
from torch.nn.modules.batchnorm import _BatchNorm

from mmcls.models.builder import BACKBONES
from mmcls.models.utils.attention import MultiheadAttention
from mmcls.models.utils.position_encoding import ConditionalPositionEncoding


class GlobalSubsampledAttention(MultiheadAttention):
    """Global Sub-sampled Attention (GSA) module.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        input_dims (int, optional): The input dimension, and if None,
            use ``embed_dims``. Defaults to None.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        dropout_layer (dict): The dropout config before adding the shortcut.
            Defaults to ``dict(type='Dropout', drop_prob=0.)``.
        qkv_bias (bool): If True, add a learnable bias to q, k, v.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        qk_scale (float, optional): Override default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
            Defaults to True.
        v_shortcut (bool): Add a shortcut from value to output. It's usually
            used if ``input_dims`` is different from ``embed_dims``.
            Defaults to False.
        sr_ratio (float): The ratio of spatial reduction in attention modules.
            Defaults to 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 norm_cfg=dict(type='LN'),
                 qkv_bias=True,
                 sr_ratio=1,
                 **kwargs):
        super(GlobalSubsampledAttention,
              self).__init__(embed_dims, num_heads, **kwargs)

        self.qkv_bias = qkv_bias
        self.q = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
        self.kv = nn.Linear(self.input_dims, embed_dims * 2, bias=qkv_bias)

        # remove self.qkv, here split into self.q, self.kv
        delattr(self, 'qkv')

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # use a conv as the spatial-reduction operation, the kernel_size
            # and stride in conv are equal to the sr_ratio.
            self.sr = Conv2d(
                in_channels=embed_dims,
                out_channels=embed_dims,
                kernel_size=sr_ratio,
                stride=sr_ratio)
            # The ret[0] of build_norm_layer is norm name.
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x, hw_shape):
        B, N, C = x.shape
        H, W = hw_shape
        assert H * W == N, 'The product of h and w of hw_shape must be N, ' \
            'which is the 2nd dim number of the input Tensor x.'

        q = self.q(x).reshape(B, N, self.num_heads,
                              C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x = x.permute(0, 2, 1).reshape(B, C, *hw_shape)  # BNC_2_BCHW
            x = self.sr(x)
            x = x.reshape(B, C, -1).permute(0, 2, 1)  # BCHW_2_BNC
            x = self.norm(x)

        kv = self.kv(x).reshape(B, -1, 2, self.num_heads,
                                self.head_dims).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.out_drop(self.proj_drop(x))
        if self.v_shortcut:
            x = v.squeeze(1) + x
        return x
class GSAEncoderLayer(BaseModule):
    """Implements one encoder layer with GlobalSubsampledAttention(GSA).

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        sr_ratio (float): The ratio of spatial reduction in attention modules.
            Defaults to 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 sr_ratio=1.,
                 init_cfg=None):
        super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = GlobalSubsampledAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            norm_cfg=norm_cfg,
            sr_ratio=sr_ratio)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x
class LocallyGroupedSelfAttention(BaseModule):
    """Locally-grouped Self Attention (LSA) module.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads. Default: 8.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: False.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0.
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        window_size (int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 window_size=1,
                 init_cfg=None):
        super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg)

        assert embed_dims % num_heads == 0, \
            f'dim {embed_dims} should be divided by num_heads {num_heads}'

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        head_dim = embed_dims // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)
        self.window_size = window_size

    def forward(self, x, hw_shape):
        B, N, C = x.shape
        H, W = hw_shape
        x = x.view(B, H, W, C)

        # pad feature maps to multiples of Local-groups
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))

        # calculate attention mask for LSA
        Hp, Wp = x.shape[1:-1]
        _h, _w = Hp // self.window_size, Wp // self.window_size
        mask = torch.zeros((1, Hp, Wp), device=x.device)
        mask[:, -pad_b:, :].fill_(1)
        mask[:, :, -pad_r:].fill_(1)

        # [B, _h, _w, window_size, window_size, C]
        x = x.reshape(B, _h, self.window_size, _w, self.window_size,
                      C).transpose(2, 3)
        mask = mask.reshape(1, _h, self.window_size, _w,
                            self.window_size).transpose(2, 3).reshape(
                                1, _h * _w,
                                self.window_size * self.window_size)
        # [1, _h*_w, window_size*window_size, window_size*window_size]
        attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-1000.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        # [3, B, _w*_h, nhead, window_size*window_size, dim]
        qkv = self.qkv(x).reshape(B, _h * _w,
                                  self.window_size * self.window_size, 3,
                                  self.num_heads,
                                  C // self.num_heads).permute(
                                      3, 0, 1, 4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn + attn_mask.unsqueeze(2)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.window_size,
                                                  self.window_size, C)
        x = attn.transpose(2, 3).reshape(B, _h * self.window_size,
                                         _w * self.window_size, C)
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class LSAEncoderLayer(BaseModule):
    """Implements one encoder layer with LocallyGroupedSelfAttention(LSA).

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        window_size (int): Window size of LSA. Default: 1.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 qk_scale=None,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 window_size=1,
                 init_cfg=None):
        super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
        self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
                                                qkv_bias, qk_scale,
                                                attn_drop_rate, drop_rate,
                                                window_size)

        self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg,
            add_identity=False)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate)
        ) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, hw_shape):
        x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
        x = x + self.drop_path(self.ffn(self.norm2(x)))
        return x
@BACKBONES.register_module()
class PCPVT(BaseModule):
    """The backbone of Twins-PCPVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        arch (dict, str): PCPVT architecture, a str value in arch zoo or a
            detailed configuration dict with 7 keys, and the length of all
            the values in dict should be the same:

            - depths (List[int]): The number of encoder layers in each stage.
            - embed_dims (List[int]): Embedding dimension in each stage.
            - patch_sizes (List[int]): The patch sizes in each stage.
            - num_heads (List[int]): Numbers of attention head in each stage.
            - strides (List[int]): The strides in each stage.
            - mlp_ratios (List[int]): The ratios of mlp in each stage.
            - sr_ratios (List[int]): The ratios of GSA-encoder layers in each
              stage.

        in_channels (int): Number of input channels. Default: 3.
        out_indices (tuple[int]): Output from which stages.
            Default: (3, ).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Probability of an element to be zeroed.
            Default 0.
        attn_drop_rate (float): The drop out rate for attention layer.
            Default 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        norm_after_stage (bool, List[bool]): Add extra norm after each stage.
            Default False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import PCPVT
        >>> import torch
        >>> pcpvt_cfg = {'arch': "small",
        >>>              'norm_after_stage': [False, False, False, True]}
        >>> model = PCPVT(**pcpvt_cfg)
        >>> x = torch.rand(1, 3, 224, 224)
        >>> outputs = model(x)
        >>> print(outputs[-1].shape)
        torch.Size([1, 512, 7, 7])
        >>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True]
        >>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3)
        >>> model = PCPVT(**pcpvt_cfg)
        >>> outputs = model(x)
        >>> for feat in outputs:
        >>>     print(feat.shape)
        torch.Size([1, 64, 56, 56])
        torch.Size([1, 128, 28, 28])
        torch.Size([1, 320, 14, 14])
        torch.Size([1, 512, 7, 7])
    """
    arch_zoo = {
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 4, 6, 3],
                         'num_heads': [1, 2, 5, 8],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [8, 8, 4, 4],
                         'sr_ratios': [8, 4, 2, 1]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 4, 18, 3],
                         'num_heads': [1, 2, 5, 8],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [8, 8, 4, 4],
                         'sr_ratios': [8, 4, 2, 1]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 8, 27, 3],
                         'num_heads': [1, 2, 5, 8],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [8, 8, 4, 4],
                         'sr_ratios': [8, 4, 2, 1]}),
    }  # yapf: disable

    essential_keys = {
        'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
        'mlp_ratios', 'sr_ratios'
    }
    def __init__(self,
                 arch,
                 in_channels=3,
                 out_indices=(3, ),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 norm_after_stage=False,
                 init_cfg=None):
        super(PCPVT, self).__init__(init_cfg=init_cfg)
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            assert isinstance(arch, dict) and (
                set(arch) == self.essential_keys
            ), f'Custom arch needs a dict with keys {self.essential_keys}.'
            self.arch_settings = arch

        self.depths = self.arch_settings['depths']
        self.embed_dims = self.arch_settings['embed_dims']
        self.patch_sizes = self.arch_settings['patch_sizes']
        self.strides = self.arch_settings['strides']
        self.mlp_ratios = self.arch_settings['mlp_ratios']
        self.num_heads = self.arch_settings['num_heads']
        self.sr_ratios = self.arch_settings['sr_ratios']

        self.num_extra_tokens = 0  # there is no cls-token in Twins
        self.num_stage = len(self.depths)
        for key, value in self.arch_settings.items():
            assert isinstance(value, list) and len(value) == self.num_stage, (
                'Length of setting item in arch dict must be type of list and'
                ' have the same length.')

        # patch_embeds
        self.patch_embeds = ModuleList()
        self.position_encoding_drops = ModuleList()
        self.stages = ModuleList()

        for i in range(self.num_stage):
            # use in_channels of the model in the first stage
            if i == 0:
                stage_in_channels = in_channels
            else:
                stage_in_channels = self.embed_dims[i - 1]

            self.patch_embeds.append(
                PatchEmbed(
                    in_channels=stage_in_channels,
                    embed_dims=self.embed_dims[i],
                    conv_type='Conv2d',
                    kernel_size=self.patch_sizes[i],
                    stride=self.strides[i],
                    padding='corner',
                    norm_cfg=dict(type='LN')))

            self.position_encoding_drops.append(nn.Dropout(p=drop_rate))

        # PEGs
        self.position_encodings = ModuleList([
            ConditionalPositionEncoding(embed_dim, embed_dim)
            for embed_dim in self.embed_dims
        ])

        # stochastic depth
        total_depth = sum(self.depths)
        self.dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        cur = 0
        for k in range(len(self.depths)):
            _block = ModuleList([
                GSAEncoderLayer(
                    embed_dims=self.embed_dims[k],
                    num_heads=self.num_heads[k],
                    feedforward_channels=self.mlp_ratios[k] *
                    self.embed_dims[k],
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    drop_path_rate=self.dpr[cur + i],
                    num_fcs=2,
                    qkv_bias=qkv_bias,
                    act_cfg=dict(type='GELU'),
                    norm_cfg=norm_cfg,
                    sr_ratio=self.sr_ratios[k]) for i in range(self.depths[k])
            ])
            self.stages.append(_block)
            cur += self.depths[k]

        self.out_indices = out_indices

        assert isinstance(norm_after_stage, (bool, list))
        if isinstance(norm_after_stage, bool):
            self.norm_after_stage = [norm_after_stage] * self.num_stage
        else:
            self.norm_after_stage = norm_after_stage
        assert len(self.norm_after_stage) == self.num_stage, \
            (f'Number of norm_after_stage({len(self.norm_after_stage)}) should'
             f' be equal to the number of stages({self.num_stage}).')

        for i, has_norm in enumerate(self.norm_after_stage):
            assert isinstance(has_norm, bool), 'norm_after_stage should be ' \
                'bool or List[bool].'
            if has_norm and norm_cfg is not None:
                norm_layer = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
            else:
                norm_layer = nn.Identity()

            self.add_module(f'norm_after_stage{i}', norm_layer)
    def init_weights(self):
        if self.init_cfg is not None:
            super(PCPVT, self).init_weights()
        else:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
    def forward(self, x):
        outputs = list()

        b = x.shape[0]

        for i in range(self.num_stage):
            x, hw_shape = self.patch_embeds[i](x)
            h, w = hw_shape
            x = self.position_encoding_drops[i](x)
            for j, blk in enumerate(self.stages[i]):
                x = blk(x, hw_shape)
                if j == 0:
                    x = self.position_encodings[i](x, hw_shape)

            norm_layer = getattr(self, f'norm_after_stage{i}')
            x = norm_layer(x)
            x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()

            if i in self.out_indices:
                outputs.append(x)

        return tuple(outputs)
@BACKBONES.register_module()
class SVT(PCPVT):
    """The backbone of Twins-SVT.

    This backbone is the implementation of `Twins: Revisiting the Design
    of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.

    Args:
        arch (dict, str): SVT architecture, a str value in arch zoo or a
            detailed configuration dict with 8 keys, and the length of all
            the values in dict should be the same:

            - depths (List[int]): The number of encoder layers in each stage.
            - embed_dims (List[int]): Embedding dimension in each stage.
            - patch_sizes (List[int]): The patch sizes in each stage.
            - num_heads (List[int]): Numbers of attention head in each stage.
            - strides (List[int]): The strides in each stage.
            - mlp_ratios (List[int]): The ratios of mlp in each stage.
            - sr_ratios (List[int]): The ratios of GSA-encoder layers in each
              stage.
            - window_sizes (List[int]): The window sizes in LSA-encoder layers
              in each stage.

        in_channels (int): Number of input channels. Default: 3.
        out_indices (tuple[int]): Output from which stages.
            Default: (3, ).
        qkv_bias (bool): Enable bias for qkv if True. Default: False.
        drop_rate (float): Dropout rate. Default 0.
        attn_drop_rate (float): Dropout ratio of attention weight.
            Default 0.0.
        drop_path_rate (float): Stochastic depth rate. Default 0.0.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        norm_after_stage (bool, List[bool]): Add extra norm after each stage.
            Default False.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import SVT
        >>> import torch
        >>> svt_cfg = {'arch': "small",
        >>>            'norm_after_stage': [False, False, False, True]}
        >>> model = SVT(**svt_cfg)
        >>> x = torch.rand(1, 3, 224, 224)
        >>> outputs = model(x)
        >>> print(outputs[-1].shape)
        torch.Size([1, 512, 7, 7])
        >>> svt_cfg["out_indices"] = (0, 1, 2, 3)
        >>> svt_cfg["norm_after_stage"] = [True, True, True, True]
        >>> model = SVT(**svt_cfg)
        >>> output = model(x)
        >>> for feat in output:
        >>>     print(feat.shape)
        torch.Size([1, 64, 56, 56])
        torch.Size([1, 128, 28, 28])
        torch.Size([1, 320, 14, 14])
        torch.Size([1, 512, 7, 7])
    """
    arch_zoo = {
        **dict.fromkeys(['s', 'small'],
                        {'embed_dims': [64, 128, 256, 512],
                         'depths': [2, 2, 10, 4],
                         'num_heads': [2, 4, 8, 16],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [4, 4, 4, 4],
                         'sr_ratios': [8, 4, 2, 1],
                         'window_sizes': [7, 7, 7, 7]}),
        **dict.fromkeys(['b', 'base'],
                        {'embed_dims': [96, 192, 384, 768],
                         'depths': [2, 2, 18, 2],
                         'num_heads': [3, 6, 12, 24],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [4, 4, 4, 4],
                         'sr_ratios': [8, 4, 2, 1],
                         'window_sizes': [7, 7, 7, 7]}),
        **dict.fromkeys(['l', 'large'],
                        {'embed_dims': [128, 256, 512, 1024],
                         'depths': [2, 2, 18, 2],
                         'num_heads': [4, 8, 16, 32],
                         'patch_sizes': [4, 2, 2, 2],
                         'strides': [4, 2, 2, 2],
                         'mlp_ratios': [4, 4, 4, 4],
                         'sr_ratios': [8, 4, 2, 1],
                         'window_sizes': [7, 7, 7, 7]}),
    }  # yapf: disable

    essential_keys = {
        'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
        'mlp_ratios', 'sr_ratios', 'window_sizes'
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 out_indices=(3, ),
                 qkv_bias=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.0,
                 norm_cfg=dict(type='LN'),
                 norm_after_stage=False,
                 init_cfg=None):
        super(SVT, self).__init__(arch, in_channels, out_indices, qkv_bias,
                                  drop_rate, attn_drop_rate, drop_path_rate,
                                  norm_cfg, norm_after_stage, init_cfg)

        self.window_sizes = self.arch_settings['window_sizes']

        for k in range(self.num_stage):
            for i in range(self.depths[k]):
                # in even-numbered layers of each stage, replace GSA with LSA
                if i % 2 == 0:
                    ffn_channels = self.mlp_ratios[k] * self.embed_dims[k]
                    self.stages[k][i] = \
                        LSAEncoderLayer(
                            embed_dims=self.embed_dims[k],
                            num_heads=self.num_heads[k],
                            feedforward_channels=ffn_channels,
                            drop_rate=drop_rate,
                            norm_cfg=norm_cfg,
                            attn_drop_rate=attn_drop_rate,
                            drop_path_rate=self.dpr[sum(self.depths[:k]) + i],
                            qkv_bias=qkv_bias,
                            window_size=self.window_sizes[k])
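The only difference between SVT and its parent PCPVT is the loop above, which swaps the even-indexed GSA layers of every stage for window-based LSA layers. A small inspection sketch (not part of the commit) makes the resulting interleaving visible:

from mmcls.models import SVT

model = SVT(arch='s')
# Stage 0 of the small arch has depth 2: layer 0 is local (LSA), layer 1
# stays global (GSA).
for i, blk in enumerate(model.stages[0]):
    print(i, type(blk).__name__)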
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/van.py 0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmcv.runner import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class MixFFN(BaseModule):
    """An implementation of MixFFN of VAN. Refer to
    mmdetection/mmdet/models/backbones/pvt.py.

    The differences between MixFFN & FFN:
        1. Use 1X1 Conv to replace Linear layer.
        2. Introduce 3X3 Depth-wise Conv to encode positional information.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`.
        feedforward_channels (int): The hidden dimension of FFNs.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 init_cfg=None):
        super(MixFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        self.fc1 = Conv2d(
            in_channels=embed_dims,
            out_channels=feedforward_channels,
            kernel_size=1)
        self.dwconv = Conv2d(
            in_channels=feedforward_channels,
            out_channels=feedforward_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=feedforward_channels)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class LKA(BaseModule):
    """Large Kernel Attention (LKA) of VAN.

    .. code:: text

            DW_conv (depth-wise convolution)
                            |
                            |
        DW_D_conv (depth-wise dilation convolution)
                            |
                            |
        Transition Convolution (1×1 convolution)

    Args:
        embed_dims (int): Number of input channels.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, init_cfg=None):
        super(LKA, self).__init__(init_cfg=init_cfg)

        # a spatial local convolution (depth-wise convolution)
        self.DW_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=5,
            padding=2,
            groups=embed_dims)

        # a spatial long-range convolution (depth-wise dilation convolution)
        self.DW_D_conv = Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=7,
            stride=1,
            padding=9,
            groups=embed_dims,
            dilation=3)

        self.conv1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        u = x.clone()
        attn = self.DW_conv(x)
        attn = self.DW_D_conv(attn)
        attn = self.conv1(attn)

        return u * attn
class SpatialAttention(BaseModule):
    """Basic attention module in VANBlock.

    Args:
        embed_dims (int): Number of input channels.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
        super(SpatialAttention, self).__init__(init_cfg=init_cfg)

        self.proj_1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.activation = build_activation_layer(act_cfg)
        self.spatial_gating_unit = LKA(embed_dims)
        self.proj_2 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
        return x
class VANBlock(BaseModule):
    """A block of VAN.

    Args:
        embed_dims (int): Number of input channels.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', eps=1e-5).
        layer_scale_init_value (float): Init value for Layer Scale.
            Defaults to 1e-2.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 ffn_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN', eps=1e-5),
                 layer_scale_init_value=1e-2,
                 init_cfg=None):
        super(VANBlock, self).__init__(init_cfg=init_cfg)
        self.out_channels = embed_dims

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
        self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
        mlp_hidden_dim = int(embed_dims * ffn_ratio)
        self.mlp = MixFFN(
            embed_dims=embed_dims,
            feedforward_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            ffn_drop=drop_rate)
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((embed_dims)),
            requires_grad=True) if layer_scale_init_value > 0 else None
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((embed_dims)),
            requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        identity = x
        x = self.norm1(x)
        x = self.attn(x)
        if self.layer_scale_1 is not None:
            x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x
        x = identity + self.drop_path(x)

        identity = x
        x = self.norm2(x)
        x = self.mlp(x)
        if self.layer_scale_2 is not None:
            x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x
        x = identity + self.drop_path(x)

        return x
class VANPatchEmbed(PatchEmbed):
    """Image to Patch Embedding of VAN.

    The differences between VANPatchEmbed & PatchEmbed:
        1. Use BN.
        2. Do not use 'flatten' and 'transpose'.
    """

    def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
        super(VANPatchEmbed, self).__init__(
            *args, norm_cfg=norm_cfg, **kwargs)

    def forward(self, x):
        """
        Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.

        Returns:
            tuple: Contains merged results and its spatial shape.

            - x (Tensor): Has shape (B, embed_dims, out_h, out_w), since
              the feature map is not flattened here.
            - out_size (tuple[int]): Spatial shape of x, arranged as
              (out_h, out_w).
        """
        if self.adaptive_padding:
            x = self.adaptive_padding(x)
        x = self.projection(x)
        out_size = (x.shape[2], x.shape[3])
        if self.norm is not None:
            x = self.norm(x)
        return x, out_size

@BACKBONES.register_module()
class VAN(BaseBackbone):
    """Visual Attention Network.

    A PyTorch implementation of `Visual Attention Network
    <https://arxiv.org/pdf/2202.09741v2.pdf>`_

    Inspiration from
    https://github.com/Visual-Attention-Network/VAN-Classification

    Args:
        arch (str | dict): Visual Attention Network architecture. If a string,
            choose from 'b0', 'b1', 'b2', 'b3' and so on; if a dict, it should
            have the below keys:

            - **embed_dims** (List[int]): The dimensions of embedding.
            - **depths** (List[int]): The number of blocks in each stage.
            - **ffn_ratios** (List[int]): The expansion ratio of feedforward
              network hidden layer channels.

            Defaults to 'tiny'.
        patch_sizes (List[int | tuple]): The patch size in patch embeddings.
            Defaults to [7, 3, 3, 3].
        in_channels (int): The num of input channels. Defaults to 3.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
        out_indices (Sequence[int]): Output from which stages.
            Default: ``(3, )``.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer for all output
            features. Defaults to ``dict(type='LN')``.
        block_cfgs (Sequence[dict] | dict): The extra config of each block.
            Defaults to empty dicts.
        init_cfg (dict, optional): The Config for initialization.
            Defaults to None.

    Examples:
        >>> from mmcls.models import VAN
        >>> import torch
        >>> model = VAN(arch='b0')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> for out in outputs:
        >>>     print(out.size())
        (1, 256, 7, 7)
    """
    arch_zoo = {
        **dict.fromkeys(['b0', 't', 'tiny'],
                        {'embed_dims': [32, 64, 160, 256],
                         'depths': [3, 3, 5, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b1', 's', 'small'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [2, 2, 4, 2],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b2', 'b', 'base'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 3, 12, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b3', 'l', 'large'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 5, 27, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b4'],
                        {'embed_dims': [64, 128, 320, 512],
                         'depths': [3, 6, 40, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b5'],
                        {'embed_dims': [96, 192, 480, 768],
                         'depths': [3, 3, 24, 3],
                         'ffn_ratios': [8, 8, 4, 4]}),
        **dict.fromkeys(['b6'],
                        {'embed_dims': [96, 192, 384, 768],
                         'depths': [6, 6, 90, 6],
                         'ffn_ratios': [8, 8, 4, 4]}),
    }  # yapf: disable

    def __init__(self,
                 arch='tiny',
                 patch_sizes=[7, 3, 3, 3],
                 in_channels=3,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 out_indices=(3, ),
                 frozen_stages=-1,
                 norm_eval=False,
                 norm_cfg=dict(type='LN'),
                 block_cfgs=dict(),
                 init_cfg=None):
        super(VAN, self).__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
            assert isinstance(arch, dict) and set(arch) == essential_keys, \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.depths = self.arch_settings['depths']
        self.ffn_ratios = self.arch_settings['ffn_ratios']
        self.num_stages = len(self.depths)
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval

        total_depth = sum(self.depths)
        dpr = [
            x.item()
            for x in torch.linspace(0, drop_path_rate, total_depth)
        ]  # stochastic depth decay rule

        cur_block_idx = 0
        for i, depth in enumerate(self.depths):
            patch_embed = VANPatchEmbed(
                in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
                input_size=None,
                embed_dims=self.embed_dims[i],
                kernel_size=patch_sizes[i],
                stride=patch_sizes[i] // 2 + 1,
                padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
                norm_cfg=dict(type='BN'))

            blocks = ModuleList([
                VANBlock(
                    embed_dims=self.embed_dims[i],
                    ffn_ratio=self.ffn_ratios[i],
                    drop_rate=drop_rate,
                    drop_path_rate=dpr[cur_block_idx + j],
                    **block_cfgs) for j in range(depth)
            ])
            cur_block_idx += depth
            norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]

            self.add_module(f'patch_embed{i + 1}', patch_embed)
            self.add_module(f'blocks{i + 1}', blocks)
            self.add_module(f'norm{i + 1}', norm)

    def train(self, mode=True):
        super(VAN, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has an effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()

    def _freeze_stages(self):
        for i in range(0, self.frozen_stages + 1):
            # freeze patch embed
            m = getattr(self, f'patch_embed{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # freeze blocks
            m = getattr(self, f'blocks{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # freeze norm
            m = getattr(self, f'norm{i + 1}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def forward(self, x):
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f'patch_embed{i + 1}')
            blocks = getattr(self, f'blocks{i + 1}')
            norm = getattr(self, f'norm{i + 1}')
            x, hw_shape = patch_embed(x)
            for block in blocks:
                x = block(x)
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            x = x.reshape(-1, *hw_shape,
                          block.out_channels).permute(0, 3, 1, 2).contiguous()
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)
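Combined with the arch table above, the patch-embed strides (``patch_size // 2 + 1``, i.e. 4, 2, 2, 2) fix the output shape of every stage. A usage sketch requesting all four stages of the 'b0' variant:

import torch
from mmcls.models import VAN

# All four stages of 'b0' (embed_dims [32, 64, 160, 256]) on a 224x224 input.
model = VAN(arch='b0', out_indices=(0, 1, 2, 3))
model.eval()
with torch.no_grad():
    outs = model(torch.rand(1, 3, 224, 224))
for out in outs:
    print(tuple(out.shape))
# (1, 32, 56, 56)
# (1, 64, 28, 28)
# (1, 160, 14, 14)
# (1, 256, 7, 7)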
openmmlab_test/mmclassification-speed-benchmark/mmcls/models/backbones/vgg.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/vgg.py
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
...
@@ -45,13 +46,11 @@ class VGG(BaseBackbone):
         num_stages (int): VGG stages, normally 5.
         dilations (Sequence[int]): Dilation of each stage.
         out_indices (Sequence[int], optional): Output from which stages.
-            If only one stage is specified, a single tensor (feature map) is
-            returned, otherwise multiple stages are specified, a tuple of
-            tensors will be returned. When it is None, the default behavior
-            depends on whether num_classes is specified. If num_classes <= 0,
-            the default value is (4, ), outputing the last feature map before
-            classifier. If num_classes > 0, the default value is (5, ),
-            outputing the classification score. Default: None.
+            When it is None, the default behavior depends on whether
+            num_classes is specified. If num_classes <= 0, the default value
+            is (4, ), output the last feature map before classifier. If
+            num_classes > 0, the default value is (5, ), output the
+            classification score. Default: None.
         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
             not freezing any parameters.
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
@@ -162,9 +161,7 @@ class VGG(BaseBackbone):
             x = x.view(x.size(0), -1)
             x = self.classifier(x)
             outs.append(x)
-        if len(outs) == 1:
-            return outs[0]
-        else:
-            return tuple(outs)
+
+        return tuple(outs)

     def _freeze_stages(self):
...
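The behavioral consequence of this hunk is that ``VGG.forward`` now always returns a tuple, even when only one stage is requested, so callers index into the result themselves. A small sketch (the instantiation below is illustrative; with the default ``num_classes=-1``, ``out_indices`` falls back to ``(4, )``):

import torch
from mmcls.models import VGG

model = VGG(depth=11)
model.eval()
with torch.no_grad():
    feats = model(torch.rand(1, 3, 224, 224))
print(type(feats), len(feats))   # <class 'tuple'> 1
print(feats[0].shape)            # torch.Size([1, 512, 7, 7])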
openmmlab_test/mmclassification-0.24.1/mmcls/models/backbones/vision_transformer.py (new file, 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList

from mmcls.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone


class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to True.
        act_cfg (dict): The activation config for FFNs.
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            qkv_bias=qkv_bias)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, self.embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        self.ffn = FFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            num_fcs=num_fcs,
            ffn_drop=drop_rate,
            dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
            act_cfg=act_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)

    def init_weights(self):
        super(TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = self.ffn(self.norm2(x), identity=x)
        return x
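This forward pass is the standard pre-norm residual pattern: x = x + Attn(LN(x)), then x = x + FFN(LN(x)); the FFN's ``identity=x`` argument supplies the second residual branch. A minimal sketch of the same pattern, with ``torch.nn.MultiheadAttention`` standing in for mmcls's ``MultiheadAttention`` and the residual additions written out explicitly:

import torch
import torch.nn as nn

# Pre-norm encoder layer sketch; dims follow the 'base' arch for concreteness.
class PreNormLayer(nn.Module):
    def __init__(self, dim=768, heads=12, hidden=3072):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # first residual
        x = x + self.ffn(self.norm2(x))                    # second residual
        return x

tokens = torch.rand(1, 197, 768)  # (B, num_tokens, embed_dims)
print(PreNormLayer()(tokens).shape)  # torch.Size([1, 197, 768])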

@BACKBONES.register_module()
class VisionTransformer(BaseBackbone):
    """Vision Transformer.

    A PyTorch implementation of `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_

    Args:
        arch (str | dict): Vision Transformer architecture. If a string,
            choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
            and 'deit-base'. If a dict, it should have the below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The expected input image shape. Because we
            support dynamic input shape, just set the argument to the most
            common input image shape. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        with_cls_token (bool): Whether to concatenate the class token into the
            image tokens as transformer input. Defaults to True.
        output_cls_token (bool): Whether to output the cls_token. If set True,
            ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    arch_zoo = {
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 768,
                'num_layers': 8,
                'num_heads': 8,
                'feedforward_channels': 768 * 3,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': 4096
            }),
        **dict.fromkeys(
            ['deit-t', 'deit-tiny'], {
                'embed_dims': 192,
                'num_layers': 12,
                'num_heads': 3,
                'feedforward_channels': 192 * 4
            }),
        **dict.fromkeys(
            ['deit-s', 'deit-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': 384 * 4
            }),
        **dict.fromkeys(
            ['deit-b', 'deit-base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 768 * 4
            }),
    }
    # Some structures have multiple extra tokens, like DeiT.
    num_extra_tokens = 1  # cls_token

    def __init__(self,
                 arch='base',
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 out_indices=-1,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 final_norm=True,
                 with_cls_token=True,
                 output_cls_token=True,
                 interpolate_mode='bicubic',
                 patch_cfg=dict(),
                 layer_cfgs=dict(),
                 init_cfg=None):
        super(VisionTransformer, self).__init__(init_cfg)

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(self.arch_zoo), \
                f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
            }
            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'
            self.arch_settings = arch

        self.embed_dims = self.arch_settings['embed_dims']
        self.num_layers = self.arch_settings['num_layers']
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type='Conv2d',
            kernel_size=patch_size,
            stride=patch_size,
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size
        num_patches = self.patch_resolution[0] * self.patch_resolution[1]

        # Set cls token
        if output_cls_token:
            assert with_cls_token is True, \
                f'with_cls_token must be True if ' \
                f'output_cls_token is set to True, but got {with_cls_token}'
        self.with_cls_token = with_cls_token
        self.output_cls_token = output_cls_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_extra_tokens,
                        self.embed_dims))
        self._register_load_state_dict_pre_hook(self._prepare_pos_embed)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings['num_heads'],
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg)
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        self.final_norm = final_norm
        if final_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, self.embed_dims, postfix=1)
            self.add_module(self.norm1_name, norm1)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    def init_weights(self):
        super(VisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.pos_embed, std=0.02)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'pos_embed'
        if name not in state_dict.keys():
            return

        ckpt_pos_embed_shape = state_dict[name].shape
        if self.pos_embed.shape != ckpt_pos_embed_shape:
            from mmcv.utils import print_log
            logger = get_root_logger()
            print_log(
                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
                f'to {self.pos_embed.shape}.',
                logger=logger)

            ckpt_pos_embed_shape = to_2tuple(
                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
            pos_embed_shape = self.patch_embed.init_out_size

            state_dict[name] = resize_pos_embed(state_dict[name],
                                                ckpt_pos_embed_shape,
                                                pos_embed_shape,
                                                self.interpolate_mode,
                                                self.num_extra_tokens)

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility."""
        return resize_pos_embed(*args, **kwargs)
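The actual ``resize_pos_embed`` helper lives in ``mmcls.models.utils`` and is not part of this diff, but its job can be sketched: leave the extra tokens untouched, reshape the patch portion onto its 2D grid, interpolate it to the new grid, then flatten and concatenate back:

import torch
import torch.nn.functional as F

# Sketch only; the real helper in mmcls.models.utils may differ in details.
def resize_pos_embed_sketch(pos_embed, src_shape, dst_shape,
                            mode='bicubic', num_extra_tokens=1):
    extra = pos_embed[:, :num_extra_tokens]   # e.g. the cls token slot
    grid = pos_embed[:, num_extra_tokens:]    # (1, H*W, C)
    C = grid.shape[-1]
    grid = grid.reshape(1, *src_shape, C).permute(0, 3, 1, 2)  # (1, C, H, W)
    grid = F.interpolate(grid, size=dst_shape, mode=mode, align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, -1, C)
    return torch.cat([extra, grid], dim=1)

pe = torch.rand(1, 1 + 14 * 14, 768)  # checkpoint trained at 224 / patch 16
print(resize_pos_embed_sketch(pe, (14, 14), (24, 24)).shape)
# torch.Size([1, 577, 768]) -> ready for 384 / patch 16 inputs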

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        if not self.with_cls_token:
            # Remove class token for transformer encoder input
            x = x[:, 1:]

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                B, _, C = x.shape
                if self.with_cls_token:
                    patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = x[:, 0]
                else:
                    patch_token = x.reshape(B, *patch_resolution, C)
                    patch_token = patch_token.permute(0, 3, 1, 2)
                    cls_token = None
                if self.output_cls_token:
                    out = [patch_token, cls_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)
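A usage sketch of the default configuration (arch='base', 224 input, 16x16 patches, ``output_cls_token=True``), where each requested stage yields a ``[patch_token, cls_token]`` pair:

import torch
from mmcls.models import VisionTransformer

model = VisionTransformer(arch='base')
model.eval()
with torch.no_grad():
    # out_indices defaults to -1, so the tuple holds one [patch, cls] pair.
    (patch_token, cls_token), = model(torch.rand(1, 3, 224, 224))
print(patch_token.shape)  # torch.Size([1, 768, 14, 14])
print(cls_token.shape)    # torch.Size([1, 768])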
openmmlab_test/mmclassification-0.24.1/mmcls/models/builder.py (new file, 0 → 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS)

BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
CLASSIFIERS = MODELS

ATTENTION = Registry('attention', parent=MMCV_ATTENTION)


def build_backbone(cfg):
    """Build backbone."""
    return BACKBONES.build(cfg)


def build_neck(cfg):
    """Build neck."""
    return NECKS.build(cfg)


def build_head(cfg):
    """Build head."""
    return HEADS.build(cfg)


def build_loss(cfg):
    """Build loss."""
    return LOSSES.build(cfg)


def build_classifier(cfg):
    """Build classifier."""
    return CLASSIFIERS.build(cfg)
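Since all the component registries alias the single ``MODELS`` registry, one config dict can describe a whole classifier: the ``type`` key selects a registered class and the remaining keys become constructor arguments. A sketch using the usual mmcls registrations (ImageClassifier, ResNet, GlobalAveragePooling, LinearClsHead, CrossEntropyLoss):

from mmcls.models import build_classifier

# ResNet-18's last stage has 512 channels, hence in_channels=512 in the head.
cfg = dict(
    type='ImageClassifier',
    backbone=dict(type='ResNet', depth=18, out_indices=(3, )),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss')))
model = build_classifier(cfg)  # equivalent to CLASSIFIERS.build(cfg)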
openmmlab_test/mmclassification-0.24.1/mmcls/models/classifiers/__init__.py (new file, 0 → 100644)

# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseClassifier
from .image import ImageClassifier

__all__ = ['BaseClassifier', 'ImageClassifier']
openmmlab_test/mmclassification-speed-benchmark/mmcls/models/classifiers/base.py → openmmlab_test/mmclassification-0.24.1/mmcls/models/classifiers/base.py
-import warnings
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
+from typing import Sequence
-import cv2
 import mmcv
 import torch
 import torch.distributed as dist
-from mmcv import color_val
-from mmcv.runner import BaseModule
+from mmcv.runner import BaseModule, auto_fp16
-# TODO import `auto_fp16` from mmcv and delete them from mmcls
-try:
-    from mmcv.runner import auto_fp16
-except ImportError:
-    warnings.warn('auto_fp16 from mmcls will be deprecated.'
-                  'Please install mmcv>=1.1.4.')
-    from mmcls.core import auto_fp16
+from mmcls.core.visualization import imshow_infos


 class BaseClassifier(BaseModule, metaclass=ABCMeta):
...
@@ -34,13 +27,14 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
         return hasattr(self, 'head') and self.head is not None

     @abstractmethod
-    def extract_feat(self, imgs):
+    def extract_feat(self, imgs, stage=None):
         pass

-    def extract_feats(self, imgs):
-        assert isinstance(imgs, list)
+    def extract_feats(self, imgs, stage=None):
+        assert isinstance(imgs, Sequence)
+        kwargs = {} if stage is None else {'stage': stage}
         for img in imgs:
-            yield self.extract_feat(img)
+            yield self.extract_feat(img, **kwargs)

     @abstractmethod
     def forward_train(self, imgs, **kwargs):
...
@@ -117,7 +111,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
         return loss, log_vars

-    def train_step(self, data, optimizer):
+    def train_step(self, data, optimizer=None, **kwargs):
         """The iteration step during training.

         This method defines an iteration step during training, except for the
...
@@ -128,20 +122,19 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
         Args:
             data (dict): The output of dataloader.
-            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
-                runner is passed to ``train_step()``. This argument is unused
-                and reserved.
+            optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.

         Returns:
-            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
-                ``num_samples``.
-                ``loss`` is a tensor for back propagation, which can be a
-                weighted sum of multiple losses.
-                ``log_vars`` contains all the variables to be sent to the
-                logger.
-                ``num_samples`` indicates the batch size (when the model is
-                DDP, it means the batch size on each GPU), which is used for
-                averaging the logs.
+            dict: Dict of outputs. The following fields are contained.
+
+                - loss (torch.Tensor): A tensor for back propagation, which \
+                    can be a weighted sum of multiple losses.
+                - log_vars (dict): Dict contains all the variables to be sent \
+                    to the logger.
+                - num_samples (int): Indicates the batch size (when the model \
+                    is DDP, it means the batch size on each GPU), which is \
+                    used for averaging the logs.
         """
         losses = self(**data)
         loss, log_vars = self._parse_losses(losses)
...
@@ -151,12 +144,28 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
         return outputs

-    def val_step(self, data, optimizer):
+    def val_step(self, data, optimizer=None, **kwargs):
         """The iteration step during validation.

         This method shares the same signature as :func:`train_step`, but used
         during val epochs. Note that the evaluation after training epochs is
         not implemented with this method, but an evaluation hook.
+
+        Args:
+            data (dict): The output of dataloader.
+            optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.
+
+        Returns:
+            dict: Dict of outputs. The following fields are contained.
+
+                - loss (torch.Tensor): A tensor for back propagation, which \
+                    can be a weighted sum of multiple losses.
+                - log_vars (dict): Dict contains all the variables to be sent \
+                    to the logger.
+                - num_samples (int): Indicates the batch size (when the model \
+                    is DDP, it means the batch size on each GPU), which is \
+                    used for averaging the logs.
         """
         losses = self(**data)
         loss, log_vars = self._parse_losses(losses)
...
@@ -169,56 +178,47 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
     def show_result(self,
                     img,
                     result,
-                    text_color='green',
+                    text_color='white',
                     font_scale=0.5,
                     row_width=20,
                     show=False,
+                    fig_size=(15, 10),
                     win_name='',
                     wait_time=0,
                     out_file=None):
         """Draw `result` over `img`.

         Args:
-            img (str or Tensor): The image to be displayed.
-            result (Tensor): The classification results to draw over `img`.
+            img (str or ndarray): The image to be displayed.
+            result (dict): The classification results to draw over `img`.
             text_color (str or tuple or :obj:`Color`): Color of texts.
             font_scale (float): Font scales of texts.
             row_width (int): width between each row of results on the image.
             show (bool): Whether to show the image.
                 Default: False.
+            fig_size (tuple): Image show figure size. Defaults to (15, 10).
             win_name (str): The window name.
-            wait_time (int): Value of waitKey param.
-                Default: 0.
+            wait_time (int): How many seconds to display the image.
+                Defaults to 0.
             out_file (str or None): The filename to write the image.
                 Default: None.

         Returns:
-            img (Tensor): Only if not `show` or `out_file`
+            img (ndarray): Image with overlaid results.
         """
         img = mmcv.imread(img)
         img = img.copy()
-        # write results on left-top of the image
-        x, y = 0, row_width
-        text_color = color_val(text_color)
-        for k, v in result.items():
-            if isinstance(v, float):
-                v = f'{v:.2f}'
-            label_text = f'{k}: {v}'
-            cv2.putText(img, label_text, (x, y),
-                        cv2.FONT_HERSHEY_COMPLEX,
-                        font_scale, text_color)
-            y += row_width
-        # if out_file specified, do not show image in window
-        if out_file is not None:
-            show = False
-        if show:
-            mmcv.imshow(img, win_name, wait_time)
-        if out_file is not None:
-            mmcv.imwrite(img, out_file)
-        if not (show or out_file):
-            warnings.warn('show==False and out_file is not specified, only '
-                          'result image will be returned')
+
+        img = imshow_infos(
+            img,
+            result,
+            text_color=text_color,
+            font_size=int(font_scale * 50),
+            row_width=row_width,
+            win_name=win_name,
+            show=show,
+            fig_size=fig_size,
+            wait_time=wait_time,
+            out_file=out_file)

         return img
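A hypothetical call against the new signature, assuming ``model`` is a built classifier; the ``result`` keys and file paths below are illustrative, and the actual drawing is delegated to ``mmcls.core.visualization.imshow_infos``:

# `result` is now a plain dict of label/score fields rather than a Tensor.
result = {'pred_class': 'sea snake', 'pred_score': 0.97}
vis = model.show_result(
    'demo/demo.JPEG',          # image path or ndarray
    result,
    show=False,
    out_file='demo/out.jpg')   # the annotated ndarray is also returned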