# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import to_2tuple
from .base_backbone import BaseBackbone
class TransformerBlock(BaseModule):
"""Implement a transformer block in TnTLayer.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
Default: 4
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default 0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.
drop_path_rate (float): stochastic depth rate. Default 0.
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
qkv_bias (bool): Enable bias for qkv if True. Default False
act_cfg (dict): The activation config for FFNs. Defaults to GELU.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim) or (n, batch, embed_dim).
            (batch, n, embed_dim) is common case in CV. Defaults to True
init_cfg (dict, optional): Initialization config dict. Defaults to None
"""
def __init__(self,
embed_dims,
num_heads,
ffn_ratio=4,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=False,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
batch_first=True,
init_cfg=None):
super(TransformerBlock, self).__init__(init_cfg=init_cfg)
self.norm_attn = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
batch_first=batch_first)
self.norm_ffn = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=embed_dims * ffn_ratio,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
if not qkv_bias:
self.attn.attn.in_proj_bias = None
def forward(self, x):
x = self.attn(self.norm_attn(x), identity=x)
x = self.ffn(self.norm_ffn(x), identity=x)
return x
class TnTLayer(BaseModule):
"""Implement one encoder layer in Transformer in Transformer.
Args:
num_pixel (int): The pixel number in target patch transformed with
a linear projection in inner transformer
embed_dims_inner (int): Feature dimension in inner transformer block
embed_dims_outer (int): Feature dimension in outer transformer block
num_heads_inner (int): Parallel attention heads in inner transformer.
num_heads_outer (int): Parallel attention heads in outer transformer.
inner_block_cfg (dict): Extra config of inner transformer block.
Defaults to empty dict.
outer_block_cfg (dict): Extra config of outer transformer block.
Defaults to empty dict.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization
init_cfg (dict, optional): Initialization config dict. Defaults to None
"""
def __init__(self,
num_pixel,
embed_dims_inner,
embed_dims_outer,
num_heads_inner,
num_heads_outer,
inner_block_cfg=dict(),
outer_block_cfg=dict(),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(TnTLayer, self).__init__(init_cfg=init_cfg)
self.inner_block = TransformerBlock(
embed_dims=embed_dims_inner,
num_heads=num_heads_inner,
**inner_block_cfg)
self.norm_proj = build_norm_layer(norm_cfg, embed_dims_inner)[1]
self.projection = nn.Linear(
embed_dims_inner * num_pixel, embed_dims_outer, bias=True)
self.outer_block = TransformerBlock(
embed_dims=embed_dims_outer,
num_heads=num_heads_outer,
**outer_block_cfg)
def forward(self, pixel_embed, patch_embed):
pixel_embed = self.inner_block(pixel_embed)
B, N, C = patch_embed.size()
patch_embed[:, 1:] = patch_embed[:, 1:] + self.projection(
self.norm_proj(pixel_embed).reshape(B, N - 1, -1))
patch_embed = self.outer_block(patch_embed)
return pixel_embed, patch_embed
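# Shape sketch (editorial note, not part of the original file): one TnT layer
# first runs attention over the pixel (inner) tokens, then folds each patch's
# pixel tokens back into its outer patch token before the outer attention, so
# both inputs keep their shapes, e.g.
#
#   >>> layer = TnTLayer(num_pixel=16, embed_dims_inner=48,
#   ...                  embed_dims_outer=640, num_heads_inner=4,
#   ...                  num_heads_outer=10)
#   >>> pixel = torch.rand(2 * 196, 16, 48)  # (B * num_patches, pixels, C_in)
#   >>> patch = torch.rand(2, 1 + 196, 640)  # (B, cls token + patches, C_out)
#   >>> pixel, patch = layer(pixel, patch)
#   >>> pixel.shape, patch.shape
#   (torch.Size([392, 16, 48]), torch.Size([2, 197, 640]))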
class PixelEmbed(BaseModule):
"""Image to Pixel Embedding.
Args:
img_size (int | tuple): The size of input image
patch_size (int): The size of one patch
in_channels (int): The num of input channels
embed_dims_inner (int): The num of channels of the target patch
transformed with a linear projection in inner transformer
        stride (int): The stride of the conv2d layer. We use a conv2d layer
            and an unfold layer to implement image to pixel embedding.
init_cfg (dict, optional): Initialization config dict
"""
def __init__(self,
img_size=224,
patch_size=16,
in_channels=3,
embed_dims_inner=48,
stride=4,
init_cfg=None):
super(PixelEmbed, self).__init__(init_cfg=init_cfg)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
# patches_resolution property necessary for resizing
# positional embedding
patches_resolution = [
img_size[0] // patch_size[0], img_size[1] // patch_size[1]
]
num_patches = patches_resolution[0] * patches_resolution[1]
self.img_size = img_size
self.num_patches = num_patches
self.embed_dims_inner = embed_dims_inner
new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
self.new_patch_size = new_patch_size
self.proj = nn.Conv2d(
in_channels,
self.embed_dims_inner,
kernel_size=7,
padding=3,
stride=stride)
self.unfold = nn.Unfold(
kernel_size=new_patch_size, stride=new_patch_size)
def forward(self, x, pixel_pos):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model " \
f'({self.img_size[0]}*{self.img_size[1]}).'
x = self.proj(x)
x = self.unfold(x)
x = x.transpose(1,
2).reshape(B * self.num_patches, self.embed_dims_inner,
self.new_patch_size[0],
self.new_patch_size[1])
x = x + pixel_pos
x = x.reshape(B * self.num_patches, self.embed_dims_inner,
-1).transpose(1, 2)
return x
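# Shape sketch (editorial note, not part of the original file): with the
# defaults above, the 7x7 / stride-4 conv turns a 224x224 image into a 56x56
# map, nn.Unfold with kernel = stride = new_patch_size (= ceil(16 / 4) = 4)
# cuts it into 14 * 14 = 196 patches, and each patch becomes a sequence of
# 4 * 4 = 16 pixel tokens:
#
#   >>> pe = PixelEmbed(img_size=224, patch_size=16, embed_dims_inner=48,
#   ...                 stride=4)
#   >>> pixel_pos = torch.zeros(1, 48, *pe.new_patch_size)
#   >>> pe(torch.rand(2, 3, 224, 224), pixel_pos).shape
#   torch.Size([392, 16, 48])  # (B * num_patches, 4 * 4, embed_dims_inner)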
@MODELS.register_module()
class TNT(BaseBackbone):
"""Transformer in Transformer.
A PyTorch implement of: `Transformer in Transformer
<https://arxiv.org/abs/2103.00112>`_
Inspiration from
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/tnt.py
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size. Defaults to 224
        patch_size (int | tuple): The patch size. Defaults to 16
in_channels (int): Number of input channels. Defaults to 3
ffn_ratio (int): A ratio to calculate the hidden_dims in ffn layer.
Default: 4
qkv_bias (bool): Enable bias for qkv if True. Default False
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default 0.
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.
drop_path_rate (float): stochastic depth rate. Default 0.
act_cfg (dict): The activation config for FFNs. Defaults to GELU.
norm_cfg (dict): Config dict for normalization layer. Default
layer normalization
        first_stride (int): The stride of the conv2d layer. We use a conv2d
            layer and an unfold layer to implement image to pixel embedding.
num_fcs (int): The number of fully-connected layers for FFNs. Default 2
init_cfg (dict, optional): Initialization config dict
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims_outer': 384,
'embed_dims_inner': 24,
'num_layers': 12,
'num_heads_outer': 6,
'num_heads_inner': 4
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims_outer': 640,
'embed_dims_inner': 40,
'num_layers': 12,
'num_heads_outer': 10,
'num_heads_inner': 4
})
}
def __init__(self,
arch='b',
img_size=224,
patch_size=16,
in_channels=3,
ffn_ratio=4,
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
first_stride=4,
num_fcs=2,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.)
]):
super(TNT, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims_outer', 'embed_dims_inner', 'num_layers',
'num_heads_inner', 'num_heads_outer'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims_inner = self.arch_settings['embed_dims_inner']
self.embed_dims_outer = self.arch_settings['embed_dims_outer']
# embed_dims for consistency with other models
self.embed_dims = self.embed_dims_outer
self.num_layers = self.arch_settings['num_layers']
self.num_heads_inner = self.arch_settings['num_heads_inner']
self.num_heads_outer = self.arch_settings['num_heads_outer']
self.pixel_embed = PixelEmbed(
img_size=img_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dims_inner=self.embed_dims_inner,
stride=first_stride)
num_patches = self.pixel_embed.num_patches
self.num_patches = num_patches
new_patch_size = self.pixel_embed.new_patch_size
num_pixel = new_patch_size[0] * new_patch_size[1]
self.norm1_proj = build_norm_layer(norm_cfg, num_pixel *
self.embed_dims_inner)[1]
self.projection = nn.Linear(num_pixel * self.embed_dims_inner,
self.embed_dims_outer)
self.norm2_proj = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims_outer))
self.patch_pos = nn.Parameter(
torch.zeros(1, num_patches + 1, self.embed_dims_outer))
self.pixel_pos = nn.Parameter(
torch.zeros(1, self.embed_dims_inner, new_patch_size[0],
new_patch_size[1]))
self.drop_after_pos = nn.Dropout(p=drop_rate)
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, self.num_layers)
] # stochastic depth decay rule
self.layers = ModuleList()
for i in range(self.num_layers):
block_cfg = dict(
ffn_ratio=ffn_ratio,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=dpr[i],
num_fcs=num_fcs,
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
batch_first=True)
self.layers.append(
TnTLayer(
num_pixel=num_pixel,
embed_dims_inner=self.embed_dims_inner,
embed_dims_outer=self.embed_dims_outer,
num_heads_inner=self.num_heads_inner,
num_heads_outer=self.num_heads_outer,
inner_block_cfg=block_cfg,
outer_block_cfg=block_cfg,
norm_cfg=norm_cfg))
self.norm = build_norm_layer(norm_cfg, self.embed_dims_outer)[1]
trunc_normal_(self.cls_token, std=.02)
trunc_normal_(self.patch_pos, std=.02)
trunc_normal_(self.pixel_pos, std=.02)
def forward(self, x):
B = x.shape[0]
pixel_embed = self.pixel_embed(x, self.pixel_pos)
patch_embed = self.norm2_proj(
self.projection(
self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
patch_embed = torch.cat(
(self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
patch_embed = patch_embed + self.patch_pos
patch_embed = self.drop_after_pos(patch_embed)
for layer in self.layers:
pixel_embed, patch_embed = layer(pixel_embed, patch_embed)
patch_embed = self.norm(patch_embed)
return (patch_embed[:, 0], )
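# Usage sketch (editorial note, not part of the original file): the backbone
# returns a single-element tuple holding the class token of the outer
# transformer, i.e. a (B, embed_dims_outer) feature:
#
#   >>> model = TNT(arch='b')
#   >>> feats = model(torch.rand(1, 3, 224, 224))
#   >>> feats[0].shape
#   torch.Size([1, 640])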
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_norm_layer
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, normal_init,
trunc_normal_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.registry import MODELS
from ..utils import ConditionalPositionEncoding, MultiheadAttention
class GlobalSubsampledAttention(MultiheadAttention):
"""Global Sub-sampled Attention (GSA) module.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
input_dims (int, optional): The input dimension, and if None,
use ``embed_dims``. Defaults to None.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
dropout_layer (dict): The dropout config before adding the shortcut.
Defaults to ``dict(type='Dropout', drop_prob=0.)``.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to output projection.
Defaults to True.
v_shortcut (bool): Add a shortcut from value to output. It's usually
used if ``input_dims`` is different from ``embed_dims``.
Defaults to False.
sr_ratio (float): The ratio of spatial reduction in attention modules.
Defaults to 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
norm_cfg=dict(type='LN'),
qkv_bias=True,
sr_ratio=1,
**kwargs):
super(GlobalSubsampledAttention,
self).__init__(embed_dims, num_heads, **kwargs)
self.qkv_bias = qkv_bias
self.q = nn.Linear(self.input_dims, embed_dims, bias=qkv_bias)
self.kv = nn.Linear(self.input_dims, embed_dims * 2, bias=qkv_bias)
# remove self.qkv, here split into self.q, self.kv
delattr(self, 'qkv')
self.sr_ratio = sr_ratio
if sr_ratio > 1:
# use a conv as the spatial-reduction operation, the kernel_size
# and stride in conv are equal to the sr_ratio.
self.sr = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=sr_ratio,
stride=sr_ratio)
# The ret[0] of build_norm_layer is norm name.
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
def forward(self, x, hw_shape):
B, N, C = x.shape
H, W = hw_shape
assert H * W == N, 'The product of h and w of hw_shape must be N, ' \
'which is the 2nd dim number of the input Tensor x.'
q = self.q(x).reshape(B, N, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x = x.permute(0, 2, 1).reshape(B, C, *hw_shape) # BNC_2_BCHW
x = self.sr(x)
x = x.reshape(B, C, -1).permute(0, 2, 1) # BCHW_2_BNC
x = self.norm(x)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads,
self.head_dims).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn_drop = self.attn_drop if self.training else 0.
x = self.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
x = x.transpose(1, 2).reshape(B, N, self.embed_dims)
x = self.proj(x)
x = self.out_drop(self.proj_drop(x))
if self.v_shortcut:
x = v.squeeze(1) + x
return x
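# Shape sketch (editorial note, not part of the original file): with
# sr_ratio=2 the keys/values come from a spatially reduced map, so the
# 14 * 14 = 196 query tokens attend to only 7 * 7 = 49 key/value tokens while
# the output keeps the original sequence length:
#
#   >>> gsa = GlobalSubsampledAttention(embed_dims=64, num_heads=2,
#   ...                                 sr_ratio=2)
#   >>> gsa(torch.rand(1, 14 * 14, 64), (14, 14)).shape
#   torch.Size([1, 196, 64])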
class GSAEncoderLayer(BaseModule):
"""Implements one encoder layer with GlobalSubsampledAttention(GSA).
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0.
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
sr_ratio (float): The ratio of spatial reduction in attention modules.
Defaults to 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
sr_ratio=1.,
init_cfg=None):
super(GSAEncoderLayer, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = GlobalSubsampledAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
norm_cfg=norm_cfg,
sr_ratio=sr_ratio)
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=False)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate)
) if drop_path_rate > 0. else nn.Identity()
def forward(self, x, hw_shape):
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
class LocallyGroupedSelfAttention(BaseModule):
"""Locally-grouped Self Attention (LSA) module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads. Default: 8
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Default: False.
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
window_size(int): Window size of LSA. Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop_rate=0.,
proj_drop_rate=0.,
window_size=1,
init_cfg=None):
super(LocallyGroupedSelfAttention, self).__init__(init_cfg=init_cfg)
assert embed_dims % num_heads == 0, \
f'dim {embed_dims} should be divided by num_heads {num_heads}'
self.embed_dims = embed_dims
self.num_heads = num_heads
head_dim = embed_dims // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop_rate)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop_rate)
self.window_size = window_size
def forward(self, x, hw_shape):
B, N, C = x.shape
H, W = hw_shape
x = x.view(B, H, W, C)
# pad feature maps to multiples of Local-groups
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
# calculate attention mask for LSA
Hp, Wp = x.shape[1:-1]
_h, _w = Hp // self.window_size, Wp // self.window_size
mask = torch.zeros((1, Hp, Wp), device=x.device)
mask[:, -pad_b:, :].fill_(1)
mask[:, :, -pad_r:].fill_(1)
# [B, _h, _w, window_size, window_size, C]
x = x.reshape(B, _h, self.window_size, _w, self.window_size,
C).transpose(2, 3)
mask = mask.reshape(1, _h, self.window_size, _w,
self.window_size).transpose(2, 3).reshape(
1, _h * _w,
self.window_size * self.window_size)
# [1, _h*_w, window_size*window_size, window_size*window_size]
attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-1000.0)).masked_fill(
attn_mask == 0, float(0.0))
# [3, B, _w*_h, nhead, window_size*window_size, dim]
qkv = self.qkv(x).reshape(B, _h * _w,
self.window_size * self.window_size, 3,
self.num_heads, C // self.num_heads).permute(
3, 0, 1, 4, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2]
# [B, _h*_w, n_head, window_size*window_size, window_size*window_size]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn + attn_mask.unsqueeze(2)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.window_size,
self.window_size, C)
x = attn.transpose(2, 3).reshape(B, _h * self.window_size,
_w * self.window_size, C)
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
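# Shape sketch (editorial note, not part of the original file): with
# window_size=7 a 14x14 map is padded (if needed) and split into 2 * 2
# non-overlapping 7x7 groups; attention is computed only inside each group
# and the output keeps the original sequence length:
#
#   >>> lsa = LocallyGroupedSelfAttention(embed_dims=64, num_heads=2,
#   ...                                   window_size=7)
#   >>> lsa(torch.rand(1, 14 * 14, 64), (14, 14)).shape
#   torch.Size([1, 196, 64])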
class LSAEncoderLayer(BaseModule):
"""Implements one encoder layer with LocallyGroupedSelfAttention(LSA).
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Default: 0.0.
attn_drop_rate (float, optional): Dropout ratio of attention weight.
Default: 0.0
drop_path_rate (float): Stochastic depth rate. Default 0.0.
num_fcs (int): The number of fully-connected layers for FFNs.
Default: 2.
qkv_bias (bool): Enable bias for qkv if True. Default: True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Default: None.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
window_size (int): Window size of LSA. Default: 1.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
qk_scale=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
window_size=1,
init_cfg=None):
super(LSAEncoderLayer, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims, postfix=1)[1]
self.attn = LocallyGroupedSelfAttention(embed_dims, num_heads,
qkv_bias, qk_scale,
attn_drop_rate, drop_rate,
window_size)
self.norm2 = build_norm_layer(norm_cfg, embed_dims, postfix=2)[1]
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
add_identity=False)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate)
) if drop_path_rate > 0. else nn.Identity()
def forward(self, x, hw_shape):
x = x + self.drop_path(self.attn(self.norm1(x), hw_shape))
x = x + self.drop_path(self.ffn(self.norm2(x)))
return x
@MODELS.register_module()
class PCPVT(BaseModule):
"""The backbone of Twins-PCPVT.
This backbone is the implementation of `Twins: Revisiting the Design
of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.
Args:
arch (dict, str): PCPVT architecture, a str value in arch zoo or a
detailed configuration dict with 7 keys, and the length of all the
values in dict should be the same:
- depths (List[int]): The number of encoder layers in each stage.
- embed_dims (List[int]): Embedding dimension in each stage.
- patch_sizes (List[int]): The patch sizes in each stage.
- num_heads (List[int]): Numbers of attention head in each stage.
- strides (List[int]): The strides in each stage.
- mlp_ratios (List[int]): The ratios of mlp in each stage.
            - sr_ratios (List[int]): The spatial-reduction ratios of the GSA
              modules in each stage.
in_channels (int): Number of input channels. Defaults to 3.
out_indices (tuple[int]): Output from which stages.
Defaults to ``(3, )``.
qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
attn_drop_rate (float): The drop out rate for attention layer.
Defaults to 0.0
drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
norm_after_stage(bool, List[bool]): Add extra norm after each stage.
Defaults to False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmpretrain.models import PCPVT
>>> import torch
>>> pcpvt_cfg = {'arch': "small",
>>> 'norm_after_stage': [False, False, False, True]}
>>> model = PCPVT(**pcpvt_cfg)
>>> x = torch.rand(1, 3, 224, 224)
>>> outputs = model(x)
>>> print(outputs[-1].shape)
torch.Size([1, 512, 7, 7])
>>> pcpvt_cfg['norm_after_stage'] = [True, True, True, True]
>>> pcpvt_cfg['out_indices'] = (0, 1, 2, 3)
>>> model = PCPVT(**pcpvt_cfg)
>>> outputs = model(x)
>>> for feat in outputs:
>>> print(feat.shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 320, 14, 14])
torch.Size([1, 512, 7, 7])
"""
arch_zoo = {
**dict.fromkeys(['s', 'small'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 4, 6, 3],
'num_heads': [1, 2, 5, 8],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [8, 8, 4, 4],
'sr_ratios': [8, 4, 2, 1]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 4, 18, 3],
'num_heads': [1, 2, 5, 8],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [8, 8, 4, 4],
'sr_ratios': [8, 4, 2, 1]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 8, 27, 3],
'num_heads': [1, 2, 5, 8],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [8, 8, 4, 4],
'sr_ratios': [8, 4, 2, 1]}),
} # yapf: disable
essential_keys = {
'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
'mlp_ratios', 'sr_ratios'
}
def __init__(self,
arch,
in_channels=3,
out_indices=(3, ),
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
norm_after_stage=False,
init_cfg=None):
super(PCPVT, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
assert isinstance(arch, dict) and (
set(arch) == self.essential_keys
), f'Custom arch needs a dict with keys {self.essential_keys}.'
self.arch_settings = arch
self.depths = self.arch_settings['depths']
self.embed_dims = self.arch_settings['embed_dims']
self.patch_sizes = self.arch_settings['patch_sizes']
self.strides = self.arch_settings['strides']
self.mlp_ratios = self.arch_settings['mlp_ratios']
self.num_heads = self.arch_settings['num_heads']
self.sr_ratios = self.arch_settings['sr_ratios']
self.num_extra_tokens = 0 # there is no cls-token in Twins
self.num_stage = len(self.depths)
for key, value in self.arch_settings.items():
assert isinstance(value, list) and len(value) == self.num_stage, (
'Length of setting item in arch dict must be type of list and'
' have the same length.')
# patch_embeds
self.patch_embeds = ModuleList()
self.position_encoding_drops = ModuleList()
self.stages = ModuleList()
for i in range(self.num_stage):
# use in_channels of the model in the first stage
if i == 0:
stage_in_channels = in_channels
else:
stage_in_channels = self.embed_dims[i - 1]
self.patch_embeds.append(
PatchEmbed(
in_channels=stage_in_channels,
embed_dims=self.embed_dims[i],
conv_type='Conv2d',
kernel_size=self.patch_sizes[i],
stride=self.strides[i],
padding='corner',
norm_cfg=dict(type='LN')))
self.position_encoding_drops.append(nn.Dropout(p=drop_rate))
# PEGs
self.position_encodings = ModuleList([
ConditionalPositionEncoding(embed_dim, embed_dim)
for embed_dim in self.embed_dims
])
# stochastic depth
total_depth = sum(self.depths)
self.dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
cur = 0
for k in range(len(self.depths)):
_block = ModuleList([
GSAEncoderLayer(
embed_dims=self.embed_dims[k],
num_heads=self.num_heads[k],
feedforward_channels=self.mlp_ratios[k] *
self.embed_dims[k],
attn_drop_rate=attn_drop_rate,
drop_rate=drop_rate,
drop_path_rate=self.dpr[cur + i],
num_fcs=2,
qkv_bias=qkv_bias,
act_cfg=dict(type='GELU'),
norm_cfg=norm_cfg,
sr_ratio=self.sr_ratios[k]) for i in range(self.depths[k])
])
self.stages.append(_block)
cur += self.depths[k]
self.out_indices = out_indices
assert isinstance(norm_after_stage, (bool, list))
if isinstance(norm_after_stage, bool):
self.norm_after_stage = [norm_after_stage] * self.num_stage
else:
self.norm_after_stage = norm_after_stage
assert len(self.norm_after_stage) == self.num_stage, \
(f'Number of norm_after_stage({len(self.norm_after_stage)}) should'
f' be equal to the number of stages({self.num_stage}).')
for i, has_norm in enumerate(self.norm_after_stage):
assert isinstance(has_norm, bool), 'norm_after_stage should be ' \
'bool or List[bool].'
if has_norm and norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm_after_stage{i}', norm_layer)
def init_weights(self):
if self.init_cfg is not None:
super(PCPVT, self).init_weights()
else:
for m in self.modules():
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
constant_init(m, val=1.0, bias=0.)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(
m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
def forward(self, x):
outputs = list()
b = x.shape[0]
for i in range(self.num_stage):
x, hw_shape = self.patch_embeds[i](x)
h, w = hw_shape
x = self.position_encoding_drops[i](x)
for j, blk in enumerate(self.stages[i]):
x = blk(x, hw_shape)
if j == 0:
x = self.position_encodings[i](x, hw_shape)
norm_layer = getattr(self, f'norm_after_stage{i}')
x = norm_layer(x)
x = x.reshape(b, h, w, -1).permute(0, 3, 1, 2).contiguous()
if i in self.out_indices:
outputs.append(x)
return tuple(outputs)
@MODELS.register_module()
class SVT(PCPVT):
"""The backbone of Twins-SVT.
This backbone is the implementation of `Twins: Revisiting the Design
of Spatial Attention in Vision Transformers
    <https://arxiv.org/abs/2104.13840>`_.
Args:
arch (dict, str): SVT architecture, a str value in arch zoo or a
detailed configuration dict with 8 keys, and the length of all the
values in dict should be the same:
- depths (List[int]): The number of encoder layers in each stage.
- embed_dims (List[int]): Embedding dimension in each stage.
- patch_sizes (List[int]): The patch sizes in each stage.
- num_heads (List[int]): Numbers of attention head in each stage.
- strides (List[int]): The strides in each stage.
- mlp_ratios (List[int]): The ratios of mlp in each stage.
            - sr_ratios (List[int]): The spatial-reduction ratios of the GSA
              modules in each stage.
            - window_sizes (List[int]): The window sizes in LSA-encoder layers
              in each stage.
in_channels (int): Number of input channels. Defaults to 3.
out_indices (tuple[int]): Output from which stages.
Defaults to (3, ).
qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
drop_rate (float): Dropout rate. Defaults to 0.
attn_drop_rate (float): Dropout ratio of attention weight.
Defaults to 0.0
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
norm_after_stage(bool, List[bool]): Add extra norm after each stage.
Defaults to False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmpretrain.models import SVT
>>> import torch
>>> svt_cfg = {'arch': "small",
>>> 'norm_after_stage': [False, False, False, True]}
>>> model = SVT(**svt_cfg)
>>> x = torch.rand(1, 3, 224, 224)
>>> outputs = model(x)
>>> print(outputs[-1].shape)
torch.Size([1, 512, 7, 7])
>>> svt_cfg["out_indices"] = (0, 1, 2, 3)
>>> svt_cfg["norm_after_stage"] = [True, True, True, True]
>>> model = SVT(**svt_cfg)
>>> output = model(x)
>>> for feat in output:
>>> print(feat.shape)
torch.Size([1, 64, 56, 56])
torch.Size([1, 128, 28, 28])
torch.Size([1, 320, 14, 14])
torch.Size([1, 512, 7, 7])
"""
arch_zoo = {
**dict.fromkeys(['s', 'small'],
{'embed_dims': [64, 128, 256, 512],
'depths': [2, 2, 10, 4],
'num_heads': [2, 4, 8, 16],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [4, 4, 4, 4],
'sr_ratios': [8, 4, 2, 1],
'window_sizes': [7, 7, 7, 7]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': [96, 192, 384, 768],
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [4, 4, 4, 4],
'sr_ratios': [8, 4, 2, 1],
'window_sizes': [7, 7, 7, 7]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': [128, 256, 512, 1024],
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'patch_sizes': [4, 2, 2, 2],
'strides': [4, 2, 2, 2],
'mlp_ratios': [4, 4, 4, 4],
'sr_ratios': [8, 4, 2, 1],
'window_sizes': [7, 7, 7, 7]}),
} # yapf: disable
essential_keys = {
'embed_dims', 'depths', 'num_heads', 'patch_sizes', 'strides',
'mlp_ratios', 'sr_ratios', 'window_sizes'
}
def __init__(self,
arch,
in_channels=3,
out_indices=(3, ),
qkv_bias=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.0,
norm_cfg=dict(type='LN'),
norm_after_stage=False,
init_cfg=None):
super(SVT, self).__init__(arch, in_channels, out_indices, qkv_bias,
drop_rate, attn_drop_rate, drop_path_rate,
norm_cfg, norm_after_stage, init_cfg)
self.window_sizes = self.arch_settings['window_sizes']
for k in range(self.num_stage):
for i in range(self.depths[k]):
                # in the even-indexed layers (the 1st, 3rd, ... blocks) of
                # each stage, replace GSA with LSA
if i % 2 == 0:
ffn_channels = self.mlp_ratios[k] * self.embed_dims[k]
self.stages[k][i] = \
LSAEncoderLayer(
embed_dims=self.embed_dims[k],
num_heads=self.num_heads[k],
feedforward_channels=ffn_channels,
drop_rate=drop_rate,
norm_cfg=norm_cfg,
attn_drop_rate=attn_drop_rate,
drop_path_rate=self.dpr[sum(self.depths[:k])+i],
qkv_bias=qkv_bias,
window_size=self.window_sizes[k])
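# Layout sketch (editorial note, not part of the original file): after the
# replacement above, every stage alternates locally-grouped and
# globally-subsampled attention, starting with LSA, e.g. for the first stage
# of the 'small' arch (depth 2):
#
#   >>> model = SVT(arch='small')
#   >>> [type(blk).__name__ for blk in model.stages[0]]
#   ['LSAEncoderLayer', 'GSAEncoderLayer']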
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import Conv2d, build_activation_layer, build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
class MixFFN(BaseModule):
"""An implementation of MixFFN of VAN. Refer to
mmdetection/mmdet/models/backbones/pvt.py.
The differences between MixFFN & FFN:
1. Use 1X1 Conv to replace Linear layer.
2. Introduce 3X3 Depth-wise Conv to encode positional information.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
feedforward_channels (int): The hidden dimension of FFNs.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='GELU').
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
feedforward_channels,
act_cfg=dict(type='GELU'),
ffn_drop=0.,
init_cfg=None):
super(MixFFN, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.act_cfg = act_cfg
self.fc1 = Conv2d(
in_channels=embed_dims,
out_channels=feedforward_channels,
kernel_size=1)
self.dwconv = Conv2d(
in_channels=feedforward_channels,
out_channels=feedforward_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True,
groups=feedforward_channels)
self.act = build_activation_layer(act_cfg)
self.fc2 = Conv2d(
in_channels=feedforward_channels,
out_channels=embed_dims,
kernel_size=1)
self.drop = nn.Dropout(ffn_drop)
def forward(self, x):
x = self.fc1(x)
x = self.dwconv(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
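# Shape sketch (editorial note, not part of the original file): unlike a
# plain FFN, MixFFN works directly on (B, C, H, W) maps, using 1x1 convs in
# place of linear layers and a 3x3 depth-wise conv in between:
#
#   >>> ffn = MixFFN(embed_dims=32, feedforward_channels=128)
#   >>> ffn(torch.rand(1, 32, 14, 14)).shape
#   torch.Size([1, 32, 14, 14])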
class LKA(BaseModule):
"""Large Kernel Attention(LKA) of VAN.
.. code:: text
DW_conv (depth-wise convolution)
|
|
DW_D_conv (depth-wise dilation convolution)
|
|
Transition Convolution (1×1 convolution)
Args:
embed_dims (int): Number of input channels.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self, embed_dims, init_cfg=None):
super(LKA, self).__init__(init_cfg=init_cfg)
# a spatial local convolution (depth-wise convolution)
self.DW_conv = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=5,
padding=2,
groups=embed_dims)
# a spatial long-range convolution (depth-wise dilation convolution)
self.DW_D_conv = Conv2d(
in_channels=embed_dims,
out_channels=embed_dims,
kernel_size=7,
stride=1,
padding=9,
groups=embed_dims,
dilation=3)
self.conv1 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
def forward(self, x):
u = x.clone()
attn = self.DW_conv(x)
attn = self.DW_D_conv(attn)
attn = self.conv1(attn)
return u * attn
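# Receptive-field sketch (editorial note, not part of the original file): the
# 5x5 depth-wise conv followed by the 7x7 depth-wise conv with dilation 3
# (spanning 1 + (7 - 1) * 3 = 19 pixels) and the 1x1 conv gives an effective
# receptive field of 5 + 19 - 1 = 23, which VAN uses as a cheap stand-in for
# a large-kernel (roughly 21x21) convolution; the spatial size is preserved:
#
#   >>> lka = LKA(embed_dims=32)
#   >>> lka(torch.rand(1, 32, 14, 14)).shape
#   torch.Size([1, 32, 14, 14])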
class SpatialAttention(BaseModule):
"""Basic attention module in VANBloack.
Args:
embed_dims (int): Number of input channels.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='GELU').
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self, embed_dims, act_cfg=dict(type='GELU'), init_cfg=None):
super(SpatialAttention, self).__init__(init_cfg=init_cfg)
self.proj_1 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
self.activation = build_activation_layer(act_cfg)
self.spatial_gating_unit = LKA(embed_dims)
self.proj_2 = Conv2d(
in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shortcut
return x
class VANBlock(BaseModule):
"""A block of VAN.
Args:
embed_dims (int): Number of input channels.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='GELU').
layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 1e-2.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
ffn_ratio=4.,
drop_rate=0.,
drop_path_rate=0.,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN', eps=1e-5),
layer_scale_init_value=1e-2,
init_cfg=None):
super(VANBlock, self).__init__(init_cfg=init_cfg)
self.out_channels = embed_dims
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = SpatialAttention(embed_dims, act_cfg=act_cfg)
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
mlp_hidden_dim = int(embed_dims * ffn_ratio)
self.mlp = MixFFN(
embed_dims=embed_dims,
feedforward_channels=mlp_hidden_dim,
act_cfg=act_cfg,
ffn_drop=drop_rate)
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True) if layer_scale_init_value > 0 else None
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((embed_dims)),
requires_grad=True) if layer_scale_init_value > 0 else None
def forward(self, x):
identity = x
x = self.norm1(x)
x = self.attn(x)
if self.layer_scale_1 is not None:
x = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * x
x = identity + self.drop_path(x)
identity = x
x = self.norm2(x)
x = self.mlp(x)
if self.layer_scale_2 is not None:
x = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * x
x = identity + self.drop_path(x)
return x
class VANPatchEmbed(PatchEmbed):
"""Image to Patch Embedding of VAN.
The differences between VANPatchEmbed & PatchEmbed:
1. Use BN.
2. Do not use 'flatten' and 'transpose'.
"""
def __init__(self, *args, norm_cfg=dict(type='BN'), **kwargs):
super(VANPatchEmbed, self).__init__(*args, norm_cfg=norm_cfg, **kwargs)
def forward(self, x):
"""
Args:
x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
            - x (Tensor): Has shape (B, embed_dims, out_h, out_w)
- out_size (tuple[int]): Spatial shape of x, arrange as
(out_h, out_w).
"""
if self.adaptive_padding:
x = self.adaptive_padding(x)
x = self.projection(x)
out_size = (x.shape[2], x.shape[3])
if self.norm is not None:
x = self.norm(x)
return x, out_size
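# Shape sketch (editorial note, not part of the original file): unlike the
# parent PatchEmbed, the output stays in (B, C, H, W) layout, e.g. with the
# stage-0 settings used by VAN below (kernel_size=7, stride=4, padding=3):
#
#   >>> pe = VANPatchEmbed(in_channels=3, embed_dims=32, kernel_size=7,
#   ...                    stride=4, padding=3)
#   >>> x, out_size = pe(torch.rand(1, 3, 224, 224))
#   >>> x.shape, out_size
#   (torch.Size([1, 32, 56, 56]), (56, 56))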
@MODELS.register_module()
class VAN(BaseBackbone):
"""Visual Attention Network.
A PyTorch implement of : `Visual Attention Network
<https://arxiv.org/pdf/2202.09741v2.pdf>`_
Inspiration from
https://github.com/Visual-Attention-Network/VAN-Classification
Args:
arch (str | dict): Visual Attention Network architecture.
If use string, choose from 'tiny', 'small', 'base' and 'large'.
If use dict, it should have below keys:
- **embed_dims** (List[int]): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
            - **ffn_ratios** (List[int]): The expansion ratios of the
              feedforward network hidden layer channels in each stage.
Defaults to 'tiny'.
patch_sizes (List[int | tuple]): The patch size in patch embeddings.
Defaults to [7, 3, 3, 3].
in_channels (int): The num of input channels. Defaults to 3.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
out_indices (Sequence[int]): Output from which stages.
Default: ``(3, )``.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmpretrain.models import VAN
>>> import torch
>>> cfg = dict(arch='tiny')
>>> model = VAN(**cfg)
>>> inputs = torch.rand(1, 3, 224, 224)
>>> outputs = model(inputs)
>>> for out in outputs:
>>> print(out.size())
        torch.Size([1, 256, 7, 7])
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': [32, 64, 160, 256],
'depths': [3, 3, 5, 2],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': [64, 128, 320, 512],
'depths': [2, 2, 4, 2],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 3, 12, 3],
'ffn_ratios': [8, 8, 4, 4]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': [64, 128, 320, 512],
'depths': [3, 5, 27, 3],
'ffn_ratios': [8, 8, 4, 4]}),
} # yapf: disable
def __init__(self,
arch='tiny',
patch_sizes=[7, 3, 3, 3],
in_channels=3,
drop_rate=0.,
drop_path_rate=0.,
out_indices=(3, ),
frozen_stages=-1,
norm_eval=False,
norm_cfg=dict(type='LN'),
block_cfgs=dict(),
init_cfg=None):
super(VAN, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'ffn_ratios'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.ffn_ratios = self.arch_settings['ffn_ratios']
self.num_stages = len(self.depths)
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.norm_eval = norm_eval
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
cur_block_idx = 0
for i, depth in enumerate(self.depths):
patch_embed = VANPatchEmbed(
in_channels=in_channels if i == 0 else self.embed_dims[i - 1],
input_size=None,
embed_dims=self.embed_dims[i],
kernel_size=patch_sizes[i],
stride=patch_sizes[i] // 2 + 1,
padding=(patch_sizes[i] // 2, patch_sizes[i] // 2),
norm_cfg=dict(type='BN'))
blocks = ModuleList([
VANBlock(
embed_dims=self.embed_dims[i],
ffn_ratio=self.ffn_ratios[i],
drop_rate=drop_rate,
drop_path_rate=dpr[cur_block_idx + j],
**block_cfgs) for j in range(depth)
])
cur_block_idx += depth
norm = build_norm_layer(norm_cfg, self.embed_dims[i])[1]
self.add_module(f'patch_embed{i + 1}', patch_embed)
self.add_module(f'blocks{i + 1}', blocks)
self.add_module(f'norm{i + 1}', norm)
def train(self, mode=True):
super(VAN, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval has an effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _freeze_stages(self):
for i in range(0, self.frozen_stages + 1):
# freeze patch embed
m = getattr(self, f'patch_embed{i + 1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze blocks
m = getattr(self, f'blocks{i + 1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze norm
m = getattr(self, f'norm{i + 1}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def forward(self, x):
outs = []
for i in range(self.num_stages):
patch_embed = getattr(self, f'patch_embed{i + 1}')
blocks = getattr(self, f'blocks{i + 1}')
norm = getattr(self, f'norm{i + 1}')
x, hw_shape = patch_embed(x)
for block in blocks:
x = block(x)
x = x.flatten(2).transpose(1, 2)
x = norm(x)
x = x.reshape(-1, *hw_shape,
block.out_channels).permute(0, 3, 1, 2).contiguous()
if i in self.out_indices:
outs.append(x)
return tuple(outs)
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
def make_vgg_layer(in_channels,
out_channels,
num_blocks,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
dilation=1,
with_norm=False,
ceil_mode=False):
layers = []
for _ in range(num_blocks):
layer = ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
dilation=dilation,
padding=dilation,
bias=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
layers.append(layer)
in_channels = out_channels
layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
return layers
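# Structure sketch (editorial note, not part of the original file): each call
# builds one VGG stage, i.e. `num_blocks` 3x3 ConvModules followed by a 2x2
# max pooling:
#
#   >>> [type(m).__name__ for m in make_vgg_layer(3, 64, num_blocks=2)]
#   ['ConvModule', 'ConvModule', 'MaxPool2d']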
@MODELS.register_module()
class VGG(BaseBackbone):
"""VGG backbone.
Args:
depth (int): Depth of vgg, from {11, 13, 16, 19}.
with_norm (bool): Use BatchNorm or not.
num_classes (int): number of classes for classification.
num_stages (int): VGG stages, normally 5.
dilations (Sequence[int]): Dilation of each stage.
out_indices (Sequence[int], optional): Output from which stages.
When it is None, the default behavior depends on whether
num_classes is specified. If num_classes <= 0, the default value is
(4, ), output the last feature map before classifier. If
num_classes > 0, the default value is (5, ), output the
classification score. Default: None.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
ceil_mode (bool): Whether to use ceil_mode of MaxPool. Default: False.
with_last_pool (bool): Whether to keep the last pooling before
classifier. Default: True.
"""
# Parameters to build layers. Each element specifies the number of conv in
# each stage. For example, VGG11 contains 11 layers with learnable
# parameters. 11 is computed as 11 = (1 + 1 + 2 + 2 + 2) + 3,
# where 3 indicates the last three fully-connected layers.
arch_settings = {
11: (1, 1, 2, 2, 2),
13: (2, 2, 2, 2, 2),
16: (2, 2, 3, 3, 3),
19: (2, 2, 4, 4, 4)
}
def __init__(self,
depth,
num_classes=-1,
num_stages=5,
dilations=(1, 1, 1, 1, 1),
out_indices=None,
frozen_stages=-1,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
norm_eval=False,
ceil_mode=False,
with_last_pool=True,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(type='Constant', val=1., layer=['_BatchNorm']),
dict(type='Normal', std=0.01, layer=['Linear'])
]):
super(VGG, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for vgg')
assert num_stages >= 1 and num_stages <= 5
stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
assert len(dilations) == num_stages
self.num_classes = num_classes
self.frozen_stages = frozen_stages
self.norm_eval = norm_eval
with_norm = norm_cfg is not None
if out_indices is None:
out_indices = (5, ) if num_classes > 0 else (4, )
assert max(out_indices) <= num_stages
self.out_indices = out_indices
self.in_channels = 3
start_idx = 0
vgg_layers = []
self.range_sub_modules = []
for i, num_blocks in enumerate(self.stage_blocks):
num_modules = num_blocks + 1
end_idx = start_idx + num_modules
dilation = dilations[i]
out_channels = 64 * 2**i if i < 4 else 512
vgg_layer = make_vgg_layer(
self.in_channels,
out_channels,
num_blocks,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
dilation=dilation,
with_norm=with_norm,
ceil_mode=ceil_mode)
vgg_layers.extend(vgg_layer)
self.in_channels = out_channels
self.range_sub_modules.append([start_idx, end_idx])
start_idx = end_idx
if not with_last_pool:
vgg_layers.pop(-1)
self.range_sub_modules[-1][1] -= 1
self.module_name = 'features'
self.add_module(self.module_name, nn.Sequential(*vgg_layers))
if self.num_classes > 0:
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
def forward(self, x):
outs = []
vgg_layers = getattr(self, self.module_name)
for i in range(len(self.stage_blocks)):
for j in range(*self.range_sub_modules[i]):
vgg_layer = vgg_layers[j]
x = vgg_layer(x)
if i in self.out_indices:
outs.append(x)
if self.num_classes > 0:
x = x.view(x.size(0), -1)
x = self.classifier(x)
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
vgg_layers = getattr(self, self.module_name)
for i in range(self.frozen_stages):
for j in range(*self.range_sub_modules[i]):
m = vgg_layers[j]
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super(VGG, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval has an effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
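# Usage sketch (editorial note, not part of the original file): with
# num_classes <= 0 the backbone returns the last feature map before the
# classifier (224 / 2**5 = 7 after the five poolings):
#
#   >>> import torch
#   >>> model = VGG(depth=11)
#   >>> feats = model(torch.rand(1, 3, 224, 224))
#   >>> feats[-1].shape
#   torch.Size([1, 512, 7, 7])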
# Copyright (c) OpenMMLab. All rights reserved.
# modified from
# https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer
from mmcv.cnn.bricks import DropPath
from mmengine.model import ModuleList, Sequential
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import build_norm_layer
def get_2d_relative_pos_embed(embed_dim, grid_size):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, grid_size*grid_size]
"""
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size)
relative_pos = 2 * np.matmul(pos_embed,
pos_embed.transpose()) / pos_embed.shape[1]
return relative_pos
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
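# Shape sketch (editorial note, not part of the original file): the 2D sin-cos
# embedding devotes half of the channels to the row index and half to the
# column index, and the relative variant is its (scaled) Gram matrix:
#
#   >>> get_2d_sincos_pos_embed(embed_dim=8, grid_size=4).shape
#   (16, 8)
#   >>> get_2d_relative_pos_embed(embed_dim=8, grid_size=4).shape
#   (16, 16)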
def xy_pairwise_distance(x, y):
"""Compute pairwise distance of a point cloud.
Args:
x: tensor (batch_size, num_points, num_dims)
y: tensor (batch_size, num_points, num_dims)
Returns:
pairwise distance: (batch_size, num_points, num_points)
"""
with torch.no_grad():
xy_inner = -2 * torch.matmul(x, y.transpose(2, 1))
x_square = torch.sum(torch.mul(x, x), dim=-1, keepdim=True)
y_square = torch.sum(torch.mul(y, y), dim=-1, keepdim=True)
return x_square + xy_inner + y_square.transpose(2, 1)
def xy_dense_knn_matrix(x, y, k=16, relative_pos=None):
"""Get KNN based on the pairwise distance.
Args:
x: (batch_size, num_dims, num_points, 1)
y: (batch_size, num_dims, num_points, 1)
k: int
        relative_pos: relative position bias added to the pairwise distance,
            if given. Defaults to None.
Returns:
nearest neighbors:
(batch_size, num_points, k) (batch_size, num_points, k)
"""
with torch.no_grad():
x = x.transpose(2, 1).squeeze(-1)
y = y.transpose(2, 1).squeeze(-1)
batch_size, n_points, n_dims = x.shape
dist = xy_pairwise_distance(x.detach(), y.detach())
if relative_pos is not None:
dist += relative_pos
_, nn_idx = torch.topk(-dist, k=k)
center_idx = torch.arange(
0, n_points, device=x.device).repeat(batch_size, k,
1).transpose(2, 1)
return torch.stack((nn_idx, center_idx), dim=0)
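# Shape sketch (editorial note, not part of the original file): the result
# stacks the k-nearest-neighbour indices and the repeated centre indices
# along a new leading dimension:
#
#   >>> pts = torch.rand(2, 8, 100, 1)          # (B, dims, num_points, 1)
#   >>> xy_dense_knn_matrix(pts, pts, k=9).shape
#   torch.Size([2, 2, 100, 9])                  # (2, B, num_points, k)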
class DenseDilated(nn.Module):
"""Find dilated neighbor from neighbor list.
edge_index: (2, batch_size, num_points, k)
"""
def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0):
super(DenseDilated, self).__init__()
self.dilation = dilation
self.use_stochastic = use_stochastic
self.epsilon = epsilon
self.k = k
def forward(self, edge_index):
if self.use_stochastic:
if torch.rand(1) < self.epsilon and self.training:
num = self.k * self.dilation
randnum = torch.randperm(num)[:self.k]
edge_index = edge_index[:, :, :, randnum]
else:
edge_index = edge_index[:, :, :, ::self.dilation]
else:
edge_index = edge_index[:, :, :, ::self.dilation]
return edge_index
class DenseDilatedKnnGraph(nn.Module):
"""Find the neighbors' indices based on dilated knn."""
def __init__(self, k=9, dilation=1, use_stochastic=False, epsilon=0.0):
super(DenseDilatedKnnGraph, self).__init__()
self.dilation = dilation
self.use_stochastic = use_stochastic
self.epsilon = epsilon
self.k = k
self._dilated = DenseDilated(k, dilation, use_stochastic, epsilon)
def forward(self, x, y=None, relative_pos=None):
if y is not None:
x = F.normalize(x, p=2.0, dim=1)
y = F.normalize(y, p=2.0, dim=1)
edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation,
relative_pos)
else:
x = F.normalize(x, p=2.0, dim=1)
y = x.clone()
edge_index = xy_dense_knn_matrix(x, y, self.k * self.dilation,
relative_pos)
return self._dilated(edge_index)
class BasicConv(Sequential):
def __init__(self,
channels,
act_cfg,
norm_cfg=None,
graph_conv_bias=True,
drop=0.):
m = []
for i in range(1, len(channels)):
m.append(
nn.Conv2d(
channels[i - 1],
channels[i],
1,
bias=graph_conv_bias,
groups=4))
if norm_cfg is not None:
m.append(build_norm_layer(norm_cfg, channels[-1]))
if act_cfg is not None:
m.append(build_activation_layer(act_cfg))
if drop > 0:
m.append(nn.Dropout2d(drop))
super(BasicConv, self).__init__(*m)
def batched_index_select(x, idx):
r"""fetches neighbors features from a given neighbor idx
Args:
x (Tensor): input feature Tensor
:math:
`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times 1}`.
idx (Tensor): edge_idx
:math:`\mathbf{X} \in \mathbb{R}^{B \times N \times l}`.
Returns:
Tensor: output neighbors features
:math:`\mathbf{X} \in \mathbb{R}^{B \times C \times N \times k}`.
"""
batch_size, num_dims, num_vertices_reduced = x.shape[:3]
_, num_vertices, k = idx.shape
idx_base = torch.arange(
0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices_reduced
idx = idx + idx_base
idx = idx.contiguous().view(-1)
x = x.transpose(2, 1)
feature = x.contiguous().view(batch_size * num_vertices_reduced,
-1)[idx, :]
feature = feature.view(batch_size, num_vertices, k,
num_dims).permute(0, 3, 1, 2).contiguous()
return feature
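# Shape sketch (editorial note, not part of the original file): gathering the
# k neighbour features of every vertex turns (B, C, N, 1) node features and a
# (B, N, k) index map into a (B, C, N, k) neighbour tensor:
#
#   >>> feats = torch.rand(2, 16, 100, 1)
#   >>> idx = torch.randint(0, 100, (2, 100, 9))
#   >>> batched_index_select(feats, idx).shape
#   torch.Size([2, 16, 100, 9])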
class MRConv2d(nn.Module):
"""Max-Relative Graph Convolution (Paper: https://arxiv.org/abs/1904.03751)
for dense data type."""
def __init__(self,
in_channels,
out_channels,
act_cfg,
norm_cfg=None,
graph_conv_bias=True):
super(MRConv2d, self).__init__()
self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg,
graph_conv_bias)
def forward(self, x, edge_index, y=None):
x_i = batched_index_select(x, edge_index[1])
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
x_j, _ = torch.max(x_j - x_i, -1, keepdim=True)
b, c, n, _ = x.shape
x = torch.cat([x.unsqueeze(2), x_j.unsqueeze(2)],
dim=2).reshape(b, 2 * c, n, _)
return self.nn(x)
class EdgeConv2d(nn.Module):
"""Edge convolution layer (with activation, batch normalization) for dense
data type."""
def __init__(self,
in_channels,
out_channels,
act_cfg,
norm_cfg=None,
graph_conv_bias=True):
super(EdgeConv2d, self).__init__()
self.nn = BasicConv([in_channels * 2, out_channels], act_cfg, norm_cfg,
graph_conv_bias)
def forward(self, x, edge_index, y=None):
x_i = batched_index_select(x, edge_index[1])
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
max_value, _ = torch.max(
self.nn(torch.cat([x_i, x_j - x_i], dim=1)), -1, keepdim=True)
return max_value
class GraphSAGE(nn.Module):
"""GraphSAGE Graph Convolution (Paper: https://arxiv.org/abs/1706.02216)
for dense data type."""
def __init__(self,
in_channels,
out_channels,
act_cfg,
norm_cfg=None,
graph_conv_bias=True):
super(GraphSAGE, self).__init__()
self.nn1 = BasicConv([in_channels, in_channels], act_cfg, norm_cfg,
graph_conv_bias)
self.nn2 = BasicConv([in_channels * 2, out_channels], act_cfg,
norm_cfg, graph_conv_bias)
def forward(self, x, edge_index, y=None):
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
x_j, _ = torch.max(self.nn1(x_j), -1, keepdim=True)
return self.nn2(torch.cat([x, x_j], dim=1))
class GINConv2d(nn.Module):
"""GIN Graph Convolution (Paper: https://arxiv.org/abs/1810.00826) for
dense data type."""
def __init__(self,
in_channels,
out_channels,
act_cfg,
norm_cfg=None,
graph_conv_bias=True):
super(GINConv2d, self).__init__()
self.nn = BasicConv([in_channels, out_channels], act_cfg, norm_cfg,
graph_conv_bias)
eps_init = 0.0
self.eps = nn.Parameter(torch.Tensor([eps_init]))
def forward(self, x, edge_index, y=None):
if y is not None:
x_j = batched_index_select(y, edge_index[0])
else:
x_j = batched_index_select(x, edge_index[0])
x_j = torch.sum(x_j, -1, keepdim=True)
return self.nn((1 + self.eps) * x + x_j)
class GraphConv2d(nn.Module):
"""Static graph convolution layer."""
def __init__(self,
in_channels,
out_channels,
graph_conv_type,
act_cfg,
norm_cfg=None,
graph_conv_bias=True):
super(GraphConv2d, self).__init__()
if graph_conv_type == 'edge':
self.gconv = EdgeConv2d(in_channels, out_channels, act_cfg,
norm_cfg, graph_conv_bias)
elif graph_conv_type == 'mr':
self.gconv = MRConv2d(in_channels, out_channels, act_cfg, norm_cfg,
graph_conv_bias)
elif graph_conv_type == 'sage':
self.gconv = GraphSAGE(in_channels, out_channels, act_cfg,
norm_cfg, graph_conv_bias)
elif graph_conv_type == 'gin':
self.gconv = GINConv2d(in_channels, out_channels, act_cfg,
norm_cfg, graph_conv_bias)
else:
raise NotImplementedError(
'graph_conv_type:{} is not supported'.format(graph_conv_type))
def forward(self, x, edge_index, y=None):
return self.gconv(x, edge_index, y)
class DyGraphConv2d(GraphConv2d):
"""Dynamic graph convolution layer."""
def __init__(self,
in_channels,
out_channels,
k=9,
dilation=1,
graph_conv_type='mr',
act_cfg=dict(type='GELU'),
norm_cfg=None,
graph_conv_bias=True,
use_stochastic=False,
epsilon=0.2,
r=1):
super(DyGraphConv2d,
self).__init__(in_channels, out_channels, graph_conv_type,
act_cfg, norm_cfg, graph_conv_bias)
self.k = k
self.d = dilation
self.r = r
self.dilated_knn_graph = DenseDilatedKnnGraph(k, dilation,
use_stochastic, epsilon)
def forward(self, x, relative_pos=None):
B, C, H, W = x.shape
y = None
if self.r > 1:
y = F.avg_pool2d(x, self.r, self.r)
y = y.reshape(B, C, -1, 1).contiguous()
x = x.reshape(B, C, -1, 1).contiguous()
edge_index = self.dilated_knn_graph(x, y, relative_pos)
x = super(DyGraphConv2d, self).forward(x, edge_index, y)
return x.reshape(B, -1, H, W).contiguous()
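# Shape-flow sketch (illustration only): DyGraphConv2d flattens the feature
# map into nodes, builds the (dilated) KNN graph, applies the chosen graph
# convolution and restores the spatial layout.
#   conv = DyGraphConv2d(64, 128, k=9, r=1)
#   conv(torch.rand(2, 64, 14, 14)).shape  # torch.Size([2, 128, 14, 14])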
class Grapher(nn.Module):
"""Grapher module with graph convolution and fc layers."""
def __init__(self,
in_channels,
k=9,
dilation=1,
graph_conv_type='mr',
act_cfg=dict(type='GELU'),
norm_cfg=None,
graph_conv_bias=True,
use_stochastic=False,
epsilon=0.2,
r=1,
n=196,
drop_path=0.0,
relative_pos=False):
super(Grapher, self).__init__()
self.channels = in_channels
self.n = n
self.r = r
self.fc1 = Sequential(
nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
build_norm_layer(dict(type='BN'), in_channels),
)
self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, k,
dilation, graph_conv_type, act_cfg,
norm_cfg, graph_conv_bias,
use_stochastic, epsilon, r)
self.fc2 = Sequential(
nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
build_norm_layer(dict(type='BN'), in_channels),
)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.relative_pos = None
if relative_pos:
relative_pos_tensor = torch.from_numpy(
np.float32(
get_2d_relative_pos_embed(in_channels, int(
n**0.5)))).unsqueeze(0).unsqueeze(1)
relative_pos_tensor = F.interpolate(
relative_pos_tensor,
size=(n, n // (r * r)),
mode='bicubic',
align_corners=False)
self.relative_pos = nn.Parameter(
-relative_pos_tensor.squeeze(1), requires_grad=False)
def _get_relative_pos(self, relative_pos, H, W):
if relative_pos is None or H * W == self.n:
return relative_pos
else:
N = H * W
N_reduced = N // (self.r * self.r)
return F.interpolate(
relative_pos.unsqueeze(0), size=(N, N_reduced),
mode='bicubic').squeeze(0)
def forward(self, x):
B, C, H, W = x.shape
relative_pos = self._get_relative_pos(self.relative_pos, H, W)
shortcut = x
x = self.fc1(x)
x = self.graph_conv(x, relative_pos)
x = self.fc2(x)
x = self.drop_path(x) + shortcut
return x
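# Usage sketch (illustration only): a Grapher block preserves both resolution
# and channel count; it runs fc1 -> graph convolution (which doubles the
# channels) -> fc2 with a residual connection. With relative_pos=True it
# expects n == H * W of the input feature map.
#   block = Grapher(in_channels=192, k=9, n=196)
#   block(torch.rand(1, 192, 14, 14)).shape  # torch.Size([1, 192, 14, 14])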
class FFN(nn.Module):
""""out_features = out_features or in_features\n
hidden_features = hidden_features or in_features"""
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_cfg=dict(type='GELU'),
drop_path=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = Sequential(
nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
build_norm_layer(dict(type='BN'), hidden_features),
)
self.act = build_activation_layer(act_cfg)
self.fc2 = Sequential(
nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
build_norm_layer(dict(type='BN'), out_features),
)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.drop_path(x) + shortcut
return x
@MODELS.register_module()
class Vig(BaseBackbone):
"""Vision GNN backbone.
A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
<https://arxiv.org/abs/2206.00272>`_.
Modified from the official implementation
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch
Args:
arch (str): Vision GNN architecture,
choose from 'tiny', 'small' and 'base'.
in_channels (int): The number of channels of input images.
Defaults to 3.
k (int): The number of KNN's k. Defaults to 9.
out_indices (Sequence | int): Output from which blocks.
Defaults to -1, which means the last block.
act_cfg (dict): The config of activation functions.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): The config of normalization layers.
Defaults to ``dict(type='BN', eps=1e-6)``.
graph_conv_bias (bool): Whether to use bias in the convolution
layers in Grapher. Defaults to True.
graph_conv_type (str): The type of graph convolution, choose
from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
epsilon (float): Probability of random arrangement in KNN. It only
works when ``use_dilation=True`` and ``use_stochastic=True``.
Defaults to 0.2.
use_dilation (bool): Whether to use dilation in KNN. Defaults to True.
use_stochastic (bool): Whether to use stochastic KNN. Defaults to False.
drop_path (float): Stochastic depth rate. Defaults to 0.
relative_pos (bool): Whether to use relative position embedding.
Defaults to False.
norm_eval (bool): Whether to set the normalization layer to eval mode.
Defaults to False.
frozen_stages (int): Blocks to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
init_cfg (dict, optional): The initialization configs.
Defaults to None.
""" # noqa: E501
arch_settings = {
'tiny': dict(num_blocks=12, channels=192),
'small': dict(num_blocks=16, channels=320),
'base': dict(num_blocks=16, channels=640),
}
def __init__(self,
arch,
in_channels=3,
k=9,
out_indices=-1,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_bias=True,
graph_conv_type='mr',
epsilon=0.2,
use_dilation=True,
use_stochastic=False,
drop_path=0.,
relative_pos=False,
norm_eval=False,
frozen_stages=0,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
arch = self.arch_settings[arch]
self.num_blocks = arch['num_blocks']
channels = arch['channels']
if isinstance(out_indices, int):
out_indices = [out_indices]
elif isinstance(out_indices, tuple):
out_indices = list(out_indices)
elif not isinstance(out_indices, list):
raise TypeError('"out_indices" must by a tuple, list or int, '
f'get {type(out_indices)} instead.')
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_blocks + index
assert 0 <= out_indices[i] <= self.num_blocks, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
self.stem = Sequential(
nn.Conv2d(in_channels, channels // 8, 3, stride=2, padding=1),
build_norm_layer(norm_cfg, channels // 8),
build_activation_layer(act_cfg),
nn.Conv2d(channels // 8, channels // 4, 3, stride=2, padding=1),
build_norm_layer(norm_cfg, channels // 4),
build_activation_layer(act_cfg),
nn.Conv2d(channels // 4, channels // 2, 3, stride=2, padding=1),
build_norm_layer(norm_cfg, channels // 2),
build_activation_layer(act_cfg),
nn.Conv2d(channels // 2, channels, 3, stride=2, padding=1),
build_norm_layer(norm_cfg, channels),
build_activation_layer(act_cfg),
nn.Conv2d(channels, channels, 3, stride=1, padding=1),
build_norm_layer(norm_cfg, channels),
)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)]
# number of knn's k
num_knn = [
int(x.item()) for x in torch.linspace(k, 2 * k, self.num_blocks)
]
max_dilation = 196 // max(num_knn)
self.pos_embed = nn.Parameter(torch.zeros(1, channels, 14, 14))
self.blocks = ModuleList([
Sequential(
Grapher(
in_channels=channels,
k=num_knn[i],
dilation=min(i // 4 +
1, max_dilation) if use_dilation else 1,
graph_conv_type=graph_conv_type,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
graph_conv_bias=graph_conv_bias,
use_stochastic=use_stochastic,
epsilon=epsilon,
drop_path=dpr[i],
relative_pos=relative_pos),
FFN(in_features=channels,
hidden_features=channels * 4,
act_cfg=act_cfg,
drop_path=dpr[i])) for i in range(self.num_blocks)
])
self.norm_eval = norm_eval
self.frozen_stages = frozen_stages
def forward(self, inputs):
outs = []
x = self.stem(inputs) + self.pos_embed
for i, block in enumerate(self.blocks):
x = block(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
self.stem.eval()
for i in range(self.frozen_stages):
m = self.blocks[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super(Vig, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval() only has an effect on BatchNorm layers
if isinstance(m, _BatchNorm):
m.eval()
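# Hedged usage sketch (assumes the 224x224 input the 14x14 positional
# embedding is built for):
#   model = Vig(arch='tiny')
#   feats = model(torch.rand(1, 3, 224, 224))
#   feats[-1].shape  # torch.Size([1, 192, 14, 14]) for the 'tiny' setting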
@MODELS.register_module()
class PyramidVig(BaseBackbone):
"""Pyramid Vision GNN backbone.
A PyTorch implementation of `Vision GNN: An Image is Worth Graph of Nodes
<https://arxiv.org/abs/2206.00272>`_.
Modified from the official implementation
https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch
Args:
arch (str): Vision GNN architecture, choose from 'tiny',
'small' and 'base'.
in_channels (int): The number of channels of input images.
Defaults to 3.
k (int): The number of KNN's k. Defaults to 9.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, which means the last stage.
act_cfg (dict): The config of activation functions.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): The config of normalization layers.
Defaults to ``dict(type='BN')``.
graph_conv_bias (bool): Whether to use bias in the convolution
layers in Grapher. Defaults to True.
graph_conv_type (str): The type of graph convolution, choose
from 'edge', 'mr', 'sage' and 'gin'. Defaults to 'mr'.
epsilon (float): Probability of random arrangement in KNN. It only
works when ``use_stochastic=True``. Defaults to 0.2.
use_stochastic (bool): Whether to use stochastic KNN.
Defaults to False.
drop_path (float): Stochastic depth rate. Defaults to 0.
norm_eval (bool): Whether to set the normalization layer to eval mode.
Defaults to False.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
init_cfg (dict, optional): The initialization configs.
Defaults to None.
""" # noqa: E501
arch_settings = {
'tiny': dict(blocks=[2, 2, 6, 2], channels=[48, 96, 240, 384]),
'small': dict(blocks=[2, 2, 6, 2], channels=[80, 160, 400, 640]),
'medium': dict(blocks=[2, 2, 16, 2], channels=[96, 192, 384, 768]),
'base': dict(blocks=[2, 2, 18, 2], channels=[128, 256, 512, 1024]),
}
def __init__(self,
arch,
in_channels=3,
k=9,
out_indices=-1,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_bias=True,
graph_conv_type='mr',
epsilon=0.2,
use_stochastic=False,
drop_path=0.,
norm_eval=False,
frozen_stages=0,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
arch = self.arch_settings[arch]
self.blocks = arch['blocks']
self.num_blocks = sum(self.blocks)
self.num_stages = len(self.blocks)
channels = arch['channels']
self.channels = channels
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_stages + index
assert 0 <= out_indices[i] <= self.num_stages, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
self.stem = Sequential(
nn.Conv2d(in_channels, channels[0] // 2, 3, stride=2, padding=1),
build_norm_layer(norm_cfg, channels[0] // 2),
build_activation_layer(act_cfg),
nn.Conv2d(channels[0] // 2, channels[0], 3, stride=2, padding=1),
build_norm_layer(norm_cfg, channels[0]),
build_activation_layer(act_cfg),
nn.Conv2d(channels[0], channels[0], 3, stride=1, padding=1),
build_norm_layer(norm_cfg, channels[0]),
)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, drop_path, self.num_blocks)]
# number of knn's k
num_knn = [
int(x.item()) for x in torch.linspace(k, k, self.num_blocks)
]
max_dilation = 49 // max(num_knn)
self.pos_embed = nn.Parameter(
torch.zeros(1, channels[0], 224 // 4, 224 // 4))
HW = 224 // 4 * 224 // 4
reduce_ratios = [4, 2, 1, 1]
self.stages = ModuleList()
block_idx = 0
for stage_idx, num_blocks in enumerate(self.blocks):
mid_channels = channels[stage_idx]
reduce_ratio = reduce_ratios[stage_idx]
blocks = []
if stage_idx > 0:
blocks.append(
Sequential(
nn.Conv2d(
self.channels[stage_idx - 1],
mid_channels,
kernel_size=3,
stride=2,
padding=1),
build_norm_layer(norm_cfg, mid_channels),
))
HW = HW // 4
for _ in range(num_blocks):
blocks.append(
Sequential(
Grapher(
in_channels=mid_channels,
k=num_knn[block_idx],
dilation=min(block_idx // 4 + 1, max_dilation),
graph_conv_type=graph_conv_type,
act_cfg=act_cfg,
norm_cfg=norm_cfg,
graph_conv_bias=graph_conv_bias,
use_stochastic=use_stochastic,
epsilon=epsilon,
r=reduce_ratio,
n=HW,
drop_path=dpr[block_idx],
relative_pos=True),
FFN(in_features=mid_channels,
hidden_features=mid_channels * 4,
act_cfg=act_cfg,
drop_path=dpr[block_idx])))
block_idx += 1
self.stages.append(Sequential(*blocks))
self.norm_eval = norm_eval
self.frozen_stages = frozen_stages
def forward(self, inputs):
outs = []
x = self.stem(inputs) + self.pos_embed
for i, blocks in enumerate(self.stages):
x = blocks(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
self.stem.eval()
for i in range(self.frozen_stages):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super(PyramidVig, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval() only has an effect on BatchNorm layers
if isinstance(m, _BatchNorm):
m.eval()
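# Hedged usage sketch: the pyramid variant downsamples between stages, so the
# default out_indices=-1 returns the last-stage feature map.
#   model = PyramidVig(arch='tiny')
#   feats = model(torch.rand(1, 3, 224, 224))
#   feats[-1].shape  # torch.Size([1, 384, 7, 7]) for the 'tiny' setting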
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import (MultiheadAttention, SwiGLUFFNFused, build_norm_layer,
resize_pos_embed, to_2tuple)
from .base_backbone import BaseBackbone
class TransformerEncoderLayer(BaseModule):
"""Implements one encoder layer in Vision Transformer.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension for FFNs
layer_scale_init_value (float or torch.Tensor): Init value of layer
scale. Defaults to 0.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The dropout rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
ffn_type (str): Select the type of ffn layers. Defaults to 'origin'.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
layer_scale_init_value=0.,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=True,
ffn_type='origin',
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
layer_scale_init_value=layer_scale_init_value)
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
if ffn_type == 'origin':
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg,
layer_scale_init_value=layer_scale_init_value)
elif ffn_type == 'swiglu_fused':
self.ffn = SwiGLUFFNFused(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
layer_scale_init_value=layer_scale_init_value)
else:
raise NotImplementedError
@property
def norm1(self):
return self.ln1
@property
def norm2(self):
return self.ln2
def init_weights(self):
super(TransformerEncoderLayer, self).init_weights()
for m in self.ffn.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.normal_(m.bias, std=1e-6)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
return x
@MODELS.register_module()
class VisionTransformer(BaseBackbone):
"""Vision Transformer.
A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
Args:
arch (str | dict): Vision Transformer architecture. If a string is
given, choose from 'small', 'base', 'large', 'deit-tiny',
'deit-small' and 'deit-base'. If a dict is given, it should have
the following keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, which means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
the final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
with_cls_token (bool): Whether to concatenate the class token with the
image tokens as the transformer input. Defaults to True.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
layer_scale_init_value (float or torch.Tensor): Init value of layer
scale. Defaults to 0.
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 768,
'num_layers': 8,
'num_heads': 8,
'feedforward_channels': 768 * 3,
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 3072
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}),
**dict.fromkeys(
['h', 'huge'],
{
# The same as the implementation in MAE
# <https://arxiv.org/abs/2111.06377>
'embed_dims': 1280,
'num_layers': 32,
'num_heads': 16,
'feedforward_channels': 5120
}),
**dict.fromkeys(
['eva-g', 'eva-giant'],
{
# The implementation in EVA
# <https://arxiv.org/abs/2211.07636>
'embed_dims': 1408,
'num_layers': 40,
'num_heads': 16,
'feedforward_channels': 6144
}),
**dict.fromkeys(
['deit-t', 'deit-tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': 192 * 4
}),
**dict.fromkeys(
['deit-s', 'deit-small', 'dinov2-s', 'dinov2-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': 384 * 4
}),
**dict.fromkeys(
['deit-b', 'deit-base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 768 * 4
}),
**dict.fromkeys(
['dinov2-g', 'dinov2-giant'], {
'embed_dims': 1536,
'num_layers': 40,
'num_heads': 24,
'feedforward_channels': 6144
}),
}
num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='cls_token',
with_cls_token=True,
frozen_stages=-1,
interpolate_mode='bicubic',
layer_scale_init_value=0.,
patch_cfg=dict(),
layer_cfgs=dict(),
pre_norm=False,
init_cfg=None):
super(VisionTransformer, self).__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.num_layers = self.arch_settings['num_layers']
self.img_size = to_2tuple(img_size)
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
bias=not pre_norm, # disable bias if pre_norm is used(e.g., CLIP)
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
self.with_cls_token = with_cls_token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_extra_tokens,
self.embed_dims))
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
layer_scale_init_value=layer_scale_init_value,
drop_rate=drop_rate,
drop_path_rate=dpr[i],
qkv_bias=qkv_bias,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
self.frozen_stages = frozen_stages
if pre_norm:
self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
else:
self.pre_norm = nn.Identity()
self.final_norm = final_norm
if final_norm:
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
if self.out_type == 'avg_featmap':
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
self._freeze_stages()
@property
def norm1(self):
return self.ln1
@property
def norm2(self):
return self.ln2
def init_weights(self):
super(VisionTransformer, self).init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if (not self.with_cls_token
and ckpt_pos_embed_shape[1] == self.pos_embed.shape[1] + 1):
# Remove cls token from state dict if it's not used.
state_dict[name] = state_dict[name][:, 1:]
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
@staticmethod
def resize_pos_embed(*args, **kwargs):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)
def _freeze_stages(self):
# freeze position embedding
if self.pos_embed is not None:
self.pos_embed.requires_grad = False
# set dropout to eval mode
self.drop_after_pos.eval()
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze pre-norm
for param in self.pre_norm.parameters():
param.requires_grad = False
# freeze cls_token
if self.cls_token is not None:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers):
if self.final_norm:
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
if self.out_type == 'avg_featmap':
self.ln2.eval()
for param in self.ln2.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
x = self.pre_norm(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return self.ln2(patch_token.mean(dim=1))
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``)
"""
num_layers = self.num_layers + 2
if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
if param_name in ('cls_token', 'pos_embed'):
layer_depth = 0
elif param_name.startswith('patch_embed'):
layer_depth = 0
elif param_name.startswith('layers'):
layer_id = int(param_name.split('.')[1])
layer_depth = layer_id + 1
else:
layer_depth = num_layers - 1
return layer_depth, num_layers
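# Hedged usage sketch (defaults: arch='base', out_type='cls_token'):
#   model = VisionTransformer(arch='base', img_size=224, patch_size=16)
#   cls_feat = model(torch.rand(1, 3, 224, 224))[-1]
#   cls_feat.shape  # torch.Size([1, 768])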
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, ModuleList
from mmpretrain.registry import MODELS
from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer,
resize_pos_embed)
from .vision_transformer import VisionTransformer
class AttentionWithRoPE(BaseModule):
"""Multi-head Attention Module with 2D sincos position embedding (RoPE).
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
qkv_bias (bool): If True, add a learnable bias to q and v. Note
that we follow the official implementation where ``k_bias``
is 0. Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): If True, add a learnable bias to output projection.
Defaults to True.
rope (:obj:`torch.nn.Module`, optional): If it is an object of the
``RotaryEmbedding``, the rotation of the token position will be
performed before the softmax. Defaults to None.
with_cls_token (bool): Whether to concatenate the class token with the
image tokens as the transformer input. Defaults to True.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
attn_drop=0.,
proj_drop=0.,
qkv_bias=True,
qk_scale=None,
proj_bias=True,
rope=None,
with_cls_token=True,
init_cfg=None):
super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
self.head_dims = embed_dims // num_heads
self.scale = qk_scale or self.head_dims**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.with_cls_token = with_cls_token
self.rope = rope
def forward(self, x, patch_resolution):
B, N, _ = x.shape
qkv = self.qkv(x)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(dim=0)
if self.rope:
if self.with_cls_token:
q_t = q[:, :, 1:, :]
ro_q_t = self.rope(q_t, patch_resolution)
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
k_t = k[:, :, 1:, :] if self.with_cls_token else k
ro_k_t = self.rope(k_t, patch_resolution)
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
else:
q = self.rope(q, patch_resolution)
k = self.rope(k, patch_resolution)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1).type_as(x)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class EVA02EndcoderLayer(BaseModule):
"""Implements one encoder EVA02EndcoderLayer in EVA02.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension of FFNs.
sub_ln (bool): Whether to add the sub layer normalization
in the SwiGLU FFN module. Defaults to False.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): enable bias for projection in the attention module
if True. Defaults to True.
rope (:obj:`torch.nn.Module`, optional): RotaryEmbedding object
in the attention module. Defaults to None.
drop_rate (float): Dropout rate in the mlp module. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
sub_ln=False,
attn_drop=0.,
proj_drop=0.,
qkv_bias=False,
qk_scale=None,
proj_bias=True,
rope=None,
with_cls_token=True,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
init_cfg=None):
super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)
self.attn = AttentionWithRoPE(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop,
proj_drop=proj_drop,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
proj_bias=proj_bias,
rope=rope,
with_cls_token=with_cls_token)
self.drop_path = build_dropout(
dict(type='DropPath', drop_prob=drop_path_rate))
self.norm2 = build_norm_layer(norm_cfg, embed_dims)
if drop_rate > 0:
dropout_layer = dict(type='Dropout', drop_prob=drop_rate)
else:
dropout_layer = None
if sub_ln:
ffn_norm = norm_cfg
else:
ffn_norm = None
self.mlp = SwiGLUFFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
dropout_layer=dropout_layer,
norm_cfg=ffn_norm,
add_identity=False,
)
def forward(self, x, patch_resolution):
inputs = x
x = self.norm1(x)
x = self.attn(x, patch_resolution)
x = self.drop_path(x)
x = inputs + x
inputs = x
x = self.norm2(x)
x = self.mlp(x)
x = self.drop_path(x)
x = inputs + x
return x
@MODELS.register_module()
class ViTEVA02(VisionTransformer):
"""EVA02 Vision Transformer.
A PyTorch implementation of: `EVA-02: A Visual Representation for Neon Genesis
<https://arxiv.org/abs/2303.11331>`_
Args:
arch (str | dict): Vision Transformer architecture. If a string is
given, choose from 'tiny', 'small', 'base' and 'large'. If a dict
is given, it should have the following keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **mlp_ratio** (float): The ratio of the mlp module.
Defaults to 'tiny'.
sub_ln (bool): Whether to add the sub layer normalization in swiglu.
Defaults to False.
drop_rate (float): Probability of an element to be zeroed in the
mlp module. Defaults to 0.
attn_drop_rate (float): Probability of an element to be zeroed after
the softmax in the attention. Defaults to 0.
proj_drop_rate (float): Probability of an element to be zeroed after
projection in the attention. Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
with_cls_token (bool): Whether to concatenate the class token with the
image tokens as the transformer input. Defaults to True.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
**kwargs(dict, optional): Other args for Vision Transformer.
"""
arch_zoo = {
**dict.fromkeys(
['t', 'ti', 'tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': int(192 * 4 * 2 / 3)
}),
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': int(384 * 4 * 2 / 3)
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': int(768 * 4 * 2 / 3)
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': int(1024 * 4 * 2 / 3)
})
}
num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
arch='tiny',
sub_ln=False,
drop_rate=0.,
attn_drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN'),
with_cls_token=True,
layer_cfgs=dict(),
**kwargs):
# set essential args for Vision Transformer
kwargs.update(
arch=arch,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
norm_cfg=norm_cfg,
with_cls_token=with_cls_token)
super(ViTEVA02, self).__init__(**kwargs)
self.num_heads = self.arch_settings['num_heads']
# Set RoPE
head_dim = self.embed_dims // self.num_heads
self.rope = RotaryEmbeddingFast(
embed_dims=head_dim, patch_resolution=self.patch_resolution)
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.num_heads,
feedforward_channels=self.
arch_settings['feedforward_channels'],
sub_ln=sub_ln,
norm_cfg=norm_cfg,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_rate=drop_rate,
qkv_bias=qkv_bias,
rope=self.rope,
with_cls_token=with_cls_token,
drop_path_rate=dpr[i])
_layer_cfg.update(layer_cfgs[i])
self.layers.append(EVA02EndcoderLayer(**_layer_cfg))
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
x = self.pre_norm(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x, patch_resolution)
if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
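# Hedged usage sketch: ViTEVA02 keeps the VisionTransformer interface, so the
# default out_type='cls_token' yields the class-token feature.
#   model = ViTEVA02(arch='tiny', img_size=224, patch_size=16)
#   model(torch.rand(1, 3, 224, 224))[-1].shape  # torch.Size([1, 192])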
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import LayerNorm2d, build_norm_layer, resize_pos_embed, to_2tuple
from .base_backbone import BaseBackbone
def window_partition(x: torch.Tensor,
window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""Partition into non-overlapping windows with padding if needed.
Borrowed from https://github.com/facebookresearch/segment-anything/
Args:
x (torch.Tensor): Input tokens with [B, H, W, C].
window_size (int): Window size.
Returns:
Tuple[torch.Tensor, Tuple[int, int]]
- ``windows``: Windows after partition with
[B * num_windows, window_size, window_size, C].
- ``(Hp, Wp)``: Padded height and width before partition
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4,
5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
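# Shape sketch (illustration only): a 14x14 token map split with window_size=7
# gives four windows per image and needs no padding.
#   x = torch.rand(1, 14, 14, 64)
#   windows, pad_hw = window_partition(x, window_size=7)
#   windows.shape, pad_hw  # (torch.Size([4, 7, 7, 64]), (14, 14))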
def window_unpartition(windows: torch.Tensor, window_size: int,
pad_hw: Tuple[int, int],
hw: Tuple[int, int]) -> torch.Tensor:
"""Window unpartition into original sequences and removing padding.
Borrowed from https://github.com/facebookresearch/segment-anything/
Args:
windows (torch.Tensor): Input tokens with
[B * num_windows, window_size, window_size, C].
window_size (int): Window size.
pad_hw (tuple): Padded height and width (Hp, Wp).
hw (tuple): Original height and width (H, W) before padding.
Returns:
torch.Tensor: Unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def get_rel_pos(q_size: int, k_size: int,
rel_pos: torch.Tensor) -> torch.Tensor:
"""Get relative positional embeddings according to the relative positions
of query and key sizes.
Borrowed from https://github.com/facebookresearch/segment-anything/
Args:
q_size (int): Size of query q.
k_size (int): Size of key k.
rel_pos (torch.Tensor): Relative position embeddings (L, C).
Returns:
torch.Tensor: Extracted positional embeddings according to relative
positions.
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode='linear',
)
rel_pos_resized = rel_pos_resized.reshape(-1,
max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords -
k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int],
) -> torch.Tensor:
"""Borrowed from https://github.com/facebookresearch/segment-anything/
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Args:
attn (torch.Tensor): Attention map.
q (torch.Tensor): Query q in the attention layer with shape
(B, q_h * q_w, C).
rel_pos_h (torch.Tensor): Relative position embeddings (Lh, C) for
height axis.
rel_pos_w (torch.Tensor): Relative position embeddings (Lw, C) for
width axis.
q_size (tuple): Spatial sequence size of query q with (q_h, q_w).
k_size (tuple): Spatial sequence size of key k with (k_h, k_w).
Returns:
torch.Tensor: Attention map with added relative positional embeddings.
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum('bhwc,hkc->bhwk', r_q, Rh)
rel_w = torch.einsum('bhwc,wkc->bhwk', r_q, Rw)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] +
rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
return attn
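# Sketch of the decomposition above (illustration only): the bias added to
# each attention logit factorises over the two spatial axes,
#   attn[b, (h, w), (kh, kw)] += rel_h[b, h, w, kh] + rel_w[b, h, w, kw],
# so only (2*H - 1) and (2*W - 1) embeddings are stored per axis instead of a
# full (H*W) x (H*W) relative-position table.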
class Attention(nn.Module):
"""Multi-head Attention block with relative position embeddings.
Borrowed from https://github.com/facebookresearch/segment-anything/
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
use_rel_pos (bool): Whether to use relative position embedding.
Defaults to False.
input_size (int, optional): Input resolution for calculating the
relative positional parameter size. Defaults to None.
"""
def __init__(
self,
embed_dims: int,
num_heads: int = 8,
qkv_bias: bool = True,
use_rel_pos: bool = False,
input_size: Optional[Tuple[int, int]] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = head_embed_dims**-0.5
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dims, embed_dims)
self.use_rel_pos = use_rel_pos
if self.use_rel_pos:
assert (input_size is not None), \
'Input size must be provided if using relative position embed.'
# initialize relative positional embeddings
self.rel_pos_h = nn.Parameter(
torch.zeros(2 * input_size[0] - 1, head_embed_dims))
self.rel_pos_w = nn.Parameter(
torch.zeros(2 * input_size[1] - 1, head_embed_dims))
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads,
-1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h,
self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W,
-1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
class TransformerEncoderLayer(BaseModule):
"""Encoder layer with window attention in Vision Transformer.
Args:
embed_dims (int): The feature dimension
num_heads (int): Parallel attention heads
feedforward_channels (int): The hidden dimension for FFNs
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
use_rel_pos (bool): Whether to use relative position embedding.
Defaults to False.
window_size (int): Window size for window attention. Defaults to 0.
input_size (int, optional): Input resolution for calculating the
relative positional parameter size. Defaults to None.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
feedforward_channels: int,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
num_fcs: int = 2,
qkv_bias: bool = True,
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
use_rel_pos: bool = False,
window_size: int = 0,
input_size: Optional[Tuple[int, int]] = None,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.embed_dims = embed_dims
self.window_size = window_size
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.attn = Attention(
embed_dims=embed_dims,
num_heads=num_heads,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
input_size=input_size if window_size == 0 else
(window_size, window_size),
)
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
@property
def norm1(self):
return self.ln1
@property
def norm2(self):
return self.ln2
def forward(self, x):
shortcut = x
x = self.ln1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = self.ffn(self.ln2(x), identity=x)
return x
@MODELS.register_module()
class ViTSAM(BaseBackbone):
"""Vision Transformer as image encoder used in SAM.
A PyTorch implementation of the backbone in `Segment Anything
<https://arxiv.org/abs/2304.02643>`_
Args:
arch (str | dict): Vision Transformer architecture. If a string is
given, choose from 'base', 'large' and 'huge'. If a dict is given,
it should have the following keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
- **global_attn_indexes** (int): The index of layers with global
attention.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
out_channels (int): The num of output channels, if equal to 0, the
channel reduction layer is disabled. Defaults to 256.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, which means the last stage.
out_type (str): The type of output features. Please choose from
- ``"raw"`` or ``"featmap"``: The feature map tensor from the
patch tokens with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
Defaults to ``"raw"``.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
use_abs_pos (bool): Whether to use absolute position embedding.
Defaults to True.
use_rel_pos (bool): Whether to use relative position embedding.
Defaults to True.
window_size (int): Window size for window attention. Defaults to 14.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 3072,
'global_attn_indexes': [2, 5, 8, 11]
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096,
'global_attn_indexes': [5, 11, 17, 23]
}),
**dict.fromkeys(
['h', 'huge'], {
'embed_dims': 1280,
'num_layers': 32,
'num_heads': 16,
'feedforward_channels': 5120,
'global_attn_indexes': [7, 15, 23, 31]
}),
}
OUT_TYPES = {'raw', 'featmap', 'avg_featmap'}
def __init__(self,
arch: str = 'base',
img_size: int = 224,
patch_size: int = 16,
in_channels: int = 3,
out_channels: int = 256,
out_indices: int = -1,
out_type: str = 'raw',
drop_rate: float = 0.,
drop_path_rate: float = 0.,
qkv_bias: bool = True,
use_abs_pos: bool = True,
use_rel_pos: bool = True,
window_size: int = 14,
norm_cfg: dict = dict(type='LN', eps=1e-6),
frozen_stages: int = -1,
interpolate_mode: str = 'bicubic',
patch_cfg: dict = dict(),
layer_cfgs: dict = dict(),
init_cfg: Optional[dict] = None):
super().__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.num_layers = self.arch_settings['num_layers']
self.global_attn_indexes = self.arch_settings['global_attn_indexes']
self.img_size = to_2tuple(img_size)
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
self.use_abs_pos = use_abs_pos
self.interpolate_mode = interpolate_mode
if use_abs_pos:
# Set position embedding
self.pos_embed = nn.Parameter(
torch.zeros(1, *self.patch_resolution, self.embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
if use_rel_pos:
self._register_load_state_dict_pre_hook(
self._prepare_relative_position)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = self.num_layers + index
assert 0 <= out_indices[i] <= self.num_layers, \
f'Invalid out_indices {index}'
self.out_indices = out_indices
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
qkv_bias=qkv_bias,
window_size=window_size
if i not in self.global_attn_indexes else 0,
input_size=self.patch_resolution,
use_rel_pos=use_rel_pos,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(TransformerEncoderLayer(**_layer_cfg))
self.out_channels = out_channels
if self.out_channels > 0:
self.channel_reduction = nn.Sequential(
nn.Conv2d(
self.embed_dims,
out_channels,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_channels, eps=1e-6),
nn.Conv2d(
out_channels,
out_channels,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_channels, eps=1e-6),
)
# freeze stages only when self.frozen_stages > 0
self.frozen_stages = frozen_stages
if self.frozen_stages > 0:
self._freeze_stages()
def init_weights(self):
super().init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=0.02)

    def _freeze_stages(self):
        # freeze position embedding (only present when ``use_abs_pos=True``)
        if getattr(self, 'pos_embed', None) is not None:
            self.pos_embed.requires_grad = False
            # set the post-embedding dropout to eval mode
            self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze channel_reduction module
if self.frozen_stages == self.num_layers and self.out_channels > 0:
m = self.channel_reduction
m.eval()
for param in m.parameters():
param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
x = x.view(B, patch_resolution[0], patch_resolution[1],
self.embed_dims)
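        # Keep the tokens in the 2D (B, H, W, C) layout so that the windowed
        # attention blocks can partition the feature map spatially.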
if self.use_abs_pos:
# 'resize_pos_embed' only supports 'pos_embed' with ndim==3, but
# in ViTSAM, the 'pos_embed' has 4 dimensions (1, H, W, C), so it
# is flattened. Besides, ViTSAM doesn't have any extra token.
resized_pos_embed = resize_pos_embed(
self.pos_embed.flatten(1, 2),
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=0)
x = x + resized_pos_embed.view(1, *patch_resolution,
self.embed_dims)
x = self.drop_after_pos(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
# (B, H, W, C) -> (B, C, H, W)
x_reshape = x.permute(0, 3, 1, 2)
if self.out_channels > 0:
x_reshape = self.channel_reduction(x_reshape)
outs.append(self._format_output(x_reshape))
return tuple(outs)

    def _format_output(self, x) -> torch.Tensor:
if self.out_type == 'raw' or self.out_type == 'featmap':
return x
elif self.out_type == 'avg_featmap':
# (B, C, H, W) -> (B, C, N) -> (B, N, C)
x = x.flatten(2).permute(0, 2, 1)
return x.mean(dim=1)

    def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = ckpt_pos_embed_shape[1:3]
pos_embed_shape = self.patch_embed.init_out_size
flattened_pos_embed = state_dict[name].flatten(1, 2)
resized_pos_embed = resize_pos_embed(flattened_pos_embed,
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode, 0)
state_dict[name] = resized_pos_embed.view(1, *pos_embed_shape,
self.embed_dims)

    def _prepare_relative_position(self, state_dict, prefix, *args, **kwargs):
state_dict_model = self.state_dict()
all_keys = list(state_dict_model.keys())
for key in all_keys:
if 'rel_pos_' in key:
ckpt_key = prefix + key
if ckpt_key not in state_dict:
continue
relative_position_pretrained = state_dict[ckpt_key]
relative_position_current = state_dict_model[key]
L1, _ = relative_position_pretrained.size()
L2, _ = relative_position_current.size()
if L1 != L2:
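                    # The relative position table length depends on the input
                    # resolution, so linearly interpolate the pretrained table
                    # along its first axis to match the current length.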
new_rel_pos = F.interpolate(
relative_position_pretrained.reshape(1, L1,
-1).permute(
0, 2, 1),
size=L2,
mode='linear',
)
new_rel_pos = new_rel_pos.reshape(-1, L2).permute(1, 0)
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(f'Resize the {ckpt_key} from '
f'{state_dict[ckpt_key].shape} to '
f'{new_rel_pos.shape}')
state_dict[ckpt_key] = new_rel_pos

    def get_layer_depth(self, param_name: str, prefix: str = ''):
        """Get the layer-wise depth of a parameter.

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix for the parameter.
                Defaults to an empty string.

        Returns:
            Tuple[int, int]: The layer-wise depth and the num of layers.

        Note:
            The first depth is the stem module (``layer_depth=0``), and the
            last depth is the subsequent module (``layer_depth=num_layers-1``)
        """
num_layers = self.num_layers + 2
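        # ``+ 2`` accounts for the stem (patch/pos embedding, depth 0) and the
        # subsequent modules such as the head (depth ``num_layers - 1``).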
if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
if param_name in ('cls_token', 'pos_embed'):
layer_depth = 0
elif param_name.startswith('patch_embed'):
layer_depth = 0
elif param_name.startswith('layers'):
layer_id = int(param_name.split('.')[1])
layer_depth = layer_id + 1
else:
layer_depth = num_layers - 1
return layer_depth, num_layers
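

# A minimal usage sketch, not part of the original module. It assumes this
# backbone is registered in mmpretrain's MODELS registry under the name
# ``ViTSAM`` and that the default 1024x1024 input size applies; adjust the
# config if the registry name or arch settings differ.
if __name__ == '__main__':
    import torch

    from mmpretrain.registry import MODELS

    cfg = dict(type='ViTSAM', arch='base', out_channels=256,
               out_type='featmap')
    model = MODELS.build(cfg)
    model.eval()
    # A 1024x1024 image with 16x16 patches gives a 64x64 token grid, and the
    # channel-reduction neck maps the embedding dimension to 256, so the
    # single output feature map should be (1, 256, 64, 64).
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 1024, 1024))
    print([feat.shape for feat in feats])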