Commit 0fd8347d authored by unknown

Add mmclassification-0.24.1 code and delete mmclassification-speed-benchmark

parent cc567e9e
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -260,7 +261,7 @@ class Bottleneck(_Bottleneck):
class ResNeSt(ResNetV1d):
"""ResNeSt backbone.
Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`_ for
Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`__ for
details.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init)
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer, constant_init)
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
eps = 1.0e-5
class BasicBlock(nn.Module):
class BasicBlock(BaseModule):
"""BasicBlock for ResNet.
Args:
......@@ -41,8 +47,11 @@ class BasicBlock(nn.Module):
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN')):
super(BasicBlock, self).__init__()
norm_cfg=dict(type='BN'),
drop_path_rate=0.0,
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=None):
super(BasicBlock, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.expansion = expansion
......@@ -80,8 +89,10 @@ class BasicBlock(nn.Module):
bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
self.relu = build_activation_layer(act_cfg)
self.downsample = downsample
self.drop_path = DropPath(drop_prob=drop_path_rate
) if drop_path_rate > eps else nn.Identity()
@property
def norm1(self):
......@@ -106,6 +117,8 @@ class BasicBlock(nn.Module):
if self.downsample is not None:
identity = self.downsample(x)
out = self.drop_path(out)
out += identity
return out
......@@ -120,7 +133,7 @@ class BasicBlock(nn.Module):
return out
class Bottleneck(nn.Module):
class Bottleneck(BaseModule):
"""Bottleneck block for ResNet.
Args:
......@@ -153,8 +166,11 @@ class Bottleneck(nn.Module):
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN')):
super(Bottleneck, self).__init__()
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU', inplace=True),
drop_path_rate=0.0,
init_cfg=None):
super(Bottleneck, self).__init__(init_cfg=init_cfg)
assert style in ['pytorch', 'caffe']
self.in_channels = in_channels
......@@ -210,8 +226,10 @@ class Bottleneck(nn.Module):
bias=False)
self.add_module(self.norm3_name, norm3)
self.relu = nn.ReLU(inplace=True)
self.relu = build_activation_layer(act_cfg)
self.downsample = downsample
self.drop_path = DropPath(drop_prob=drop_path_rate
) if drop_path_rate > eps else nn.Identity()
@property
def norm1(self):
......@@ -244,6 +262,8 @@ class Bottleneck(nn.Module):
if self.downsample is not None:
identity = self.downsample(x)
out = self.drop_path(out)
out += identity
return out
......@@ -382,7 +402,7 @@ class ResLayer(nn.Sequential):
class ResNet(BaseBackbone):
"""ResNet backbone.
Please refer to the `paper <https://arxiv.org/abs/1512.03385>`_ for
Please refer to the `paper <https://arxiv.org/abs/1512.03385>`__ for
details.
Args:
......@@ -395,10 +415,8 @@ class ResNet(BaseBackbone):
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
stage is specified, a single tensor (feature map) is returned,
otherwise multiple stages are specified, a tuple of tensors will
be returned. Default: ``(3, )``.
out_indices (Sequence[int]): Output from which stages.
Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
......@@ -466,7 +484,8 @@ class ResNet(BaseBackbone):
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
],
drop_path_rate=0.0):
super(ResNet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
......@@ -513,7 +532,8 @@ class ResNet(BaseBackbone):
avg_down=self.avg_down,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg)
norm_cfg=norm_cfg,
drop_path_rate=drop_path_rate)
_in_channels = _out_channels
_out_channels *= 2
layer_name = f'layer{i + 1}'
......@@ -594,10 +614,14 @@ class ResNet(BaseBackbone):
for param in m.parameters():
param.requires_grad = False
# def init_weights(self, pretrained=None):
def init_weights(self):
super(ResNet, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress zero_init_residual if use pretrained model.
return
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
......@@ -619,9 +643,6 @@ class ResNet(BaseBackbone):
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def train(self, mode=True):
......@@ -634,10 +655,27 @@ class ResNet(BaseBackbone):
m.eval()
@BACKBONES.register_module()
class ResNetV1c(ResNet):
"""ResNetV1c backbone.
This variant is described in `Bag of Tricks.
<https://arxiv.org/pdf/1812.01187.pdf>`_.
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
in the input stem with three 3x3 convs.
"""
def __init__(self, **kwargs):
super(ResNetV1c, self).__init__(
deep_stem=True, avg_down=False, **kwargs)
@BACKBONES.register_module()
class ResNetV1d(ResNet):
"""ResNetV1d variant described in `Bag of Tricks.
"""ResNetV1d backbone.
This variant is described in `Bag of Tricks.
<https://arxiv.org/pdf/1812.01187.pdf>`_.
Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
......
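The resnet.py hunks above thread a new `drop_path_rate` argument from `ResNet` down to its blocks and drop the single-output unwrapping in `forward`, so the backbone now always returns a tuple. A minimal sketch of how this could be exercised, assuming the mmcls 0.24.1 API reflected in this diff (the 0.05 rate is an arbitrary illustrative value):

import torch
from mmcls.models import build_backbone

backbone = build_backbone(dict(type='ResNet', depth=50, drop_path_rate=0.05))
backbone.init_weights()
feats = backbone(torch.rand(1, 3, 224, 224))
# forward() no longer unwraps a single output stage, so a tuple is returned
# even with the default out_indices=(3, ).
assert isinstance(feats, tuple) and feats[0].shape == (1, 2048, 7, 7)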
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
......@@ -77,7 +78,4 @@ class ResNet_CIFAR(ResNet):
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
......@@ -89,7 +90,7 @@ class Bottleneck(_Bottleneck):
class ResNeXt(ResNet):
"""ResNeXt backbone.
Please refer to the `paper <https://arxiv.org/abs/1611.05431>`_ for
Please refer to the `paper <https://arxiv.org/abs/1611.05431>`__ for
details.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch.utils.checkpoint as cp
from ..builder import BACKBONES
......@@ -57,7 +58,7 @@ class SEBottleneck(Bottleneck):
class SEResNet(ResNet):
"""SEResNet backbone.
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
details.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from ..builder import BACKBONES
......@@ -95,7 +96,7 @@ class SEBottleneck(_SEBottleneck):
class SEResNeXt(SEResNet):
"""SEResNeXt backbone.
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`_ for
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
details.
Args:
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, constant_init,
normal_init)
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle, make_divisible
......@@ -10,7 +12,7 @@ from ..builder import BACKBONES
from .base_backbone import BaseBackbone
class ShuffleUnit(nn.Module):
class ShuffleUnit(BaseModule):
"""ShuffleUnit block.
ShuffleNet unit with pointwise group convolution (GConv) and channel
......@@ -22,7 +24,7 @@ class ShuffleUnit(nn.Module):
groups (int): The number of groups to be used in grouped 1x1
convolutions in each ShuffleUnit. Default: 3
first_block (bool): Whether it is the first ShuffleUnit of a
sequential ShuffleUnits. Default: False, which means not using the
sequential ShuffleUnits. Default: True, which means not using the
grouped 1x1 convolution.
combine (str): The ways to combine the input and output
branches. Default: 'add'.
......@@ -184,6 +186,7 @@ class ShuffleNetV1(BaseBackbone):
with_cp=False,
init_cfg=None):
super(ShuffleNetV1, self).__init__(init_cfg)
self.init_cfg = init_cfg
self.stage_blocks = [4, 8, 4]
self.groups = groups
......@@ -250,6 +253,12 @@ class ShuffleNetV1(BaseBackbone):
def init_weights(self):
super(ShuffleNetV1, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
......@@ -257,7 +266,7 @@ class ShuffleNetV1(BaseBackbone):
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m.weight, val=1, bias=0.0001)
constant_init(m, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
......@@ -269,7 +278,7 @@ class ShuffleNetV1(BaseBackbone):
out_channels (int): out_channels of the block.
num_blocks (int): Number of blocks.
first_block (bool): Whether it is the first ShuffleUnit of a
sequential ShuffleUnits. Default: False, which means not using
sequential ShuffleUnits. Default: False, which means using
the grouped 1x1 convolution.
"""
layers = []
......@@ -301,9 +310,6 @@ class ShuffleNetV1(BaseBackbone):
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def train(self, mode=True):
......
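One of the shufflenet_v1.py fixes above passes the module itself to `constant_init` rather than its weight tensor. A small sketch of the intended call, assuming mmcv's `constant_init(module, val, bias=0)` helper:

import torch.nn as nn
from mmcv.cnn import constant_init

bn = nn.BatchNorm2d(8)
# Fills bn.weight with 1 and bn.bias with 1e-4. Passing bn.weight (as the old
# line did) would fail, since a tensor has no .weight attribute to initialize.
constant_init(bn, val=1, bias=0.0001)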
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, constant_init, normal_init
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.utils import channel_shuffle
......@@ -9,7 +11,7 @@ from ..builder import BACKBONES
from .base_backbone import BaseBackbone
class InvertedResidual(nn.Module):
class InvertedResidual(BaseModule):
"""InvertedResidual block for ShuffleNetV2 backbone.
Args:
......@@ -36,8 +38,9 @@ class InvertedResidual(nn.Module):
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False):
super(InvertedResidual, self).__init__()
with_cp=False,
init_cfg=None):
super(InvertedResidual, self).__init__(init_cfg)
self.stride = stride
self.with_cp = with_cp
......@@ -112,7 +115,14 @@ class InvertedResidual(nn.Module):
if self.stride > 1:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
else:
x1, x2 = x.chunk(2, dim=1)
# Channel Split operation. using these lines of code to replace
# ``chunk(x, 2, dim=1)`` can make it easier to deploy a
# shufflenetv2 model by using mmdeploy.
channels = x.shape[1]
c = channels // 2 + channels % 2
x1 = x[:, :c, :, :]
x2 = x[:, c:, :, :]
out = torch.cat((x1, self.branch2(x2)), dim=1)
out = channel_shuffle(out, 2)
......@@ -253,8 +263,14 @@ class ShuffleNetV2(BaseBackbone):
for param in m.parameters():
param.requires_grad = False
def init_weighs(self):
def init_weights(self):
super(ShuffleNetV2, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
......@@ -277,9 +293,6 @@ class ShuffleNetV2(BaseBackbone):
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def train(self, mode=True):
......
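The shufflenet_v2.py hunk above replaces `x.chunk(2, dim=1)` with explicit slicing so the channel split exports more easily through mmdeploy. An illustrative check (not part of the diff) that the two forms produce the same split:

import torch

x = torch.randn(1, 58, 7, 7)  # arbitrary channel count for illustration
c = x.shape[1] // 2 + x.shape[1] % 2
x1, x2 = x[:, :c, :, :], x[:, c:, :, :]
a, b = x.chunk(2, dim=1)
assert torch.equal(x1, a) and torch.equal(x2, b)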
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from ..utils import (ShiftWindowMSA, resize_pos_embed,
resize_relative_position_bias_table, to_2tuple)
from .base_backbone import BaseBackbone
class SwinBlock(BaseModule):
"""Swin Transformer block.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
shift (bool): Shift the attention window or not. Defaults to False.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is commonly used in classification.
Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=7,
shift=False,
ffn_ratio=4.,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super(SwinBlock, self).__init__(init_cfg)
self.with_cp = with_cp
_attn_cfgs = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'pad_small_map': pad_small_map,
**attn_cfgs
}
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(**_attn_cfgs)
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
**ffn_cfgs
}
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(**_ffn_cfgs)
def forward(self, x, hw_shape):
def _inner_forward(x):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SwinBlockSequence(BaseModule):
"""Module with successive Swin Transformer blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is commonly used in classification.
Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
depth,
num_heads,
window_size=7,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
pad_small_map=False,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
self.embed_dims = embed_dims
self.blocks = ModuleList()
for i in range(depth):
_block_cfg = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'window_size': window_size,
'shift': False if i % 2 == 0 else True,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
**block_cfgs[i]
}
block = SwinBlock(**_block_cfg)
self.blocks.append(block)
if downsample:
_downsample_cfg = {
'in_channels': embed_dims,
'out_channels': 2 * embed_dims,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = PatchMerging(**_downsample_cfg)
else:
self.downsample = None
def forward(self, x, in_shape, do_downsample=True):
for block in self.blocks:
x = block(x, in_shape)
if self.downsample is not None and do_downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
return x, out_shape
@property
def out_channels(self):
if self.downsample:
return self.downsample.out_channels
else:
return self.embed_dims
@BACKBONES.register_module()
class SwinTransformer(BaseBackbone):
"""Swin Transformer.
A PyTorch implementation of `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>`_
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture. If use string, choose
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
Defaults to 'tiny'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
in_channels (int): The num of input channels. Defaults to 3.
window_size (int): The height and width of the window. Defaults to 7.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
out_after_downsample (bool): Whether to output the feature map of a
stage after the following downsample layer. Defaults to False.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
position embedding vector resize. Defaults to "bicubic".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is commonly used in classification.
Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import SwinTransformer
>>> import torch
>>> extra_config = dict(
>>> arch='tiny',
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
>>> 'expansion_ratio': 3}))
>>> self = SwinTransformer(**extra_config)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> output = self.forward(inputs)
>>> print(output.shape)
(1, 2592, 4)
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24]}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48]}),
} # yapf: disable
_version = 3
num_extra_tokens = 0
def __init__(self,
arch='tiny',
img_size=224,
patch_size=4,
in_channels=3,
window_size=7,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
out_after_downsample=False,
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
frozen_stages=-1,
norm_eval=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(),
patch_cfg=dict(),
init_cfg=None):
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'num_heads'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.out_after_downsample = out_after_downsample
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)
self._register_load_state_dict_pre_hook(
self._prepare_relative_position_bias_table)
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
# stochastic depth
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
embed_dims = [self.embed_dims]
for i, (depth,
num_heads) in enumerate(zip(self.depths, self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
downsample = True if i < self.num_layers - 1 else False
_stage_cfg = {
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': window_size,
'downsample': downsample,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
**stage_cfg
}
stage = SwinBlockSequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
if self.out_after_downsample:
self.num_features = embed_dims[1:]
else:
self.num_features = embed_dims[:-1]
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg,
self.num_features[i])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm{i}', norm_layer)
def init_weights(self):
super(SwinTransformer, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(
x, hw_shape, do_downsample=self.out_after_downsample)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
if stage.downsample is not None and not self.out_after_downsample:
x, hw_shape = stage.downsample(x, hw_shape)
return tuple(outs)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
**kwargs):
"""load checkpoints."""
# Names of some parameters have been changed.
version = local_metadata.get('version', None)
if (version is None
or version < 2) and self.__class__ is SwinTransformer:
final_stage_num = len(self.stages) - 1
state_dict_keys = list(state_dict.keys())
for k in state_dict_keys:
if k.startswith('norm.') or k.startswith('backbone.norm.'):
convert_key = k.replace('norm.', f'norm{final_stage_num}.')
state_dict[convert_key] = state_dict[k]
del state_dict[k]
if (version is None
or version < 3) and self.__class__ is SwinTransformer:
state_dict_keys = list(state_dict.keys())
for k in state_dict_keys:
if 'attn_mask' in k:
del state_dict[k]
super()._load_from_state_dict(state_dict, prefix, local_metadata,
*args, **kwargs)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def train(self, mode=True):
super(SwinTransformer, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval has effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'absolute_pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
**kwargs):
state_dict_model = self.state_dict()
all_keys = list(state_dict_model.keys())
for key in all_keys:
if 'relative_position_bias_table' in key:
ckpt_key = prefix + key
if ckpt_key not in state_dict:
continue
relative_position_bias_table_pretrained = state_dict[ckpt_key]
relative_position_bias_table_current = state_dict_model[key]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if L1 != L2:
src_size = int(L1**0.5)
dst_size = int(L2**0.5)
new_rel_pos_bias = resize_relative_position_bias_table(
src_size, dst_size,
relative_position_bias_table_pretrained, nH1)
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info('Resize the relative_position_bias_table from '
f'{state_dict[ckpt_key].shape} to '
f'{new_rel_pos_bias.shape}')
state_dict[ckpt_key] = new_rel_pos_bias
# The index buffer needs to be re-generated.
index_buffer = ckpt_key.replace('bias_table', 'index')
del state_dict[index_buffer]
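For reference, a hedged sketch of running the SwinTransformer backbone defined above; the shapes follow from the 'tiny' entry in `arch_zoo` (patch size 4, three PatchMerging downsamples, default out_indices=(3, )):

import torch
from mmcls.models import SwinTransformer

model = SwinTransformer(arch='tiny')
model.init_weights()
outs = model(torch.rand(1, 3, 224, 224))
# 224 / 4 = 56 patches per side, halved three times -> 7x7,
# with channels 96 -> 192 -> 384 -> 768.
assert outs[0].shape == (1, 768, 7, 7)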
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from ..utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2,
resize_pos_embed, to_2tuple)
from .base_backbone import BaseBackbone
class SwinBlockV2(BaseModule):
"""Swin Transformer V2 block. Use post normalization.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
shift (bool): Shift the attention window or not. Defaults to False.
extra_norm (bool): Whether add extra norm at the end of main branch.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is commonly used in classification.
Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pretrained_window_size (int): Window size in pretrained.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=8,
shift=False,
extra_norm=False,
ffn_ratio=4.,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
pretrained_window_size=0,
init_cfg=None):
super(SwinBlockV2, self).__init__(init_cfg)
self.with_cp = with_cp
self.extra_norm = extra_norm
_attn_cfgs = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'pad_small_map': pad_small_map,
**attn_cfgs
}
# use V2 attention implementation
_attn_cfgs.update(
window_msa=WindowMSAV2,
msa_cfg=dict(
pretrained_window_size=to_2tuple(pretrained_window_size)))
self.attn = ShiftWindowMSA(**_attn_cfgs)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
'add_identity': False,
**ffn_cfgs
}
self.ffn = FFN(**_ffn_cfgs)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
# add extra norm for every n blocks in huge and giant model
if self.extra_norm:
self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]
def forward(self, x, hw_shape):
def _inner_forward(x):
# Use post normalization
identity = x
x = self.attn(x, hw_shape)
x = self.norm1(x)
x = x + identity
identity = x
x = self.ffn(x)
x = self.norm2(x)
x = x + identity
if self.extra_norm:
x = self.norm3(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SwinBlockV2Sequence(BaseModule):
"""Module with successive Swin Transformer blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is commonly used in classification.
Defaults to False.
extra_norm_every_n_blocks (int): Add extra norm at the end of main
branch every n blocks. Defaults to 0, which means no needs for
extra norm layer.
pretrained_window_size (int): Window size in pretrained.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
depth,
num_heads,
window_size=8,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
pad_small_map=False,
extra_norm_every_n_blocks=0,
pretrained_window_size=0,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
if downsample:
self.out_channels = 2 * embed_dims
_downsample_cfg = {
'in_channels': embed_dims,
'out_channels': self.out_channels,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = PatchMerging(**_downsample_cfg)
else:
self.out_channels = embed_dims
self.downsample = None
self.blocks = ModuleList()
for i in range(depth):
extra_norm = True if extra_norm_every_n_blocks and \
(i + 1) % extra_norm_every_n_blocks == 0 else False
_block_cfg = {
'embed_dims': self.out_channels,
'num_heads': num_heads,
'window_size': window_size,
'shift': False if i % 2 == 0 else True,
'extra_norm': extra_norm,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'pretrained_window_size': pretrained_window_size,
**block_cfgs[i]
}
block = SwinBlockV2(**_block_cfg)
self.blocks.append(block)
def forward(self, x, in_shape):
if self.downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
for block in self.blocks:
x = block(x, out_shape)
return x, out_shape
@BACKBONES.register_module()
class SwinTransformerV2(BaseBackbone):
"""Swin Transformer V2.
A PyTorch implementation of `Swin Transformer V2:
Scaling Up Capacity and Resolution
<https://arxiv.org/abs/2111.09883>`_
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture. If use string, choose
from 'tiny', 'small', 'base' and 'large'. If use dict, it should
have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
- **extra_norm_every_n_blocks** (int): Add extra norm at the end
of main branch every n blocks.
Defaults to 'tiny'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
in_channels (int): The num of input channels. Defaults to 3.
window_size (int | Sequence): The height and width of the window.
Defaults to 7.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
position embedding vector resize. Defaults to "bicubic".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is commonly used in classification.
Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
pretrained_window_sizes (tuple(int)): Pretrained window sizes of
each layer.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import SwinTransformerV2
>>> import torch
>>> extra_config = dict(
>>> arch='tiny',
>>> stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
>>> 'padding': 'same'}))
>>> self = SwinTransformerV2(**extra_config)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> output = self.forward(inputs)
>>> print(output.shape)
(1, 2592, 4)
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48],
'extra_norm_every_n_blocks': 0}),
# head count not certain for huge, and is employed for another
# parallel study about self-supervised learning.
**dict.fromkeys(['h', 'huge'],
{'embed_dims': 352,
'depths': [2, 2, 18, 2],
'num_heads': [8, 16, 32, 64],
'extra_norm_every_n_blocks': 6}),
**dict.fromkeys(['g', 'giant'],
{'embed_dims': 512,
'depths': [2, 2, 42, 4],
'num_heads': [16, 32, 64, 128],
'extra_norm_every_n_blocks': 6}),
} # yapf: disable
_version = 1
num_extra_tokens = 0
def __init__(self,
arch='tiny',
img_size=256,
patch_size=4,
in_channels=3,
window_size=8,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
frozen_stages=-1,
norm_eval=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(downsample_cfg=dict(is_post_norm=True)),
patch_cfg=dict(),
pretrained_window_sizes=[0, 0, 0, 0],
init_cfg=None):
super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'depths', 'num_heads',
'extra_norm_every_n_blocks'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.extra_norm_every_n_blocks = self.arch_settings[
'extra_norm_every_n_blocks']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
if isinstance(window_size, int):
self.window_sizes = [window_size for _ in range(self.num_layers)]
elif isinstance(window_size, Sequence):
assert len(window_size) == self.num_layers, \
f'Length of window_sizes {len(window_size)} is not equal to '\
f'length of stages {self.num_layers}.'
self.window_sizes = window_size
else:
raise TypeError('window_size should be a Sequence or int.')
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)
self._register_load_state_dict_pre_hook(self._delete_reinit_params)
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
# stochastic depth
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
embed_dims = [self.embed_dims]
for i, (depth,
num_heads) in enumerate(zip(self.depths, self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
downsample = True if i > 0 else False
_stage_cfg = {
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': self.window_sizes[i],
'downsample': downsample,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
'pretrained_window_size': pretrained_window_sizes[i],
**stage_cfg
}
stage = SwinBlockV2Sequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm{i}', norm_layer)
def init_weights(self):
super(SwinTransformerV2, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress default init if use pretrained model.
return
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
outs.append(out)
return tuple(outs)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def train(self, mode=True):
super(SwinTransformerV2, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval has effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'absolute_pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
# delete relative_position_index since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_position_index' in k
]
for k in relative_position_index_keys:
del state_dict[k]
# delete relative_coords_table since we always re-init it
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_coords_table' in k
]
for k in relative_position_index_keys:
del state_dict[k]
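A similar sketch for the SwinTransformerV2 backbone above, assuming it is exported from mmcls.models as in this tree; note that V2 defaults to img_size=256 and window_size=8, and downsamples at the start of every stage except the first:

import torch
from mmcls.models import SwinTransformerV2

model = SwinTransformerV2(arch='tiny')
model.init_weights()
outs = model(torch.rand(1, 3, 256, 256))
# 256 / 4 = 64 patches per side, halved before stages 1-3 -> 8x8,
# with channels 96 -> 192 -> 384 -> 768.
assert outs[0].shape == (1, 768, 8, 8)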
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from ..builder import BACKBONES
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
class T2TTransformerLayer(BaseModule):
"""Transformer Layer for T2T_ViT.
Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
different ``input_dims`` and ``embed_dims``.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs
input_dims (int, optional): The input token dimension.
Defaults to None.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
qk_scale (float, optional): Override default qk scale of
``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
Notes:
In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
``(embed_dims // num_heads) ** -0.5``. However, in the official
code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
keep the same with the official implementation.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
input_dims=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=False,
qk_scale=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)
self.v_shortcut = True if input_dims is not None else False
input_dims = input_dims or embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, input_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = MultiheadAttention(
input_dims=input_dims,
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
v_shortcut=self.v_shortcut)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
if self.v_shortcut:
x = self.attn(self.norm1(x))
else:
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
return x
class T2TModule(BaseModule):
"""Tokens-to-Token module.
"Tokens-to-Token module" (T2T Module) can model the local structure
information of images and reduce the length of tokens progressively.
Args:
img_size (int): Input image size
in_channels (int): Number of input channels
embed_dims (int): Embedding dimension
token_dims (int): Tokens dimension in T2TModuleAttention.
use_performer (bool): If True, use the Performer version of self-attention
instead of regular self-attention. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
Notes:
Usually, ``token_dim`` is set as a small value (32 or 64) to reduce
MACs
"""
def __init__(
self,
img_size=224,
in_channels=3,
embed_dims=384,
token_dims=64,
use_performer=False,
init_cfg=None,
):
super(T2TModule, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.soft_split0 = nn.Unfold(
kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
if not use_performer:
self.attention1 = T2TTransformerLayer(
input_dims=in_channels * 7 * 7,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.attention2 = T2TTransformerLayer(
input_dims=token_dims * 3 * 3,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
else:
raise NotImplementedError("Performer hasn't been implemented.")
# there are 3 soft splits, with strides 4, 2, 2 respectively
out_side = img_size // (4 * 2 * 2)
self.init_out_size = [out_side, out_side]
self.num_patches = out_side**2
@staticmethod
def _get_unfold_size(unfold: nn.Unfold, input_size):
h, w = input_size
kernel_size = to_2tuple(unfold.kernel_size)
stride = to_2tuple(unfold.stride)
padding = to_2tuple(unfold.padding)
dilation = to_2tuple(unfold.dilation)
h_out = (h + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (w + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
return (h_out, w_out)
def forward(self, x):
# step0: soft split
hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
x = self.soft_split0(x).transpose(1, 2)
for step in [1, 2]:
# re-structurization/reconstruction
attn = getattr(self, f'attention{step}')
x = attn(x).transpose(1, 2)
B, C, _ = x.shape
x = x.reshape(B, C, hw_shape[0], hw_shape[1])
# soft split
soft_split = getattr(self, f'soft_split{step}')
hw_shape = self._get_unfold_size(soft_split, hw_shape)
x = soft_split(x).transpose(1, 2)
# final tokens
x = self.project(x)
return x, hw_shape
def get_sinusoid_encoding(n_position, embed_dims):
"""Generate sinusoid encoding table.
Sinusoid encoding is a kind of relative position encoding method that comes
from `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Args:
n_position (int): The length of the input token.
embed_dims (int): The position embedding dimension.
Returns:
:obj:`torch.FloatTensor`: The sinusoid encoding table.
"""
vec = torch.arange(embed_dims, dtype=torch.float64)
vec = (vec - vec % 2) / embed_dims
vec = torch.pow(10000, -vec).view(1, -1)
sinusoid_table = torch.arange(n_position).view(-1, 1) * vec
sinusoid_table[:, 0::2].sin_() # dim 2i
sinusoid_table[:, 1::2].cos_() # dim 2i+1
sinusoid_table = sinusoid_table.to(torch.float32)
return sinusoid_table.unsqueeze(0)
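# Illustrative note (not part of the original file): the table returned above
# has shape (1, n_position, embed_dims), with sin in even dims and cos in odd
# dims, e.g.
#   >>> get_sinusoid_encoding(n_position=4, embed_dims=8).shape
#   torch.Size([1, 4, 8])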
@BACKBONES.register_module()
class T2T_ViT(BaseBackbone):
"""Tokens-to-Token Vision Transformer (T2T-ViT)
A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_
Args:
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
in_channels (int): Number of input channels.
embed_dims (int): Embedding dimension.
num_layers (int): Num of transformer layers in encoder.
Defaults to 14.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Dropout rate after position embedding.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer. Defaults to
``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
t2t_cfg (dict): Extra config of Tokens-to-Token module.
Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
num_extra_tokens = 1 # cls_token
def __init__(self,
img_size=224,
in_channels=3,
embed_dims=384,
num_layers=14,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
final_norm=True,
with_cls_token=True,
output_cls_token=True,
interpolate_mode='bicubic',
t2t_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
super(T2T_ViT, self).__init__(init_cfg)
# Token-to-Token Module
self.tokens_to_token = T2TModule(
img_size=img_size,
in_channels=in_channels,
embed_dims=embed_dims,
**t2t_cfg)
self.patch_resolution = self.tokens_to_token.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set cls token
if output_cls_token:
assert with_cls_token is True, f'with_cls_token must be True if' \
f'set output_cls_token to True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
# Set position embedding
self.interpolate_mode = interpolate_mode
sinusoid_table = get_sinusoid_encoding(
num_patches + self.num_extra_tokens, embed_dims)
self.register_buffer('pos_embed', sinusoid_table)
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must be a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = num_layers + index
assert 0 <= out_indices[i] <= num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
self.encoder = ModuleList()
for i in range(num_layers):
if isinstance(layer_cfgs, Sequence):
layer_cfg = layer_cfgs[i]
else:
layer_cfg = deepcopy(layer_cfgs)
layer_cfg = {
'embed_dims': embed_dims,
'num_heads': 6,
'feedforward_channels': 3 * embed_dims,
'drop_path_rate': dpr[i],
'qkv_bias': False,
'norm_cfg': norm_cfg,
**layer_cfg
}
layer = T2TTransformerLayer(**layer_cfg)
self.encoder.append(layer)
self.final_norm = final_norm
if final_norm:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = nn.Identity()
def init_weights(self):
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress custom init if use pretrained model.
return
trunc_normal_(self.cls_token, std=.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmcls.utils import get_root_logger
logger = get_root_logger()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.tokens_to_token.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.tokens_to_token(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.encoder):
x = layer(x)
if i == len(self.encoder) - 1 and self.final_norm:
x = self.norm(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
return tuple(outs)
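# A minimal usage sketch, assuming T2T_ViT is importable from mmcls.models
# like the other registered backbones here. With the defaults
# ``with_cls_token=True`` and ``output_cls_token=True``, each entry of the
# returned tuple is ``[patch_token, cls_token]``, where ``patch_token`` is a
# (B, C, H_p, W_p) feature map over the T2T token grid and ``cls_token`` is
# the (B, C) class-token feature:
# >>> import torch
# >>> from mmcls.models import T2T_ViT
# >>> model = T2T_ViT()
# >>> patch_token, cls_token = model(torch.rand(1, 3, 224, 224))[-1]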
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None
import warnings
from mmcv.cnn.bricks.registry import NORM_LAYERS
from ...utils import get_root_logger
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
def print_timm_feature_info(feature_info):
"""Print feature_info of timm backbone to help development and debug.
Args:
feature_info (list[dict] | timm.models.features.FeatureInfo | None):
feature_info of timm backbone.
"""
logger = get_root_logger()
if feature_info is None:
logger.warning('This backbone does not have feature_info')
elif isinstance(feature_info, list):
for feat_idx, each_info in enumerate(feature_info):
logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
else:
try:
logger.info(f'backbone out_indices: {feature_info.out_indices}')
logger.info(f'backbone out_channels: {feature_info.channels()}')
logger.info(f'backbone out_strides: {feature_info.reduction()}')
except AttributeError:
logger.warning('Unexpected format of backbone feature_info')
@BACKBONES.register_module()
class TIMMBackbone(BaseBackbone):
"""Wrapper to use backbones from timm library.
More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_.
See especially the document for `feature extraction
<https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.
Args:
model_name (str): Name of timm model to instantiate.
features_only (bool): Whether to extract feature pyramid (multi-scale
feature maps from the deepest layer at each stride). For Vision
Transformer models that do not support this argument,
set this to False. Defaults to False.
pretrained (bool): Whether to load pretrained weights.
Defaults to False.
checkpoint_path (str): Path of checkpoint to load at the last of
``timm.create_model``. Defaults to empty string, which means
not loading.
in_channels (int): Number of input image channels. Defaults to 3.
init_cfg (dict or list[dict], optional): Initialization config dict of
OpenMMLab projects. Defaults to None.
**kwargs: Other timm & model specific arguments.
"""
def __init__(self,
model_name,
features_only=False,
pretrained=False,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs):
if timm is None:
raise RuntimeError(
'Failed to import timm. Please run "pip install timm". '
'"pip install dataclasses" may also be needed for Python 3.6.')
if not isinstance(pretrained, bool):
raise TypeError('pretrained must be bool, not str for model path')
if features_only and checkpoint_path:
warnings.warn(
'Using both features_only and checkpoint_path will cause error'
' in timm. See '
'https://github.com/rwightman/pytorch-image-models/issues/488')
super(TIMMBackbone, self).__init__(init_cfg)
if 'norm_layer' in kwargs:
kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs)
# reset classifier
if hasattr(self.timm_model, 'reset_classifier'):
self.timm_model.reset_classifier(0, '')
# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True
feature_info = getattr(self.timm_model, 'feature_info', None)
print_timm_feature_info(feature_info)
def forward(self, x):
features = self.timm_model(x)
if isinstance(features, (list, tuple)):
features = tuple(features)
else:
features = (features, )
return features
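# A minimal usage sketch, assuming timm is installed; 'resnet18' is only an
# illustrative model name:
# >>> import torch
# >>> from mmcls.models import TIMMBackbone
# >>> model = TIMMBackbone(model_name='resnet18', features_only=True)
# >>> feats = model(torch.rand(1, 3, 224, 224))
# >>> len(feats)  # one feature map per level reported by feature_info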
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from ..builder import BACKBONES
from ..utils import to_2tuple
from .base_backbone import BaseBackbone
class TransformerBlock(BaseModule):
"""Implement a transformer block in TnTLayer.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
Default: 4
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default 0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.
drop_path_rate (float): stochastic depth rate. Default 0.
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
qkv_bias (bool): Enable bias for qkv if True. Default False
act_cfg (dict): The activation config for FFNs. Defaults to GELU.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization
batch_first (bool): Whether Key, Query and Value are of shape
(batch, n, embed_dim) rather than (n, batch, embed_dim);
(batch, n, embed_dim) is the common case in CV. Defaults to True
init_cfg (dict, optional): Initialization config dict. Default to None
"""
def __init__(self,
embed_dims,
num_heads,
ffn_ratio=4,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
init_cfg=None):
super(TransformerBlock, self).__init__(init_cfg=init_cfg)
self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
batch_first=batch_first)
self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=embed_dims * ffn_ratio,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
if not qkv_bias:
self.attn.attn.in_proj_bias = None
def forward(self, x):
x = self.attn(self.norm_attn(x), identity=x)
x = self.ffn(self.norm_ffn(x), identity=x)
return x
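# Residual note for the block above: mmcv's ``MultiheadAttention`` and ``FFN``
# add the tensor passed via ``identity=`` to their output internally, so each
# call in ``forward`` is a pre-norm residual branch even though no explicit
# ``x + ...`` appears here.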
class TnTLayer(BaseModule):
"""Implement one encoder layer in Transformer in Transformer.
Args:
num_pixel (int): The pixel number in target patch transformed with
a linear projection in inner transformer
embed_dims_inner (int): Feature dimension in inner transformer block
embed_dims_outer (int): Feature dimension in outer transformer block
num_heads_inner (int): Parallel attention heads in inner transformer.
num_heads_outer (int): Parallel attention heads in outer transformer.
inner_block_cfg (dict): Extra config of inner transformer block.
Defaults to empty dict.
outer_block_cfg (dict): Extra config of outer transformer block.
Defaults to empty dict.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization
init_cfg (dict, optional): Initialization config dict. Default to None
"""
def __init__(self,
num_pixel,
embed_dims_inner,
embed_dims_outer,
num_heads_inner,
num_heads_outer,
inner_block_cfg=dict(),
outer_block_cfg=dict(),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(TnTLayer, self).__init__(init_cfg=init_cfg)
self.inner_block = TransformerBlock(
embed_dims=embed_dims_inner,
num_heads=num_heads_inner,
**inner_block_cfg)
self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1]
self.projection = nn.Linear(
embed_dims_inner * num_pixel, embed_dims_outer, bias=True)
self.outer_block = TransformerBlock(
embed_dims=embed_dims_outer,
num_heads=num_heads_outer,
**outer_block_cfg)
def forward(self, pixel_embed, patch_embed):
pixel_embed = self.inner_block(pixel_embed)
B, N, C = patch_embed.size()
patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection(
self.norm_proj(pixel_embed).reshape(B, N - 1, -1))
patch_embed = self.outer_block(patch_embed)
return pixel_embed, patch_embed
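# Shape note for the layer above: ``pixel_embed`` has shape
# (B * num_patches, num_pixel, embed_dims_inner). After the inner block it is
# flattened to (B, num_patches, num_pixel * embed_dims_inner), projected to
# ``embed_dims_outer`` and added to the patch tokens (skipping the cls token
# at index 0) before the outer block runs on ``patch_embed``.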
class PixelEmbed(BaseModule):
"""Image to Pixel Embedding.
Args:
img_size (int | tuple): The size of input image
patch_size (int): The size of one patch
in_channels (int): The num of input channels
embed_dims_inner (int): The num of channels of the target patch
transformed with a linear projection in inner transformer
stride (int): The stride of the conv2d layer. We use a conv2d layer
and an unfold layer to implement image to pixel embedding.
init_cfg (dict, optional): Initialization config dict
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims_inner=48,
stride=4,
init_cfg=None):
super(PixelEmbed, self).__init__(init_cfg=init_cfg)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
# patches_resolution property necessary for resizing
# positional embedding
patches_resolution = [
img_size[0] // patch_size[0], img_size[1] // patch_size[1]
]
num_patches = patches_resolution[0] * patches_resolution[1]
self.img_size = img_size
self.num_patches = num_patches
self.embed_dims_inner = embed_dims_inner
new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
self.new_patch_size = new_patch_size
self.proj = nn.Conv2d(
in_channels,
self.embed_dims_inner,
kernel_size=7,
padding=3,
stride=stride)
self.unfold = nn.Unfold(
kernel_size=new_patch_size, stride=new_patch_size)
def forward(self, x, pixel_pos):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model " \
f'({self.img_size[0]}*{self.img_size[1]}).'
x = self.proj(x)
x = self.unfold(x)
x = x.transpose(1,
2).reshape(B * self.num_patches, self.embed_dims_inner,
self.new_patch_size[0],
self.new_patch_size[1])
x = x + pixel_pos
x = x.reshape(B * self.num_patches, self.embed_dims_inner,
-1).transpose(1, 2)
return x
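# A doctest-style shape check (assuming ``PixelEmbed`` above is in scope),
# derived from its defaults (img_size=224, patch_size=16, stride=4):
# 14 * 14 = 196 patches, new_patch_size = 4, so each patch contributes
# 4 * 4 = 16 pixel tokens of width embed_dims_inner:
# >>> import torch
# >>> pixel_embed = PixelEmbed()
# >>> pixel_pos = torch.zeros(1, 48, 4, 4)
# >>> pixel_embed(torch.rand(1, 3, 224, 224), pixel_pos).shape
# torch.Size([196, 16, 48])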
@BACKBONES.register_module()
class TNT(BaseBackbone):
"""Transformer in Transformer.
A PyTorch implement of: `Transformer in Transformer
<https://arxiv.org/abs/2103.00112>`_
Inspiration from
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size. Default to 224
patch_size (int | tuple): The patch size. Default to 16
in_channels (int): Number of input channels. Default to 3
ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
Default: 4
qkv_bias (bool): Enable bias for qkv if True. Default False
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default 0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.
drop_path_rate (float): stochastic depth rate. Default 0.
act_cfg (dict): The activation config for FFNs. Defaults to GELU.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization
first_stride (int): The stride of the conv2d layer. We use a conv2d
layer and an unfold layer to implement image to pixel embedding.
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
init_cfg (dict, optional): Initialization config dict
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims_outer': 384,
'embed_dims_inner': 24,
'num_layers': 12,
'num_heads_outer': 6,
'num_heads_inner': 4
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims_outer': 640,
'embed_dims_inner': 40,
'num_layers': 12,
'num_heads_outer': 10,
'num_heads_inner': 4
})
}
def __init__(self,
arch='b',
img_size=224,
patch_size=16,
in_channels=3,
ffn_ratio=4,
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
first_stride=4,
num_fcs=2,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
]):
super(TNT, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims_outer', 'embed_dims_inner', 'num_layers',
'num_heads_inner', 'num_heads_outer'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims_inner = self.arch_settings['embed_dims_inner']
self.embed_dims_outer = self.arch_settings['embed_dims_outer']
# embed_dims for consistency with other models
self.embed_dims = self.embed_dims_outer
self.num_layers = self.arch_settings['num_layers']
self.num_heads_inner = self.arch_settings['num_heads_inner']
self.num_heads_outer = self.arch_settings['num_heads_outer']
self.pixel_embed = PixelEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dims_inner=self.embed_dims_inner,
stride=first_stride)
num_patches = self.pixel_embed.num_patches
self.num_patches = num_patches
new_patch_size = self.pixel_embed.new_patch_size
num_pixel = new_patch_size[0] * new_patch_size[1]
self.norm1_proj = build_norm_layer(norm_cfg, num_pixel *
self.embed_dims_inner)[1]
self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
self.embed_dims_outer)
self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer))
self.patch_pos = nn.Parameter(
torch.zeros(1, num_patches + 1, self.embed_dims_outer))
self.pixel_pos = nn.Parameter(
torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
new_patch_size[1]))
self.drop_after_pos = nn.Dropout(p=drop_rate)
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, self.num_layers)
] # stochastic depth decay rule
self.layers = ModuleList()
for i in range(self.num_layers):
block_cfg = dict(
ffn_ratio=ffn_ratio,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
batch_first=True)
self.layers.append(
TnTLayer(
num_pixel=num_pixel,
embed_dims_inner=self.embed_dims_inner,
embed_dims_outer=self.embed_dims_outer,
num_heads_inner=self.num_heads_inner,
num_heads_outer=self.num_heads_outer,
inner_block_cfg=block_cfg,
outer_block_cfg=block_cfg,
norm_cfg=norm_cfg))
self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.patch_pos, std=.02)
trunc_normal_(self.pixel_pos, std=.02)
def forward(self, x):
B = x.shape[0]
pixel_embed = self.pixel_embed(x, self.pixel_pos)
patch_embed = self.norm2_proj(
self.projection(
self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
patch_embed = torch.cat(
(self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
patch_embed = patch_embed + self.patch_pos
patch_embed = self.drop_after_pos(patch_embed)
for layer in self.layers:
pixel_embed, patch_embed = layer(pixel_embed, patch_embed)
patch_embed = self.norm(patch_embed)
return (patch_embed[:, 0], )
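# A minimal usage sketch, assuming TNT is importable from mmcls.models. The
# backbone returns a single-element tuple holding the outer cls-token
# feature, e.g. 384-d for the small arch defined above:
# >>> import torch
# >>> from mmcls.models import TNT
# >>> model = TNT(arch='s')
# >>> model(torch.rand(1, 3, 224, 224))[0].shape
# torch.Size([1, 384])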
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
trunc_normal_init)
from mmcv.runner import BaseModule, ModuleList
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.builder import BACKBONES
from mmcls.models.utils.attention import MultiheadAttention
from mmcls.models.utils.position_encoding import ConditionalPositionEncoding
class GlobalSubsampledAttention(MultiheadAttention):
"""Global Sub-sampled Attention (GSA) module.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
input_dims (int, optional): The input dimension, and if None,
use ``embed_dims``. Defaults to None.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
dropout_layer (dict): The dropout config before adding the shortcut.
Defaults to ``dict(type='Dropout', drop_prob=0.)``.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): If True, add a learnable bias to output projection.
Defaults to True.
v_shortcut (bool): Add a shortcut from value to output. It's usually
used if ``input_dims`` is different from ``embed_dims``.
Defaults to False.
sr_ratio (float): The ratio of spatial reduction in attention modules.
Defaults to 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
norm_cfg=dict(type='LN'),
qkv_bias=True,
sr_ratio=1,
**kwargs):
super(GlobalSubsampledAttention,
self).__init__(embed_dims, num_heads, **kwargs)
self.qkv_bias = qkv_bias
self.q = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
self.kv = nn.Linear(self.input_dims, embed_dims * 2, bias=qkv_bias)
# remove self.qkv, here split into self.q, self.kv
delattr(self, 'qkv')
self.sr_ratio = sr_ratio
if sr_ratio > 1:
# use a conv as the spatial-reduction operation, the kernel_size
# and stride in conv are equal to the sr_ratio.
self.sr = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=sr_ratio,
stride=sr_ratio)
# The ret[0] of build_norm_layer is norm name.
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
def forward(self, x, hw_shape):
B, N, C = x.shape
H, W = hw_shape
assert H * W == N, 'The product of h and w of hw_shape must equal N, ' \
'the length of the 2nd dim of the input tensor x.'
q = self.q(x).reshape(B, N, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x = x.permute(0, 2, 1).reshape(B, C, *hw_shape) # BNC_2_BCHW
x = self.sr(x)
x = x.reshape(B, C, -1).permute(0, 2, 1) # BCHW_2_BNC
x = self.norm(x)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads,
self.head_dims).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.out_drop(self.proj_drop(x))
if self.v_shortcut:
x = v.squeeze(1) + x
return x
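# Spatial-reduction arithmetic for the attention above: with ``sr_ratio=8``
# and ``hw_shape=(56, 56)``, the queries keep all 56 * 56 = 3136 tokens while
# the keys/values come from the sub-sampled 7 x 7 = 49 tokens, shrinking the
# attention matrix from 3136 x 3136 to 3136 x 49 per head.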
class GSAEncoderLayer(BaseModule):
"""Implements one encoder layer with GlobalSubsampledAttention(GSA).
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (float): The ratio of spatial reduction in attention modules.
Defaults to 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
sr_ratio=1.,
init_cfg=None):
super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = GlobalSubsampledAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
sr_ratio=sr_ratio)
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=False)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate)
) if drop_path_rate > 0. else nn.Identity()
def forward(self, x, hw_shape):
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
class LocallyGroupedSelfAttention(BaseModule):
"""Locally-grouped Self Attention (LSA) module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
window_size (int): Window size of LSA. Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
window_size=1,
init_cfg=None):
super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg)
assert embed_dims % num_heads == 0, \
f'dim {embed_dims} should be divided by num_heads {num_heads}'
self.embed_dims = embed_dims
self.num_heads = num_heads
head_dim = embed_dims // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.window_size = window_size
def forward(self, x, hw_shape):
B, N, C = x.shape
H, W = hw_shape
x = x.view(B, H, W, C)
# pad feature maps to multiples of Local-groups
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
# calculate attention mask for LSA
Hp, Wp = x.shape[1:-1]
_h, _w = Hp // self.window_size, Wp // self.window_size
mask = torch.zeros((1, Hp, Wp), device=x.device)
mask[:, -pad_b:, :].fill_(1)
mask[:, :, -pad_r:].fill_(1)
# [B, _h, _w, window_size, window_size, C]
x = x.reshape(B, _h, self.window_size, _w, self.window_size,
C).transpose(2, 3)
mask = mask.reshape(1, _h, self.window_size, _w,
self.window_size).transpose(2, 3).reshape(
1, _h * _w,
self.window_size * self.window_size)
# [1, _h*_w, window_size*window_size, window_size*window_size]
attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-1000.0)).masked_fill(
attn_mask == 0, float(0.0))
# [3, B, _w*_h, nhead, window_size*window_size, dim]
qkv = self.qkv(x).reshape(B, _h * _w,
self.window_size * self.window_size, 3,
self.num_heads, C // self.num_heads).permute(
3, 0, 1, 4, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2]
# [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn + attn_mask.unsqueeze(2)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.window_size,
self.window_size, C)
x = attn.transpose(2, 3).reshape(B, _h * self.window_size,
_w * self.window_size, C)
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
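# Window-partition arithmetic for the attention above: with ``window_size=7``
# and a 56 x 56 feature map, no padding is needed and the map splits into
# _h * _w = 8 * 8 = 64 local groups; self-attention is then computed inside
# each group over 7 * 7 = 49 tokens instead of over all 3136 tokens.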
class LSAEncoderLayer(BaseModule):
"""Implements one encoder layer with LocallyGroupedSelfAttention(LSA).
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
window_size (int): Window size of LSA. Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
qk_scale=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
window_size=1,
init_cfg=None):
super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
qkv_bias, qk_scale,
attn_drop_rate, drop_rate,
window_size)
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=False)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate)
) if drop_path_rate > 0. else nn.Identity()
def forward(self, x, hw_shape):
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
@BACKBONES.register_module()
class PCPVT(BaseModule):
"""The backbone of Twins-PCPVT.
This backbone is the implementation of `Twins: Revisiting the Design
of Spatial Attention in Vision Transformers
<https://arxiv.org/abs/2104.13840>`_.
Args:
arch (dict, str): PCPVT architecture, a str value in arch zoo or a
detailed configuration dict with 7 keys, and the length of all the
values in dict should be the same:
- depths (List[int]): The number of encoder layers in each stage.
- embed_dims (List[int]): Embedding dimension in each stage.
- patch_sizes (List[int]): The patch sizes in each stage.
- num_heads (List[int]): Numbers of attention head in each stage.
- strides (List[int]): The strides in each stage.
- mlp_ratios (List[int]): The ratios of mlp in each stage.
- sr_ratios (List[int]): The spatial-reduction ratios of the
GSA-encoder layers in each stage.
in_channels (int): Number of input channels. Default: 3.
out_indices (tuple[int]): Output from which stages.
Default: (3, ).
qkv_bias (bool): Enable bias for qkv if True. Default: False.
drop_rate (float): Probability of an element to be zeroed.
Default 0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
norm_after_stage (bool | List[bool]): Add an extra norm after each stage.
Default False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import PCPVT
>>> import torch
>>> pcpvt_cfg = {'arch': "small",
>>> 'norm_after_stage': [False, False, False, True]}
>>> model = PCPVT(**pcpvt_cfg)
>>> x = torch.rand(1, 3, 224, 224)
>>> outputs = model(x)
>>> print(outputs[-1].shape)
torch.Size([1, 512, 7, 7])
>>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True]
>>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3)
>>> model = PCPVT(**pcpvt_cfg)
>>> outputs = model(x)
>>> for feat in outputs:
>>> print(feat.shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 320, 14, 14])
torch.Size([1, 512, 7, 7])
"""
arch_zoo = {
**dict.fromkeys(['s', 'small'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 4, 6, 3],
'num_heads': [1, 2, 5, 8],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [8, 8, 4, 4],
'sr_ratios': [8, 4, 2, 1]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 4, 18, 3],
'num_heads': [1, 2, 5, 8],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [8, 8, 4, 4],
'sr_ratios': [8, 4, 2, 1]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 8, 27, 3],
'num_heads': [1, 2, 5, 8],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [8, 8, 4, 4],
'sr_ratios': [8, 4, 2, 1]}),
} # yapf: disable
essential_keys = {
'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
'mlp_ratios', 'sr_ratios'
}
def __init__(self,
arch,
in_channels=3,
out_indices=(3, ),
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
norm_after_stage=False,
init_cfg=None):
super(PCPVT, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
assert isinstance(arch, dict) and (
set(arch) == self.essential_keys
), f'Custom arch needs a dict with keys {self.essential_keys}.'
self.arch_settings = arch
self.depths = self.arch_settings['depths']
self.embed_dims = self.arch_settings['embed_dims']
self.patch_sizes = self.arch_settings['patch_sizes']
self.strides = self.arch_settings['strides']
self.mlp_ratios = self.arch_settings['mlp_ratios']
self.num_heads = self.arch_settings['num_heads']
self.sr_ratios = self.arch_settings['sr_ratios']
self.num_extra_tokens = 0 # there is no cls-token in Twins
self.num_stage = len(self.depths)
for key, value in self.arch_settings.items():
assert isinstance(value, list) and len(value) == self.num_stage, (
'Length of setting item in arch dict must be type of list and'
' have the same length.')
# patch_embeds
self.patch_embeds = ModuleList()
self.position_encoding_drops = ModuleList()
self.stages = ModuleList()
for i in range(self.num_stage):
# use in_channels of the model in the first stage
if i == 0:
stage_in_channels = in_channels
else:
stage_in_channels = self.embed_dims[i - 1]
self.patch_embeds.append(
PatchEmbed(
in_channels=stage_in_channels,
embed_dims=self.embed_dims[i],
conv_type='Conv2d',
kernel_size=self.patch_sizes[i],
stride=self.strides[i],
padding='corner',
norm_cfg=dict(type='LN')))
self.position_encoding_drops.append(nn.Dropout(p=drop_rate))
# PEGs
self.position_encodings = ModuleList([
ConditionalPositionEncoding(embed_dim, embed_dim)
for embed_dim in self.embed_dims
])
# stochastic depth
total_depth = sum(self.depths)
self.dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
cur = 0
for k in range(len(self.depths)):
_block = ModuleList([
GSAEncoderLayer(
embed_dims=self.embed_dims[k],
num_heads=self.num_heads[k],
feedforward_channels=self.mlp_ratios[k] *
self.embed_dims[k],
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=self.dpr[cur + i],
num_fcs=2,
qkv_bias=qkv_bias,
act_cfg=dict(type='GELU'),
norm_cfg=norm_cfg,
sr_ratio=self.sr_ratios[k]) for i in range(self.depths[k])
])
self.stages.append(_block)
cur += self.depths[k]
self.out_indices = out_indices
assert isinstance(norm_after_stage, (bool, list))
if isinstance(norm_after_stage, bool):
self.norm_after_stage = [norm_after_stage] * self.num_stage
else:
self.norm_after_stage = norm_after_stage
assert len(self.norm_after_stage) == self.num_stage, \
(f'Number of norm_after_stage({len(self.norm_after_stage)}) should'
f' be equal to the number of stages({self.num_stage}).')
for i, has_norm in enumerate(self.norm_after_stage):
assert isinstance(has_norm, bool), 'norm_after_stage should be ' \
'bool or List[bool].'
if has_norm and norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm_after_stage{i}', norm_layer)
def init_weights(self):
if self.init_cfg is not None:
super(PCPVT, self).init_weights()
else:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
def forward(self, x):
outputs = list()
b = x.shape[0]
for i in range(self.num_stage):
x, hw_shape = self.patch_embeds[i](x)
h, w = hw_shape
x = self.position_encoding_drops[i](x)
for j, blk in enumerate(self.stages[i]):
x = blk(x, hw_shape)
if j == 0:
x = self.position_encodings[i](x, hw_shape)
norm_layer = getattr(self, f'norm_after_stage{i}')
x = norm_layer(x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
if i in self.out_indices:
outputs.append(x)
return tuple(outputs)
@BACKBONES.register_module()
class SVT(PCPVT):
"""The backbone of Twins-SVT.
This backbone is the implementation of `Twins: Revisiting the Design
of Spatial Attention in Vision Transformers
<https://arxiv.org/abs/2104.13840>`_.
Args:
arch (dict, str): SVT architecture, a str value in arch zoo or a
detailed configuration dict with 8 keys, and the length of all the
values in dict should be the same:
- depths (List[int]): The number of encoder layers in each stage.
- embed_dims (List[int]): Embedding dimension in each stage.
- patch_sizes (List[int]): The patch sizes in each stage.
- num_heads (List[int]): Numbers of attention head in each stage.
- strides (List[int]): The strides in each stage.
- mlp_ratios (List[int]): The ratios of mlp in each stage.
- sr_ratios (List[int]): The spatial-reduction ratios of the
GSA-encoder layers in each stage.
- window_sizes (List[int]): The window sizes in LSA-encoder layers
in each stage.
in_channels (int): Number of input channels. Default: 3.
out_indices (tuple[int]): Output from which stages.
Default: (3, ).
qkv_bias (bool): Enable bias for qkv if True. Default: False.
drop_rate (float): Dropout rate. Default 0.
attn_drop_rate (float): Dropout ratio of attention weight.
Default 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
norm_after_stage (bool | List[bool]): Add an extra norm after each stage.
Default False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import SVT
>>> import torch
>>> svt_cfg = {'arch': "small",
>>> 'norm_after_stage': [False, False, False, True]}
>>> model = SVT(**svt_cfg)
>>> x = torch.rand(1, 3, 224, 224)
>>> outputs = model(x)
>>> print(outputs[-1].shape)
torch.Size([1, 512, 7, 7])
>>> svt_cfg["out_indices"] = (0, 1, 2, 3)
>>> svt_cfg["norm_after_stage"] = [True, True, True, True]
>>> model = SVT(**svt_cfg)
>>> output = model(x)
>>> for feat in output:
>>> print(feat.shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 320, 14, 14])
torch.Size([1, 512, 7, 7])
"""
arch_zoo = {
**dict.fromkeys(['s', 'small'],
{'embed_dims': [64, 128, 256, 512],
'depths': [2, 2, 10, 4],
'num_heads': [2, 4, 8, 16],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [4, 4, 4, 4],
'sr_ratios': [8, 4, 2, 1],
'window_sizes': [7, 7, 7, 7]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': [96, 192, 384, 768],
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [4, 4, 4, 4],
'sr_ratios': [8, 4, 2, 1],
'window_sizes': [7, 7, 7, 7]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': [128, 256, 512, 1024],
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [4, 4, 4, 4],
'sr_ratios': [8, 4, 2, 1],
'window_sizes': [7, 7, 7, 7]}),
} # yapf: disable
essential_keys = {
'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
'mlp_ratios', 'sr_ratios', 'window_sizes'
}
def __init__(self,
arch,
in_channels=3,
out_indices=(3, ),
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.0,
norm_cfg=dict(type='LN'),
norm_after_stage=False,
init_cfg=None):
super(SVT, self).__init__(arch, in_channels, out_indices, qkv_bias,
drop_rate, attn_drop_rate, drop_path_rate,
norm_cfg, norm_after_stage, init_cfg)
self.window_sizes = self.arch_settings['window_sizes']
for k in range(self.num_stage):
for i in range(self.depths[k]):
# in even-numbered layers of each stage, replace GSA with LSA
if i % 2 == 0:
ffn_channels = self.mlp_ratios[k] * self.embed_dims[k]
self.stages[k][i] = \
LSAEncoderLayer(
embed_dims=self.embed_dims[k],
num_heads=self.num_heads[k],
feedforward_channels=ffn_channels,
drop_rate=drop_rate,
norm_cfg=norm_cfg,
attn_drop_rate=attn_drop_rate,
drop_path_rate=self.dpr[sum(self.depths[:k])+i],
qkv_bias=qkv_bias,
window_size=self.window_sizes[k])
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmcv.runner import BaseModule, ModuleList
from mmcv.utils.parrots_wrapper import _BatchNorm
from ..builder import BACKBONES
from .base_backbone import BaseBackbone
class MixFFN(BaseModule):
"""An implementation of MixFFN of VAN. Refer to
mmdetection/mmdet/models/backbones/pvt.py.
The differences between MixFFN & FFN:
1. Use 1X1 Conv to replace Linear layer.
2. Introduce 3X3 Depth-wise Conv to encode positional information.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
feedforward_channels (int): The hidden dimension of FFNs.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='GELU').
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
feedforward_channels,
act_cfg=dict(type='GELU'),
ffn_drop=0.,
init_cfg=None):
super(MixFFN, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.act_cfg = act_cfg
self.fc1 = Conv2d(
in_channels=embed_dims,
out_channels=feedforward_channels,
kernel_size=1)
self.dwconv = Conv2d(
in_channels=feedforward_channels,
out_channels=feedforward_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
groups=feedforward_channels)
self.act = build_activation_layer(act_cfg)
self.fc2 = Conv2d(
in_channels=feedforward_channels,
out_channels=embed_dims,
kernel_size=1)
self.drop = nn.Dropout(ffn_drop)
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
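# Shape note for the MixFFN above: since every layer is convolutional, the
# (B, C, H, W) layout used throughout VAN is preserved; an input of shape
# (B, embed_dims, H, W) is expanded to (B, feedforward_channels, H, W) by
# ``fc1``, mixed spatially by the 3x3 depth-wise conv, and projected back to
# (B, embed_dims, H, W) by ``fc2``.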
class LKA(BaseModule):
"""Large Kernel Attention(LKA) of VAN.
.. code:: text
DW_conv (depth-wise convolution)
|
|
DW_D_conv (depth-wise dilation convolution)
|
|
Transition Convolution (1×1 convolution)
Args:
embed_dims (int): Number of input channels.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self, embed_dims, init_cfg=None):
super(LKA, self).__init__(init_cfg=init_cfg)
# a spatial local convolution (depth-wise convolution)
self.DW_conv = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=5,
padding=2,
groups=embed_dims)
# a spatial long-range convolution (depth-wise dilation convolution)
self.DW_D_conv = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=7,
stride=1,
padding=9,
groups=embed_dims,
dilation=3)
self.conv1 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
def forward(self, x):
u = x.clone()
attn = self.DW_conv(x)
attn = self.DW_D_conv(attn)
attn = self.conv1(attn)
return u * attn
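# Note on the LKA above: the 5x5 depth-wise conv followed by the 7x7
# depth-wise conv with dilation 3 (effective extent 19) builds a large local
# receptive field cheaply, the 1x1 conv mixes channels, and the result is
# applied as a multiplicative gate (``u * attn``) rather than a softmax
# attention map.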
class SpatialAttention(BaseModule):
"""Basic attention module in VANBloack.
Args:
embed_dims (int): Number of input channels.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='GELU').
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
super(SpatialAttention, self).__init__(init_cfg=init_cfg)
self.proj_1 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
self.activation = build_activation_layer(act_cfg)
self.spatial_gating_unit = LKA(embed_dims)
self.proj_2 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
def forward(self, x):
shortcut = x.clone()
x = self.proj_1(x)
x = self.activation(x)
x = self.spatial_gating_unit(x)
x = self.proj_2(x)
x = x + shortcut
return x
class VANBlock(BaseModule):
"""A block of VAN.
Args:
embed_dims (int): Number of input channels.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='GELU').
layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 1e-2.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
ffn_ratio=4.,
drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN', eps=1e-5),
layer_scale_init_value=1e-2,
init_cfg=None):
super(VANBlock, self).__init__(init_cfg=init_cfg)
self.out_channels = embed_dims
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
mlp_hidden_dim = int(embed_dims * ffn_ratio)
self.mlp = MixFFN(
embed_dims=embed_dims,
feedforward_channels=mlp_hidden_dim,
act_cfg=act_cfg,
ffn_drop=drop_rate)
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True) if layer_scale_init_value > 0 else None
def forward(self, x):
identity = x
x = self.norm1(x)
x = self.attn(x)
if self.layer_scale_1 is not None:
x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x
x = identity + self.drop_path(x)
identity = x
x = self.norm2(x)
x = self.mlp(x)
if self.layer_scale_2 is not None:
x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x
x = identity + self.drop_path(x)
return x
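# Layer-scale note for the block above: when ``layer_scale_init_value > 0``,
# each residual branch is scaled channel-wise by a learnable vector
# initialised to that value (1e-2 by default), which keeps the residual
# contributions small early in training; a value of 0 disables the scaling.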
class VANPatchEmbed(PatchEmbed):
"""Image to Patch Embedding of VAN.
The differences between VANPatchEmbed & PatchEmbed:
1. Use BN.
2. Do not use 'flatten' and 'transpose'.
"""
def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.
Returns:
tuple: Contains the embedded results and their spatial shape.
- x (Tensor): Has shape (B, embed_dims, out_h, out_w), since
VANPatchEmbed keeps the 2D layout (no flatten or transpose).
- out_size (tuple[int]): Spatial shape of x, arranged as
(out_h, out_w).
"""
if self.adaptive_padding:
x = self.adaptive_padding(x)
x = self.projection(x)
out_size = (x.shape[2], x.shape[3])
if self.norm is not None:
x = self.norm(x)
return x, out_size
@BACKBONES.register_module()
class VAN(BaseBackbone):
"""Visual Attention Network.
A PyTorch implement of : `Visual Attention Network
<https://arxiv.org/pdf/2202.09741v2.pdf>`_
Inspiration from
https://github.com/Visual-Attention-Network/VAN-Classification
Args:
arch (str | dict): Visual Attention Network architecture.
If use string, choose from 'b0', 'b1', 'b2', 'b3', etc.,
if use dict, it should have below keys:
- **embed_dims** (List[int]): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **ffn_ratios** (List[int]): The number of expansion ratio of
feedforward network hidden layer channels.
Defaults to 'tiny'.
patch_sizes (List[int | tuple]): The patch size in patch embeddings.
Defaults to [7, 3, 3, 3].
in_channels (int): The num of input channels. Defaults to 3.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
out_indices (Sequence[int]): Output from which stages.
Default: ``(3, )``.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmcls.models import VAN
>>> import torch
>>> model = VAN(arch='b0')
>>> inputs = torch.rand(1, 3, 224, 224)
>>> outputs = model(inputs)
>>> for out in outputs:
>>> print(out.size())
(1, 256, 7, 7)
"""
arch_zoo = {
**dict.fromkeys(['b0', 't', 'tiny'],
{'embed_dims': [32, 64, 160, 256],
'depths': [3, 3, 5, 2],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['b1', 's', 'small'],
{'embed_dims': [64, 128, 320, 512],
'depths': [2, 2, 4, 2],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['b2', 'b', 'base'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 3, 12, 3],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['b3', 'l', 'large'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 5, 27, 3],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['b4'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 6, 40, 3],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['b5'],
{'embed_dims': [96, 192, 480, 768],
'depths': [3, 3, 24, 3],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['b6'],
{'embed_dims': [96, 192, 384, 768],
'depths': [6, 6, 90, 6],
'ffn_ratios': [8, 8, 4, 4]}),
} # yapf: disable
def __init__(self,
arch='tiny',
patch_sizes=[7, 3, 3, 3],
in_channels=3,
drop_rate=0.,
drop_path_rate=0.,
out_indices=(3, ),
frozen_stages=-1,
norm_eval=False,
norm_cfg=dict(type='LN'),
block_cfgs=dict(),
init_cfg=None):
super(VAN, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.ffn_ratios = self.arch_settings['ffn_ratios']
self.num_stages = len(self.depths)
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.norm_eval = norm_eval
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
cur_block_idx = 0
for i, depth in enumerate(self.depths):
patch_embed = VANPatchEmbed(
in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
input_size=None,
embed_dims=self.embed_dims[i],
kernel_size=patch_sizes[i],
stride=patch_sizes[i] // 2 + 1,
padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
norm_cfg=dict(type='BN'))
blocks = ModuleList([
VANBlock(
embed_dims=self.embed_dims[i],
ffn_ratio=self.ffn_ratios[i],
drop_rate=drop_rate,
drop_path_rate=dpr[cur_block_idx + j],
**block_cfgs) for j in range(depth)
])
cur_block_idx += depth
norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
self.add_module(f'patch_embed{i + 1}', patch_embed)
self.add_module(f'blocks{i + 1}', blocks)
self.add_module(f'norm{i + 1}', norm)
def train(self, mode=True):
super(VAN, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval has effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _freeze_stages(self):
for i in range(0, self.frozen_stages + 1):
# freeze patch embed
m = getattr(self, f'patch_embed{i + 1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze blocks
m = getattr(self, f'blocks{i + 1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze norm
m = getattr(self, f'norm{i + 1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def forward(self, x):
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f'patch_embed{i + 1}')
blocks = getattr(self, f'blocks{i + 1}')
norm = getattr(self, f'norm{i + 1}')
x, hw_shape = patch_embed(x)
for block in blocks:
x = block(x)
x = x.flatten(2).transpose(1, 2)
x = norm(x)
x = x.reshape(-1, *hw_shape,
block.out_channels).permute(0, 3, 1, 2).contiguous()
if i in self.out_indices:
outs.append(x)
return tuple(outs)
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.utils.parrots_wrapper import _BatchNorm
......@@ -45,13 +46,11 @@ class VGG(BaseBackbone):
num_stages (int): VGG stages, normally 5.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int], optional): Output from which stages.
If only one stage is specified, a single tensor (feature map) is
returned, otherwise multiple stages are specified, a tuple of
tensors will be returned. When it is None, the default behavior
depends on whether num_classes is specified. If num_classes <= 0,
the default value is (4, ), outputing the last feature map before
classifier. If num_classes > 0, the default value is (5, ),
outputing the classification score. Default: None.
When it is None, the default behavior depends on whether
num_classes is specified. If num_classes <= 0, the default value is
(4, ), output the last feature map before classifier. If
num_classes > 0, the default value is (5, ), output the
classification score. Default: None.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
......@@ -162,9 +161,7 @@ class VGG(BaseBackbone):
x = x.view(x.size(0), -1)
x = self.classifier(x)
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def _freeze_stages(self):
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcls.utils import get_root_logger
from ..builder import BACKBONES
from ..utils import MultiheadAttention, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension for FFNs
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def init_weights(self):
super(TransformerEncoderLayer, self).init_weights()
for m in self.ffn.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
return x
@BACKBONES.register_module()
class VisionTransformer(BaseBackbone):
"""Vision Transformer.
A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
Args:
arch (str | dict): Vision Transformer architecture. If use string,
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
and 'deit-base'. If use dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
the final feature map. Defaults to True.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
``with_cls_token`` must be True. Defaults to True.
interpolate_mode (str): Select the interpolation mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 768,
'num_layers': 8,
'num_heads': 8,
'feedforward_channels': 768 * 3,
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 3072
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}),
**dict.fromkeys(
['deit-t', 'deit-tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': 192 * 4
}),
**dict.fromkeys(
['deit-s', 'deit-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': 384 * 4
}),
**dict.fromkeys(
['deit-b', 'deit-base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 768 * 4
}),
}
# Some structures have multiple extra tokens, like DeiT.
num_extra_tokens = 1 # cls_token
def __init__(self,
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
with_cls_token=True,
output_cls_token=True,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.num_layers = self.arch_settings['num_layers']
self.img_size = to_2tuple(img_size)
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set cls token
if output_cls_token:
assert with_cls_token is True, 'with_cls_token must be True when ' \
f'output_cls_token is True, but got {with_cls_token}'
self.with_cls_token = with_cls_token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
# Set position embedding
self.interpolate_mode = interpolate_mode
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_extra_tokens,
self.embed_dims))
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
qkv_bias=qkv_bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
self.final_norm = final_norm
if final_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
super(VisionTransformer, self).init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
trunc_normal_(self.pos_embed, std=0.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmcv.utils import print_log
logger = get_root_logger()
print_log(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.',
logger=logger)
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
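# A sketch of the resizing performed above (shapes are illustrative): a
# checkpoint trained on 224x224 inputs with 16x16 patches stores
# 14*14 + 1 = 197 position tokens, while a 384x384 model expects
# 24*24 + 1 = 577. The hook rebuilds the patch grid via, e.g.,
#     resize_pos_embed(ckpt_pos_embed, (14, 14), (24, 24),
#                      mode='bicubic', num_extra_tokens=1)
# interpolating the patch tokens and keeping the extra cls token untouched.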
@staticmethod
def resize_pos_embed(*args, **kwargs):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
return tuple(outs)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
CLASSIFIERS = MODELS
ATTENTION = Registry('attention', parent=MMCV_ATTENTION)
def build_backbone(cfg):
"""Build backbone."""
return BACKBONES.build(cfg)
def build_neck(cfg):
"""Build neck."""
return NECKS.build(cfg)
def build_head(cfg):
"""Build head."""
return HEADS.build(cfg)
def build_loss(cfg):
"""Build loss."""
return LOSSES.build(cfg)
def build_classifier(cfg):
return CLASSIFIERS.build(cfg)
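# A minimal usage sketch of the registry-based builders (the ResNet arguments
# below are illustrative, not a recommended config): any class decorated with
# ``@BACKBONES.register_module()`` can be built from a plain config dict whose
# ``type`` field names the registered class.
#
#     from mmcls.models import build_backbone
#     backbone_cfg = dict(type='ResNet', depth=50, num_stages=4,
#                         out_indices=(3, ))
#     backbone = build_backbone(backbone_cfg)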
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseClassifier
from .image import ImageClassifier
__all__ = ['BaseClassifier', 'ImageClassifier']
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Sequence
import cv2
import mmcv
import torch
import torch.distributed as dist
from mmcv import color_val
from mmcv.runner import BaseModule
from mmcv.runner import BaseModule, auto_fp16
# TODO import `auto_fp16` from mmcv and delete them from mmcls
try:
from mmcv.runner import auto_fp16
except ImportError:
warnings.warn('auto_fp16 from mmcls will be deprecated. '
'Please install mmcv>=1.1.4.')
from mmcls.core import auto_fp16
from mmcls.core.visualization import imshow_infos
class BaseClassifier(BaseModule, metaclass=ABCMeta):
......@@ -34,13 +27,14 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
return hasattr(self, 'head') and self.head is not None
@abstractmethod
def extract_feat(self, imgs):
def extract_feat(self, imgs, stage=None):
pass
def extract_feats(self, imgs):
assert isinstance(imgs, list)
def extract_feats(self, imgs, stage=None):
assert isinstance(imgs, Sequence)
kwargs = {} if stage is None else {'stage': stage}
for img in imgs:
yield self.extract_feat(img)
yield self.extract_feat(img, **kwargs)
@abstractmethod
def forward_train(self, imgs, **kwargs):
......@@ -117,7 +111,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
return loss, log_vars
def train_step(self, data, optimizer):
def train_step(self, data, optimizer=None, **kwargs):
"""The iteration step during training.
This method defines an iteration step during training, except for the
......@@ -128,20 +122,19 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
Args:
data (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
optimizer of runner is passed to ``train_step()``. This
argument is unused and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
``num_samples``.
``loss`` is a tensor for back propagation, which can be a
weighted sum of multiple losses.
``log_vars`` contains all the variables to be sent to the
logger.
``num_samples`` indicates the batch size (when the model is
DDP, it means the batch size on each GPU), which is used for
averaging the logs.
dict: Dict of outputs. The following fields are contained.
- loss (torch.Tensor): A tensor for back propagation, which \
can be a weighted sum of multiple losses.
- log_vars (dict): Dict contains all the variables to be sent \
to the logger.
- num_samples (int): Indicates the batch size (when the model \
is DDP, it means the batch size on each GPU), which is \
used for averaging the logs.
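Example (an illustrative sketch; ``model`` is a built classifier and
``data`` is a batch from the dataloader):
    >>> outputs = model.train_step(data, optimizer=None)
    >>> sorted(outputs.keys())
    ['log_vars', 'loss', 'num_samples']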
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
......@@ -151,12 +144,28 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
return outputs
def val_step(self, data, optimizer):
def val_step(self, data, optimizer=None, **kwargs):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
Args:
data (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
optimizer of runner is passed to ``train_step()``. This
argument is unused and reserved.
Returns:
dict: Dict of outputs. The following fields are contained.
- loss (torch.Tensor): A tensor for back propagation, which \
can be a weighted sum of multiple losses.
- log_vars (dict): Dict contains all the variables to be sent \
to the logger.
- num_samples (int): Indicates the batch size (when the model \
is DDP, it means the batch size on each GPU), which is \
used for averaging the logs.
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
......@@ -169,56 +178,47 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
def show_result(self,
img,
result,
text_color='green',
text_color='white',
font_scale=0.5,
row_width=20,
show=False,
fig_size=(15, 10),
win_name='',
wait_time=0,
out_file=None):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
result (Tensor): The classification results to draw over `img`.
img (str or ndarray): The image to be displayed.
result (dict): The classification results to draw over `img`.
text_color (str or tuple or :obj:`Color`): Color of texts.
font_scale (float): Font scales of texts.
row_width (int): Width between each row of results on the image.
show (bool): Whether to show the image.
Default: False.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
wait_time (int): How many seconds to display the image.
Defaults to 0.
out_file (str or None): The filename to write the image.
Default: None.
Returns:
img (Tensor): Only if not `show` or `out_file`
img (ndarray): Image with overlaid results.
"""
img = mmcv.imread(img)
img = img.copy()
# write results on left-top of the image
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width
# if out_file specified, do not show image in window
if out_file is not None:
show = False
if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
img = imshow_infos(
img,
result,
text_color=text_color,
font_size=int(font_scale * 50),
row_width=row_width,
win_name=win_name,
show=show,
fig_size=fig_size,
wait_time=wait_time,
out_file=out_file)
return img
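# A minimal usage sketch of ``show_result`` (the image path and result dict
# are illustrative): ``result`` maps text labels to values, which are drawn
# on the image via ``imshow_infos`` and optionally shown or written to disk.
#
#     result = {'pred_class': 'sea snake', 'pred_score': 0.98}
#     model.show_result('demo.jpg', result, show=False, out_file='vis.jpg')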