Commit cbc25585 authored by limm's avatar limm

add mmpretrain/ part

parent 1baf0566
# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from official impl at https://github.com/DingXiaoH/RepMLP.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.cnn.bricks.transformer import PatchEmbed as _PatchEmbed
from mmengine.model import BaseModule, ModuleList, Sequential
from mmpretrain.models.utils import SELayer, to_2tuple
from mmpretrain.registry import MODELS
def fuse_bn(conv_or_fc, bn):
"""fuse conv and bn."""
std = (bn.running_var + bn.eps).sqrt()
tmp_weight = bn.weight / std
tmp_weight = tmp_weight.reshape(-1, 1, 1, 1)
if len(tmp_weight) == conv_or_fc.weight.size(0):
return (conv_or_fc.weight * tmp_weight,
bn.bias - bn.running_mean * bn.weight / std)
else:
# in RepMLPBlock, dim0 of fc3 weights and fc3_bn weights
# are different.
repeat_times = conv_or_fc.weight.size(0) // len(tmp_weight)
repeated = tmp_weight.repeat_interleave(repeat_times, 0)
fused_weight = conv_or_fc.weight * repeated
bias = bn.bias - bn.running_mean * bn.weight / std
fused_bias = (bias).repeat_interleave(repeat_times, 0)
return (fused_weight, fused_bias)
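# A minimal sanity sketch (the helper name is illustrative, not part of the
# original commit): fusing a Conv2d + BatchNorm2d pair with ``fuse_bn``
# should reproduce the eval-mode output of the unfused pair.
def _check_fuse_bn():
    conv = nn.Conv2d(8, 8, kernel_size=1, bias=False)
    bn = nn.BatchNorm2d(8).eval()
    # give the running statistics non-trivial values
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    weight, bias = fuse_bn(conv, bn)
    x = torch.randn(2, 8, 4, 4)
    assert torch.allclose(bn(conv(x)), F.conv2d(x, weight, bias), atol=1e-5)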
class PatchEmbed(_PatchEmbed):
"""Image to Patch Embedding.
    Compared with the default patch embedding in ViT, the patch embedding of
    RepMLP applies a ReLU and does not convert the output tensor into shape
    (N, L, C).
Args:
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_type (str): The type of convolution
to generate patch embedding. Default: "Conv2d".
kernel_size (int): The kernel_size of embedding conv. Default: 16.
stride (int): The slide stride of embedding conv.
Default: 16.
padding (int | tuple | string): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Default: "corner".
dilation (int): The dilation rate of embedding conv. Default: 1.
bias (bool): Bias of embed conv. Default: True.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
input_size (int | tuple | None): The size of input, which will be
used to calculate the out size. Only works when `dynamic_size`
is False. Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self, *args, **kwargs):
super(PatchEmbed, self).__init__(*args, **kwargs)
self.relu = nn.ReLU()
def forward(self, x):
"""
Args:
            x (Tensor): Has shape (B, C, H, W). In most cases, C is 3.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): The output tensor.
- out_size (tuple[int]): Spatial shape of x, arrange as
(out_h, out_w).
"""
if self.adaptive_padding:
x = self.adaptive_padding(x)
x = self.projection(x)
if self.norm is not None:
x = self.norm(x)
x = self.relu(x)
out_size = (x.shape[2], x.shape[3])
return x, out_size
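# Usage sketch with illustrative values: unlike the ViT patch embedding, the
# output keeps the (N, C, H, W) layout. The arguments mirror the
# ``_patch_cfg`` built in ``RepMLPNet.__init__`` below.
def _demo_patch_embed():
    embed = PatchEmbed(
        in_channels=3,
        input_size=224,
        embed_dims=96,
        conv_type='Conv2d',
        kernel_size=4,
        stride=4,
        norm_cfg=dict(type='BN'),
        bias=False)
    feat, out_size = embed(torch.randn(1, 3, 224, 224))
    assert feat.shape == (1, 96, 56, 56) and out_size == (56, 56)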
class GlobalPerceptron(SELayer):
"""GlobalPerceptron implemented by using ``mmpretrain.modes.SELayer``.
Args:
input_channels (int): The number of input (and output) channels
in the GlobalPerceptron.
ratio (int): Squeeze ratio in GlobalPerceptron, the intermediate
channel will be ``make_divisible(channels // ratio, divisor)``.
"""
def __init__(self, input_channels: int, ratio: int, **kwargs) -> None:
super(GlobalPerceptron, self).__init__(
channels=input_channels,
ratio=ratio,
return_weight=True,
act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
**kwargs)
class RepMLPBlock(BaseModule):
"""Basic RepMLPNet, consists of PartitionPerceptron and GlobalPerceptron.
Args:
channels (int): The number of input and the output channels of the
block.
        path_h (int): The height of patches.
        path_w (int): The width of patches.
        reparam_conv_kernels (Sequence[int] | None): The conv kernels in the
            Local Perceptron. Default: None.
        globalperceptron_ratio (int): The reduction ratio in the
            GlobalPerceptron. Default: 4.
        num_sharesets (int): The number of sharesets in the
            PartitionPerceptron. Default: 1.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True).
deploy (bool): Whether to switch the model structure to
deployment mode. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
channels,
path_h,
path_w,
reparam_conv_kernels=None,
globalperceptron_ratio=4,
num_sharesets=1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
deploy=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.deploy = deploy
self.channels = channels
self.num_sharesets = num_sharesets
self.path_h, self.path_w = path_h, path_w
# the input channel of fc3
        self._path_vec_channels = path_h * path_w * num_sharesets
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.gp = GlobalPerceptron(
input_channels=channels, ratio=globalperceptron_ratio)
# using a conv layer to implement a fc layer
self.fc3 = build_conv_layer(
conv_cfg,
            in_channels=self._path_vec_channels,
            out_channels=self._path_vec_channels,
kernel_size=1,
stride=1,
padding=0,
bias=deploy,
groups=num_sharesets)
if deploy:
self.fc3_bn = nn.Identity()
else:
norm_layer = build_norm_layer(norm_cfg, num_sharesets)[1]
self.add_module('fc3_bn', norm_layer)
self.reparam_conv_kernels = reparam_conv_kernels
if not deploy and reparam_conv_kernels is not None:
for k in reparam_conv_kernels:
conv_branch = ConvModule(
in_channels=num_sharesets,
out_channels=num_sharesets,
kernel_size=k,
stride=1,
padding=k // 2,
norm_cfg=dict(type='BN', requires_grad=True),
groups=num_sharesets,
act_cfg=None)
self.__setattr__('repconv{}'.format(k), conv_branch)
def partition(self, x, h_parts, w_parts):
# convert (N, C, H, W) to (N, h_parts, w_parts, C, path_h, path_w)
x = x.reshape(-1, self.channels, h_parts, self.path_h, w_parts,
self.path_w)
x = x.permute(0, 2, 4, 1, 3, 5)
return x
def partition_affine(self, x, h_parts, w_parts):
"""perform Partition Perceptron."""
        fc_inputs = x.reshape(-1, self._path_vec_channels, 1, 1)
out = self.fc3(fc_inputs)
out = out.reshape(-1, self.num_sharesets, self.path_h, self.path_w)
out = self.fc3_bn(out)
out = out.reshape(-1, h_parts, w_parts, self.num_sharesets,
self.path_h, self.path_w)
return out
def forward(self, inputs):
# Global Perceptron
global_vec = self.gp(inputs)
origin_shape = inputs.size()
h_parts = origin_shape[2] // self.path_h
w_parts = origin_shape[3] // self.path_w
partitions = self.partition(inputs, h_parts, w_parts)
# Channel Perceptron
fc3_out = self.partition_affine(partitions, h_parts, w_parts)
# perform Local Perceptron
if self.reparam_conv_kernels is not None and not self.deploy:
conv_inputs = partitions.reshape(-1, self.num_sharesets,
self.path_h, self.path_w)
conv_out = 0
for k in self.reparam_conv_kernels:
conv_branch = self.__getattr__('repconv{}'.format(k))
conv_out += conv_branch(conv_inputs)
conv_out = conv_out.reshape(-1, h_parts, w_parts,
self.num_sharesets, self.path_h,
self.path_w)
fc3_out += conv_out
# N, h_parts, w_parts, num_sharesets, out_h, out_w
fc3_out = fc3_out.permute(0, 3, 1, 4, 2, 5)
out = fc3_out.reshape(*origin_shape)
out = out * global_vec
return out
def get_equivalent_fc3(self):
"""get the equivalent fc3 weight and bias."""
fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn)
if self.reparam_conv_kernels is not None:
largest_k = max(self.reparam_conv_kernels)
largest_branch = self.__getattr__('repconv{}'.format(largest_k))
total_kernel, total_bias = fuse_bn(largest_branch.conv,
largest_branch.bn)
for k in self.reparam_conv_kernels:
if k != largest_k:
k_branch = self.__getattr__('repconv{}'.format(k))
kernel, bias = fuse_bn(k_branch.conv, k_branch.bn)
total_kernel += F.pad(kernel, [(largest_k - k) // 2] * 4)
total_bias += bias
rep_weight, rep_bias = self._convert_conv_to_fc(
total_kernel, total_bias)
final_fc3_weight = rep_weight.reshape_as(fc_weight) + fc_weight
final_fc3_bias = rep_bias + fc_bias
else:
final_fc3_weight = fc_weight
final_fc3_bias = fc_bias
return final_fc3_weight, final_fc3_bias
def local_inject(self):
"""inject the Local Perceptron into Partition Perceptron."""
self.deploy = True
# Locality Injection
fc3_weight, fc3_bias = self.get_equivalent_fc3()
# Remove Local Perceptron
if self.reparam_conv_kernels is not None:
for k in self.reparam_conv_kernels:
self.__delattr__('repconv{}'.format(k))
self.__delattr__('fc3')
self.__delattr__('fc3_bn')
self.fc3 = build_conv_layer(
self.conv_cfg,
            self._path_vec_channels,
            self._path_vec_channels,
1,
1,
0,
bias=True,
groups=self.num_sharesets)
self.fc3_bn = nn.Identity()
self.fc3.weight.data = fc3_weight
self.fc3.bias.data = fc3_bias
def _convert_conv_to_fc(self, conv_kernel, conv_bias):
"""convert conv_k1 to fc, which is still a conv_k2, and the k2 > k1."""
in_channels = torch.eye(self.path_h * self.path_w).repeat(
1, self.num_sharesets).reshape(self.path_h * self.path_w,
self.num_sharesets, self.path_h,
self.path_w).to(conv_kernel.device)
fc_k = F.conv2d(
in_channels,
conv_kernel,
padding=(conv_kernel.size(2) // 2, conv_kernel.size(3) // 2),
groups=self.num_sharesets)
        fc_k = fc_k.reshape(self.path_h * self.path_w, self.num_sharesets *
                            self.path_h * self.path_w).t()
fc_bias = conv_bias.repeat_interleave(self.path_h * self.path_w)
return fc_k, fc_bias
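# Equivalence sketch with illustrative sizes: after ``local_inject``, the
# deploy-mode block should reproduce the three-branch training structure,
# provided BN statistics are frozen via ``eval()`` first.
def _check_local_inject():
    block = RepMLPBlock(
        channels=4,
        path_h=7,
        path_w=7,
        reparam_conv_kernels=(1, 3),
        num_sharesets=2)
    block.eval()
    x = torch.randn(1, 4, 14, 14)
    ref = block(x)
    block.local_inject()
    assert torch.allclose(ref, block(x), atol=1e-4)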
class RepMLPNetUnit(BaseModule):
"""A basic unit in RepMLPNet : [REPMLPBlock + BN + ConvFFN + BN].
Args:
channels (int): The number of input and the output channels of the
unit.
path_h (int): The height of patches.
path_w (int): The weidth of patches.
reparam_conv_kernels (Squeue(int) | None): The conv kernels in the
GlobalPerceptron. Default: None.
globalperceptron_ratio (int): The reducation ratio in the
GlobalPerceptron. Default: 4.
num_sharesets (int): The number of sharesets in the
PartitionPerceptron. Default 1.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True).
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
deploy (bool): Whether to switch the model structure to
deployment mode. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
channels,
path_h,
path_w,
reparam_conv_kernels,
globalperceptron_ratio,
norm_cfg=dict(type='BN', requires_grad=True),
ffn_expand=4,
num_sharesets=1,
deploy=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.repmlp_block = RepMLPBlock(
channels=channels,
path_h=path_h,
path_w=path_w,
reparam_conv_kernels=reparam_conv_kernels,
globalperceptron_ratio=globalperceptron_ratio,
num_sharesets=num_sharesets,
deploy=deploy)
self.ffn_block = ConvFFN(channels, channels * ffn_expand)
norm1 = build_norm_layer(norm_cfg, channels)[1]
self.add_module('norm1', norm1)
norm2 = build_norm_layer(norm_cfg, channels)[1]
self.add_module('norm2', norm2)
def forward(self, x):
y = x + self.repmlp_block(self.norm1(x))
out = y + self.ffn_block(self.norm2(y))
return out
class ConvFFN(nn.Module):
"""ConvFFN implemented by using point-wise convs."""
def __init__(self,
in_channels,
hidden_channels=None,
out_channels=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='GELU')):
super().__init__()
out_features = out_channels or in_channels
hidden_features = hidden_channels or in_channels
self.ffn_fc1 = ConvModule(
in_channels=in_channels,
out_channels=hidden_features,
kernel_size=1,
stride=1,
padding=0,
norm_cfg=norm_cfg,
act_cfg=None)
self.ffn_fc2 = ConvModule(
in_channels=hidden_features,
out_channels=out_features,
kernel_size=1,
stride=1,
padding=0,
norm_cfg=norm_cfg,
act_cfg=None)
self.act = build_activation_layer(act_cfg)
def forward(self, x):
x = self.ffn_fc1(x)
x = self.act(x)
x = self.ffn_fc2(x)
return x
@MODELS.register_module()
class RepMLPNet(BaseModule):
"""RepMLPNet backbone.
    A PyTorch implementation of: `RepMLP: Re-parameterizing Convolutions into
Fully-connected Layers for Image Recognition
<https://arxiv.org/abs/2105.01883>`_
Args:
        arch (str | dict): RepMLP architecture. If a string, choose from
            'base' and 'b'. If a dict, it should contain the keys below:
            - channels (List[int]): The number of channels in each stage.
            - depths (List[int]): The number of units in each stage.
            - sharesets_nums (List[int]): The number of sharesets in each
              stage.
img_size (int | tuple): The size of input image. Defaults: 224.
in_channels (int): Number of input image channels. Default: 3.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
out_indices (Sequence[int]): Output from which stages.
Default: ``(3, )``.
        reparam_conv_kernels (Sequence[int] | None): The conv kernels in the
            Local Perceptron. Default: ``(3, )``.
        globalperceptron_ratio (int): The reduction ratio in the
            GlobalPerceptron. Default: 4.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
Default: dict(type='BN', requires_grad=True).
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
deploy (bool): Whether to switch the model structure to deployment
mode. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
arch_zoo = {
**dict.fromkeys(['b', 'base'],
{'channels': [96, 192, 384, 768],
'depths': [2, 2, 12, 2],
'sharesets_nums': [1, 4, 32, 128]}),
} # yapf: disable
num_extra_tokens = 0 # there is no cls-token in RepMLP
def __init__(self,
arch,
img_size=224,
in_channels=3,
patch_size=4,
out_indices=(3, ),
reparam_conv_kernels=(3, ),
globalperceptron_ratio=4,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
patch_cfg=dict(),
final_norm=True,
deploy=False,
init_cfg=None):
super(RepMLPNet, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'channels', 'depths', 'sharesets_nums'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}.'
self.arch_settings = arch
self.img_size = to_2tuple(img_size)
self.patch_size = to_2tuple(patch_size)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.num_stage = len(self.arch_settings['channels'])
for value in self.arch_settings.values():
            assert isinstance(value, list) and len(value) == self.num_stage, (
                'Each setting item in the arch dict must be a list with the '
                'same length as the number of stages.')
self.channels = self.arch_settings['channels']
self.depths = self.arch_settings['depths']
self.sharesets_nums = self.arch_settings['sharesets_nums']
_patch_cfg = dict(
in_channels=in_channels,
input_size=self.img_size,
embed_dims=self.channels[0],
conv_type='Conv2d',
kernel_size=self.patch_size,
stride=self.patch_size,
norm_cfg=self.norm_cfg,
bias=False)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
self.patch_hs = [
self.patch_resolution[0] // 2**i for i in range(self.num_stage)
]
self.patch_ws = [
self.patch_resolution[1] // 2**i for i in range(self.num_stage)
]
self.stages = ModuleList()
self.downsample_layers = ModuleList()
for stage_idx in range(self.num_stage):
# make stage layers
_stage_cfg = dict(
channels=self.channels[stage_idx],
path_h=self.patch_hs[stage_idx],
path_w=self.patch_ws[stage_idx],
reparam_conv_kernels=reparam_conv_kernels,
globalperceptron_ratio=globalperceptron_ratio,
norm_cfg=self.norm_cfg,
ffn_expand=4,
num_sharesets=self.sharesets_nums[stage_idx],
deploy=deploy)
stage_blocks = [
RepMLPNetUnit(**_stage_cfg)
for _ in range(self.depths[stage_idx])
]
self.stages.append(Sequential(*stage_blocks))
# make downsample layers
if stage_idx < self.num_stage - 1:
self.downsample_layers.append(
ConvModule(
in_channels=self.channels[stage_idx],
out_channels=self.channels[stage_idx + 1],
kernel_size=2,
stride=2,
padding=0,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=True))
        self.out_indices = out_indices
if final_norm:
norm_layer = build_norm_layer(norm_cfg, self.channels[-1])[1]
else:
norm_layer = nn.Identity()
self.add_module('final_norm', norm_layer)
def forward(self, x):
assert x.shape[2:] == self.img_size, \
"The Rep-MLP doesn't support dynamic input shape. " \
f'Please input images with shape {self.img_size}'
outs = []
x, _ = self.patch_embed(x)
for i, stage in enumerate(self.stages):
x = stage(x)
# downsample after each stage except last stage
if i < len(self.stages) - 1:
downsample = self.downsample_layers[i]
x = downsample(x)
            if i in self.out_indices:
if self.final_norm and i == len(self.stages) - 1:
out = self.final_norm(x)
else:
out = x
outs.append(out)
return tuple(outs)
def switch_to_deploy(self):
for m in self.modules():
if hasattr(m, 'local_inject'):
m.local_inject()
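# Usage sketch (illustrative): build the 'b' arch, then re-parameterize for
# deployment; in eval mode the outputs should agree up to float error
# accumulated over the stages.
def _demo_repmlpnet():
    model = RepMLPNet(arch='b', img_size=224, reparam_conv_kernels=(1, 3))
    model.eval()
    x = torch.randn(1, 3, 224, 224)
    feats = model(x)
    assert feats[-1].shape == (1, 768, 7, 7)
    model.switch_to_deploy()
    assert torch.allclose(feats[-1], model(x)[-1], atol=1e-3)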
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmengine.model import BaseModule, Sequential
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch import nn
from mmpretrain.registry import MODELS
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone
class RepVGGBlock(BaseModule):
"""RepVGG block for RepVGG backbone.
Args:
in_channels (int): The input channels of the block.
out_channels (int): The output channels of the block.
stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1.
padding (int): Padding of the 3x3 convolution layer.
dilation (int): Dilation of the 3x3 convolution layer.
groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1.
padding_mode (str): Padding mode of the 3x3 convolution layer.
Default: 'zeros'.
se_cfg (None or dict): The configuration of the se module.
Default: None.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN', requires_grad=True).
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
deploy (bool): Whether to switch the model structure to
deployment mode. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
padding=1,
dilation=1,
groups=1,
padding_mode='zeros',
se_cfg=None,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
deploy=False,
init_cfg=None):
super(RepVGGBlock, self).__init__(init_cfg)
assert se_cfg is None or isinstance(se_cfg, dict)
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.se_cfg = se_cfg
self.with_cp = with_cp
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.deploy = deploy
if deploy:
self.branch_reparam = build_conv_layer(
conv_cfg,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
padding_mode=padding_mode)
else:
            # Judge whether the input and output shapes are the same.
            # If so, add a normalized identity shortcut.
if out_channels == in_channels and stride == 1 and \
padding == dilation:
self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
else:
self.branch_norm = None
self.branch_3x3 = self.create_conv_bn(
kernel_size=3,
dilation=dilation,
padding=padding,
)
self.branch_1x1 = self.create_conv_bn(kernel_size=1)
if se_cfg is not None:
self.se_layer = SELayer(channels=out_channels, **se_cfg)
else:
self.se_layer = None
self.act = build_activation_layer(act_cfg)
def create_conv_bn(self, kernel_size, dilation=1, padding=0):
conv_bn = Sequential()
conv_bn.add_module(
'conv',
build_conv_layer(
self.conv_cfg,
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=kernel_size,
stride=self.stride,
dilation=dilation,
padding=padding,
groups=self.groups,
bias=False))
conv_bn.add_module(
'norm',
build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])
return conv_bn
def forward(self, x):
def _inner_forward(inputs):
if self.deploy:
return self.branch_reparam(inputs)
if self.branch_norm is None:
branch_norm_out = 0
else:
branch_norm_out = self.branch_norm(inputs)
inner_out = self.branch_3x3(inputs) + self.branch_1x1(
inputs) + branch_norm_out
if self.se_cfg is not None:
inner_out = self.se_layer(inner_out)
return inner_out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.act(out)
return out
def switch_to_deploy(self):
"""Switch the model structure from training mode to deployment mode."""
if self.deploy:
return
assert self.norm_cfg['type'] == 'BN', \
"Switch is not allowed when norm_cfg['type'] != 'BN'."
reparam_weight, reparam_bias = self.reparameterize()
self.branch_reparam = build_conv_layer(
self.conv_cfg,
self.in_channels,
self.out_channels,
kernel_size=3,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=True)
self.branch_reparam.weight.data = reparam_weight
self.branch_reparam.bias.data = reparam_bias
for param in self.parameters():
param.detach_()
delattr(self, 'branch_3x3')
delattr(self, 'branch_1x1')
delattr(self, 'branch_norm')
self.deploy = True
def reparameterize(self):
"""Fuse all the parameters of all branches.
Returns:
tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
branches. the first element is the weights and the second is
the bias.
"""
weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3)
weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1)
# pad a conv1x1 weight to a conv3x3 weight
weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0)
weight_norm, bias_norm = 0, 0
if self.branch_norm:
tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm)
weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)
return (weight_3x3 + weight_1x1 + weight_norm,
bias_3x3 + bias_1x1 + bias_norm)
def _fuse_conv_bn(self, branch):
"""Fuse the parameters in a branch with a conv and bn.
Args:
            branch (mmengine.model.Sequential): A branch with conv and bn.
Returns:
tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
fusing the parameters of conv and bn in one branch.
The first element is the weight and the second is the bias.
"""
if branch is None:
return 0, 0
conv_weight = branch.conv.weight
running_mean = branch.norm.running_mean
running_var = branch.norm.running_var
gamma = branch.norm.weight
beta = branch.norm.bias
eps = branch.norm.eps
std = (running_var + eps).sqrt()
fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight
fused_bias = -running_mean * gamma / std + beta
return fused_weight, fused_bias
    def _norm_to_conv3x3(self, branch_norm):
        """Convert a norm layer to a conv3x3-bn sequence.
        Args:
            branch_norm (nn.BatchNorm2d): A branch with only a bn layer in
                the block.
        Returns:
            tmp_conv3x3 (mmengine.model.Sequential): a sequential with
                conv3x3 and bn.
        """
        input_dim = self.in_channels // self.groups
        conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
                                  dtype=branch_norm.weight.dtype)
        for i in range(self.in_channels):
            conv_weight[i, i % input_dim, 1, 1] = 1
        conv_weight = conv_weight.to(branch_norm.weight.device)
        tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
        tmp_conv3x3.conv.weight.data = conv_weight
        tmp_conv3x3.norm = branch_norm
        return tmp_conv3x3
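# Equivalence sketch: with BN statistics frozen via ``eval()``, switching a
# block to deploy mode should not change its output.
def _check_repvgg_reparam():
    block = RepVGGBlock(8, 8, stride=1).eval()
    x = torch.randn(1, 8, 14, 14)
    ref = block(x)
    block.switch_to_deploy()
    assert torch.allclose(ref, block(x), atol=1e-5)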
class MTSPPF(BaseModule):
"""MTSPPF block for YOLOX-PAI RepVGG backbone.
Args:
in_channels (int): The input channels of the block.
out_channels (int): The output channels of the block.
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
kernel_size (int): Kernel size of pooling. Default: 5.
"""
def __init__(self,
in_channels,
out_channels,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
kernel_size=5):
super().__init__()
hidden_features = in_channels // 2 # hidden channels
self.conv1 = ConvModule(
in_channels,
hidden_features,
1,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = ConvModule(
hidden_features * 4,
out_channels,
1,
stride=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.maxpool = nn.MaxPool2d(
kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
def forward(self, x):
x = self.conv1(x)
y1 = self.maxpool(x)
y2 = self.maxpool(y1)
return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))
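# Shape sketch with illustrative sizes: the stride-1 poolings keep the
# spatial size, so MTSPPF only changes the channel count.
def _demo_mtsppf():
    sppf = MTSPPF(64, 128, kernel_size=5)
    out = sppf(torch.randn(1, 64, 20, 20))
    assert out.shape == (1, 128, 20, 20)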
@MODELS.register_module()
class RepVGG(BaseBackbone):
"""RepVGG backbone.
    A PyTorch implementation of: `RepVGG: Making VGG-style ConvNets Great Again
<https://arxiv.org/abs/2101.03697>`_
Args:
        arch (str | dict): RepVGG architecture. If a string, choose from
            'A0', 'A1', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2', 'B2g2',
            'B2g4', 'B3', 'B3g2', 'B3g4', 'D2se' or 'yolox-pai-small'.
            If a dict, it should contain the keys below:
- **num_blocks** (Sequence[int]): Number of blocks in each stage.
            - **width_factor** (Sequence[float]): Width multiplier in each
              stage.
- **group_layer_map** (dict | None): RepVGG Block that declares
the need to apply group convolution.
- **se_cfg** (dict | None): SE Layer config.
- **stem_channels** (int, optional): The stem channels, the final
stem channels will be
``min(stem_channels, base_channels*width_factor[0])``.
If not set here, 64 is used by default in the code.
in_channels (int): Number of input image channels. Defaults to 3.
base_channels (int): Base channels of RepVGG backbone, work with
width_factor together. Defaults to 64.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
strides (Sequence[int]): Strides of the first block of each stage.
Defaults to ``(2, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Defaults to ``(1, 1, 1, 1)``.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters. Defaults to -1.
conv_cfg (dict | None): The config dict for conv layers.
Defaults to None.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='BN')``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='ReLU')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
deploy (bool): Whether to switch the model structure to deployment
mode. Defaults to False.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
add_ppf (bool): Whether to use the MTSPPF block. Defaults to False.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_layer_map = {layer: 2 for layer in groupwise_layers}
g4_layer_map = {layer: 4 for layer in groupwise_layers}
arch_settings = {
'A0':
dict(
num_blocks=[2, 4, 14, 1],
width_factor=[0.75, 0.75, 0.75, 2.5],
group_layer_map=None,
se_cfg=None),
'A1':
dict(
num_blocks=[2, 4, 14, 1],
width_factor=[1, 1, 1, 2.5],
group_layer_map=None,
se_cfg=None),
'A2':
dict(
num_blocks=[2, 4, 14, 1],
width_factor=[1.5, 1.5, 1.5, 2.75],
group_layer_map=None,
se_cfg=None),
'B0':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[1, 1, 1, 2.5],
group_layer_map=None,
se_cfg=None,
stem_channels=64),
'B1':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[2, 2, 2, 4],
group_layer_map=None,
se_cfg=None),
'B1g2':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[2, 2, 2, 4],
group_layer_map=g2_layer_map,
se_cfg=None),
'B1g4':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[2, 2, 2, 4],
group_layer_map=g4_layer_map,
se_cfg=None),
'B2':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[2.5, 2.5, 2.5, 5],
group_layer_map=None,
se_cfg=None),
'B2g2':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[2.5, 2.5, 2.5, 5],
group_layer_map=g2_layer_map,
se_cfg=None),
'B2g4':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[2.5, 2.5, 2.5, 5],
group_layer_map=g4_layer_map,
se_cfg=None),
'B3':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[3, 3, 3, 5],
group_layer_map=None,
se_cfg=None),
'B3g2':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[3, 3, 3, 5],
group_layer_map=g2_layer_map,
se_cfg=None),
'B3g4':
dict(
num_blocks=[4, 6, 16, 1],
width_factor=[3, 3, 3, 5],
group_layer_map=g4_layer_map,
se_cfg=None),
'D2se':
dict(
num_blocks=[8, 14, 24, 1],
width_factor=[2.5, 2.5, 2.5, 5],
group_layer_map=None,
se_cfg=dict(ratio=16, divisor=1)),
'yolox-pai-small':
dict(
num_blocks=[3, 5, 7, 3],
width_factor=[1, 1, 1, 1],
group_layer_map=None,
se_cfg=None,
stem_channels=32),
}
def __init__(self,
arch,
in_channels=3,
base_channels=64,
out_indices=(3, ),
strides=(2, 2, 2, 2),
dilations=(1, 1, 1, 1),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False,
deploy=False,
norm_eval=False,
add_ppf=False,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]):
super(RepVGG, self).__init__(init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'"arch": "{arch}" is not one of the arch_settings'
arch = self.arch_settings[arch]
elif not isinstance(arch, dict):
raise TypeError('Expect "arch" to be either a string '
f'or a dict, got {type(arch)}')
assert len(arch['num_blocks']) == len(
arch['width_factor']) == len(strides) == len(dilations)
assert max(out_indices) < len(arch['num_blocks'])
if arch['group_layer_map'] is not None:
assert max(arch['group_layer_map'].keys()) <= sum(
arch['num_blocks'])
if arch['se_cfg'] is not None:
assert isinstance(arch['se_cfg'], dict)
self.base_channels = base_channels
self.arch = arch
self.in_channels = in_channels
self.out_indices = out_indices
self.strides = strides
self.dilations = dilations
self.deploy = deploy
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
        # If stem_channels is not set in the arch dict, default to 64 to
        # prevent BC-breaking; the stem channels should not be larger than
        # that of stage1.
channels = min(
arch.get('stem_channels', 64),
int(self.base_channels * self.arch['width_factor'][0]))
self.stem = RepVGGBlock(
self.in_channels,
channels,
stride=2,
se_cfg=arch['se_cfg'],
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
deploy=deploy)
next_create_block_idx = 1
self.stages = []
for i in range(len(arch['num_blocks'])):
num_blocks = self.arch['num_blocks'][i]
stride = self.strides[i]
dilation = self.dilations[i]
out_channels = int(self.base_channels * 2**i *
self.arch['width_factor'][i])
stage, next_create_block_idx = self._make_stage(
channels, out_channels, num_blocks, stride, dilation,
next_create_block_idx, init_cfg)
stage_name = f'stage_{i + 1}'
self.add_module(stage_name, stage)
self.stages.append(stage_name)
channels = out_channels
if add_ppf:
self.ppf = MTSPPF(
out_channels,
out_channels,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
kernel_size=5)
else:
self.ppf = nn.Identity()
def _make_stage(self, in_channels, out_channels, num_blocks, stride,
dilation, next_create_block_idx, init_cfg):
strides = [stride] + [1] * (num_blocks - 1)
dilations = [dilation] * num_blocks
blocks = []
for i in range(num_blocks):
groups = self.arch['group_layer_map'].get(
next_create_block_idx,
1) if self.arch['group_layer_map'] is not None else 1
blocks.append(
RepVGGBlock(
in_channels,
out_channels,
stride=strides[i],
padding=dilations[i],
dilation=dilations[i],
groups=groups,
se_cfg=self.arch['se_cfg'],
with_cp=self.with_cp,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
deploy=self.deploy,
init_cfg=init_cfg))
in_channels = out_channels
next_create_block_idx += 1
return Sequential(*blocks), next_create_block_idx
def forward(self, x):
x = self.stem(x)
outs = []
for i, stage_name in enumerate(self.stages):
stage = getattr(self, stage_name)
x = stage(x)
if i + 1 == len(self.stages):
x = self.ppf(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.stem.eval()
for param in self.stem.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
stage = getattr(self, f'stage_{i+1}')
stage.eval()
for param in stage.parameters():
param.requires_grad = False
def train(self, mode=True):
super(RepVGG, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
def switch_to_deploy(self):
for m in self.modules():
if isinstance(m, RepVGGBlock):
m.switch_to_deploy()
self.deploy = True
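# Usage sketch (illustrative): the 'A0' arch with all four stage outputs;
# ``switch_to_deploy`` fuses every RepVGGBlock in place.
def _demo_repvgg():
    model = RepVGG(arch='A0', out_indices=(0, 1, 2, 3))
    model.eval()
    x = torch.randn(1, 3, 64, 64)
    outs = model(x)
    assert outs[-1].shape == (1, 1280, 2, 2)
    model.switch_to_deploy()
    assert torch.allclose(outs[-1], model(x)[-1], atol=1e-3)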
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import ModuleList, Sequential
from mmpretrain.registry import MODELS
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResNet
class Bottle2neck(_Bottleneck):
expansion = 4
def __init__(self,
in_channels,
out_channels,
scales=4,
base_width=26,
base_channels=64,
stage_type='normal',
**kwargs):
"""Bottle2neck block for Res2Net."""
super(Bottle2neck, self).__init__(in_channels, out_channels, **kwargs)
assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.'
mid_channels = out_channels // self.expansion
width = int(math.floor(mid_channels * (base_width / base_channels)))
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, width * scales, postfix=1)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.out_channels, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.in_channels,
width * scales,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
if stage_type == 'stage':
self.pool = nn.AvgPool2d(
kernel_size=3, stride=self.conv2_stride, padding=1)
self.convs = ModuleList()
self.bns = ModuleList()
for i in range(scales - 1):
self.convs.append(
build_conv_layer(
self.conv_cfg,
width,
width,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
bias=False))
self.bns.append(
build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
self.conv3 = build_conv_layer(
self.conv_cfg,
width * scales,
self.out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
self.stage_type = stage_type
self.scales = scales
self.width = width
delattr(self, 'conv2')
delattr(self, self.norm2_name)
def forward(self, x):
"""Forward function."""
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
spx = torch.split(out, self.width, 1)
sp = self.convs[0](spx[0].contiguous())
sp = self.relu(self.bns[0](sp))
out = sp
for i in range(1, self.scales - 1):
if self.stage_type == 'stage':
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp.contiguous())
sp = self.relu(self.bns[i](sp))
out = torch.cat((out, sp), 1)
if self.stage_type == 'normal' and self.scales != 1:
out = torch.cat((out, spx[self.scales - 1]), 1)
elif self.stage_type == 'stage' and self.scales != 1:
out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)
out = self.conv3(out)
out = self.norm3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
class Res2Layer(Sequential):
"""Res2Layer to build Res2Net style backbone.
Args:
        block (nn.Module): Residual block used to build Res2Layer.
        in_channels (int): Input channels of this layer.
        out_channels (int): Output channels of this layer.
num_blocks (int): number of blocks.
stride (int): stride of the first block. Default: 1
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottle2neck. Defaults to True.
conv_cfg (dict): dictionary to construct and config conv layer.
Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
scales (int): Scales used in Res2Net. Default: 4
base_width (int): Basic width of each scale. Default: 26
        drop_path_rate (float or list): stochastic depth rate.
Default: 0.
"""
def __init__(self,
block,
in_channels,
out_channels,
num_blocks,
stride=1,
avg_down=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
scales=4,
base_width=26,
drop_path_rate=0.0,
**kwargs):
self.block = block
if isinstance(drop_path_rate, float):
drop_path_rate = [drop_path_rate] * num_blocks
        assert len(drop_path_rate) == num_blocks, \
            'Please check the length of drop_path_rate'
downsample = None
if stride != 1 or in_channels != out_channels:
if avg_down:
downsample = nn.Sequential(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False),
build_conv_layer(
conv_cfg,
in_channels,
out_channels,
kernel_size=1,
stride=1,
bias=False),
build_norm_layer(norm_cfg, out_channels)[1],
)
else:
downsample = nn.Sequential(
build_conv_layer(
conv_cfg,
in_channels,
out_channels,
kernel_size=1,
stride=stride,
bias=False),
build_norm_layer(norm_cfg, out_channels)[1],
)
layers = []
layers.append(
block(
in_channels=in_channels,
out_channels=out_channels,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
stage_type='stage',
drop_path_rate=drop_path_rate[0],
**kwargs))
in_channels = out_channels
for i in range(1, num_blocks):
layers.append(
block(
in_channels=in_channels,
out_channels=out_channels,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
scales=scales,
base_width=base_width,
drop_path_rate=drop_path_rate[i],
**kwargs))
super(Res2Layer, self).__init__(*layers)
@MODELS.register_module()
class Res2Net(ResNet):
"""Res2Net backbone.
    A PyTorch implementation of: `Res2Net: A New Multi-scale Backbone
Architecture <https://arxiv.org/pdf/1904.01169.pdf>`_
Args:
depth (int): Depth of Res2Net, choose from {50, 101, 152}.
scales (int): Scales used in Res2Net. Defaults to 4.
base_width (int): Basic width of each scale. Defaults to 26.
in_channels (int): Number of input image channels. Defaults to 3.
num_stages (int): Number of Res2Net stages. Defaults to 4.
strides (Sequence[int]): Strides of the first block of each stage.
Defaults to ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Defaults to ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
style (str): "pytorch" or "caffe". If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer. Defaults to "pytorch".
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Defaults to True.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottle2neck. Defaults to True.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to ``dict(type='BN', requires_grad=True)``.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Defaults to True.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
Example:
>>> from mmpretrain.models import Res2Net
>>> import torch
>>> model = Res2Net(depth=50,
... scales=4,
... base_width=26,
... out_indices=(0, 1, 2, 3))
>>> model.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = model.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 256, 8, 8)
(1, 512, 4, 4)
(1, 1024, 2, 2)
(1, 2048, 1, 1)
"""
arch_settings = {
50: (Bottle2neck, (3, 4, 6, 3)),
101: (Bottle2neck, (3, 4, 23, 3)),
152: (Bottle2neck, (3, 8, 36, 3))
}
def __init__(self,
scales=4,
base_width=26,
style='pytorch',
deep_stem=True,
avg_down=True,
init_cfg=None,
**kwargs):
self.scales = scales
self.base_width = base_width
super(Res2Net, self).__init__(
style=style,
deep_stem=deep_stem,
avg_down=avg_down,
init_cfg=init_cfg,
**kwargs)
def make_res_layer(self, **kwargs):
return Res2Layer(
scales=self.scales,
base_width=self.base_width,
base_channels=self.base_channels,
**kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmpretrain.registry import MODELS
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResLayer, ResNetV1d
class RSoftmax(nn.Module):
"""Radix Softmax module in ``SplitAttentionConv2d``.
Args:
radix (int): Radix of input.
groups (int): Groups of input.
"""
def __init__(self, radix, groups):
super().__init__()
self.radix = radix
self.groups = groups
def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x
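# Shape sketch with illustrative sizes: with radix > 1, the logits are
# softmax-normalized across the radix dimension and flattened back to
# (batch, channels); with radix == 1 a plain sigmoid is used.
def _demo_rsoftmax():
    rs = RSoftmax(radix=2, groups=4)
    out = rs(torch.randn(8, 24, 1, 1))
    assert out.shape == (8, 24)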
class SplitAttentionConv2d(nn.Module):
"""Split-Attention Conv2d.
Args:
in_channels (int): Same as nn.Conv2d.
        channels (int): Same as nn.Conv2d.
kernel_size (int | tuple[int]): Same as nn.Conv2d.
stride (int | tuple[int]): Same as nn.Conv2d.
padding (int | tuple[int]): Same as nn.Conv2d.
dilation (int | tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
reduction_factor (int): Reduction factor of SplitAttentionConv2d.
Default: 4.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None.
"""
def __init__(self,
in_channels,
channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
radix=2,
reduction_factor=4,
conv_cfg=None,
norm_cfg=dict(type='BN')):
super(SplitAttentionConv2d, self).__init__()
inter_channels = max(in_channels * radix // reduction_factor, 32)
self.radix = radix
self.groups = groups
self.channels = channels
self.conv = build_conv_layer(
conv_cfg,
in_channels,
channels * radix,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups * radix,
bias=False)
self.norm0_name, norm0 = build_norm_layer(
norm_cfg, channels * radix, postfix=0)
self.add_module(self.norm0_name, norm0)
self.relu = nn.ReLU(inplace=True)
self.fc1 = build_conv_layer(
None, channels, inter_channels, 1, groups=self.groups)
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, inter_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.fc2 = build_conv_layer(
None, inter_channels, channels * radix, 1, groups=self.groups)
self.rsoftmax = RSoftmax(radix, groups)
@property
def norm0(self):
return getattr(self, self.norm0_name)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def forward(self, x):
x = self.conv(x)
x = self.norm0(x)
x = self.relu(x)
        batch = x.size(0)
if self.radix > 1:
splits = x.view(batch, self.radix, -1, *x.shape[2:])
gap = splits.sum(dim=1)
else:
gap = x
gap = F.adaptive_avg_pool2d(gap, 1)
gap = self.fc1(gap)
gap = self.norm1(gap)
gap = self.relu(gap)
atten = self.fc2(gap)
atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
if self.radix > 1:
attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
out = torch.sum(attens * splits, dim=1)
else:
out = atten * x
return out.contiguous()
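# Usage sketch with illustrative sizes: the module maps ``in_channels`` to
# ``channels`` while attending over ``radix`` splits internally.
def _demo_split_attention():
    conv = SplitAttentionConv2d(32, 32, kernel_size=3, padding=1, radix=2)
    out = conv(torch.randn(2, 32, 16, 16))
    assert out.shape == (2, 32, 16, 16)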
class Bottleneck(_Bottleneck):
"""Bottleneck block for ResNeSt.
Args:
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
groups (int): Groups of conv2.
width_per_group (int): Width per group of conv2. 64x4d indicates
``groups=64, width_per_group=4`` and 32x8d indicates
``groups=32, width_per_group=8``.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
reduction_factor (int): Reduction factor of SplitAttentionConv2d.
Default: 4.
avg_down_stride (bool): Whether to use average pool for stride in
Bottleneck. Default: True.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module, optional): downsample operation on identity
branch. Default: None
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
conv_cfg (dict, optional): dictionary to construct and config conv
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
def __init__(self,
in_channels,
out_channels,
groups=1,
width_per_group=4,
base_channels=64,
radix=2,
reduction_factor=4,
avg_down_stride=True,
**kwargs):
super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
self.groups = groups
self.width_per_group = width_per_group
# For ResNet bottleneck, middle channels are determined by expansion
# and out_channels, but for ResNeXt bottleneck, it is determined by
# groups and width_per_group and the stage it is located in.
if groups != 1:
assert self.mid_channels % base_channels == 0
self.mid_channels = (
groups * width_per_group * self.mid_channels // base_channels)
self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, self.mid_channels, postfix=1)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.out_channels, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.in_channels,
self.mid_channels,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = SplitAttentionConv2d(
self.mid_channels,
self.mid_channels,
kernel_size=3,
stride=1 if self.avg_down_stride else self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
radix=radix,
reduction_factor=reduction_factor,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg)
delattr(self, self.norm2_name)
if self.avg_down_stride:
self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
self.conv3 = build_conv_layer(
self.conv_cfg,
self.mid_channels,
self.out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
if self.avg_down_stride:
out = self.avd_layer(out)
out = self.conv3(out)
out = self.norm3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
@MODELS.register_module()
class ResNeSt(ResNetV1d):
"""ResNeSt backbone.
Please refer to the `paper <https://arxiv.org/pdf/2004.08955.pdf>`__ for
details.
Args:
        depth (int): Network depth, from {50, 101, 152, 200, 269}.
        groups (int): Groups of conv2 in Bottleneck. Default: 1.
width_per_group (int): Width per group of conv2 in Bottleneck.
Default: 4.
        radix (int): Radix of SplitAttentionConv2d. Default: 2.
reduction_factor (int): Reduction factor of SplitAttentionConv2d.
Default: 4.
avg_down_stride (bool): Whether to use average pool for stride in
Bottleneck. Default: True.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
stage is specified, a single tensor (feature map) is returned,
otherwise multiple stages are specified, a tuple of tensors will
be returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3)),
200: (Bottleneck, (3, 24, 36, 3)),
269: (Bottleneck, (3, 30, 48, 8))
}
def __init__(self,
depth,
groups=1,
width_per_group=4,
radix=2,
reduction_factor=4,
avg_down_stride=True,
**kwargs):
self.groups = groups
self.width_per_group = width_per_group
self.radix = radix
self.reduction_factor = reduction_factor
self.avg_down_stride = avg_down_stride
super(ResNeSt, self).__init__(depth=depth, **kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(
groups=self.groups,
width_per_group=self.width_per_group,
base_channels=self.base_channels,
radix=self.radix,
reduction_factor=self.reduction_factor,
avg_down_stride=self.avg_down_stride,
**kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.cnn.bricks import DropPath
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
eps = 1.0e-5
class BasicBlock(BaseModule):
"""BasicBlock for ResNet.
Args:
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
expansion (int): The ratio of ``out_channels/mid_channels`` where
``mid_channels`` is the output channels of conv1. This is a
reserved argument in BasicBlock and should always be 1. Default: 1.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module, optional): downsample operation on identity
branch. Default: None.
style (str): `pytorch` or `caffe`. It is unused and reserved for
unified API with Bottleneck.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
conv_cfg (dict, optional): dictionary to construct and config conv
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
def __init__(self,
in_channels,
out_channels,
expansion=1,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
drop_path_rate=0.0,
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=None):
super(BasicBlock, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.expansion = expansion
assert self.expansion == 1
assert out_channels % expansion == 0
self.mid_channels = out_channels // expansion
self.stride = stride
self.dilation = dilation
self.style = style
self.with_cp = with_cp
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.mid_channels, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, out_channels, postfix=2)
self.conv1 = build_conv_layer(
conv_cfg,
in_channels,
self.mid_channels,
3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg,
self.mid_channels,
out_channels,
3,
padding=1,
bias=False)
self.add_module(self.norm2_name, norm2)
self.relu = build_activation_layer(act_cfg)
self.downsample = downsample
self.drop_path = DropPath(drop_prob=drop_path_rate
) if drop_path_rate > eps else nn.Identity()
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.drop_path(out)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
class Bottleneck(BaseModule):
"""Bottleneck block for ResNet.
Args:
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
expansion (int): The ratio of ``out_channels/mid_channels`` where
``mid_channels`` is the input/output channels of conv2. Default: 4.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module, optional): downsample operation on identity
branch. Default: None.
style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
stride-two layer is the 3x3 conv layer, otherwise the stride-two
layer is the first 1x1 conv layer. Default: "pytorch".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
conv_cfg (dict, optional): dictionary to construct and config conv
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
"""
def __init__(self,
in_channels,
out_channels,
expansion=4,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU', inplace=True),
drop_path_rate=0.0,
init_cfg=None):
super(Bottleneck, self).__init__(init_cfg=init_cfg)
assert style in ['pytorch', 'caffe']
self.in_channels = in_channels
self.out_channels = out_channels
self.expansion = expansion
assert out_channels % expansion == 0
self.mid_channels = out_channels // expansion
self.stride = stride
self.dilation = dilation
self.style = style
self.with_cp = with_cp
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
if self.style == 'pytorch':
self.conv1_stride = 1
self.conv2_stride = stride
else:
self.conv1_stride = stride
self.conv2_stride = 1
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, self.mid_channels, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, self.mid_channels, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
norm_cfg, out_channels, postfix=3)
self.conv1 = build_conv_layer(
conv_cfg,
in_channels,
self.mid_channels,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg,
self.mid_channels,
self.mid_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=dilation,
dilation=dilation,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
conv_cfg,
self.mid_channels,
out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
self.relu = build_activation_layer(act_cfg)
self.downsample = downsample
self.drop_path = DropPath(drop_prob=drop_path_rate
) if drop_path_rate > eps else nn.Identity()
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
@property
def norm3(self):
return getattr(self, self.norm3_name)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.norm3(out)
if self.downsample is not None:
identity = self.downsample(x)
out = self.drop_path(out)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
def get_expansion(block, expansion=None):
"""Get the expansion of a residual block.
    The block expansion will be obtained in the following order:
    1. If ``expansion`` is given, just return it.
    2. If ``block`` has the attribute ``expansion``, then return
       ``block.expansion``.
    3. Return the default value according to the block type:
1 for ``BasicBlock`` and 4 for ``Bottleneck``.
Args:
block (class): The block class.
expansion (int | None): The given expansion ratio.
Returns:
int: The expansion of the block.
"""
if isinstance(expansion, int):
assert expansion > 0
elif expansion is None:
if hasattr(block, 'expansion'):
expansion = block.expansion
elif issubclass(block, BasicBlock):
expansion = 1
elif issubclass(block, Bottleneck):
expansion = 4
else:
raise TypeError(f'expansion is not specified for {block.__name__}')
else:
raise TypeError('expansion must be an integer or None')
return expansion
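# Illustrative behaviour of ``get_expansion`` (a sketch, not part of the
# upstream API surface):
#   get_expansion(Bottleneck)      -> 4   (default for Bottleneck)
#   get_expansion(BasicBlock)      -> 1   (default for BasicBlock)
#   get_expansion(Bottleneck, 2)   -> 2   (an explicit value always wins)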
class ResLayer(nn.Sequential):
"""ResLayer to build ResNet style backbone.
Args:
block (nn.Module): Residual block used to build ResLayer.
num_blocks (int): Number of blocks.
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
expansion (int, optional): The expansion for BasicBlock/Bottleneck.
If not specified, it will firstly be obtained via
``block.expansion``. If the block has no attribute "expansion",
the following default values will be used: 1 for BasicBlock and
4 for Bottleneck. Default: None.
stride (int): stride of the first block. Default: 1.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False
conv_cfg (dict, optional): dictionary to construct and config conv
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
drop_path_rate (float or list): stochastic depth rate.
Default: 0.
"""
def __init__(self,
block,
num_blocks,
in_channels,
out_channels,
expansion=None,
stride=1,
avg_down=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
drop_path_rate=0.0,
**kwargs):
self.block = block
self.expansion = get_expansion(block, expansion)
if isinstance(drop_path_rate, float):
drop_path_rate = [drop_path_rate] * num_blocks
assert len(drop_path_rate
) == num_blocks, 'Please check the length of drop_path_rate'
downsample = None
if stride != 1 or in_channels != out_channels:
downsample = []
conv_stride = stride
if avg_down and stride != 1:
conv_stride = 1
downsample.append(
nn.AvgPool2d(
kernel_size=stride,
stride=stride,
ceil_mode=True,
count_include_pad=False))
downsample.extend([
build_conv_layer(
conv_cfg,
in_channels,
out_channels,
kernel_size=1,
stride=conv_stride,
bias=False),
build_norm_layer(norm_cfg, out_channels)[1]
])
downsample = nn.Sequential(*downsample)
layers = []
layers.append(
block(
in_channels=in_channels,
out_channels=out_channels,
expansion=self.expansion,
stride=stride,
downsample=downsample,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
drop_path_rate=drop_path_rate[0],
**kwargs))
in_channels = out_channels
for i in range(1, num_blocks):
layers.append(
block(
in_channels=in_channels,
out_channels=out_channels,
expansion=self.expansion,
stride=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
drop_path_rate=drop_path_rate[i],
**kwargs))
super(ResLayer, self).__init__(*layers)
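# Example (an illustrative sketch): the first stage of a ResNet-50 could be
# built as
#   ResLayer(Bottleneck, num_blocks=3, in_channels=64, out_channels=256)
# which stacks 3 Bottleneck blocks with mid_channels = 256 // 4 = 64 and
# adds a 1x1-conv downsample on the identity branch, because
# in_channels != out_channels.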
@MODELS.register_module()
class ResNet(BaseBackbone):
"""ResNet backbone.
Please refer to the `paper <https://arxiv.org/abs/1512.03385>`__ for
details.
Args:
depth (int): Network depth, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
base_channels (int): Middle channels of the first stage. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages.
Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
Example:
>>> from mmpretrain.models import ResNet
>>> import torch
        >>> self = ResNet(depth=18, out_indices=(0, 1, 2, 3))
>>> self.eval()
>>> inputs = torch.rand(1, 3, 32, 32)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
(1, 64, 8, 8)
(1, 128, 4, 4)
(1, 256, 2, 2)
(1, 512, 1, 1)
"""
arch_settings = {
18: (BasicBlock, (2, 2, 2, 2)),
34: (BasicBlock, (3, 4, 6, 3)),
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self,
depth,
in_channels=3,
stem_channels=64,
base_channels=64,
expansion=None,
num_stages=4,
strides=(1, 2, 2, 2),
dilations=(1, 1, 1, 1),
out_indices=(3, ),
style='pytorch',
deep_stem=False,
avg_down=False,
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
zero_init_residual=True,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
],
drop_path_rate=0.0):
super(ResNet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet')
self.depth = depth
self.stem_channels = stem_channels
self.base_channels = base_channels
self.num_stages = num_stages
assert num_stages >= 1 and num_stages <= 4
self.strides = strides
self.dilations = dilations
assert len(strides) == len(dilations) == num_stages
self.out_indices = out_indices
assert max(out_indices) < num_stages
self.style = style
self.deep_stem = deep_stem
self.avg_down = avg_down
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.with_cp = with_cp
self.norm_eval = norm_eval
self.zero_init_residual = zero_init_residual
self.block, stage_blocks = self.arch_settings[depth]
self.stage_blocks = stage_blocks[:num_stages]
self.expansion = get_expansion(self.block, expansion)
self._make_stem_layer(in_channels, stem_channels)
self.res_layers = []
_in_channels = stem_channels
_out_channels = base_channels * self.expansion
# stochastic depth decay rule
total_depth = sum(stage_blocks)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
]
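        # For example, depth=50 with drop_path_rate=0.1 produces 16 rates
        # linearly spaced in [0, 0.1]; each stage below consumes its share
        # of this list and passes one rate per block.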
for i, num_blocks in enumerate(self.stage_blocks):
stride = strides[i]
dilation = dilations[i]
res_layer = self.make_res_layer(
block=self.block,
num_blocks=num_blocks,
in_channels=_in_channels,
out_channels=_out_channels,
expansion=self.expansion,
stride=stride,
dilation=dilation,
style=self.style,
avg_down=self.avg_down,
with_cp=with_cp,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
drop_path_rate=dpr[:num_blocks])
_in_channels = _out_channels
_out_channels *= 2
dpr = dpr[num_blocks:]
layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name)
self._freeze_stages()
self.feat_dim = res_layer[-1].out_channels
def make_res_layer(self, **kwargs):
return ResLayer(**kwargs)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def _make_stem_layer(self, in_channels, stem_channels):
if self.deep_stem:
self.stem = nn.Sequential(
ConvModule(
in_channels,
stem_channels // 2,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=True),
ConvModule(
stem_channels // 2,
stem_channels // 2,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=True),
ConvModule(
stem_channels // 2,
stem_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
inplace=True))
else:
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
stem_channels,
kernel_size=7,
stride=2,
padding=3,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, stem_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
def _freeze_stages(self):
if self.frozen_stages >= 0:
if self.deep_stem:
self.stem.eval()
for param in self.stem.parameters():
param.requires_grad = False
else:
self.norm1.eval()
for m in [self.conv1, self.norm1]:
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, f'layer{i}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self):
super(ResNet, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
            # Suppress zero_init_residual when using a pretrained model.
return
if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
def forward(self, x):
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
super(ResNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval() has effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer id to set the different learning rates for ResNet.
ResNet stages:
50 : [3, 4, 6, 3]
101 : [3, 4, 23, 3]
152 : [3, 8, 36, 3]
200 : [3, 24, 36, 3]
eca269d: [3, 30, 48, 8]
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
"""
depths = self.stage_blocks
if depths[1] == 4 and depths[2] == 6:
blk2, blk3 = 2, 3
elif depths[1] == 4 and depths[2] == 23:
blk2, blk3 = 2, 3
elif depths[1] == 8 and depths[2] == 36:
blk2, blk3 = 4, 4
elif depths[1] == 24 and depths[2] == 36:
blk2, blk3 = 4, 4
elif depths[1] == 30 and depths[2] == 48:
blk2, blk3 = 5, 6
else:
raise NotImplementedError
N2, N3 = math.ceil(depths[1] / blk2 -
1e-5), math.ceil(depths[2] / blk3 - 1e-5)
N = 2 + N2 + N3 # r50: 2 + 2 + 2 = 6
max_layer_id = N + 1 # r50: 2 + 2 + 2 + 1(like head) = 7
if not param_name.startswith(prefix):
            # For subsequent modules like the head
return max_layer_id, max_layer_id + 1
if param_name.startswith('backbone.layer'):
stage_id = int(param_name.split('.')[1][5:])
block_id = int(param_name.split('.')[2])
if stage_id == 1:
layer_id = 1
elif stage_id == 2:
layer_id = 2 + block_id // blk2 # r50: 2, 3
elif stage_id == 3:
layer_id = 2 + N2 + block_id // blk3 # r50: 4, 5
else: # stage_id == 4
layer_id = N # r50: 6
return layer_id, max_layer_id + 1
else:
return 0, max_layer_id + 1
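# Worked example (illustrative) for ResNet-50 (stage_blocks=[3, 4, 6, 3]):
# blk2, blk3 = 2, 3, so N2 = 2, N3 = 2, N = 6 and max_layer_id = 7. The
# parameter 'backbone.layer2.3.conv1.weight' has stage_id=2 and block_id=3,
# giving layer_id = 2 + 3 // 2 = 3, so get_layer_depth returns (3, 8).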
@MODELS.register_module()
class ResNetV1c(ResNet):
"""ResNetV1c backbone.
This variant is described in `Bag of Tricks.
<https://arxiv.org/pdf/1812.01187.pdf>`_.
Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
in the input stem with three 3x3 convs.
"""
def __init__(self, **kwargs):
super(ResNetV1c, self).__init__(
deep_stem=True, avg_down=False, **kwargs)
@MODELS.register_module()
class ResNetV1d(ResNet):
"""ResNetV1d backbone.
This variant is described in `Bag of Tricks.
<https://arxiv.org/pdf/1812.01187.pdf>`_.
Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in
the input stem with three 3x3 convs. And in the downsampling block, a 2x2
avg_pool with stride 2 is added before conv, whose stride is changed to 1.
"""
def __init__(self, **kwargs):
super(ResNetV1d, self).__init__(
deep_stem=True, avg_down=True, **kwargs)
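# Quick comparison (illustrative): both variants use the deep 3x3 stem;
# ResNetV1d additionally replaces the strided 1x1 shortcut conv with an
# AvgPool followed by a stride-1 conv:
#   ResNetV1c(depth=50)   # deep_stem=True, avg_down=False
#   ResNetV1d(depth=50)   # deep_stem=True, avg_down=True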
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmpretrain.registry import MODELS
from .resnet import ResNet
@MODELS.register_module()
class ResNet_CIFAR(ResNet):
"""ResNet backbone for CIFAR.
Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in
    conv1, and does not apply MaxPooling after the stem. It has been proven
    to be more efficient than standard ResNet in other public codebases,
    e.g.,
`https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`.
Args:
depth (int): Network depth, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
base_channels (int): Middle channels of the first stage. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned;
            if multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
        deep_stem (bool): This network has a specifically designed stem,
            thus it is asserted to be False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
"""
def __init__(self, depth, deep_stem=False, **kwargs):
super(ResNet_CIFAR, self).__init__(
depth, deep_stem=deep_stem, **kwargs)
        assert not self.deep_stem, 'ResNet_CIFAR does not support deep_stem'
def _make_stem_layer(self, in_channels, base_channels):
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels,
base_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, base_channels, postfix=1)
self.add_module(self.norm1_name, norm1)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
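# Shape check (illustrative): for a CIFAR-sized input,
#   model = ResNet_CIFAR(depth=18, out_indices=(3, ))
#   model(torch.rand(1, 3, 32, 32))[0].shape  # -> (1, 512, 4, 4)
# since the stride-1 stem and the missing MaxPool keep the 32x32 resolution
# until the stride-2 blocks of stages 2-4 halve it.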
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmpretrain.registry import MODELS
from .resnet import Bottleneck as _Bottleneck
from .resnet import ResLayer, ResNet
class Bottleneck(_Bottleneck):
"""Bottleneck block for ResNeXt.
Args:
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
groups (int): Groups of conv2.
width_per_group (int): Width per group of conv2. 64x4d indicates
``groups=64, width_per_group=4`` and 32x8d indicates
``groups=32, width_per_group=8``.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module, optional): downsample operation on identity
branch. Default: None
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
conv_cfg (dict, optional): dictionary to construct and config conv
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
def __init__(self,
in_channels,
out_channels,
base_channels=64,
groups=32,
width_per_group=4,
**kwargs):
super(Bottleneck, self).__init__(in_channels, out_channels, **kwargs)
self.groups = groups
self.width_per_group = width_per_group
# For ResNet bottleneck, middle channels are determined by expansion
# and out_channels, but for ResNeXt bottleneck, it is determined by
# groups and width_per_group and the stage it is located in.
if groups != 1:
assert self.mid_channels % base_channels == 0
self.mid_channels = (
groups * width_per_group * self.mid_channels // base_channels)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, self.mid_channels, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, self.mid_channels, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.out_channels, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.in_channels,
self.mid_channels,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
self.mid_channels,
self.mid_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
self.mid_channels,
self.out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
@MODELS.register_module()
class ResNeXt(ResNet):
"""ResNeXt backbone.
Please refer to the `paper <https://arxiv.org/abs/1611.05431>`__ for
details.
Args:
depth (int): Network depth, from {50, 101, 152}.
groups (int): Groups of conv2 in Bottleneck. Default: 32.
width_per_group (int): Width per group of conv2 in Bottleneck.
Default: 4.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned;
            if multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
"""
arch_settings = {
50: (Bottleneck, (3, 4, 6, 3)),
101: (Bottleneck, (3, 4, 23, 3)),
152: (Bottleneck, (3, 8, 36, 3))
}
def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
self.groups = groups
self.width_per_group = width_per_group
super(ResNeXt, self).__init__(depth, **kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(
groups=self.groups,
width_per_group=self.width_per_group,
base_channels=self.base_channels,
**kwargs)
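# Example (illustrative): ResNeXt-50 (32x4d) is
#   ResNeXt(depth=50, groups=32, width_per_group=4)
# For its first stage, the plain-ResNet mid_channels of 256 // 4 = 64
# become 32 * 4 * 64 // 64 = 128 grouped channels in conv2.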
# Copyright (c) OpenMMLab. All rights reserved.
import sys
import numpy as np
import torch
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from torch import nn
from torch.autograd import Function as Function
from mmpretrain.models.backbones.base_backbone import BaseBackbone
from mmpretrain.registry import MODELS
from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
to_2tuple)
class RevBackProp(Function):
"""Custom Backpropagation function to allow (A) flushing memory in forward
and (B) activation recomputation reversibly in backward for gradient
calculation.
Inspired by
https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py
"""
@staticmethod
def forward(
ctx,
x,
layers,
            buffer_layers,  # ids of layers whose activations are buffered
):
"""Reversible Forward pass.
        Any intermediate activations from `buffer_layers` are cached in ctx
        during the forward pass. This is not necessary for standard use
        cases. Each
reversible layer implements its own forward pass logic.
"""
buffer_layers.sort()
x1, x2 = torch.chunk(x, 2, dim=-1)
intermediate = []
for layer in layers:
x1, x2 = layer(x1, x2)
if layer.layer_id in buffer_layers:
intermediate.extend([x1.detach(), x2.detach()])
if len(buffer_layers) == 0:
all_tensors = [x1.detach(), x2.detach()]
else:
intermediate = [torch.LongTensor(buffer_layers), *intermediate]
all_tensors = [x1.detach(), x2.detach(), *intermediate]
ctx.save_for_backward(*all_tensors)
ctx.layers = layers
return torch.cat([x1, x2], dim=-1)
@staticmethod
def backward(ctx, dx):
"""Reversible Backward pass.
Any intermediate activations from `buffer_layers` are recovered from
        ctx. Each layer implements its own logic for the backward pass (both
activation recomputation and grad calculation).
"""
d_x1, d_x2 = torch.chunk(dx, 2, dim=-1)
# retrieve params from ctx for backward
x1, x2, *int_tensors = ctx.saved_tensors
        # recover buffered layer ids, if any; otherwise no buffering was used
if len(int_tensors) != 0:
buffer_layers = int_tensors[0].tolist()
else:
buffer_layers = []
layers = ctx.layers
for _, layer in enumerate(layers[::-1]):
if layer.layer_id in buffer_layers:
x1, x2, d_x1, d_x2 = layer.backward_pass(
y1=int_tensors[buffer_layers.index(layer.layer_id) * 2 +
1],
y2=int_tensors[buffer_layers.index(layer.layer_id) * 2 +
2],
d_y1=d_x1,
d_y2=d_x2,
)
else:
x1, x2, d_x1, d_x2 = layer.backward_pass(
y1=x1,
y2=x2,
d_y1=d_x1,
d_y2=d_x2,
)
dx = torch.cat([d_x1, d_x2], dim=-1)
del int_tensors
del d_x1, d_x2, x1, x2
return dx, None, None
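# The coupling implemented by the layers driven above follows RevNet:
#   forward:  y1 = x1 + Attn(LN1(x2));  y2 = x2 + FFN(LN2(y1))
#   inverse:  x2 = y2 - FFN(LN2(y1));   x1 = y1 - Attn(LN1(x2))
# so intermediate activations can be recomputed instead of stored.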
class RevTransformerEncoderLayer(BaseModule):
"""Reversible Transformer Encoder Layer.
This module is a building block of Reversible Transformer Encoder,
    which supports backpropagation without storing activations.
The residual connection is not applied to the FFN layer.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
drop_rate (float): Probability of an element to be zeroed.
Default: 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default: 0.0
drop_path_rate (float): stochastic depth rate.
Default 0.0
num_fcs (int): The number of linear in FFN
Default: 2
qkv_bias (bool): enable bias for qkv if True.
Default: True
act_cfg (dict): The activation config for FFNs.
Default: dict(type='GELU')
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
layer_id (int): The layer id of current layer. Used in RevBackProp.
Default: 0
init_cfg (dict or list[dict], optional): Initialization config dict.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
feedforward_channels: int,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
num_fcs: int = 2,
qkv_bias: bool = True,
act_cfg: dict = dict(type='GELU'),
norm_cfg: dict = dict(type='LN'),
layer_id: int = 0,
init_cfg=None):
super(RevTransformerEncoderLayer, self).__init__(init_cfg=init_cfg)
self.drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate)
self.embed_dims = embed_dims
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
self.attn = MultiheadAttention(
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
qkv_bias=qkv_bias)
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
act_cfg=act_cfg,
add_identity=False)
self.layer_id = layer_id
self.seeds = {}
def init_weights(self):
super(RevTransformerEncoderLayer, self).init_weights()
for m in self.ffn.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.normal_(m.bias, std=1e-6)
def seed_cuda(self, key):
"""Fix seeds to allow for stochastic elements such as dropout to be
reproduced exactly in activation recomputation in the backward pass."""
# randomize seeds
# use cuda generator if available
if (hasattr(torch.cuda, 'default_generators')
and len(torch.cuda.default_generators) > 0):
# GPU
device_idx = torch.cuda.current_device()
seed = torch.cuda.default_generators[device_idx].seed()
else:
# CPU
seed = int(torch.seed() % sys.maxsize)
self.seeds[key] = seed
torch.manual_seed(self.seeds[key])
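    # The stored seeds let ``backward_pass`` replay the exact dropout and
    # droppath masks: re-seeding with ``self.seeds[key]`` before each
    # recomputed branch makes the recomputation match the forward pass
    # bit-for-bit.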
def forward(self, x1, x2):
"""
Implementation of Reversible TransformerEncoderLayer
`
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
`
"""
self.seed_cuda('attn')
# attention output
f_x2 = self.attn(self.ln1(x2))
# apply droppath on attention output
self.seed_cuda('droppath')
f_x2_dropped = build_dropout(self.drop_path_cfg)(f_x2)
y1 = x1 + f_x2_dropped
# free memory
if self.training:
del x1
# ffn output
self.seed_cuda('ffn')
g_y1 = self.ffn(self.ln2(y1))
# apply droppath on ffn output
torch.manual_seed(self.seeds['droppath'])
g_y1_dropped = build_dropout(self.drop_path_cfg)(g_y1)
# final output
y2 = x2 + g_y1_dropped
# free memory
if self.training:
del x2
return y1, y2
def backward_pass(self, y1, y2, d_y1, d_y2):
"""Activation re-compute with the following equation.
x2 = y2 - g(y1), g = FFN
x1 = y1 - f(x2), f = MSHA
"""
# temporarily record intermediate activation for G
# and use them for gradient calculation of G
with torch.enable_grad():
y1.requires_grad = True
torch.manual_seed(self.seeds['ffn'])
g_y1 = self.ffn(self.ln2(y1))
torch.manual_seed(self.seeds['droppath'])
g_y1 = build_dropout(self.drop_path_cfg)(g_y1)
g_y1.backward(d_y2, retain_graph=True)
            # activation recomputation is by design and not part of
            # the computation graph in the forward pass
with torch.no_grad():
x2 = y2 - g_y1
del g_y1
d_y1 = d_y1 + y1.grad
y1.grad = None
# record F activation and calculate gradients on F
with torch.enable_grad():
x2.requires_grad = True
torch.manual_seed(self.seeds['attn'])
f_x2 = self.attn(self.ln1(x2))
torch.manual_seed(self.seeds['droppath'])
f_x2 = build_dropout(self.drop_path_cfg)(f_x2)
f_x2.backward(d_y1, retain_graph=True)
# propagate reverse computed activations at the
# start of the previous block
with torch.no_grad():
x1 = y1 - f_x2
del f_x2, y1
d_y2 = d_y2 + x2.grad
x2.grad = None
x2 = x2.detach()
return x1, x2, d_y1, d_y2
class TwoStreamFusion(nn.Module):
"""A general constructor for neural modules fusing two equal sized tensors
in forward.
Args:
mode (str): The mode of fusion. Options are 'add', 'max', 'min',
'avg', 'concat'.
"""
def __init__(self, mode: str):
super().__init__()
self.mode = mode
if mode == 'add':
self.fuse_fn = lambda x: torch.stack(x).sum(dim=0)
elif mode == 'max':
self.fuse_fn = lambda x: torch.stack(x).max(dim=0).values
elif mode == 'min':
self.fuse_fn = lambda x: torch.stack(x).min(dim=0).values
elif mode == 'avg':
self.fuse_fn = lambda x: torch.stack(x).mean(dim=0)
elif mode == 'concat':
self.fuse_fn = lambda x: torch.cat(x, dim=-1)
else:
raise NotImplementedError
def forward(self, x):
# split the tensor into two halves in the channel dimension
x = torch.chunk(x, 2, dim=2)
return self.fuse_fn(x)
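# Illustrative shapes: given an input of shape (B, L, 2 * C), forward()
# chunks it into two (B, L, C) halves; 'avg' returns their element-wise
# mean with shape (B, L, C), while 'concat' re-concatenates the halves and
# keeps the (B, L, 2 * C) shape.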
@MODELS.register_module()
class RevVisionTransformer(BaseBackbone):
"""Reversible Vision Transformer.
A PyTorch implementation of : `Reversible Vision Transformers
<https://openaccess.thecvf.com/content/CVPR2022/html/Mangalam_Reversible_Vision_Transformers_CVPR_2022_paper.html>`_ # noqa: E501
Args:
arch (str | dict): Vision Transformer architecture. If use string,
choose from 'small', 'base', 'large', 'deit-tiny', 'deit-small'
and 'deit-base'. If use dict, it should have below keys:
- **embed_dims** (int): The dimensions of embedding.
- **num_layers** (int): The number of transformer encoder layers.
- **num_heads** (int): The number of heads in attention modules.
- **feedforward_channels** (int): The hidden dimensions in
feedforward modules.
Defaults to 'base'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 16.
in_channels (int): The num of input channels. Defaults to 3.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
qkv_bias (bool): Whether to add bias for qkv in attention modules.
Defaults to True.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"avg_featmap"``.
        with_cls_token (bool): Whether to concatenate the class token into
            the image tokens as transformer input. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
fusion_mode (str): The fusion mode of transformer layers.
Defaults to 'concat'.
no_custom_backward (bool): Whether to use custom backward.
Defaults to False.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
arch_zoo = {
**dict.fromkeys(
['s', 'small'], {
'embed_dims': 768,
'num_layers': 8,
'num_heads': 8,
'feedforward_channels': 768 * 3,
}),
**dict.fromkeys(
['b', 'base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 3072
}),
**dict.fromkeys(
['l', 'large'], {
'embed_dims': 1024,
'num_layers': 24,
'num_heads': 16,
'feedforward_channels': 4096
}),
**dict.fromkeys(
['h', 'huge'],
{
# The same as the implementation in MAE
# <https://arxiv.org/abs/2111.06377>
'embed_dims': 1280,
'num_layers': 32,
'num_heads': 16,
'feedforward_channels': 5120
}),
**dict.fromkeys(
['deit-t', 'deit-tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': 192 * 4
}),
**dict.fromkeys(
['deit-s', 'deit-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': 384 * 4
}),
**dict.fromkeys(
['deit-b', 'deit-base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 768 * 4
}),
}
num_extra_tokens = 0 # The official RevViT doesn't have class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
arch='base',
img_size=224,
patch_size=16,
in_channels=3,
drop_rate=0.,
drop_path_rate=0.,
qkv_bias=True,
norm_cfg=dict(type='LN', eps=1e-6),
final_norm=True,
out_type='avg_featmap',
with_cls_token=False,
frozen_stages=-1,
interpolate_mode='bicubic',
patch_cfg=dict(),
layer_cfgs=dict(),
fusion_mode='concat',
no_custom_backward=False,
init_cfg=None):
super(RevVisionTransformer, self).__init__(init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.num_layers = self.arch_settings['num_layers']
self.img_size = to_2tuple(img_size)
self.no_custom_backward = no_custom_backward
# Set patch embedding
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
self.num_extra_tokens = 1
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_extra_tokens,
self.embed_dims))
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
# stochastic depth decay rule
dpr = np.linspace(0, drop_path_rate, self.num_layers)
self.layers = ModuleList()
if isinstance(layer_cfgs, dict):
layer_cfgs = [layer_cfgs] * self.num_layers
for i in range(self.num_layers):
_layer_cfg = dict(
embed_dims=self.embed_dims,
num_heads=self.arch_settings['num_heads'],
feedforward_channels=self.
arch_settings['feedforward_channels'],
drop_rate=drop_rate,
drop_path_rate=dpr[i],
qkv_bias=qkv_bias,
layer_id=i,
norm_cfg=norm_cfg)
_layer_cfg.update(layer_cfgs[i])
self.layers.append(RevTransformerEncoderLayer(**_layer_cfg))
# fusion operation for the final output
self.fusion_layer = TwoStreamFusion(mode=fusion_mode)
self.frozen_stages = frozen_stages
self.final_norm = final_norm
if final_norm:
self.ln1 = build_norm_layer(norm_cfg, self.embed_dims * 2)
# freeze stages only when self.frozen_stages > 0
if self.frozen_stages > 0:
self._freeze_stages()
def init_weights(self):
super(RevVisionTransformer, self).init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
trunc_normal_(self.pos_embed, std=0.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
@staticmethod
def resize_pos_embed(*args, **kwargs):
"""Interface for backward-compatibility."""
return resize_pos_embed(*args, **kwargs)
def _freeze_stages(self):
# freeze position embedding
self.pos_embed.requires_grad = False
        # set dropout to eval mode
self.drop_after_pos.eval()
# freeze patch embedding
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
# freeze cls_token
if self.cls_token is not None:
self.cls_token.requires_grad = False
# freeze layers
for i in range(1, self.frozen_stages + 1):
m = self.layers[i - 1]
m.eval()
for param in m.parameters():
param.requires_grad = False
# freeze the last layer norm
if self.frozen_stages == len(self.layers) and self.final_norm:
self.ln1.eval()
for param in self.ln1.parameters():
param.requires_grad = False
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
if self.cls_token is not None:
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
x = torch.cat([x, x], dim=-1)
# forward with different conditions
if not self.training or self.no_custom_backward:
            # in eval/inference mode
executing_fn = RevVisionTransformer._forward_vanilla_bp
else:
# use custom backward when self.training=True.
executing_fn = RevBackProp.apply
x = executing_fn(x, self.layers, [])
if self.final_norm:
x = self.ln1(x)
x = self.fusion_layer(x)
return (self._format_output(x, patch_resolution), )
@staticmethod
def _forward_vanilla_bp(hidden_state, layers, buffer=[]):
"""Using reversible layers without reversible backpropagation.
        For debugging purposes only. Activated with ``self.no_custom_backward``.
"""
# split into ffn state(ffn_out) and attention output(attn_out)
ffn_out, attn_out = torch.chunk(hidden_state, 2, dim=-1)
del hidden_state
for _, layer in enumerate(layers):
attn_out, ffn_out = layer(attn_out, ffn_out)
return torch.cat([attn_out, ffn_out], dim=-1)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return patch_token.mean(dim=1)
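# Shape walk-through (illustrative): for a 224x224 input with the 'base'
# arch, patch_embed yields (B, 196, 768); duplicating along the channel
# dimension gives (B, 196, 1536) for the two reversible streams, and with
# fusion_mode='concat' the 'avg_featmap' output is (B, 1536).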
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence
import torch
import torch.nn as nn
from mmcv.cnn.bricks import DropPath, build_norm_layer
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
from .poolformer import Mlp, PatchEmbed
class Affine(nn.Module):
"""Affine Transformation module.
Args:
in_features (int): Input dimension.
"""
def __init__(self, in_features):
super().__init__()
self.affine = nn.Conv2d(
in_features,
in_features,
kernel_size=1,
stride=1,
padding=0,
groups=in_features,
bias=True)
def forward(self, x):
return self.affine(x) - x
class RIFormerBlock(BaseModule):
"""RIFormer Block.
Args:
dim (int): Embedding dim.
mlp_ratio (float): Mlp expansion ratio. Defaults to 4.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='GN', num_groups=1)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
drop (float): Dropout rate. Defaults to 0.
drop_path (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 1e-5.
deploy (bool): Whether to switch the model structure to
deployment mode. Default: False.
"""
def __init__(self,
dim,
mlp_ratio=4.,
norm_cfg=dict(type='GN', num_groups=1),
act_cfg=dict(type='GELU'),
drop=0.,
drop_path=0.,
layer_scale_init_value=1e-5,
deploy=False):
super().__init__()
if deploy:
self.norm_reparam = build_norm_layer(norm_cfg, dim)[1]
else:
self.norm1 = build_norm_layer(norm_cfg, dim)[1]
self.token_mixer = Affine(in_features=dim)
self.norm2 = build_norm_layer(norm_cfg, dim)[1]
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop)
# The following two techniques are useful to train deep RIFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.norm_cfg = norm_cfg
self.dim = dim
self.deploy = deploy
def forward(self, x):
if hasattr(self, 'norm_reparam'):
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
self.norm_reparam(x))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
self.mlp(self.norm2(x)))
return x
def fuse_affine(self, norm, token_mixer):
gamma_affn = token_mixer.affine.weight.reshape(-1)
gamma_affn = gamma_affn - torch.ones_like(gamma_affn)
beta_affn = token_mixer.affine.bias
gamma_ln = norm.weight
beta_ln = norm.bias
return (gamma_ln * gamma_affn), (beta_ln * gamma_affn + beta_affn)
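    # Reparameterization sketch: token_mixer(x) = affine(x) - x scales each
    # channel by (w - 1), so composing it with the preceding norm gives an
    # equivalent norm with weight gamma_ln * (w - 1) and bias
    # beta_ln * (w - 1) + beta_affn, which is what fuse_affine() returns.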
def get_equivalent_scale_bias(self):
eq_s, eq_b = self.fuse_affine(self.norm1, self.token_mixer)
return eq_s, eq_b
def switch_to_deploy(self):
if self.deploy:
return
eq_s, eq_b = self.get_equivalent_scale_bias()
self.norm_reparam = build_norm_layer(self.norm_cfg, self.dim)[1]
self.norm_reparam.weight.data = eq_s
self.norm_reparam.bias.data = eq_b
self.__delattr__('norm1')
if hasattr(self, 'token_mixer'):
self.__delattr__('token_mixer')
self.deploy = True
def basic_blocks(dim,
index,
layers,
mlp_ratio=4.,
norm_cfg=dict(type='GN', num_groups=1),
act_cfg=dict(type='GELU'),
drop_rate=.0,
drop_path_rate=0.,
layer_scale_init_value=1e-5,
deploy=False):
"""generate RIFormer blocks for a stage."""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
sum(layers) - 1)
blocks.append(
RIFormerBlock(
dim,
mlp_ratio=mlp_ratio,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
drop=drop_rate,
drop_path=block_dpr,
layer_scale_init_value=layer_scale_init_value,
deploy=deploy,
))
blocks = nn.Sequential(*blocks)
return blocks
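# For example (illustrative): with layers=[2, 2, 6, 2] and
# drop_path_rate=0.1, block_dpr rises linearly from 0 for the very first
# block to 0.1 for the last one, i.e. 0.1 * block_position / 11.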
@MODELS.register_module()
class RIFormer(BaseBackbone):
"""RIFormer.
A PyTorch implementation of RIFormer introduced by:
`RIFormer: Keep Your Vision Backbone Effective But Removing Token Mixer <https://arxiv.org/abs/xxxx.xxxxx>`_
Args:
arch (str | dict): The model's architecture. If string, it should be
one of architecture in ``RIFormer.arch_settings``. And if dict, it
should include the following two keys:
- layers (list[int]): Number of blocks at each stage.
- embed_dims (list[int]): The number of channels at each stage.
- mlp_ratios (list[int]): Expansion ratio of MLPs.
- layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 'S12'.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='LN2d', eps=1e-6)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
        in_patch_size (int): The patch size of input image patch embedding.
Defaults to 7.
in_stride (int): The stride of input image patch embedding.
Defaults to 4.
in_pad (int): The padding of input image patch embedding.
Defaults to 2.
down_patch_size (int): The patch size of downsampling patch embedding.
Defaults to 3.
down_stride (int): The stride of downsampling patch embedding.
Defaults to 2.
down_pad (int): The padding of downsampling patch embedding.
Defaults to 1.
drop_rate (float): Dropout rate. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
out_indices (Sequence | int): Output from which network position.
Index 0-6 respectively corresponds to
[stage1, downsampling, stage2, downsampling, stage3, downsampling, stage4]
            Defaults to -1, which means the last stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to -1, which means not freezing any parameters.
deploy (bool): Whether to switch the model structure to
deployment mode. Default: False.
init_cfg (dict, optional): Initialization config dict
""" # noqa: E501
# --layers: [x,x,x,x], numbers of layers for the four stages
# --embed_dims, --mlp_ratios:
# embedding dims and mlp ratios for the four stages
# --downsamples: flags to apply downsampling or not in four blocks
arch_settings = {
's12': {
'layers': [2, 2, 6, 2],
'embed_dims': [64, 128, 320, 512],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-5,
},
's24': {
'layers': [4, 4, 12, 4],
'embed_dims': [64, 128, 320, 512],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-5,
},
's36': {
'layers': [6, 6, 18, 6],
'embed_dims': [64, 128, 320, 512],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-6,
},
'm36': {
'layers': [6, 6, 18, 6],
'embed_dims': [96, 192, 384, 768],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-6,
},
'm48': {
'layers': [8, 8, 24, 8],
'embed_dims': [96, 192, 384, 768],
'mlp_ratios': [4, 4, 4, 4],
'layer_scale_init_value': 1e-6,
},
}
def __init__(self,
arch='s12',
in_channels=3,
norm_cfg=dict(type='GN', num_groups=1),
act_cfg=dict(type='GELU'),
in_patch_size=7,
in_stride=4,
in_pad=2,
down_patch_size=3,
down_stride=2,
down_pad=1,
drop_rate=0.,
drop_path_rate=0.,
out_indices=-1,
frozen_stages=-1,
init_cfg=None,
deploy=False):
super().__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
assert 'layers' in arch and 'embed_dims' in arch, \
f'The arch dict must have "layers" and "embed_dims", ' \
f'but got {list(arch.keys())}.'
layers = arch['layers']
embed_dims = arch['embed_dims']
mlp_ratios = arch['mlp_ratios'] \
if 'mlp_ratios' in arch else [4, 4, 4, 4]
layer_scale_init_value = arch['layer_scale_init_value'] \
if 'layer_scale_init_value' in arch else 1e-5
self.patch_embed = PatchEmbed(
patch_size=in_patch_size,
stride=in_stride,
padding=in_pad,
in_chans=in_channels,
embed_dim=embed_dims[0])
# set the main block in network
network = []
for i in range(len(layers)):
stage = basic_blocks(
embed_dims[i],
i,
layers,
mlp_ratio=mlp_ratios[i],
norm_cfg=norm_cfg,
act_cfg=act_cfg,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
layer_scale_init_value=layer_scale_init_value,
deploy=deploy)
network.append(stage)
if i >= len(layers) - 1:
break
if embed_dims[i] != embed_dims[i + 1]:
# downsampling between two stages
network.append(
PatchEmbed(
patch_size=down_patch_size,
stride=down_stride,
padding=down_pad,
in_chans=embed_dims[i],
embed_dim=embed_dims[i + 1]))
self.network = nn.ModuleList(network)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = 7 + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
if self.out_indices:
for i_layer in self.out_indices:
layer = build_norm_layer(norm_cfg,
embed_dims[(i_layer + 1) // 2])[1]
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
self.frozen_stages = frozen_stages
self._freeze_stages()
self.deploy = deploy
def forward_embeddings(self, x):
x = self.patch_embed(x)
return x
def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
x = block(x)
if idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
x_out = norm_layer(x)
outs.append(x_out)
return tuple(outs)
def forward(self, x):
# input embedding
x = self.forward_embeddings(x)
# through backbone
x = self.forward_tokens(x)
return x
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
# Include both block and downsample layer.
module = self.network[i]
module.eval()
for param in module.parameters():
param.requires_grad = False
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
norm_layer.eval()
for param in norm_layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super(RIFormer, self).train(mode)
self._freeze_stages()
return self
def switch_to_deploy(self):
for m in self.modules():
if isinstance(m, RIFormerBlock):
m.switch_to_deploy()
self.deploy = True
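# Typical deployment flow (an illustrative sketch):
#   model = RIFormer(arch='s12', deploy=False)
#   model.eval()
#   model.switch_to_deploy()  # folds norm1 + Affine into norm_reparam
# After switching, forward() takes the ``norm_reparam`` branch in each block.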
# Copyright (c) OpenMMLab. All rights reserved.
import torch.utils.checkpoint as cp
from mmpretrain.registry import MODELS
from ..utils.se_layer import SELayer
from .resnet import Bottleneck, ResLayer, ResNet
class SEBottleneck(Bottleneck):
"""SEBottleneck block for SEResNet.
Args:
in_channels (int): The input channels of the SEBottleneck block.
out_channels (int): The output channel of the SEBottleneck block.
se_ratio (int): Squeeze ratio in SELayer. Default: 16
"""
def __init__(self, in_channels, out_channels, se_ratio=16, **kwargs):
super(SEBottleneck, self).__init__(in_channels, out_channels, **kwargs)
self.se_layer = SELayer(out_channels, ratio=se_ratio)
def forward(self, x):
def _inner_forward(x):
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.norm2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.norm3(out)
out = self.se_layer(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
out = self.relu(out)
return out
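# Note (illustrative summary): the SE recalibration above is applied to the
# residual branch before the identity shortcut is added, matching the
# original SENet design.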
@MODELS.register_module()
class SEResNet(ResNet):
"""SEResNet backbone.
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
details.
Args:
depth (int): Network depth, from {50, 101, 152}.
se_ratio (int): Squeeze ratio in SELayer. Default: 16.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned;
            if multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
Example:
>>> from mmpretrain.models import SEResNet
>>> import torch
        >>> self = SEResNet(depth=50, out_indices=(0, 1, 2, 3))
>>> self.eval()
>>> inputs = torch.rand(1, 3, 224, 224)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
        (1, 256, 56, 56)
        (1, 512, 28, 28)
        (1, 1024, 14, 14)
        (1, 2048, 7, 7)
"""
arch_settings = {
50: (SEBottleneck, (3, 4, 6, 3)),
101: (SEBottleneck, (3, 4, 23, 3)),
152: (SEBottleneck, (3, 8, 36, 3))
}
def __init__(self, depth, se_ratio=16, **kwargs):
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for SEResNet')
self.se_ratio = se_ratio
super(SEResNet, self).__init__(depth, **kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(se_ratio=self.se_ratio, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmpretrain.registry import MODELS
from .resnet import ResLayer
from .seresnet import SEBottleneck as _SEBottleneck
from .seresnet import SEResNet
class SEBottleneck(_SEBottleneck):
"""SEBottleneck block for SEResNeXt.
Args:
in_channels (int): Input channels of this block.
out_channels (int): Output channels of this block.
base_channels (int): Middle channels of the first stage. Default: 64.
groups (int): Groups of conv2.
width_per_group (int): Width per group of conv2. 64x4d indicates
``groups=64, width_per_group=4`` and 32x8d indicates
``groups=32, width_per_group=8``.
stride (int): stride of the block. Default: 1
dilation (int): dilation of convolution. Default: 1
downsample (nn.Module, optional): downsample operation on identity
branch. Default: None
se_ratio (int): Squeeze ratio in SELayer. Default: 16
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
conv_cfg (dict, optional): dictionary to construct and config conv
layer. Default: None
norm_cfg (dict): dictionary to construct and config norm layer.
Default: dict(type='BN')
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
def __init__(self,
in_channels,
out_channels,
base_channels=64,
groups=32,
width_per_group=4,
se_ratio=16,
**kwargs):
super(SEBottleneck, self).__init__(in_channels, out_channels, se_ratio,
**kwargs)
self.groups = groups
self.width_per_group = width_per_group
        # We follow the same rationale as ResNeXt to compute mid_channels.
# For SEResNet bottleneck, middle channels are determined by expansion
# and out_channels, but for SEResNeXt bottleneck, it is determined by
# groups and width_per_group and the stage it is located in.
if groups != 1:
assert self.mid_channels % base_channels == 0
self.mid_channels = (
groups * width_per_group * self.mid_channels // base_channels)
self.norm1_name, norm1 = build_norm_layer(
self.norm_cfg, self.mid_channels, postfix=1)
self.norm2_name, norm2 = build_norm_layer(
self.norm_cfg, self.mid_channels, postfix=2)
self.norm3_name, norm3 = build_norm_layer(
self.norm_cfg, self.out_channels, postfix=3)
self.conv1 = build_conv_layer(
self.conv_cfg,
self.in_channels,
self.mid_channels,
kernel_size=1,
stride=self.conv1_stride,
bias=False)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
self.conv_cfg,
self.mid_channels,
self.mid_channels,
kernel_size=3,
stride=self.conv2_stride,
padding=self.dilation,
dilation=self.dilation,
groups=groups,
bias=False)
self.add_module(self.norm2_name, norm2)
self.conv3 = build_conv_layer(
self.conv_cfg,
self.mid_channels,
self.out_channels,
kernel_size=1,
bias=False)
self.add_module(self.norm3_name, norm3)
@MODELS.register_module()
class SEResNeXt(SEResNet):
"""SEResNeXt backbone.
Please refer to the `paper <https://arxiv.org/abs/1709.01507>`__ for
details.
Args:
depth (int): Network depth, from {50, 101, 152}.
groups (int): Groups of conv2 in Bottleneck. Default: 32.
width_per_group (int): Width per group of conv2 in Bottleneck.
Default: 4.
se_ratio (int): Squeeze ratio in SELayer. Default: 16.
in_channels (int): Number of input image channels. Default: 3.
stem_channels (int): Output channels of the stem layer. Default: 64.
num_stages (int): Stages of the network. Default: 4.
strides (Sequence[int]): Strides of the first block of each stage.
Default: ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Default: ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned;
            if multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Default: False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Default: False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Default: -1.
conv_cfg (dict | None): The config dict for conv layers. Default: None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Default: True.
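Examples:
A minimal usage sketch (illustrative; the shape assumes the default
``out_indices=(3, )`` and a 224x224 input):
>>> import torch
>>> from mmpretrain.models import SEResNeXt
>>> model = SEResNeXt(depth=50)
>>> outs = model(torch.rand(1, 3, 224, 224))
>>> print(outs[0].shape)
torch.Size([1, 2048, 7, 7])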
"""
arch_settings = {
50: (SEBottleneck, (3, 4, 6, 3)),
101: (SEBottleneck, (3, 4, 23, 3)),
152: (SEBottleneck, (3, 8, 36, 3))
}
def __init__(self, depth, groups=32, width_per_group=4, **kwargs):
self.groups = groups
self.width_per_group = width_per_group
super(SEResNeXt, self).__init__(depth, **kwargs)
def make_res_layer(self, **kwargs):
return ResLayer(
groups=self.groups,
width_per_group=self.width_per_group,
base_channels=self.base_channels,
**kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_activation_layer
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.utils import channel_shuffle, make_divisible
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
class ShuffleUnit(BaseModule):
"""ShuffleUnit block.
ShuffleNet unit with pointwise group convolution (GConv) and channel
shuffle.
Args:
in_channels (int): The input channels of the ShuffleUnit.
out_channels (int): The output channels of the ShuffleUnit.
groups (int): The number of groups to be used in grouped 1x1
convolutions in each ShuffleUnit. Default: 3
first_block (bool): Whether it is the first ShuffleUnit in a
sequence of ShuffleUnits. Default: True, which means not using the
grouped 1x1 convolution.
combine (str): The ways to combine the input and output
branches. Default: 'add'.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
with_cp (bool): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Default: False.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
in_channels,
out_channels,
groups=3,
first_block=True,
combine='add',
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False):
super(ShuffleUnit, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.first_block = first_block
self.combine = combine
self.groups = groups
self.bottleneck_channels = self.out_channels // 4
self.with_cp = with_cp
if self.combine == 'add':
self.depthwise_stride = 1
self._combine_func = self._add
assert in_channels == out_channels, (
'in_channels must be equal to out_channels when combine '
'is add')
elif self.combine == 'concat':
self.depthwise_stride = 2
self._combine_func = self._concat
self.out_channels -= self.in_channels
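# e.g., the first unit of the first stage in the groups=3 model has
# in_channels=24 and out_channels=240; the conv branch then outputs
# 240 - 24 = 216 channels, and concatenation with the 24-channel
# pooled identity restores 240.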
self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
else:
raise ValueError(f'Cannot combine tensors with {self.combine}. '
'Only "add" and "concat" are supported')
self.first_1x1_groups = 1 if first_block else self.groups
self.g_conv_1x1_compress = ConvModule(
in_channels=self.in_channels,
out_channels=self.bottleneck_channels,
kernel_size=1,
groups=self.first_1x1_groups,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.depthwise_conv3x3_bn = ConvModule(
in_channels=self.bottleneck_channels,
out_channels=self.bottleneck_channels,
kernel_size=3,
stride=self.depthwise_stride,
padding=1,
groups=self.bottleneck_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
self.g_conv_1x1_expand = ConvModule(
in_channels=self.bottleneck_channels,
out_channels=self.out_channels,
kernel_size=1,
groups=self.groups,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
self.act = build_activation_layer(act_cfg)
@staticmethod
def _add(x, out):
# residual connection
return x + out
@staticmethod
def _concat(x, out):
# concatenate along channel axis
return torch.cat((x, out), 1)
def forward(self, x):
def _inner_forward(x):
residual = x
out = self.g_conv_1x1_compress(x)
out = self.depthwise_conv3x3_bn(out)
if self.groups > 1:
out = channel_shuffle(out, self.groups)
out = self.g_conv_1x1_expand(out)
if self.combine == 'concat':
residual = self.avgpool(residual)
out = self.act(out)
out = self._combine_func(residual, out)
else:
out = self._combine_func(residual, out)
out = self.act(out)
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
@MODELS.register_module()
class ShuffleNetV1(BaseBackbone):
"""ShuffleNetV1 backbone.
Args:
groups (int): The number of groups to be used in grouped 1x1
convolutions in each ShuffleUnit. Default: 3.
widen_factor (float): Width multiplier - adjusts the number
of channels in each layer by this amount. Default: 1.0.
out_indices (Sequence[int]): Output from which stages.
Default: (2, )
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
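Examples:
A minimal usage sketch (illustrative; shapes assume the default
``groups=3``, ``widen_factor=1.0`` and a 224x224 input):
>>> import torch
>>> from mmpretrain.models import ShuffleNetV1
>>> model = ShuffleNetV1(out_indices=(0, 1, 2))
>>> outs = model(torch.rand(1, 3, 224, 224))
>>> [tuple(out.shape) for out in outs]
[(1, 240, 28, 28), (1, 480, 14, 14), (1, 960, 7, 7)]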
"""
def __init__(self,
groups=3,
widen_factor=1.0,
out_indices=(2, ),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_eval=False,
with_cp=False,
init_cfg=None):
super(ShuffleNetV1, self).__init__(init_cfg)
self.init_cfg = init_cfg
self.stage_blocks = [4, 8, 4]
self.groups = groups
for index in out_indices:
if index not in range(0, 3):
raise ValueError('the item in out_indices must be in '
f'range(0, 3). But received {index}')
if frozen_stages not in range(-1, 3):
raise ValueError('frozen_stages must be in range(-1, 3). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
if groups == 1:
channels = (144, 288, 576)
elif groups == 2:
channels = (200, 400, 800)
elif groups == 3:
channels = (240, 480, 960)
elif groups == 4:
channels = (272, 544, 1088)
elif groups == 8:
channels = (384, 768, 1536)
else:
raise ValueError(f'{groups} groups is not supported for 1x1 '
'Grouped Convolutions')
channels = [make_divisible(ch * widen_factor, 8) for ch in channels]
self.in_channels = int(24 * widen_factor)
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layers = nn.ModuleList()
for i, num_blocks in enumerate(self.stage_blocks):
first_block = (i == 0)
layer = self.make_layer(channels[i], num_blocks, first_block)
self.layers.append(layer)
def _freeze_stages(self):
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
layer = self.layers[i]
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def init_weights(self):
super(ShuffleNetV1, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress the default init if using a pretrained model.
return
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
normal_init(m, mean=0, std=0.01)
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
def make_layer(self, out_channels, num_blocks, first_block=False):
"""Stack ShuffleUnit blocks to make a layer.
Args:
out_channels (int): out_channels of the block.
num_blocks (int): Number of blocks.
first_block (bool): Whether it is the first ShuffleUnit in a
sequence of ShuffleUnits. Default: False, which means using
the grouped 1x1 convolution.
"""
layers = []
for i in range(num_blocks):
first_block = first_block if i == 0 else False
combine_mode = 'concat' if i == 0 else 'add'
layers.append(
ShuffleUnit(
self.in_channels,
out_channels,
groups=self.groups,
first_block=first_block,
combine=combine_mode,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.in_channels = out_channels
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
super(ShuffleNetV1, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init, normal_init
from torch.nn.modules.batchnorm import _BatchNorm
from mmpretrain.models.utils import channel_shuffle
from mmpretrain.registry import MODELS
from .base_backbone import BaseBackbone
class InvertedResidual(BaseModule):
"""InvertedResidual block for ShuffleNetV2 backbone.
Args:
in_channels (int): The input channels of the block.
out_channels (int): The output channels of the block.
stride (int): Stride of the 3x3 convolution layer. Default: 1
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
in_channels,
out_channels,
stride=1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False,
init_cfg=None):
super(InvertedResidual, self).__init__(init_cfg)
self.stride = stride
self.with_cp = with_cp
branch_features = out_channels // 2
if self.stride == 1:
assert in_channels == branch_features * 2, (
f'in_channels ({in_channels}) should equal '
f'branch_features * 2 ({branch_features * 2}) '
'when stride is 1')
if in_channels != branch_features * 2:
assert self.stride != 1, (
f'stride ({self.stride}) should not equal 1 when '
f'in_channels != branch_features * 2')
if self.stride > 1:
self.branch1 = nn.Sequential(
ConvModule(
in_channels,
in_channels,
kernel_size=3,
stride=self.stride,
padding=1,
groups=in_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
ConvModule(
in_channels,
branch_features,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
)
self.branch2 = nn.Sequential(
ConvModule(
in_channels if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
groups=branch_features,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
ConvModule(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def forward(self, x):
def _inner_forward(x):
if self.stride > 1:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
else:
# Channel split operation. Using these lines instead of
# ``torch.chunk(x, 2, dim=1)`` makes it easier to deploy a
# ShuffleNetV2 model with mmdeploy.
channels = x.shape[1]
c = channels // 2 + channels % 2
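# e.g., for the widen_factor=1.0 model, a stride-1 unit in the first
# stage receives 116 channels, so c = 58 and each branch processes
# 58 channels; with an odd channel count, x1 takes the extra channel.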
x1 = x[:, :c, :, :]
x2 = x[:, c:, :, :]
out = torch.cat((x1, self.branch2(x2)), dim=1)
out = channel_shuffle(out, 2)
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
@MODELS.register_module()
class ShuffleNetV2(BaseBackbone):
"""ShuffleNetV2 backbone.
Args:
widen_factor (float): Width multiplier - adjusts the number of
channels in each layer by this amount. Default: 1.0.
out_indices (Sequence[int]): Output from which stages.
Default: (3, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
conv_cfg (dict, optional): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
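Examples:
A minimal usage sketch (illustrative; the shape assumes the default
``widen_factor=1.0`` and ``out_indices=(3, )``, where index 3 is the
final 1x1 conv layer, with a 224x224 input):
>>> import torch
>>> from mmpretrain.models import ShuffleNetV2
>>> model = ShuffleNetV2()
>>> outs = model(torch.rand(1, 3, 224, 224))
>>> print(outs[0].shape)
torch.Size([1, 1024, 7, 7])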
"""
def __init__(self,
widen_factor=1.0,
out_indices=(3, ),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_eval=False,
with_cp=False,
init_cfg=None):
super(ShuffleNetV2, self).__init__(init_cfg)
self.stage_blocks = [4, 8, 4]
for index in out_indices:
if index not in range(0, 4):
raise ValueError('the item in out_indices must be in '
f'range(0, 4). But received {index}')
if frozen_stages not in range(-1, 4):
raise ValueError('frozen_stages must be in range(-1, 4). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
if widen_factor == 0.5:
channels = [48, 96, 192, 1024]
elif widen_factor == 1.0:
channels = [116, 232, 464, 1024]
elif widen_factor == 1.5:
channels = [176, 352, 704, 1024]
elif widen_factor == 2.0:
channels = [244, 488, 976, 2048]
else:
raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
f'But received {widen_factor}')
self.in_channels = 24
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layers = nn.ModuleList()
for i, num_blocks in enumerate(self.stage_blocks):
layer = self._make_layer(channels[i], num_blocks)
self.layers.append(layer)
output_channels = channels[-1]
self.layers.append(
ConvModule(
in_channels=self.in_channels,
out_channels=output_channels,
kernel_size=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def _make_layer(self, out_channels, num_blocks):
"""Stack blocks to make a layer.
Args:
out_channels (int): out_channels of the block.
num_blocks (int): number of blocks.
"""
layers = []
for i in range(num_blocks):
stride = 2 if i == 0 else 1
layers.append(
InvertedResidual(
in_channels=self.in_channels,
out_channels=out_channels,
stride=stride,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.in_channels = out_channels
return nn.Sequential(*layers)
def _freeze_stages(self):
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(self.frozen_stages):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self):
super(ShuffleNetV2, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress the default init if using a pretrained model.
return
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'conv1' in name:
normal_init(m, mean=0, std=0.01)
else:
normal_init(m, mean=0, std=1.0 / m.weight.shape[1])
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, val=1, bias=0.0001)
if isinstance(m, _BatchNorm):
if m.running_mean is not None:
nn.init.constant_(m.running_mean, 0)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def train(self, mode=True):
super(ShuffleNetV2, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmengine.model import ModuleList, Sequential
from mmpretrain.registry import MODELS
from ..utils import (SparseAvgPooling, SparseConv2d, SparseHelper,
SparseMaxPooling, build_norm_layer)
from .convnext import ConvNeXt, ConvNeXtBlock
class SparseConvNeXtBlock(ConvNeXtBlock):
"""Sparse ConvNeXt Block.
Note:
There are two equivalent implementations:
1. DwConv -> SparseLayerNorm -> 1x1 Conv -> GELU -> 1x1 Conv;
all outputs are in (N, C, H, W).
2. DwConv -> SparseLayerNorm -> Permute to (N, H, W, C) -> Linear ->
GELU -> Linear; Permute back
By default, we use the second implementation to align with the
official repository, and it may be slightly faster.
"""
def forward(self, x):
def _inner_forward(x):
shortcut = x
x = self.depthwise_conv(x)
if self.linear_pw_conv:
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x, data_format='channel_last')
x = self.pointwise_conv1(x)
x = self.act(x)
if self.grn is not None:
x = self.grn(x, data_format='channel_last')
x = self.pointwise_conv2(x)
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
else:
x = self.norm(x, data_format='channel_first')
x = self.pointwise_conv1(x)
x = self.act(x)
if self.grn is not None:
x = self.grn(x, data_format='channel_first')
x = self.pointwise_conv2(x)
if self.gamma is not None:
x = x.mul(self.gamma.view(1, -1, 1, 1))
x *= SparseHelper._get_active_map_or_index(
H=x.shape[2], returning_active_map=True)
x = shortcut + self.drop_path(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
@MODELS.register_module()
class SparseConvNeXt(ConvNeXt):
"""ConvNeXt with sparse module conversion function.
Modified from
https://github.com/keyu-tian/SparK/blob/main/models/convnext.py
and
https://github.com/keyu-tian/SparK/blob/main/encoder.py
To use ConvNeXt v2, please set ``use_grn=True`` and ``layer_scale_init_value=0.``.
Args:
arch (str | dict): The model's architecture. If a string, it should be
one of the architectures in ``ConvNeXt.arch_settings``. If a dict, it
should include the following two keys:
- depths (list[int]): Number of blocks at each stage.
- channels (list[int]): The number of channels at each stage.
Defaults to 'tiny'.
in_channels (int): Number of input image channels. Defaults to 3.
stem_patch_size (int): The size of one patch in the stem layer.
Defaults to 4.
norm_cfg (dict): The config dict for norm layers.
Defaults to ``dict(type='SparseLN2d', eps=1e-6)``.
act_cfg (dict): The config dict for activation between pointwise
convolution. Defaults to ``dict(type='GELU')``.
linear_pw_conv (bool): Whether to use linear layer to do pointwise
convolution. Defaults to True.
use_grn (bool): Whether to add Global Response Normalization in the
blocks. Defaults to False.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
layer_scale_init_value (float): Init value for Layer Scale.
Defaults to 1e-6.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
gap_before_output (bool): Whether to globally average the feature
map before the final norm layer. In the official repo, it's only
used in classification task. Defaults to True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): Initialization config dict.
""" # noqa: E501
def __init__(self,
arch: str = 'small',
in_channels: int = 3,
stem_patch_size: int = 4,
norm_cfg: dict = dict(type='SparseLN2d', eps=1e-6),
act_cfg: dict = dict(type='GELU'),
linear_pw_conv: bool = True,
use_grn: bool = False,
drop_path_rate: float = 0,
layer_scale_init_value: float = 1e-6,
out_indices: int = -1,
frozen_stages: int = 0,
gap_before_output: bool = True,
with_cp: bool = False,
init_cfg: Optional[Union[dict, List[dict]]] = [
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(
type='Constant', layer=['LayerNorm'], val=1.,
bias=0.),
]):
super(ConvNeXt, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
assert arch in self.arch_settings, \
f'Unavailable arch, please choose from ' \
f'({set(self.arch_settings)}) or pass a dict.'
arch = self.arch_settings[arch]
elif isinstance(arch, dict):
assert 'depths' in arch and 'channels' in arch, \
f'The arch dict must have "depths" and "channels", ' \
f'but got {list(arch.keys())}.'
self.depths = arch['depths']
self.channels = arch['channels']
assert (isinstance(self.depths, Sequence)
and isinstance(self.channels, Sequence)
and len(self.depths) == len(self.channels)), \
f'The "depths" ({self.depths}) and "channels" ({self.channels}) ' \
'should be both sequence with the same length.'
self.num_stages = len(self.depths)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = 4 + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.gap_before_output = gap_before_output
# 4 downsample layers between stages, including the stem layer.
self.downsample_layers = ModuleList()
stem = nn.Sequential(
nn.Conv2d(
in_channels,
self.channels[0],
kernel_size=stem_patch_size,
stride=stem_patch_size),
build_norm_layer(norm_cfg, self.channels[0]),
)
self.downsample_layers.append(stem)
# stochastic depth decay rule
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(self.depths))
]
block_idx = 0
# 4 feature resolution stages, each consisting of multiple residual
# blocks
self.stages = nn.ModuleList()
for i in range(self.num_stages):
depth = self.depths[i]
channels = self.channels[i]
if i >= 1:
downsample_layer = nn.Sequential(
build_norm_layer(norm_cfg, self.channels[i - 1]),
nn.Conv2d(
self.channels[i - 1],
channels,
kernel_size=2,
stride=2),
)
self.downsample_layers.append(downsample_layer)
stage = Sequential(*[
SparseConvNeXtBlock(
in_channels=channels,
drop_path_rate=dpr[block_idx + j],
norm_cfg=norm_cfg,
act_cfg=act_cfg,
linear_pw_conv=linear_pw_conv,
layer_scale_init_value=layer_scale_init_value,
use_grn=use_grn,
with_cp=with_cp) for j in range(depth)
])
block_idx += depth
self.stages.append(stage)
self.dense_model_to_sparse(m=self)
def forward(self, x):
outs = []
for i, stage in enumerate(self.stages):
x = self.downsample_layers[i](x)
x = stage(x)
if i in self.out_indices:
if self.gap_before_output:
gap = x.mean([-2, -1], keepdim=True)
outs.append(gap.flatten(1))
else:
outs.append(x)
return tuple(outs)
def dense_model_to_sparse(self, m: nn.Module) -> nn.Module:
"""Convert regular dense modules to sparse modules."""
output = m
if isinstance(m, nn.Conv2d):
m: nn.Conv2d
bias = m.bias is not None
output = SparseConv2d(
m.in_channels,
m.out_channels,
kernel_size=m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
groups=m.groups,
bias=bias,
padding_mode=m.padding_mode,
)
output.weight.data.copy_(m.weight.data)
if bias:
output.bias.data.copy_(m.bias.data)
elif isinstance(m, nn.MaxPool2d):
m: nn.MaxPool2d
output = SparseMaxPooling(
m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
return_indices=m.return_indices,
ceil_mode=m.ceil_mode)
elif isinstance(m, nn.AvgPool2d):
m: nn.AvgPool2d
output = SparseAvgPooling(
m.kernel_size,
m.stride,
m.padding,
ceil_mode=m.ceil_mode,
count_include_pad=m.count_include_pad,
divisor_override=m.divisor_override)
# elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
# m: nn.BatchNorm2d
# output = (SparseSyncBatchNorm2d
# if enable_sync_bn else SparseBatchNorm2d)(
# m.weight.shape[0],
# eps=m.eps,
# momentum=m.momentum,
# affine=m.affine,
# track_running_stats=m.track_running_stats)
# output.weight.data.copy_(m.weight.data)
# output.bias.data.copy_(m.bias.data)
# output.running_mean.data.copy_(m.running_mean.data)
# output.running_var.data.copy_(m.running_var.data)
# output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
for name, child in m.named_children():
output.add_module(name, self.dense_model_to_sparse(child))
del m
return output
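# A minimal conversion check (an illustrative sketch; the model converts
# itself in ``__init__``, so no manual call is needed):
#
# >>> model = SparseConvNeXt(arch='tiny')
# >>> type(model.downsample_layers[0][0]).__name__
# 'SparseConv2d'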
# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional, Tuple, Union
import torch.nn as nn
from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling,
SparseBatchNorm2d,
SparseConv2d,
SparseMaxPooling,
SparseSyncBatchNorm2d)
from mmpretrain.registry import MODELS
from .resnet import ResNet
@MODELS.register_module()
class SparseResNet(ResNet):
"""ResNet with sparse module conversion function.
Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
Args:
depth (int): Network depth, from {18, 34, 50, 101, 152}.
in_channels (int): Number of input image channels. Defaults to 3.
stem_channels (int): Output channels of the stem layer. Defaults to 64.
base_channels (int): Middle channels of the first stage.
Defaults to 64.
num_stages (int): Stages of the network. Defaults to 4.
strides (Sequence[int]): Strides of the first block of each stage.
Defaults to ``(1, 2, 2, 2)``.
dilations (Sequence[int]): Dilation of each stage.
Defaults to ``(1, 1, 1, 1)``.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
layer is the 3x3 conv layer, otherwise the stride-two layer is
the first 1x1 conv layer.
deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
Defaults to False.
avg_down (bool): Use AvgPool instead of stride conv when
downsampling in the bottleneck. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
conv_cfg (dict | None): The config dict for conv layers.
Defaults to None.
norm_cfg (dict): The config dict for norm layers.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
zero_init_residual (bool): Whether to use zero init for last norm layer
in resblocks to let them behave as identity. Defaults to True.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
"""
def __init__(self,
depth: int,
in_channels: int = 3,
stem_channels: int = 64,
base_channels: int = 64,
expansion: Optional[int] = None,
num_stages: int = 4,
strides: Tuple[int] = (1, 2, 2, 2),
dilations: Tuple[int] = (1, 1, 1, 1),
out_indices: Tuple[int] = (3, ),
style: str = 'pytorch',
deep_stem: bool = False,
avg_down: bool = False,
frozen_stages: int = -1,
conv_cfg: Optional[dict] = None,
norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
norm_eval: bool = False,
with_cp: bool = False,
zero_init_residual: bool = False,
init_cfg: Optional[Union[dict, List[dict]]] = [
dict(type='Kaiming', layer=['Conv2d']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
],
drop_path_rate: float = 0,
**kwargs):
super().__init__(
depth=depth,
in_channels=in_channels,
stem_channels=stem_channels,
base_channels=base_channels,
expansion=expansion,
num_stages=num_stages,
strides=strides,
dilations=dilations,
out_indices=out_indices,
style=style,
deep_stem=deep_stem,
avg_down=avg_down,
frozen_stages=frozen_stages,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
norm_eval=norm_eval,
with_cp=with_cp,
zero_init_residual=zero_init_residual,
init_cfg=init_cfg,
drop_path_rate=drop_path_rate,
**kwargs)
norm_type = norm_cfg['type']
enable_sync_bn = False
if re.search('Sync', norm_type) is not None:
enable_sync_bn = True
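# e.g., the default ``dict(type='SparseSyncBatchNorm2d')`` matches the
# 'Sync' pattern, so BN layers are converted to the synced sparse
# variant below; ``dict(type='SparseBatchNorm2d')`` would select plain
# SparseBatchNorm2d instead.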
self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)
def dense_model_to_sparse(self, m: nn.Module,
enable_sync_bn: bool) -> nn.Module:
"""Convert regular dense modules to sparse modules."""
output = m
if isinstance(m, nn.Conv2d):
m: nn.Conv2d
bias = m.bias is not None
output = SparseConv2d(
m.in_channels,
m.out_channels,
kernel_size=m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
groups=m.groups,
bias=bias,
padding_mode=m.padding_mode,
)
output.weight.data.copy_(m.weight.data)
if bias:
output.bias.data.copy_(m.bias.data)
elif isinstance(m, nn.MaxPool2d):
m: nn.MaxPool2d
output = SparseMaxPooling(
m.kernel_size,
stride=m.stride,
padding=m.padding,
dilation=m.dilation,
return_indices=m.return_indices,
ceil_mode=m.ceil_mode)
elif isinstance(m, nn.AvgPool2d):
m: nn.AvgPool2d
output = SparseAvgPooling(
m.kernel_size,
m.stride,
m.padding,
ceil_mode=m.ceil_mode,
count_include_pad=m.count_include_pad,
divisor_override=m.divisor_override)
elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
m: nn.BatchNorm2d
output = (SparseSyncBatchNorm2d
if enable_sync_bn else SparseBatchNorm2d)(
m.weight.shape[0],
eps=m.eps,
momentum=m.momentum,
affine=m.affine,
track_running_stats=m.track_running_stats)
output.weight.data.copy_(m.weight.data)
output.bias.data.copy_(m.bias.data)
output.running_mean.data.copy_(m.running_mean.data)
output.running_var.data.copy_(m.running_var.data)
output.num_batches_tracked.data.copy_(m.num_batches_tracked.data)
elif isinstance(m, (nn.Conv1d, )):
raise NotImplementedError
for name, child in m.named_children():
output.add_module(
name,
self.dense_model_to_sparse(
child, enable_sync_bn=enable_sync_bn))
del m
return output
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed, PatchMerging
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpretrain.registry import MODELS
from ..utils import (ShiftWindowMSA, resize_pos_embed,
resize_relative_position_bias_table, to_2tuple)
from .base_backbone import BaseBackbone
class SwinBlock(BaseModule):
"""Swin Transformer block.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
shift (bool): Shift the attention window or not. Defaults to False.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=7,
shift=False,
ffn_ratio=4.,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
init_cfg=None):
super(SwinBlock, self).__init__(init_cfg)
self.with_cp = with_cp
_attn_cfgs = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'pad_small_map': pad_small_map,
**attn_cfgs
}
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
self.attn = ShiftWindowMSA(**_attn_cfgs)
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
**ffn_cfgs
}
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
self.ffn = FFN(**_ffn_cfgs)
def forward(self, x, hw_shape):
def _inner_forward(x):
identity = x
x = self.norm1(x)
x = self.attn(x, hw_shape)
x = x + identity
identity = x
x = self.norm2(x)
x = self.ffn(x, identity=identity)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
class SwinBlockSequence(BaseModule):
"""Module with successive Swin Transformer blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 7.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
depth,
num_heads,
window_size=7,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
pad_small_map=False,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
self.embed_dims = embed_dims
self.blocks = ModuleList()
for i in range(depth):
_block_cfg = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'window_size': window_size,
'shift': i % 2 == 1,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
**block_cfgs[i]
}
block = SwinBlock(**_block_cfg)
self.blocks.append(block)
if downsample:
_downsample_cfg = {
'in_channels': embed_dims,
'out_channels': 2 * embed_dims,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = PatchMerging(**_downsample_cfg)
else:
self.downsample = None
def forward(self, x, in_shape, do_downsample=True):
for block in self.blocks:
x = block(x, in_shape)
if self.downsample is not None and do_downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
return x, out_shape
@property
def out_channels(self):
if self.downsample:
return self.downsample.out_channels
else:
return self.embed_dims
@MODELS.register_module()
class SwinTransformer(BaseBackbone):
"""Swin Transformer.
A PyTorch implement of : `Swin Transformer:
Hierarchical Vision Transformer using Shifted Windows
<https://arxiv.org/abs/2103.14030>`_
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture. If a string, choose
from 'tiny', 'small', 'base' and 'large'. If a dict, it should
have the following keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
Defaults to 'tiny'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
in_channels (int): The num of input channels. Defaults to 3.
window_size (int): The height and width of the window. Defaults to 7.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
out_after_downsample (bool): Whether to output the feature map of a
stage after the following downsample layer. Defaults to False.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
position embeding vector resize. Defaults to "bicubic".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
Examples:
>>> from mmpretrain.models import SwinTransformer
>>> import torch
>>> self = SwinTransformer(arch='tiny')
>>> inputs = torch.rand(1, 3, 224, 224)
>>> output = self.forward(inputs)
>>> print(output[0].shape)
torch.Size([1, 768, 7, 7])
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24]}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24]}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32]}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48]}),
} # yapf: disable
_version = 3
num_extra_tokens = 0
def __init__(self,
arch='tiny',
img_size=224,
patch_size=4,
in_channels=3,
window_size=7,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
out_after_downsample=False,
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
frozen_stages=-1,
norm_eval=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(),
patch_cfg=dict(),
init_cfg=None):
super(SwinTransformer, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {'embed_dims', 'depths', 'num_heads'}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.out_after_downsample = out_after_downsample
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)
self._register_load_state_dict_pre_hook(
self._prepare_relative_position_bias_table)
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
# stochastic depth
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
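# e.g., for the 'tiny' arch (12 blocks in total) with the default
# drop_path_rate=0.1, dpr holds 12 evenly spaced values from 0 to 0.1;
# each stage below consumes its leading ``depth`` entries and passes
# the remainder on via ``dpr = dpr[depth:]``.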
self.stages = ModuleList()
embed_dims = [self.embed_dims]
for i, (depth,
num_heads) in enumerate(zip(self.depths, self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
downsample = i < self.num_layers - 1
_stage_cfg = {
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': window_size,
'downsample': downsample,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
**stage_cfg
}
stage = SwinBlockSequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
if self.out_after_downsample:
self.num_features = embed_dims[1:]
else:
self.num_features = embed_dims[:-1]
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg,
self.num_features[i])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm{i}', norm_layer)
def init_weights(self):
super(SwinTransformer, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress the default init if using a pretrained model.
return
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(
x, hw_shape, do_downsample=self.out_after_downsample)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)
if stage.downsample is not None and not self.out_after_downsample:
x, hw_shape = stage.downsample(x, hw_shape)
return tuple(outs)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args,
**kwargs):
"""load checkpoints."""
# Names of some parameters in has been changed.
version = local_metadata.get('version', None)
if (version is None
or version < 2) and self.__class__ is SwinTransformer:
final_stage_num = len(self.stages) - 1
state_dict_keys = list(state_dict.keys())
for k in state_dict_keys:
if k.startswith('norm.') or k.startswith('backbone.norm.'):
convert_key = k.replace('norm.', f'norm{final_stage_num}.')
state_dict[convert_key] = state_dict[k]
del state_dict[k]
if (version is None
or version < 3) and self.__class__ is SwinTransformer:
state_dict_keys = list(state_dict.keys())
for k in state_dict_keys:
if 'attn_mask' in k:
del state_dict[k]
super()._load_from_state_dict(state_dict, prefix, local_metadata,
*args, **kwargs)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def train(self, mode=True):
super(SwinTransformer, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
# trick: eval has an effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'absolute_pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def _prepare_relative_position_bias_table(self, state_dict, prefix, *args,
**kwargs):
state_dict_model = self.state_dict()
all_keys = list(state_dict_model.keys())
for key in all_keys:
if 'relative_position_bias_table' in key:
ckpt_key = prefix + key
if ckpt_key not in state_dict:
continue
relative_position_bias_table_pretrained = state_dict[ckpt_key]
relative_position_bias_table_current = state_dict_model[key]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if L1 != L2:
src_size = int(L1**0.5)
dst_size = int(L2**0.5)
new_rel_pos_bias = resize_relative_position_bias_table(
src_size, dst_size,
relative_position_bias_table_pretrained, nH1)
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info('Resize the relative_position_bias_table from '
f'{state_dict[ckpt_key].shape} to '
f'{new_rel_pos_bias.shape}')
state_dict[ckpt_key] = new_rel_pos_bias
# The index buffer needs to be re-generated.
index_buffer = ckpt_key.replace('bias_table', 'index')
del state_dict[index_buffer]
def get_layer_depth(self, param_name: str, prefix: str = ''):
"""Get the layer-wise depth of a parameter.
Args:
param_name (str): The name of the parameter.
prefix (str): The prefix for the parameter.
Defaults to an empty string.
Returns:
Tuple[int, int]: The layer-wise depth and the num of layers.
Note:
The first depth is the stem module (``layer_depth=0``), and the
last depth is the subsequent module (``layer_depth=num_layers-1``)
"""
num_layers = sum(self.depths) + 2
if not param_name.startswith(prefix):
# For subsequent module like head
return num_layers - 1, num_layers
param_name = param_name[len(prefix):]
if param_name.startswith('patch_embed'):
layer_depth = 0
elif param_name.startswith('stages'):
stage_id = int(param_name.split('.')[1])
block_id = param_name.split('.')[3]
if block_id in ('reduction', 'norm'):
layer_depth = sum(self.depths[:stage_id + 1])
else:
layer_depth = sum(self.depths[:stage_id]) + int(block_id) + 1
else:
layer_depth = num_layers - 1
return layer_depth, num_layers
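# Worked example for the 'tiny' arch (depths=[2, 2, 6, 2], so
# num_layers = 12 + 2 = 14): 'patch_embed.projection.weight' maps to
# depth 0, 'stages.2.blocks.3.norm1.weight' maps to 2 + 2 + 3 + 1 = 8,
# and a head parameter outside the prefix maps to 13.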
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from ..builder import MODELS
from ..utils import (PatchMerging, ShiftWindowMSA, WindowMSAV2,
resize_pos_embed, to_2tuple)
from .base_backbone import BaseBackbone
class SwinBlockV2(BaseModule):
"""Swin Transformer V2 block. Use post normalization.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 8.
shift (bool): Shift the attention window or not. Defaults to False.
extra_norm (bool): Whether to add an extra norm at the end of the main
branch. Defaults to False.
ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
drop_path (float): The drop path rate after attention and ffn.
Defaults to 0.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
attn_cfgs (dict): The extra config of Shift Window-MSA.
Defaults to empty dict.
ffn_cfgs (dict): The extra config of FFN. Defaults to empty dict.
norm_cfg (dict): The config of norm layers.
Defaults to ``dict(type='LN')``.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pretrained_window_size (int): Window size used during pretraining.
Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size=8,
shift=False,
extra_norm=False,
ffn_ratio=4.,
drop_path=0.,
pad_small_map=False,
attn_cfgs=dict(),
ffn_cfgs=dict(),
norm_cfg=dict(type='LN'),
with_cp=False,
pretrained_window_size=0,
init_cfg=None):
super(SwinBlockV2, self).__init__(init_cfg)
self.with_cp = with_cp
self.extra_norm = extra_norm
_attn_cfgs = {
'embed_dims': embed_dims,
'num_heads': num_heads,
'shift_size': window_size // 2 if shift else 0,
'window_size': window_size,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'pad_small_map': pad_small_map,
**attn_cfgs
}
# use V2 attention implementation
_attn_cfgs.update(
window_msa=WindowMSAV2,
pretrained_window_size=to_2tuple(pretrained_window_size))
self.attn = ShiftWindowMSA(**_attn_cfgs)
self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]
_ffn_cfgs = {
'embed_dims': embed_dims,
'feedforward_channels': int(embed_dims * ffn_ratio),
'num_fcs': 2,
'ffn_drop': 0,
'dropout_layer': dict(type='DropPath', drop_prob=drop_path),
'act_cfg': dict(type='GELU'),
'add_identity': False,
**ffn_cfgs
}
self.ffn = FFN(**_ffn_cfgs)
self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]
# add extra norm for every n blocks in huge and giant model
if self.extra_norm:
self.norm3 = build_norm_layer(norm_cfg, embed_dims)[1]
def forward(self, x, hw_shape):
def _inner_forward(x):
# Use post normalization
identity = x
x = self.attn(x, hw_shape)
x = self.norm1(x)
x = x + identity
identity = x
x = self.ffn(x)
x = self.norm2(x)
x = x + identity
if self.extra_norm:
x = self.norm3(x)
return x
if self.with_cp and x.requires_grad:
x = cp.checkpoint(_inner_forward, x)
else:
x = _inner_forward(x)
return x
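# Note: unlike SwinBlock, which uses pre-normalization
# (x + attn(norm(x))), SwinBlockV2 normalizes after attention and FFN
# (x + norm(attn(x))); the Swin V2 paper reports that this stabilizes
# training as model capacity grows.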
class SwinBlockV2Sequence(BaseModule):
"""Module with successive Swin Transformer blocks and downsample layer.
Args:
embed_dims (int): Number of input channels.
depth (int): Number of successive swin transformer blocks.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window. Defaults to 8.
downsample (bool): Downsample the output of blocks by patch merging.
Defaults to False.
downsample_cfg (dict): The extra config of the patch merging layer.
Defaults to empty dict.
drop_paths (Sequence[float] | float): The drop path rate in each block.
Defaults to 0.
block_cfgs (Sequence[dict] | dict): The extra config of each block.
Defaults to empty dicts.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
pad_small_map (bool): If True, pad the small feature map to the window
size, which is common used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is common used in classification.
Defaults to False.
extra_norm_every_n_blocks (int): Add extra norm at the end of main
branch every n blocks. Defaults to 0, which means no needs for
extra norm layer.
pretrained_window_size (int): Window size used during pretraining.
Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
depth,
num_heads,
window_size=8,
downsample=False,
downsample_cfg=dict(),
drop_paths=0.,
block_cfgs=dict(),
with_cp=False,
pad_small_map=False,
extra_norm_every_n_blocks=0,
pretrained_window_size=0,
init_cfg=None):
super().__init__(init_cfg)
if not isinstance(drop_paths, Sequence):
drop_paths = [drop_paths] * depth
if not isinstance(block_cfgs, Sequence):
block_cfgs = [deepcopy(block_cfgs) for _ in range(depth)]
if downsample:
self.out_channels = 2 * embed_dims
_downsample_cfg = {
'in_channels': embed_dims,
'out_channels': self.out_channels,
'norm_cfg': dict(type='LN'),
**downsample_cfg
}
self.downsample = PatchMerging(**_downsample_cfg)
else:
self.out_channels = embed_dims
self.downsample = None
self.blocks = ModuleList()
for i in range(depth):
extra_norm = (extra_norm_every_n_blocks > 0
              and (i + 1) % extra_norm_every_n_blocks == 0)
_block_cfg = {
'embed_dims': self.out_channels,
'num_heads': num_heads,
'window_size': window_size,
'shift': i % 2 == 1,
'extra_norm': extra_norm,
'drop_path': drop_paths[i],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'pretrained_window_size': pretrained_window_size,
**block_cfgs[i]
}
block = SwinBlockV2(**_block_cfg)
self.blocks.append(block)
def forward(self, x, in_shape):
if self.downsample:
x, out_shape = self.downsample(x, in_shape)
else:
out_shape = in_shape
for block in self.blocks:
x = block(x, out_shape)
return x, out_shape
@MODELS.register_module()
class SwinTransformerV2(BaseBackbone):
"""Swin Transformer V2.
A PyTorch implement of : `Swin Transformer V2:
Scaling Up Capacity and Resolution
<https://arxiv.org/abs/2111.09883>`_
Inspiration from
https://github.com/microsoft/Swin-Transformer
Args:
arch (str | dict): Swin Transformer architecture. If a string, choose
from 'tiny', 'small', 'base' and 'large'. If a dict, it should
have the following keys:
- **embed_dims** (int): The dimensions of embedding.
- **depths** (List[int]): The number of blocks in each stage.
- **num_heads** (List[int]): The number of heads in attention
modules of each stage.
- **extra_norm_every_n_blocks** (int): Add extra norm at the end
of main branch every n blocks.
Defaults to 'tiny'.
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 256.
patch_size (int | tuple): The patch size in patch embedding.
Defaults to 4.
in_channels (int): The num of input channels. Defaults to 3.
window_size (int | Sequence): The height and width of the window.
Defaults to 8.
drop_rate (float): Dropout rate after embedding. Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.1.
out_indices (Sequence[int]): Output from which stages.
Defaults to ``(3, )``.
use_abs_pos_embed (bool): If True, add absolute position embedding to
the patch embedding. Defaults to False.
interpolate_mode (str): Select the interpolate mode for absolute
position embeding vector resize. Defaults to "bicubic".
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
-1 means not freezing any parameters. Defaults to -1.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
        pad_small_map (bool): If True, pad the small feature map to the window
            size, which is commonly used in detection and segmentation. If
            False, avoid shifting window and shrink the window size to the
            size of feature map, which is commonly used in classification.
            Defaults to False.
norm_cfg (dict): Config dict for normalization layer for all output
features. Defaults to ``dict(type='LN')``
stage_cfgs (Sequence[dict] | dict): Extra config dict for each
stage. Defaults to an empty dict.
patch_cfg (dict): Extra config dict for patch embedding.
Defaults to an empty dict.
        pretrained_window_sizes (Sequence[int]): Pretrained window sizes of
            each stage. Defaults to ``[0, 0, 0, 0]``.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
    Examples:
        >>> from mmpretrain.models import SwinTransformerV2
        >>> import torch
        >>> extra_config = dict(
        >>>     arch='tiny',
        >>>     stage_cfgs=dict(downsample_cfg={'kernel_size': 3,
        >>>                                     'padding': 'same'}))
        >>> self = SwinTransformerV2(**extra_config)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> output = self.forward(inputs)
        >>> print(output[0].shape)
        torch.Size([1, 768, 7, 7])
"""
arch_zoo = {
**dict.fromkeys(['t', 'tiny'],
{'embed_dims': 96,
'depths': [2, 2, 6, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['s', 'small'],
{'embed_dims': 96,
'depths': [2, 2, 18, 2],
'num_heads': [3, 6, 12, 24],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['b', 'base'],
{'embed_dims': 128,
'depths': [2, 2, 18, 2],
'num_heads': [4, 8, 16, 32],
'extra_norm_every_n_blocks': 0}),
**dict.fromkeys(['l', 'large'],
{'embed_dims': 192,
'depths': [2, 2, 18, 2],
'num_heads': [6, 12, 24, 48],
'extra_norm_every_n_blocks': 0}),
        # The head counts of 'huge' are not confirmed in the paper; this
        # arch is employed in a parallel study about self-supervised
        # learning.
**dict.fromkeys(['h', 'huge'],
{'embed_dims': 352,
'depths': [2, 2, 18, 2],
'num_heads': [8, 16, 32, 64],
'extra_norm_every_n_blocks': 6}),
**dict.fromkeys(['g', 'giant'],
{'embed_dims': 512,
'depths': [2, 2, 42, 4],
'num_heads': [16, 32, 64, 128],
'extra_norm_every_n_blocks': 6}),
} # yapf: disable
_version = 1
num_extra_tokens = 0
def __init__(self,
arch='tiny',
img_size=256,
patch_size=4,
in_channels=3,
window_size=8,
drop_rate=0.,
drop_path_rate=0.1,
out_indices=(3, ),
use_abs_pos_embed=False,
interpolate_mode='bicubic',
with_cp=False,
frozen_stages=-1,
norm_eval=False,
pad_small_map=False,
norm_cfg=dict(type='LN'),
stage_cfgs=dict(),
patch_cfg=dict(),
pretrained_window_sizes=[0, 0, 0, 0],
init_cfg=None):
super(SwinTransformerV2, self).__init__(init_cfg=init_cfg)
if isinstance(arch, str):
arch = arch.lower()
assert arch in set(self.arch_zoo), \
f'Arch {arch} is not in default archs {set(self.arch_zoo)}'
self.arch_settings = self.arch_zoo[arch]
else:
essential_keys = {
'embed_dims', 'depths', 'num_heads',
'extra_norm_every_n_blocks'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
self.embed_dims = self.arch_settings['embed_dims']
self.depths = self.arch_settings['depths']
self.num_heads = self.arch_settings['num_heads']
self.extra_norm_every_n_blocks = self.arch_settings[
'extra_norm_every_n_blocks']
self.num_layers = len(self.depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
self.interpolate_mode = interpolate_mode
self.frozen_stages = frozen_stages
if isinstance(window_size, int):
self.window_sizes = [window_size for _ in range(self.num_layers)]
elif isinstance(window_size, Sequence):
            assert len(window_size) == self.num_layers, \
                f'Length of window_size {len(window_size)} is not equal to '\
                f'the number of stages {self.num_layers}.'
self.window_sizes = window_size
else:
raise TypeError('window_size should be a Sequence or int.')
_patch_cfg = dict(
in_channels=in_channels,
input_size=img_size,
embed_dims=self.embed_dims,
conv_type='Conv2d',
kernel_size=patch_size,
stride=patch_size,
norm_cfg=dict(type='LN'),
)
_patch_cfg.update(patch_cfg)
self.patch_embed = PatchEmbed(**_patch_cfg)
self.patch_resolution = self.patch_embed.init_out_size
if self.use_abs_pos_embed:
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
self.absolute_pos_embed = nn.Parameter(
torch.zeros(1, num_patches, self.embed_dims))
self._register_load_state_dict_pre_hook(
self._prepare_abs_pos_embed)
self._register_load_state_dict_pre_hook(self._delete_reinit_params)
self.drop_after_pos = nn.Dropout(p=drop_rate)
self.norm_eval = norm_eval
# stochastic depth
total_depth = sum(self.depths)
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, total_depth)
] # stochastic depth decay rule
self.stages = ModuleList()
embed_dims = [self.embed_dims]
for i, (depth,
num_heads) in enumerate(zip(self.depths, self.num_heads)):
if isinstance(stage_cfgs, Sequence):
stage_cfg = stage_cfgs[i]
else:
stage_cfg = deepcopy(stage_cfgs)
            downsample = i > 0
_stage_cfg = {
'embed_dims': embed_dims[-1],
'depth': depth,
'num_heads': num_heads,
'window_size': self.window_sizes[i],
'downsample': downsample,
'drop_paths': dpr[:depth],
'with_cp': with_cp,
'pad_small_map': pad_small_map,
'extra_norm_every_n_blocks': self.extra_norm_every_n_blocks,
'pretrained_window_size': pretrained_window_sizes[i],
'downsample_cfg': dict(use_post_norm=True),
**stage_cfg
}
stage = SwinBlockV2Sequence(**_stage_cfg)
self.stages.append(stage)
dpr = dpr[depth:]
embed_dims.append(stage.out_channels)
for i in out_indices:
if norm_cfg is not None:
norm_layer = build_norm_layer(norm_cfg, embed_dims[i + 1])[1]
else:
norm_layer = nn.Identity()
self.add_module(f'norm{i}', norm_layer)
def init_weights(self):
super(SwinTransformerV2, self).init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if using a pretrained model.
return
if self.use_abs_pos_embed:
trunc_normal_(self.absolute_pos_embed, std=0.02)
def forward(self, x):
x, hw_shape = self.patch_embed(x)
if self.use_abs_pos_embed:
x = x + resize_pos_embed(
self.absolute_pos_embed, self.patch_resolution, hw_shape,
self.interpolate_mode, self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, stage in enumerate(self.stages):
x, hw_shape = stage(x, hw_shape)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(x)
out = out.view(-1, *hw_shape,
stage.out_channels).permute(0, 3, 1,
2).contiguous()
outs.append(out)
return tuple(outs)
def _freeze_stages(self):
if self.frozen_stages >= 0:
self.patch_embed.eval()
for param in self.patch_embed.parameters():
param.requires_grad = False
for i in range(0, self.frozen_stages + 1):
m = self.stages[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
for i in self.out_indices:
if i <= self.frozen_stages:
for param in getattr(self, f'norm{i}').parameters():
param.requires_grad = False
def train(self, mode=True):
super(SwinTransformerV2, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
                # trick: eval() has effect on BatchNorm only
if isinstance(m, _BatchNorm):
m.eval()
def _prepare_abs_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'absolute_pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.absolute_pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
'Resize the absolute_pos_embed shape from '
f'{ckpt_pos_embed_shape} to {self.absolute_pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def _delete_reinit_params(self, state_dict, prefix, *args, **kwargs):
# delete relative_position_index since we always re-init it
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
        logger.info(
            'Delete `relative_position_index` and `relative_coords_table` '
            'since we always re-init these params according to '
            '`window_size`, which might otherwise cause unwanted but '
            'harmless warnings when loading the checkpoint.')
relative_position_index_keys = [
k for k in state_dict.keys() if 'relative_position_index' in k
]
for k in relative_position_index_keys:
del state_dict[k]
        # delete relative_coords_table since we always re-init it
        relative_coords_table_keys = [
            k for k in state_dict.keys() if 'relative_coords_table' in k
        ]
        for k in relative_coords_table_keys:
            del state_dict[k]
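# A minimal usage sketch (added for illustration, not part of the original
# file); it assumes mmpretrain and its dependencies are installed. Run this
# module directly to check the output shape of the backbone.
if __name__ == '__main__':
    model = SwinTransformerV2(arch='tiny', img_size=256)
    model.eval()
    with torch.no_grad():
        feats = model(torch.rand(1, 3, 256, 256))
    # With the default out_indices=(3, ), a single (B, C, H, W) feature map
    # is returned: (1, 768, 8, 8) for the 'tiny' arch at 256x256 input.
    print([tuple(f.shape) for f in feats])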
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmpretrain.registry import MODELS
from ..utils import (MultiheadAttention, build_norm_layer, resize_pos_embed,
to_2tuple)
from .base_backbone import BaseBackbone
class T2TTransformerLayer(BaseModule):
"""Transformer Layer for T2T_ViT.
    Compared with :obj:`TransformerEncoderLayer` in ViT, it supports
    different ``input_dims`` and ``embed_dims``.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs
input_dims (int, optional): The input token dimension.
Defaults to None.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The dropout rate for attention output weights.
            Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
qk_scale (float, optional): Override default qk scale of
``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
Notes:
In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
        ``(embed_dims // num_heads) ** -0.5``. However, the official
        code uses ``(input_dims // num_heads) ** -0.5``, so we keep it
        the same as the official implementation.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
input_dims=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=False,
qk_scale=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)
        self.v_shortcut = input_dims is not None
input_dims = input_dims or embed_dims
self.ln1 = build_norm_layer(norm_cfg, input_dims)
self.attn = MultiheadAttention(
input_dims=input_dims,
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
v_shortcut=self.v_shortcut)
self.ln2 = build_norm_layer(norm_cfg, embed_dims)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
def forward(self, x):
if self.v_shortcut:
x = self.attn(self.ln1(x))
else:
x = x + self.attn(self.ln1(x))
x = self.ffn(self.ln2(x), identity=x)
return x
class T2TModule(BaseModule):
"""Tokens-to-Token module.
"Tokens-to-Token module" (T2T Module) can model the local structure
information of images and reduce the length of tokens progressively.
Args:
img_size (int): Input image size
in_channels (int): Number of input channels
embed_dims (int): Embedding dimension
token_dims (int): Tokens dimension in T2TModuleAttention.
        use_performer (bool): If True, use Performer self-attention instead
            of regular self-attention. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
    Notes:
        Usually, ``token_dims`` is set to a small value (32 or 64) to reduce
        MACs.
"""
def __init__(
self,
img_size=224,
in_channels=3,
embed_dims=384,
token_dims=64,
use_performer=False,
init_cfg=None,
):
super(T2TModule, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.soft_split0 = nn.Unfold(
kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
if not use_performer:
self.attention1 = T2TTransformerLayer(
input_dims=in_channels * 7 * 7,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.attention2 = T2TTransformerLayer(
input_dims=token_dims * 3 * 3,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
else:
raise NotImplementedError("Performer hasn't been implemented.")
        # there are 3 soft splits with strides 4, 2 and 2 respectively
out_side = img_size // (4 * 2 * 2)
self.init_out_size = [out_side, out_side]
self.num_patches = out_side**2
@staticmethod
def _get_unfold_size(unfold: nn.Unfold, input_size):
h, w = input_size
kernel_size = to_2tuple(unfold.kernel_size)
stride = to_2tuple(unfold.stride)
padding = to_2tuple(unfold.padding)
dilation = to_2tuple(unfold.dilation)
h_out = (h + 2 * padding[0] - dilation[0] *
(kernel_size[0] - 1) - 1) // stride[0] + 1
w_out = (w + 2 * padding[1] - dilation[1] *
(kernel_size[1] - 1) - 1) // stride[1] + 1
return (h_out, w_out)
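    # Worked example of the formula above (added comment): soft_split0 has
    # kernel 7, stride 4, padding 2, so a 224x224 input gives
    # h_out = (224 + 2*2 - 1*(7 - 1) - 1) // 4 + 1 = 56. The two stride-2
    # splits then reduce the 56x56 grid to 28x28 and finally 14x14
    # (196 tokens), matching ``out_side = img_size // (4 * 2 * 2)``.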
def forward(self, x):
# step0: soft split
hw_shape = self._get_unfold_size(self.soft_split0, x.shape[2:])
x = self.soft_split0(x).transpose(1, 2)
for step in [1, 2]:
# re-structurization/reconstruction
attn = getattr(self, f'attention{step}')
x = attn(x).transpose(1, 2)
B, C, _ = x.shape
x = x.reshape(B, C, hw_shape[0], hw_shape[1])
# soft split
soft_split = getattr(self, f'soft_split{step}')
hw_shape = self._get_unfold_size(soft_split, hw_shape)
x = soft_split(x).transpose(1, 2)
# final tokens
x = self.project(x)
return x, hw_shape
def get_sinusoid_encoding(n_position, embed_dims):
"""Generate sinusoid encoding table.
    Sinusoid encoding is a kind of position encoding method that comes from
    `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Args:
n_position (int): The length of the input token.
embed_dims (int): The position embedding dimension.
Returns:
:obj:`torch.FloatTensor`: The sinusoid encoding table.
"""
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (i // 2) / embed_dims)
for i in range(embed_dims)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos) for pos in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
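# Worked example of the table layout (added comment): with embed_dims=4,
# position p maps to the angle vector [p/10000**0, p/10000**0,
# p/10000**0.5, p/10000**0.5]; sin is then applied at even indices and cos
# at odd indices, i.e. entry 2i holds sin(p / 10000**(2i/d)) and entry
# 2i+1 holds cos(p / 10000**(2i/d)).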
@MODELS.register_module()
class T2T_ViT(BaseBackbone):
"""Tokens-to-Token Vision Transformer (T2T-ViT)
A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_
Args:
img_size (int | tuple): The expected input image shape. Because we
support dynamic input shape, just set the argument to the most
common input image shape. Defaults to 224.
in_channels (int): Number of input channels.
embed_dims (int): Embedding dimension.
num_layers (int): Num of transformer layers in encoder.
Defaults to 14.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
drop_rate (float): Dropout rate after position embedding.
Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            final feature map. Defaults to True.
out_type (str): The type of output features. Please choose from
- ``"cls_token"``: The class token tensor with shape (B, C).
- ``"featmap"``: The feature map tensor from the patch tokens
with shape (B, C, H, W).
- ``"avg_featmap"``: The global averaged feature map tensor
with shape (B, C).
- ``"raw"``: The raw feature tensor includes patch tokens and
class tokens with shape (B, L, C).
Defaults to ``"cls_token"``.
with_cls_token (bool): Whether concatenating class token into image
tokens as transformer input. Defaults to True.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
t2t_cfg (dict): Extra config of Tokens-to-Token module.
Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}
def __init__(self,
img_size=224,
in_channels=3,
embed_dims=384,
num_layers=14,
out_indices=-1,
drop_rate=0.,
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
final_norm=True,
out_type='cls_token',
with_cls_token=True,
interpolate_mode='bicubic',
t2t_cfg=dict(),
layer_cfgs=dict(),
init_cfg=None):
super().__init__(init_cfg)
# Token-to-Token Module
self.tokens_to_token = T2TModule(
img_size=img_size,
in_channels=in_channels,
embed_dims=embed_dims,
**t2t_cfg)
self.patch_resolution = self.tokens_to_token.init_out_size
num_patches = self.patch_resolution[0] * self.patch_resolution[1]
# Set out type
if out_type not in self.OUT_TYPES:
raise ValueError(f'Unsupported `out_type` {out_type}, please '
f'choose from {self.OUT_TYPES}')
self.out_type = out_type
# Set cls token
if with_cls_token:
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
self.num_extra_tokens = 1
elif out_type != 'cls_token':
self.cls_token = None
self.num_extra_tokens = 0
else:
raise ValueError(
'with_cls_token must be True when `out_type="cls_token"`.')
# Set position embedding
self.interpolate_mode = interpolate_mode
sinusoid_table = get_sinusoid_encoding(
num_patches + self.num_extra_tokens, embed_dims)
self.register_buffer('pos_embed', sinusoid_table)
self._register_load_state_dict_pre_hook(self._prepare_pos_embed)
self.drop_after_pos = nn.Dropout(p=drop_rate)
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = num_layers + index
            assert 0 <= out_indices[i] <= num_layers, \
                f'Invalid out_indices {index}'
        self.out_indices = out_indices
# stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, num_layers).tolist()
self.encoder = ModuleList()
for i in range(num_layers):
if isinstance(layer_cfgs, Sequence):
layer_cfg = layer_cfgs[i]
else:
layer_cfg = deepcopy(layer_cfgs)
layer_cfg = {
'embed_dims': embed_dims,
'num_heads': 6,
'feedforward_channels': 3 * embed_dims,
'drop_path_rate': dpr[i],
'qkv_bias': False,
'norm_cfg': norm_cfg,
**layer_cfg
}
layer = T2TTransformerLayer(**layer_cfg)
self.encoder.append(layer)
self.final_norm = final_norm
if final_norm:
self.norm = build_norm_layer(norm_cfg, embed_dims)
else:
self.norm = nn.Identity()
def init_weights(self):
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
            # Suppress custom init if using a pretrained model.
return
trunc_normal_(self.cls_token, std=.02)
def _prepare_pos_embed(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmengine.logging import MMLogger
logger = MMLogger.get_current_instance()
logger.info(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.')
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.tokens_to_token.init_out_size
state_dict[name] = resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.tokens_to_token(x)
if self.cls_token is not None:
# stole cls_tokens impl from Phil Wang, thanks
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
outs = []
for i, layer in enumerate(self.encoder):
x = layer(x)
if i == len(self.encoder) - 1 and self.final_norm:
x = self.norm(x)
if i in self.out_indices:
outs.append(self._format_output(x, patch_resolution))
return tuple(outs)
def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
if self.out_type == 'cls_token':
return x[:, 0]
patch_token = x[:, self.num_extra_tokens:]
if self.out_type == 'featmap':
B = x.size(0)
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return patch_token.mean(dim=1)
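# A minimal usage sketch (added for illustration, not part of the original
# file); it assumes mmpretrain and its dependencies are installed. With the
# default out_type='cls_token', the backbone returns a (B, C) feature.
if __name__ == '__main__':
    model = T2T_ViT(img_size=224)
    model.eval()
    with torch.no_grad():
        (cls_feat, ) = model(torch.rand(1, 3, 224, 224))
    print(tuple(cls_feat.shape))  # (1, 384) with the default embed_dims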
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from mmengine.logging import MMLogger
from mmpretrain.registry import MODELS
from mmpretrain.utils import require
from .base_backbone import BaseBackbone
def print_timm_feature_info(feature_info):
"""Print feature_info of timm backbone to help development and debug.
Args:
feature_info (list[dict] | timm.models.features.FeatureInfo | None):
feature_info of timm backbone.
"""
logger = MMLogger.get_current_instance()
if feature_info is None:
logger.warning('This backbone does not have feature_info')
elif isinstance(feature_info, list):
for feat_idx, each_info in enumerate(feature_info):
logger.info(f'backbone feature_info[{feat_idx}]: {each_info}')
else:
try:
logger.info(f'backbone out_indices: {feature_info.out_indices}')
logger.info(f'backbone out_channels: {feature_info.channels()}')
logger.info(f'backbone out_strides: {feature_info.reduction()}')
except AttributeError:
logger.warning('Unexpected format of backbone feature_info')
@MODELS.register_module()
class TIMMBackbone(BaseBackbone):
"""Wrapper to use backbones from timm library.
More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_.
See especially the document for `feature extraction
<https://rwightman.github.io/pytorch-image-models/feature_extraction/>`_.
Args:
model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to extract feature pyramid (multi-scale
            feature maps from the deepest layer at each stride). For Vision
            Transformer models that do not support this argument,
            set it to False. Defaults to False.
pretrained (bool): Whether to load pretrained weights.
Defaults to False.
checkpoint_path (str): Path of checkpoint to load at the last of
``timm.create_model``. Defaults to empty string, which means
not loading.
in_channels (int): Number of input image channels. Defaults to 3.
init_cfg (dict or list[dict], optional): Initialization config dict of
OpenMMLab projects. Defaults to None.
**kwargs: Other timm & model specific arguments.
"""
@require('timm')
def __init__(self,
model_name,
features_only=False,
pretrained=False,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs):
import timm
if not isinstance(pretrained, bool):
raise TypeError('pretrained must be bool, not str for model path')
if features_only and checkpoint_path:
warnings.warn(
'Using both features_only and checkpoint_path will cause error'
' in timm. See '
'https://github.com/rwightman/pytorch-image-models/issues/488')
super(TIMMBackbone, self).__init__(init_cfg)
if 'norm_layer' in kwargs:
norm_class = MODELS.get(kwargs['norm_layer'])
def build_norm(*args, **kwargs):
return norm_class(*args, **kwargs)
kwargs['norm_layer'] = build_norm
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs)
# reset classifier
if hasattr(self.timm_model, 'reset_classifier'):
self.timm_model.reset_classifier(0, '')
# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True
feature_info = getattr(self.timm_model, 'feature_info', None)
print_timm_feature_info(feature_info)
def forward(self, x):
features = self.timm_model(x)
if isinstance(features, (list, tuple)):
features = tuple(features)
else:
features = (features, )
return features
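# A minimal usage sketch (added for illustration, not part of the original
# file); it assumes the timm package is installed. With features_only=True,
# 'resnet18' yields a pyramid of five feature maps at strides 2 to 32.
if __name__ == '__main__':
    import torch
    backbone = TIMMBackbone(model_name='resnet18', features_only=True)
    feats = backbone(torch.rand(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])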
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Tuple
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule, ModuleList, Sequential
from torch.nn import functional as F
from mmpretrain.registry import MODELS
from ..utils import LeAttention
from .base_backbone import BaseBackbone
class ConvBN2d(Sequential):
"""An implementation of Conv2d + BatchNorm2d with support of fusion.
Modified from
https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int): The size of the convolution kernel.
Default: 1.
stride (int): The stride of the convolution.
Default: 1.
padding (int): The padding of the convolution.
Default: 0.
dilation (int): The dilation of the convolution.
Default: 1.
groups (int): The number of groups in the convolution.
Default: 1.
bn_weight_init (float): The initial value of the weight of
the nn.BatchNorm2d layer. Default: 1.0.
init_cfg (dict): The initialization config of the module.
Default: None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
dilation=1,
groups=1,
bn_weight_init=1.0,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.add_module(
'conv2d',
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=False))
bn2d = nn.BatchNorm2d(num_features=out_channels)
# bn initialization
torch.nn.init.constant_(bn2d.weight, bn_weight_init)
torch.nn.init.constant_(bn2d.bias, 0)
self.add_module('bn2d', bn2d)
@torch.no_grad()
def fuse(self):
conv2d, bn2d = self._modules.values()
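        # Fold BN into the conv (added comment): BN(y) computes
        # gamma * (y - mean) / std + beta, so scaling the conv weight by
        # gamma / std and using bias = beta - mean * gamma / std gives a
        # single equivalent Conv2d.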
w = bn2d.weight / (bn2d.running_var + bn2d.eps)**0.5
w = conv2d.weight * w[:, None, None, None]
b = bn2d.bias - bn2d.running_mean * bn2d.weight / \
(bn2d.running_var + bn2d.eps)**0.5
        m = nn.Conv2d(
            in_channels=w.size(1) * self.conv2d.groups,
            out_channels=w.size(0),
            kernel_size=w.shape[2:],
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class PatchEmbed(BaseModule):
"""Patch Embedding for Vision Transformer.
Adapted from
https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
    Different from `mmcv.cnn.bricks.transformer.PatchEmbed`, this module uses
    Conv2d and BatchNorm2d to implement patch embedding, and the output
    shape is (N, C, H, W).
Args:
in_channels (int): The number of input channels.
embed_dim (int): The embedding dimension.
resolution (Tuple[int, int]): The resolution of the input feature.
act_cfg (dict): The activation config of the module.
Default: dict(type='GELU').
"""
def __init__(self,
in_channels,
embed_dim,
resolution,
act_cfg=dict(type='GELU')):
super().__init__()
img_size: Tuple[int, int] = resolution
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * \
self.patches_resolution[1]
self.in_channels = in_channels
self.embed_dim = embed_dim
self.seq = nn.Sequential(
ConvBN2d(
in_channels,
embed_dim // 2,
kernel_size=3,
stride=2,
padding=1),
build_activation_layer(act_cfg),
ConvBN2d(
embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
)
def forward(self, x):
return self.seq(x)
class PatchMerging(nn.Module):
"""Patch Merging for TinyViT.
Adapted from
https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
    Different from `mmpretrain.models.utils.PatchMerging`, this module uses
    Conv2d and BatchNorm2d to implement PatchMerging.
    Args:
        resolution (Tuple[int, int]): The resolution of the input feature.
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
act_cfg (dict): The activation config of the module.
Default: dict(type='GELU').
"""
def __init__(self,
resolution,
in_channels,
out_channels,
act_cfg=dict(type='GELU')):
super().__init__()
self.img_size = resolution
self.act = build_activation_layer(act_cfg)
self.conv1 = ConvBN2d(in_channels, out_channels, kernel_size=1)
self.conv2 = ConvBN2d(
out_channels,
out_channels,
kernel_size=3,
stride=2,
padding=1,
groups=out_channels)
self.conv3 = ConvBN2d(out_channels, out_channels, kernel_size=1)
self.out_resolution = (resolution[0] // 2, resolution[1] // 2)
def forward(self, x):
if len(x.shape) == 3:
H, W = self.img_size
B = x.shape[0]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = x.flatten(2).transpose(1, 2)
return x
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck Block for TinyViT. Adapted from
https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
        expand_ratio (int): The expand ratio of the hidden channels.
        drop_path (float): The drop path rate of the block.
act_cfg (dict): The activation config of the module.
Default: dict(type='GELU').
"""
def __init__(self,
in_channels,
out_channels,
expand_ratio,
drop_path,
act_cfg=dict(type='GELU')):
super().__init__()
self.in_channels = in_channels
hidden_channels = int(in_channels * expand_ratio)
# linear
self.conv1 = ConvBN2d(in_channels, hidden_channels, kernel_size=1)
self.act = build_activation_layer(act_cfg)
# depthwise conv
self.conv2 = ConvBN2d(
in_channels=hidden_channels,
out_channels=hidden_channels,
kernel_size=3,
stride=1,
padding=1,
groups=hidden_channels)
# linear
self.conv3 = ConvBN2d(
hidden_channels, out_channels, kernel_size=1, bn_weight_init=0.0)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
x = self.act(x)
return x
class ConvStage(BaseModule):
"""Convolution Stage for TinyViT.
Adapted from
https://github.com/microsoft/Cream/blob/main/TinyViT/models/tiny_vit.py
Args:
in_channels (int): The number of input channels.
resolution (Tuple[int, int]): The resolution of the input feature.
depth (int): The number of blocks in the stage.
act_cfg (dict): The activation config of the module.
        drop_path (float | list[float]): The drop path rate(s) of the blocks.
            Default: 0.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        out_channels (int, optional): The number of output channels.
            Default: None.
        conv_expand_ratio (int): The expand ratio of the hidden channels.
            Default: 4.
init_cfg (dict | list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
resolution,
depth,
act_cfg,
drop_path=0.,
downsample=None,
use_checkpoint=False,
out_channels=None,
conv_expand_ratio=4.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = ModuleList([
MBConvBlock(
in_channels=in_channels,
out_channels=in_channels,
expand_ratio=conv_expand_ratio,
drop_path=drop_path[i]
if isinstance(drop_path, list) else drop_path)
for i in range(depth)
])
# patch merging layer
if downsample is not None:
self.downsample = downsample(
resolution=resolution,
in_channels=in_channels,
out_channels=out_channels,
act_cfg=act_cfg)
self.resolution = self.downsample.out_resolution
else:
self.downsample = None
self.resolution = resolution
def forward(self, x):
for block in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(block, x)
else:
x = block(x)
if self.downsample is not None:
x = self.downsample(x)
return x
class MLP(BaseModule):
"""MLP module for TinyViT.
Args:
in_channels (int): The number of input channels.
hidden_channels (int, optional): The number of hidden channels.
Default: None.
out_channels (int, optional): The number of output channels.
Default: None.
act_cfg (dict): The activation config of the module.
Default: dict(type='GELU').
drop (float): Probability of an element to be zeroed.
Default: 0.
init_cfg (dict | list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
hidden_channels=None,
out_channels=None,
act_cfg=dict(type='GELU'),
drop=0.,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
out_channels = out_channels or in_channels
hidden_channels = hidden_channels or in_channels
self.norm = nn.LayerNorm(in_channels)
self.fc1 = nn.Linear(in_channels, hidden_channels)
self.fc2 = nn.Linear(hidden_channels, out_channels)
self.act = build_activation_layer(act_cfg)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class TinyViTBlock(BaseModule):
"""TinViT Block.
Args:
in_channels (int): The number of input channels.
resolution (Tuple[int, int]): The resolution of the input feature.
num_heads (int): The number of heads in the multi-head attention.
window_size (int): The size of the window.
Default: 7.
mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
Default: 4.
drop (float): Probability of an element to be zeroed.
Default: 0.
drop_path (float): The drop path of the block.
Default: 0.
local_conv_size (int): The size of the local convolution.
Default: 3.
act_cfg (dict): The activation config of the module.
Default: dict(type='GELU').
"""
def __init__(self,
in_channels,
resolution,
num_heads,
window_size=7,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
local_conv_size=3,
act_cfg=dict(type='GELU')):
super().__init__()
self.in_channels = in_channels
self.img_size = resolution
self.num_heads = num_heads
assert window_size > 0, 'window_size must be greater than 0'
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
        assert in_channels % num_heads == 0, \
            'in_channels must be divisible by num_heads'
head_dim = in_channels // num_heads
window_resolution = (window_size, window_size)
self.attn = LeAttention(
in_channels,
head_dim,
num_heads,
attn_ratio=1,
resolution=window_resolution)
mlp_hidden_dim = int(in_channels * mlp_ratio)
self.mlp = MLP(
in_channels=in_channels,
hidden_channels=mlp_hidden_dim,
act_cfg=act_cfg,
drop=drop)
self.local_conv = ConvBN2d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=local_conv_size,
stride=1,
padding=local_conv_size // 2,
groups=in_channels)
def forward(self, x):
H, W = self.img_size
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
res_x = x
if H == self.window_size and W == self.window_size:
x = self.attn(x)
else:
x = x.view(B, H, W, C)
pad_b = (self.window_size -
H % self.window_size) % self.window_size
pad_r = (self.window_size -
W % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = H + pad_b, W + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# window partition
x = x.view(B, nH, self.window_size, nW, self.window_size,
C).transpose(2, 3).reshape(
B * nH * nW, self.window_size * self.window_size, C)
x = self.attn(x)
# window reverse
x = x.view(B, nH, nW, self.window_size, self.window_size,
C).transpose(2, 3).reshape(B, pH, pW, C)
if padding:
x = x[:, :H, :W].contiguous()
x = x.view(B, L, C)
x = res_x + self.drop_path(x)
x = x.transpose(1, 2).reshape(B, C, H, W)
x = self.local_conv(x)
x = x.view(B, C, L).transpose(1, 2)
x = x + self.drop_path(self.mlp(x))
return x
class BasicStage(BaseModule):
"""Basic Stage for TinyViT.
Args:
in_channels (int): The number of input channels.
resolution (Tuple[int, int]): The resolution of the input feature.
depth (int): The number of blocks in the stage.
num_heads (int): The number of heads in the multi-head attention.
window_size (int): The size of the window.
mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
Default: 4.
drop (float): Probability of an element to be zeroed.
Default: 0.
drop_path (float): The drop path of the block.
Default: 0.
        downsample (None | nn.Module): The downsample operation.
            Default: None.
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False.
        local_conv_size (int): The size of the local convolution.
            Default: 3.
        out_channels (int, optional): The number of output channels.
            Default: None.
        act_cfg (dict): The activation config of the module.
            Default: dict(type='GELU').
init_cfg (dict | list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels,
resolution,
depth,
num_heads,
window_size,
mlp_ratio=4.,
drop=0.,
drop_path=0.,
downsample=None,
use_checkpoint=False,
local_conv_size=3,
out_channels=None,
act_cfg=dict(type='GELU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = ModuleList([
TinyViTBlock(
in_channels=in_channels,
resolution=resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
local_conv_size=local_conv_size,
act_cfg=act_cfg,
drop_path=drop_path[i]
if isinstance(drop_path, list) else drop_path)
for i in range(depth)
])
# build patch merging layer
if downsample is not None:
self.downsample = downsample(
resolution=resolution,
in_channels=in_channels,
out_channels=out_channels,
act_cfg=act_cfg)
self.resolution = self.downsample.out_resolution
else:
self.downsample = None
self.resolution = resolution
def forward(self, x):
for block in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(block, x)
else:
x = block(x)
if self.downsample is not None:
x = self.downsample(x)
return x
@MODELS.register_module()
class TinyViT(BaseBackbone):
"""TinyViT.
A PyTorch implementation of : `TinyViT: Fast Pretraining Distillation
for Small Vision Transformers<https://arxiv.org/abs/2201.03545v1>`_
Inspiration from
https://github.com/microsoft/Cream/blob/main/TinyViT
Args:
arch (str | dict): The architecture of TinyViT.
Default: '5m'.
img_size (tuple | int): The resolution of the input image.
Default: (224, 224)
window_size (list): The size of the window.
Default: [7, 7, 14, 7]
in_channels (int): The number of input channels.
Default: 3.
        mlp_ratio (float): The ratio of mlp hidden dim to embedding dim.
            Default: 4.
drop_rate (float): Probability of an element to be zeroed.
Default: 0.
drop_path_rate (float): The drop path of the block.
Default: 0.1.
use_checkpoint (bool): Whether to use checkpointing to save memory.
Default: False.
mbconv_expand_ratio (int): The expand ratio of the mbconv.
Default: 4.0
local_conv_size (int): The size of the local conv.
Default: 3.
layer_lr_decay (float): The layer lr decay.
Default: 1.0
        out_indices (int | list[int]): Output from which stages.
            Default: -1
        frozen_stages (int): Stages to be frozen (all parameters fixed).
            Default: 0
        gap_before_final_norm (bool): Whether to apply global average pooling
            before the final norm. Default: True.
act_cfg (dict): The activation config of the module.
Default: dict(type='GELU').
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
init_cfg (dict | list[dict], optional): Initialization config dict.
Default: None.
"""
arch_settings = {
'5m': {
'channels': [64, 128, 160, 320],
'num_heads': [2, 4, 5, 10],
'depths': [2, 2, 6, 2],
},
'11m': {
'channels': [64, 128, 256, 448],
'num_heads': [2, 4, 8, 14],
'depths': [2, 2, 6, 2],
},
'21m': {
'channels': [96, 192, 384, 576],
'num_heads': [3, 6, 12, 18],
'depths': [2, 2, 6, 2],
},
}
def __init__(self,
arch='5m',
img_size=(224, 224),
window_size=[7, 7, 14, 7],
in_channels=3,
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.1,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=1.0,
out_indices=-1,
frozen_stages=0,
gap_before_final_norm=True,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'{set(self.arch_settings)} or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            assert 'channels' in arch and 'num_heads' in arch and \
                'depths' in arch, 'The arch dict must have ' \
                f'"channels", "num_heads" and "depths" ' \
                f'keys, but got {arch.keys()}'
self.channels = arch['channels']
self.num_heads = arch['num_heads']
        self.window_sizes = window_size
self.img_size = img_size
self.depths = arch['depths']
self.num_stages = len(self.channels)
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = 4 + index
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.gap_before_final_norm = gap_before_final_norm
self.layer_lr_decay = layer_lr_decay
self.patch_embed = PatchEmbed(
in_channels=in_channels,
embed_dim=self.channels[0],
resolution=self.img_size,
act_cfg=dict(type='GELU'))
patches_resolution = self.patch_embed.patches_resolution
# stochastic depth decay rule
dpr = [
x.item()
for x in torch.linspace(0, drop_path_rate, sum(self.depths))
]
# build stages
self.stages = ModuleList()
for i in range(self.num_stages):
depth = self.depths[i]
channel = self.channels[i]
curr_resolution = (patches_resolution[0] // (2**i),
patches_resolution[1] // (2**i))
drop_path = dpr[sum(self.depths[:i]):sum(self.depths[:i + 1])]
downsample = PatchMerging if (i < self.num_stages - 1) else None
out_channels = self.channels[min(i + 1, self.num_stages - 1)]
if i >= 1:
stage = BasicStage(
in_channels=channel,
resolution=curr_resolution,
depth=depth,
num_heads=self.num_heads[i],
                    window_size=self.window_sizes[i],
mlp_ratio=mlp_ratio,
drop=drop_rate,
drop_path=drop_path,
downsample=downsample,
use_checkpoint=use_checkpoint,
local_conv_size=local_conv_size,
out_channels=out_channels,
act_cfg=act_cfg)
else:
stage = ConvStage(
in_channels=channel,
resolution=curr_resolution,
depth=depth,
act_cfg=act_cfg,
drop_path=drop_path,
downsample=downsample,
use_checkpoint=use_checkpoint,
out_channels=out_channels,
conv_expand_ratio=mbconv_expand_ratio)
self.stages.append(stage)
# add output norm
if i in self.out_indices:
norm_layer = build_norm_layer(norm_cfg, out_channels)[1]
self.add_module(f'norm{i}', norm_layer)
def set_layer_lr_decay(self, layer_lr_decay):
# TODO: add layer_lr_decay
pass
def forward(self, x):
outs = []
x = self.patch_embed(x)
for i, stage in enumerate(self.stages):
x = stage(x)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
if self.gap_before_final_norm:
gap = x.mean(1)
outs.append(norm_layer(gap))
else:
out = norm_layer(x)
                    # convert the (B, L, C) format into (B, C, H, W)
                    # format, which is more convenient for downstream tasks
B, L, C = out.shape
out = out.view(B, *stage.resolution, C)
outs.append(out.permute(0, 3, 1, 2))
return tuple(outs)
def _freeze_stages(self):
for i in range(self.frozen_stages):
stage = self.stages[i]
stage.eval()
for param in stage.parameters():
param.requires_grad = False
def train(self, mode=True):
super(TinyViT, self).train(mode)
self._freeze_stages()
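# A minimal usage sketch (added for illustration, not part of the original
# file); it assumes mmpretrain and its dependencies are installed. With the
# defaults (out_indices=-1, gap_before_final_norm=True), the backbone
# returns a single globally pooled feature vector.
if __name__ == '__main__':
    model = TinyViT(arch='5m')
    model.eval()
    with torch.no_grad():
        (feat, ) = model(torch.rand(1, 3, 224, 224))
    print(tuple(feat.shape))  # (1, 320) for the '5m' arch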