Commit 0fd8347d authored by unknown

Add the mmclassification-0.24.1 code and remove mmclassification-speed-benchmark

parent cc567e9e
# Copyright (c) OpenMMLab. All rights reserved.
import functools
import torch
import torch.nn.functional as F
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are "none", "mean" and "sum".
Returns:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
elif reduction_enum == 2:
return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Tensor): Element-wise weights.
reduction (str): Same as built-in losses of PyTorch.
avg_factor (float): Average factor when computing the mean of losses.
Returns:
Tensor: Processed loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
loss = loss * weight
# if avg_factor is not specified, just reduce the loss
if avg_factor is None:
loss = reduce_loss(loss, reduction)
else:
# if reduction is mean, then average the loss by avg_factor
if reduction == 'mean':
loss = loss.sum() / avg_factor
# if reduction is 'none', then do nothing, otherwise raise an error
elif reduction != 'none':
raise ValueError('avg_factor can not be used with reduction="sum"')
return loss
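# Editor's note: a minimal usage sketch (not part of the upstream mmcls file)
# showing how `avg_factor` replaces the element count when averaging. The
# tensors and the helper name `_demo_weight_reduce_loss` are illustrative.
def _demo_weight_reduce_loss():
    loss = torch.tensor([1., 2., 3., 4.])
    weight = torch.tensor([1., 1., 0., 0.])
    # Only the first two elements contribute: (1 + 2) / avg_factor = 1.5,
    # whereas a plain mean over all four weighted elements would give 0.75.
    return weight_reduce_loss(loss, weight, reduction='mean', avg_factor=2)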
def weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
``loss_func(pred, target, **kwargs)``. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like ``loss_func(pred, target, weight=None, reduction='mean',
avg_factor=None, **kwargs)``.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, avg_factor=2)
tensor(1.5000)
"""
@functools.wraps(loss_func)
def wrapper(pred,
target,
weight=None,
reduction='mean',
avg_factor=None,
**kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
return wrapper
def convert_to_one_hot(targets: torch.Tensor, classes) -> torch.Tensor:
"""This function converts target class indices to one-hot vectors, given
the number of classes.
Args:
targets (Tensor): The ground truth label of the prediction
with shape (N, 1)
classes (int): the number of classes.
Returns:
Tensor: One-hot encoded targets with shape (N, classes).
"""
assert (torch.max(targets).item() <
classes), 'Class Index must be less than number of classes'
one_hot_targets = F.one_hot(
targets.long().squeeze(-1), num_classes=classes)
return one_hot_targets
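# Editor's note: an illustrative sketch (not in the upstream file) of the
# expected input/output of `convert_to_one_hot`; the sample tensor is made up.
def _demo_convert_to_one_hot():
    targets = torch.tensor([[0], [2], [1]])  # shape (N, 1), class indices
    # Returns a (N, classes) one-hot matrix:
    # [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
    return convert_to_one_hot(targets, classes=3)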
# Copyright (c) OpenMMLab. All rights reserved.
from .gap import GlobalAveragePooling
from .gem import GeneralizedMeanPooling
from .hr_fuse import HRFuseScales
__all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales']
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from ..builder import NECKS
def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor:
if clamp:
x = x.clamp(min=eps)
return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p)
@NECKS.register_module()
class GeneralizedMeanPooling(nn.Module):
"""Generalized Mean Pooling neck.
Note that we use `view` to remove extra channel after pooling. We do not
use `squeeze` as it will also remove the batch dimension when the tensor
has a batch dimension of size 1, which can lead to unexpected errors.
Args:
p (float): Parameter value.
Default: 3.
eps (float): epsilon.
Default: 1e-6
clamp (bool): Use clamp before pooling.
Default: True
"""
def __init__(self, p=3., eps=1e-6, clamp=True):
assert p >= 1, "'p' must be a value greater than or equal to 1"
super(GeneralizedMeanPooling, self).__init__()
self.p = Parameter(torch.ones(1) * p)
self.eps = eps
self.clamp = clamp
def forward(self, inputs):
if isinstance(inputs, tuple):
outs = tuple([
gem(x, p=self.p, eps=self.eps, clamp=self.clamp)
for x in inputs
])
outs = tuple(
[out.view(x.size(0), -1) for out, x in zip(outs, inputs)])
elif isinstance(inputs, torch.Tensor):
outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp)
outs = outs.view(inputs.size(0), -1)
else:
raise TypeError('neck inputs should be tuple or torch.Tensor')
return outs
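# Editor's note: a small usage sketch (editor's addition) for the GeM neck
# defined above. The feature shapes are arbitrary examples.
def _demo_gem_pooling():
    neck = GeneralizedMeanPooling(p=3., eps=1e-6, clamp=True)
    feats = torch.rand(2, 64, 7, 7)  # (batch, channels, H, W) from a backbone
    # Pools each 7x7 map to a single value per channel -> shape (2, 64).
    return neck(feats)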
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn.bricks import ConvModule
from mmcv.runner import BaseModule
from ..backbones.resnet import Bottleneck, ResLayer
from ..builder import NECKS
@NECKS.register_module()
class HRFuseScales(BaseModule):
"""Fuse feature map of multiple scales in HRNet.
Args:
in_channels (list[int]): The input channels of all scales.
out_channels (int): The channels of fused feature map.
Defaults to 2048.
norm_cfg (dict): dictionary to construct norm layers.
Defaults to ``dict(type='BN', momentum=0.1)``.
init_cfg (dict | list[dict], optional): Initialization config dict.
Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.
"""
def __init__(self,
in_channels,
out_channels=2048,
norm_cfg=dict(type='BN', momentum=0.1),
init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
super(HRFuseScales, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.norm_cfg = norm_cfg
block_type = Bottleneck
out_channels = [128, 256, 512, 1024]
# Increase the channels on each resolution
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
increase_layers = []
for i in range(len(in_channels)):
increase_layers.append(
ResLayer(
block_type,
in_channels=in_channels[i],
out_channels=out_channels[i],
num_blocks=1,
stride=1,
))
self.increase_layers = nn.ModuleList(increase_layers)
# Downsample feature maps in each scale.
downsample_layers = []
for i in range(len(in_channels) - 1):
downsample_layers.append(
ConvModule(
in_channels=out_channels[i],
out_channels=out_channels[i + 1],
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
bias=False,
))
self.downsample_layers = nn.ModuleList(downsample_layers)
# The final conv block before final classifier linear layer.
self.final_layer = ConvModule(
in_channels=out_channels[3],
out_channels=self.out_channels,
kernel_size=1,
norm_cfg=self.norm_cfg,
bias=False,
)
def forward(self, x):
assert isinstance(x, tuple) and len(x) == len(self.in_channels)
feat = self.increase_layers[0](x[0])
for i in range(len(self.downsample_layers)):
feat = self.downsample_layers[i](feat) + \
self.increase_layers[i + 1](x[i + 1])
return (self.final_layer(feat), )
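# Editor's note: an illustrative sketch (not part of the upstream file) using
# the HRNet-W18 branch widths (18, 36, 72, 144); the input sizes are made up.
def _demo_hr_fuse_scales():
    import torch
    neck = HRFuseScales(in_channels=[18, 36, 72, 144])
    feats = (
        torch.rand(2, 18, 56, 56),
        torch.rand(2, 36, 28, 28),
        torch.rand(2, 72, 14, 14),
        torch.rand(2, 144, 7, 7),
    )
    # Returns a single-element tuple with a (2, 2048, 7, 7) fused feature map.
    return neck(feats)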
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import MultiheadAttention, ShiftWindowMSA, WindowMSAV2
from .augment.augments import Augments
from .channel_shuffle import channel_shuffle
from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
resize_relative_position_bias_table)
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .layer_scale import LayerScale
from .make_divisible import make_divisible
from .position_encoding import ConditionalPositionEncoding
from .se_layer import SELayer
__all__ = [
'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed',
'resize_relative_position_bias_table', 'WindowMSAV2', 'LayerScale'
]
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.registry import DROPOUT_LAYERS
from mmcv.cnn.bricks.transformer import build_dropout
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule
from ..builder import ATTENTION
from .helpers import to_2tuple
class WindowMSA(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
attn_drop (float, optional): Dropout ratio of attention weight.
Defaults to 0.
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
init_cfg=None):
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_embed_dims = embed_dims // num_heads
self.scale = qk_scale or head_embed_dims**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# About 2x faster than original impl
Wh, Ww = self.window_size
rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
rel_position_index = rel_index_coords + rel_index_coords.T
rel_position_index = rel_position_index.flip(1).contiguous()
self.register_buffer('relative_position_index', rel_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def init_weights(self):
super(WindowMSA, self).init_weights()
trunc_normal_(self.relative_position_bias_table, std=0.02)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
Wh*Ww), value should be between (-inf, 0].
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@staticmethod
def double_step_seq(step1, len1, step2, len2):
seq1 = torch.arange(0, step1 * len1, step1)
seq2 = torch.arange(0, step2 * len2, step2)
return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
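# Editor's note: a minimal usage sketch (editor's addition) for the W-MSA
# module above; the shapes follow a Swin-T style setting and are illustrative.
def _demo_window_msa():
    attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=3)
    # 8 windows of 7x7 = 49 tokens with 96 channels each.
    windows = torch.rand(8, 49, 96)
    return attn(windows)  # same shape as the input: (8, 49, 96)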
class WindowMSAV2(BaseModule):
"""Window based multi-head self-attention (W-MSA) module with relative
position bias.
Based on the implementation in the original Swin Transformer V2 repo. Refer to
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer_v2.py
for more details.
Args:
embed_dims (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Defaults to True.
attn_drop (float, optional): Dropout ratio of attention weight.
Defaults to 0.
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
pretrained_window_size (tuple(int)): The height and width of the window
in pre-training.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
cpb_mlp_hidden_dims=512,
pretrained_window_size=(0, 0),
init_cfg=None,
**kwargs): # accept extra arguments
super().__init__(init_cfg)
self.embed_dims = embed_dims
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
# Use small network for continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(
in_features=2, out_features=cpb_mlp_hidden_dims, bias=True),
nn.ReLU(inplace=True),
nn.Linear(
in_features=cpb_mlp_hidden_dims,
out_features=num_heads,
bias=False))
# Add learnable scalar for cosine attention
self.logit_scale = nn.Parameter(
torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
# get relative_coords_table
relative_coords_h = torch.arange(
-(self.window_size[0] - 1),
self.window_size[0],
dtype=torch.float32)
relative_coords_w = torch.arange(
-(self.window_size[1] - 1),
self.window_size[1],
dtype=torch.float32)
relative_coords_table = torch.stack(
torch.meshgrid([relative_coords_h, relative_coords_w])).permute(
1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (
pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (
pretrained_window_size[1] - 1)
else:
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
self.register_buffer('relative_coords_table', relative_coords_table)
# get pair-wise relative position index
# for each token inside the window
indexes_h = torch.arange(self.window_size[0])
indexes_w = torch.arange(self.window_size[1])
coordinates = torch.stack(
torch.meshgrid([indexes_h, indexes_w]), dim=0) # 2, Wh, Ww
coordinates = torch.flatten(coordinates, start_dim=1) # 2, Wh*Ww
# 2, Wh*Ww, Wh*Ww
relative_coordinates = coordinates[:, :, None] - coordinates[:,
None, :]
relative_coordinates = relative_coordinates.permute(
1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coordinates[:, :, 0] += self.window_size[
0] - 1 # shift to start from 0
relative_coordinates[:, :, 1] += self.window_size[1] - 1
relative_coordinates[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coordinates.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer('relative_position_index',
relative_position_index)
self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(embed_dims))
self.v_bias = nn.Parameter(torch.zeros(embed_dims))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x (tensor): input features with shape of (num_windows*B, N, C)
mask (tensor, Optional): mask with shape of (num_windows, Wh*Ww,
Wh*Ww), value should be between (-inf, 0].
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(self.q_bias,
torch.zeros_like(self.v_bias,
requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B_, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
# cosine attention
attn = (
F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
logit_scale = torch.clamp(
self.logit_scale, max=np.log(1. / 0.01)).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(
self.relative_coords_table).view(-1, self.num_heads)
relative_position_bias = relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ATTENTION.register_module()
class ShiftWindowMSA(BaseModule):
"""Shift Window Multihead Self-Attention Module.
Args:
embed_dims (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): The height and width of the window.
shift_size (int, optional): The shift step of each window towards
right-bottom. If zero, act as regular window-msa. Defaults to 0.
qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
Defaults to True
qk_scale (float | None, optional): Override default qk scale of
head_dim ** -0.5 if set. Defaults to None.
attn_drop (float, optional): Dropout ratio of attention weight.
Defaults to 0.0.
proj_drop (float, optional): Dropout ratio of output. Defaults to 0.
dropout_layer (dict, optional): The dropout_layer used before output.
Defaults to dict(type='DropPath', drop_prob=0.).
pad_small_map (bool): If True, pad the small feature map to the window
size, which is commonly used in detection and segmentation. If False,
avoid shifting window and shrink the window size to the size of
feature map, which is commonly used in classification.
Defaults to False.
version (str, optional): Version of implementation of Swin
Transformers. Defaults to `v1`.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
window_size,
shift_size=0,
qkv_bias=True,
qk_scale=None,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
pad_small_map=False,
input_resolution=None,
auto_pad=None,
window_msa=WindowMSA,
msa_cfg=dict(),
init_cfg=None):
super().__init__(init_cfg)
if input_resolution is not None or auto_pad is not None:
warnings.warn(
'The ShiftWindowMSA in new version has supported auto padding '
'and dynamic input shape in all condition. And the argument '
'`auto_pad` and `input_resolution` have been deprecated.',
DeprecationWarning)
self.shift_size = shift_size
self.window_size = window_size
assert 0 <= self.shift_size < self.window_size
assert issubclass(window_msa, BaseModule), \
'Expect Window based multi-head self-attention Module is type of ' \
f'{type(BaseModule)}, but got {type(window_msa)}.'
self.w_msa = window_msa(
embed_dims=embed_dims,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=proj_drop,
**msa_cfg,
)
self.drop = build_dropout(dropout_layer)
self.pad_small_map = pad_small_map
def forward(self, query, hw_shape):
B, L, C = query.shape
H, W = hw_shape
assert L == H * W, f"The query length {L} doesn't match the input "\
f'shape ({H}, {W}).'
query = query.view(B, H, W, C)
window_size = self.window_size
shift_size = self.shift_size
if min(H, W) == window_size:
# If we don't pad the small feature map, avoid shifting when the
# window size is equal to the size of the feature map. This aligns
# with the behavior of the original implementation.
shift_size = shift_size if self.pad_small_map else 0
elif min(H, W) < window_size:
# In the original implementation, the window size will be shrunk
# to the size of the feature map. This behavior differs from the
# Swin Transformer used in downstream tasks. To support dynamic
# input shape, we don't allow this feature.
assert self.pad_small_map, \
f'The input shape ({H}, {W}) is smaller than the window ' \
f'size ({window_size}). Please set `pad_small_map=True`, or ' \
'decrease the `window_size`.'
pad_r = (window_size - W % window_size) % window_size
pad_b = (window_size - H % window_size) % window_size
query = F.pad(query, (0, 0, 0, pad_r, 0, pad_b))
H_pad, W_pad = query.shape[1], query.shape[2]
# cyclic shift
if shift_size > 0:
query = torch.roll(
query, shifts=(-shift_size, -shift_size), dims=(1, 2))
attn_mask = self.get_attn_mask((H_pad, W_pad),
window_size=window_size,
shift_size=shift_size,
device=query.device)
# nW*B, window_size, window_size, C
query_windows = self.window_partition(query, window_size)
# nW*B, window_size*window_size, C
query_windows = query_windows.view(-1, window_size**2, C)
# W-MSA/SW-MSA (nW*B, window_size*window_size, C)
attn_windows = self.w_msa(query_windows, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, window_size, window_size, C)
# B H' W' C
shifted_x = self.window_reverse(attn_windows, H_pad, W_pad,
window_size)
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(
shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
else:
x = shifted_x
if H != H_pad or W != W_pad:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = self.drop(x)
return x
@staticmethod
def window_reverse(windows, H, W, window_size):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size,
window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
@staticmethod
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size,
window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
windows = windows.view(-1, window_size, window_size, C)
return windows
@staticmethod
def get_attn_mask(hw_shape, window_size, shift_size, device=None):
if shift_size > 0:
img_mask = torch.zeros(1, *hw_shape, 1, device=device)
h_slices = (slice(0, -window_size), slice(-window_size,
-shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -window_size), slice(-window_size,
-shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# nW, window_size, window_size, 1
mask_windows = ShiftWindowMSA.window_partition(
img_mask, window_size)
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0)
attn_mask = attn_mask.masked_fill(attn_mask == 0, 0.0)
else:
attn_mask = None
return attn_mask
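# Editor's note: an illustrative sketch (not in the upstream file) showing a
# shifted-window attention pass on a 56x56 token grid; the sizes are made up.
def _demo_shift_window_msa():
    attn = ShiftWindowMSA(
        embed_dims=96, num_heads=3, window_size=7, shift_size=3)
    query = torch.rand(2, 56 * 56, 96)  # (B, H*W, C)
    # The output keeps the (B, H*W, C) layout: (2, 3136, 96).
    return attn(query, hw_shape=(56, 56))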
class MultiheadAttention(BaseModule):
"""Multi-head Attention Module.
This module implements multi-head attention that supports different input
dims and embed dims. It also supports a shortcut from ``value``, which is
useful when the input dims are not the same as the embed dims.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads.
input_dims (int, optional): The input dimension, and if None,
use ``embed_dims``. Defaults to None.
attn_drop (float): Dropout rate of the dropout layer after the
attention calculation of query and key. Defaults to 0.
proj_drop (float): Dropout rate of the dropout layer after the
output projection. Defaults to 0.
dropout_layer (dict): The dropout config before adding the shortcut.
Defaults to ``dict(type='Dropout', drop_prob=0.)``.
qkv_bias (bool): If True, add a learnable bias to q, k, v.
Defaults to True.
qk_scale (float, optional): Override default qk scale of
``head_dim ** -0.5`` if set. Defaults to None.
proj_bias (bool): If True, add a learnable bias to output projection.
Defaults to True.
v_shortcut (bool): Add a shortcut from value to output. It's usually
used if ``input_dims`` is different from ``embed_dims``.
Defaults to False.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
embed_dims,
num_heads,
input_dims=None,
attn_drop=0.,
proj_drop=0.,
dropout_layer=dict(type='Dropout', drop_prob=0.),
qkv_bias=True,
qk_scale=None,
proj_bias=True,
v_shortcut=False,
init_cfg=None):
super(MultiheadAttention, self).__init__(init_cfg=init_cfg)
self.input_dims = input_dims or embed_dims
self.embed_dims = embed_dims
self.num_heads = num_heads
self.v_shortcut = v_shortcut
self.head_dims = embed_dims // num_heads
self.scale = qk_scale or self.head_dims**-0.5
self.qkv = nn.Linear(self.input_dims, embed_dims * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
self.out_drop = DROPOUT_LAYERS.build(dropout_layer)
def forward(self, x):
B, N, _ = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
self.head_dims).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.embed_dims)
x = self.proj(x)
x = self.out_drop(self.proj_drop(x))
if self.v_shortcut:
x = v.squeeze(1) + x
return x
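# Editor's note: a small usage sketch (editor's addition) of the attention
# module above; the token count and dims are arbitrary examples.
def _demo_multihead_attention():
    attn = MultiheadAttention(embed_dims=128, num_heads=4)
    tokens = torch.rand(2, 50, 128)  # (B, N, C)
    return attn(tokens)  # same shape: (2, 50, 128)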
# Copyright (c) OpenMMLab. All rights reserved.
from .augments import Augments
from .cutmix import BatchCutMixLayer
from .identity import Identity
from .mixup import BatchMixupLayer
from .resizemix import BatchResizeMixLayer
__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
'BatchResizeMixLayer')
# Copyright (c) OpenMMLab. All rights reserved.
import random
import numpy as np
@@ -9,6 +10,7 @@ class Augments(object):
"""Data augments.
We implement some data augmentation methods, such as mixup, cutmix.
Args:
augments_cfg (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
Config dict of augments
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg
AUGMENT = Registry('augment')
def build_augment(cfg, default_args=None):
return build_from_cfg(cfg, AUGMENT, default_args)
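# Editor's note: an illustrative sketch (not part of the upstream file). It
# assumes the augment classes below (e.g. the one registered as 'BatchMixup')
# have already been imported so that they are present in the AUGMENT registry.
def _demo_build_augment():
    cfg = dict(type='BatchMixup', alpha=1.0, num_classes=10, prob=1.0)
    return build_augment(cfg)  # returns a BatchMixupLayer instance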
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import numpy as np
import torch
import torch.nn.functional as F
from .builder import AUGMENT
from .utils import one_hot_encoding
class BaseCutMixLayer(object, metaclass=ABCMeta):
@@ -116,13 +117,48 @@ class BaseCutMixLayer(object, metaclass=ABCMeta):
@AUGMENT.register_module(name='BatchCutMix')
class BatchCutMixLayer(BaseCutMixLayer):
"""CutMix layer for batch CutMix."""
r"""CutMix layer for a batch of data.
CutMix is a method to improve the network's generalization capability. It's
proposed in `CutMix: Regularization Strategy to Train Strong Classifiers
with Localizable Features <https://arxiv.org/abs/1905.04899>`
With this method, patches are cut and pasted among training images where
the ground truth labels are also mixed proportionally to the area of the
patches.
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`BatchMixupLayer`.
num_classes (int): The number of classes
prob (float): The probability to execute cutmix. It should be in
range [0, 1]. Defaults to 1.0.
cutmix_minmax (List[float], optional): The min/max area ratio of the
patches. If not None, the bounding-box of patches is uniformly
sampled within this ratio range, and the ``alpha`` will be ignored.
Otherwise, the bounding-box is generated according to the
``alpha``. Defaults to None.
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Defaults to True.
Note:
If the ``cutmix_minmax`` is None, how to generate the bounding-box of
patches according to the ``alpha``?
First, generate a :math:`\lambda`, details can be found in
:class:`BatchMixupLayer`. And then, the area ratio of the bounding-box
is calculated by:
.. math::
\text{ratio} = \sqrt{1-\lambda}
"""
def __init__(self, *args, **kwargs):
super(BatchCutMixLayer, self).__init__(*args, **kwargs)
def cutmix(self, img, gt_label):
one_hot_gt_label = F.one_hot(gt_label, num_classes=self.num_classes)
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
lam = np.random.beta(self.alpha, self.alpha)
batch_size = img.size(0)
index = torch.randperm(batch_size)
......
import torch.nn.functional as F
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AUGMENT
from .utils import one_hot_encoding
@AUGMENT.register_module(name='Identity')
@@ -23,7 +23,7 @@ class Identity(object):
self.prob = prob
def one_hot(self, gt_label):
return F.one_hot(gt_label, num_classes=self.num_classes)
return one_hot_encoding(gt_label, self.num_classes)
def __call__(self, img, gt_label):
return img, self.one_hot(gt_label)
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import numpy as np
import torch
from .builder import AUGMENT
from .utils import one_hot_encoding
class BaseMixupLayer(object, metaclass=ABCMeta):
"""Base class for MixupLayer.
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number.
num_classes (int): The number of classes.
prob (float): MixUp probability. It should be in range [0, 1].
Defaults to 1.0.
"""
def __init__(self, alpha, num_classes, prob=1.0):
super(BaseMixupLayer, self).__init__()
assert isinstance(alpha, float) and alpha > 0
assert isinstance(num_classes, int)
assert isinstance(prob, float) and 0.0 <= prob <= 1.0
self.alpha = alpha
self.num_classes = num_classes
self.prob = prob
@abstractmethod
def mixup(self, imgs, gt_label):
pass
@AUGMENT.register_module(name='BatchMixup')
class BatchMixupLayer(BaseMixupLayer):
r"""Mixup layer for a batch of data.
Mixup is a method to reduce the memorization of corrupt labels and
increase the robustness to adversarial examples. It's
proposed in `mixup: Beyond Empirical Risk Minimization
<https://arxiv.org/abs/1710.09412>`_
This method simply mixes pairs of data and their labels linearly.
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
are in the note.
num_classes (int): The number of classes.
prob (float): The probability to execute mixup. It should be in
range [0, 1]. Defaults to 1.0.
Note:
The :math:`\alpha` (``alpha``) determines a random distribution
:math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
distribution.
"""
def __init__(self, *args, **kwargs):
super(BatchMixupLayer, self).__init__(*args, **kwargs)
def mixup(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
lam = np.random.beta(self.alpha, self.alpha)
batch_size = img.size(0)
index = torch.randperm(batch_size)
mixed_img = lam * img + (1 - lam) * img[index, :]
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return mixed_img, mixed_gt_label
def __call__(self, img, gt_label):
return self.mixup(img, gt_label)
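# Editor's note: a minimal usage sketch (editor's addition) for the mixup
# layer above; the batch contents are random placeholders.
def _demo_batch_mixup():
    layer = BatchMixupLayer(alpha=1.0, num_classes=5, prob=1.0)
    imgs = torch.rand(4, 3, 32, 32)
    gt_label = torch.randint(0, 5, (4,))
    # Returns the mixed images (4, 3, 32, 32) and soft labels (4, 5).
    return layer(imgs, gt_label)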
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
from mmcls.models.utils.augment.builder import AUGMENT
from .cutmix import BatchCutMixLayer
from .utils import one_hot_encoding
@AUGMENT.register_module(name='BatchResizeMix')
class BatchResizeMixLayer(BatchCutMixLayer):
r"""ResizeMix Random Paste layer for a batch of data.
The ResizeMix will resize an image to a small patch and paste it on another
image. It's proposed in `ResizeMix: Mixing Data with Preserved Object
Information and True Labels <https://arxiv.org/abs/2012.11101>`_
Args:
alpha (float): Parameters for Beta distribution to generate the
mixing ratio. It should be a positive number. More details
can be found in :class:`BatchMixupLayer`.
num_classes (int): The number of classes.
lam_min (float): The minimum value of lam. Defaults to 0.1.
lam_max (float): The maximum value of lam. Defaults to 0.8.
interpolation (str): Algorithm used for upsampling:
'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' |
'area'. Defaults to 'bilinear'.
prob (float): The probability to execute resizemix. It should be in
range [0, 1]. Defaults to 1.0.
cutmix_minmax (List[float], optional): The min/max area ratio of the
patches. If not None, the bounding-box of patches is uniformly
sampled within this ratio range, and the ``alpha`` will be ignored.
Otherwise, the bounding-box is generated according to the
``alpha``. Defaults to None.
correct_lam (bool): Whether to apply lambda correction when cutmix bbox
clipped by image borders. Defaults to True.
**kwargs: Any other parameters accepted by :class:`BatchCutMixLayer`.
Note:
The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
to the range [``lam_min``, ``lam_max``].
.. math::
\lambda = Beta(\alpha, \alpha) \times
(\lambda_{max} - \lambda_{min}) + \lambda_{min}
And the resize ratio of source images is calculated by :math:`\lambda`:
.. math::
\text{ratio} = \sqrt{1-\lambda}
"""
def __init__(self,
alpha,
num_classes,
lam_min: float = 0.1,
lam_max: float = 0.8,
interpolation='bilinear',
prob=1.0,
cutmix_minmax=None,
correct_lam=True,
**kwargs):
super(BatchResizeMixLayer, self).__init__(
alpha=alpha,
num_classes=num_classes,
prob=prob,
cutmix_minmax=cutmix_minmax,
correct_lam=correct_lam,
**kwargs)
self.lam_min = lam_min
self.lam_max = lam_max
self.interpolation = interpolation
def cutmix(self, img, gt_label):
one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
lam = np.random.beta(self.alpha, self.alpha)
lam = lam * (self.lam_max - self.lam_min) + self.lam_min
batch_size = img.size(0)
index = torch.randperm(batch_size)
(bby1, bby2, bbx1,
bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)
img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
img[index],
size=(bby2 - bby1, bbx2 - bbx1),
mode=self.interpolation)
mixed_gt_label = lam * one_hot_gt_label + (
1 - lam) * one_hot_gt_label[index, :]
return img, mixed_gt_label
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F
def one_hot_encoding(gt, num_classes):
"""Change gt_label to one_hot encoding.
If the shape has 2 or more
dimensions, return it without encoding.
Args:
gt (Tensor): The gt label with shape (N,) or shape (N, */).
num_classes (int): The number of classes.
Returns:
Tensor: One hot gt label.
"""
if gt.ndim == 1:
# multi-class classification
return F.one_hot(gt, num_classes=num_classes)
else:
# binary classification
# example. [[0], [1], [1]]
# multi-label classification
# example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
return gt
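# Editor's note: an illustrative sketch (not in the upstream file) of the two
# code paths above; the label tensors are made up.
def _demo_one_hot_encoding():
    import torch
    single_label = torch.tensor([0, 2, 1])
    multi_label = torch.tensor([[0, 1, 1], [1, 0, 0]])
    # The 1-D input is expanded to shape (3, 3); the 2-D input is returned
    # unchanged because it is already an encoded label matrix.
    return one_hot_encoding(single_label, 3), one_hot_encoding(multi_label, 3)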
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import AdaptivePadding
from mmcv.runner.base_module import BaseModule
from .helpers import to_2tuple
def resize_pos_embed(pos_embed,
src_shape,
dst_shape,
mode='bicubic',
num_extra_tokens=1):
"""Resize pos_embed weights.
Args:
pos_embed (torch.Tensor): Position embedding weights with shape
[1, L, C].
src_shape (tuple): The resolution of the downsampled original training
image, in format (H, W).
dst_shape (tuple): The resolution of the downsampled new training
image, in format (H, W).
mode (str): Algorithm used for upsampling. Choose one from 'nearest',
'linear', 'bilinear', 'bicubic' and 'trilinear'.
Defaults to 'bicubic'.
num_extra_tokens (int): The number of extra tokens, such as cls_token.
Defaults to 1.
Returns:
torch.Tensor: The resized pos_embed of shape [1, L_new, C]
"""
if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
return pos_embed
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
_, L, C = pos_embed.shape
src_h, src_w = src_shape
assert L == src_h * src_w + num_extra_tokens, \
f"The length of `pos_embed` ({L}) doesn't match the expected " \
f'shape ({src_h}*{src_w}+{num_extra_tokens}). Please check the ' \
'`img_size` argument.'
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, num_extra_tokens:]
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
dst_weight = F.interpolate(
src_weight, size=dst_shape, align_corners=False, mode=mode)
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
return torch.cat((extra_tokens, dst_weight), dim=1)
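# Editor's note: a small usage sketch (editor's addition) resizing a ViT-style
# position embedding from a 14x14 to a 16x16 patch grid; shapes are examples.
def _demo_resize_pos_embed():
    pos_embed = torch.rand(1, 1 + 14 * 14, 768)  # 1 cls token + 196 patches
    # Returns a (1, 1 + 16 * 16, 768) tensor, i.e. (1, 257, 768).
    return resize_pos_embed(pos_embed, (14, 14), (16, 16), num_extra_tokens=1)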
def resize_relative_position_bias_table(src_shape, dst_shape, table, num_head):
"""Resize relative position bias table.
Args:
src_shape (int): The resolution of the downsampled original training
image, in format (H, W).
dst_shape (int): The resolution of the downsampled new training
image, in format (H, W).
table (tensor): The relative position bias of the pretrained model.
num_head (int): Number of attention heads.
Returns:
torch.Tensor: The resized relative position bias table.
"""
from scipy import interpolate
def geometric_progression(a, r, n):
return a * (1.0 - r**n) / (1.0 - r)
left, right = 1.01, 1.5
while right - left > 1e-6:
q = (left + right) / 2.0
gp = geometric_progression(1, q, src_shape // 2)
if gp > dst_shape // 2:
right = q
else:
left = q
dis = []
cur = 1
for i in range(src_shape // 2):
dis.append(cur)
cur += q**(i + 1)
r_ids = [-_ for _ in reversed(dis)]
x = r_ids + [0] + dis
y = r_ids + [0] + dis
t = dst_shape // 2.0
dx = np.arange(-t, t + 0.1, 1.0)
dy = np.arange(-t, t + 0.1, 1.0)
all_rel_pos_bias = []
for i in range(num_head):
z = table[:, i].view(src_shape, src_shape).float().numpy()
f_cubic = interpolate.interp2d(x, y, z, kind='cubic')
all_rel_pos_bias.append(
torch.Tensor(f_cubic(dx,
dy)).contiguous().view(-1,
1).to(table.device))
new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
return new_rel_pos_bias
class PatchEmbed(BaseModule):
"""Image to Patch Embedding.
We use a conv layer to implement PatchEmbed.
Args:
img_size (int | tuple): The size of input image. Default: 224
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
norm_cfg (dict, optional): Config dict for normalization layer.
Default: None
conv_cfg (dict, optional): The config dict for conv layers.
Default: None
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None
"""
def __init__(self,
img_size=224,
in_channels=3,
embed_dims=768,
norm_cfg=None,
conv_cfg=None,
init_cfg=None):
super(PatchEmbed, self).__init__(init_cfg)
warnings.warn('The `PatchEmbed` in mmcls will be deprecated. '
'Please use `mmcv.cnn.bricks.transformer.PatchEmbed`. '
"It's more general and supports dynamic input shape")
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
self.img_size = img_size
self.embed_dims = embed_dims
# Use conv layer to embed
conv_cfg = conv_cfg or dict()
_conv_cfg = dict(
type='Conv2d', kernel_size=16, stride=16, padding=0, dilation=1)
_conv_cfg.update(conv_cfg)
self.projection = build_conv_layer(_conv_cfg, in_channels, embed_dims)
# Calculate how many patches an input image is split into.
h_out, w_out = [(self.img_size[i] + 2 * self.projection.padding[i] -
self.projection.dilation[i] *
(self.projection.kernel_size[i] - 1) - 1) //
self.projection.stride[i] + 1 for i in range(2)]
self.patches_resolution = (h_out, w_out)
self.num_patches = h_out * w_out
if norm_cfg is not None:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't " \
f'match model ({self.img_size[0]}*{self.img_size[1]}).'
# The output size is (B, N, D), where N=H*W/P/P, D is embed_dims
x = self.projection(x).flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
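# Editor's note: an illustrative sketch (not part of the upstream file); note
# that instantiating this deprecated `PatchEmbed` emits a warning by design.
def _demo_patch_embed():
    patch_embed = PatchEmbed(img_size=224, in_channels=3, embed_dims=768)
    imgs = torch.rand(2, 3, 224, 224)
    # 224 / 16 = 14 patches per side, so the output shape is (2, 196, 768).
    return patch_embed(imgs)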
# Modified from pytorch-image-models
class HybridEmbed(BaseModule):
"""CNN Feature Map Embedding.
Extract feature map from CNN, flatten,
project to embedding dim.
Args:
backbone (nn.Module): CNN backbone
img_size (int | tuple): The size of input image. Default: 224
feature_size (int | tuple, optional): Size of feature map extracted by
CNN backbone. Default: None
in_channels (int): The num of input channels. Default: 3
embed_dims (int): The dimensions of embedding. Default: 768
conv_cfg (dict, optional): The config dict for conv layers.
Default: None.
init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
Default: None.
"""
def __init__(self,
backbone,
img_size=224,
feature_size=None,
in_channels=3,
embed_dims=768,
conv_cfg=None,
init_cfg=None):
super(HybridEmbed, self).__init__(init_cfg)
assert isinstance(backbone, nn.Module)
if isinstance(img_size, int):
img_size = to_2tuple(img_size)
elif isinstance(img_size, tuple):
if len(img_size) == 1:
img_size = to_2tuple(img_size[0])
assert len(img_size) == 2, \
f'The size of image should have length 1 or 2, ' \
f'but got {len(img_size)}'
self.img_size = img_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
# FIXME this is hacky, but most reliable way of
# determining the exact dim of the output feature
# map for all networks, the feature metadata has
# reliable channel and stride info, but using
# stride to calc feature dim requires info about padding of
# each stage that isn't captured.
training = backbone.training
if training:
backbone.eval()
o = self.backbone(
torch.zeros(1, in_channels, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
# last feature if backbone outputs list/tuple of features
o = o[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, 'feature_info'):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.num_patches = feature_size[0] * feature_size[1]
# Use conv layer to embed
conv_cfg = conv_cfg or dict()
_conv_cfg = dict(
type='Conv2d', kernel_size=1, stride=1, padding=0, dilation=1)
_conv_cfg.update(conv_cfg)
self.projection = build_conv_layer(_conv_cfg, feature_dim, embed_dims)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
# last feature if backbone outputs list/tuple of features
x = x[-1]
x = self.projection(x).flatten(2).transpose(1, 2)
return x
class PatchMerging(BaseModule):
"""Merge patch feature map. Modified from mmcv, which uses pre-norm layer
whereas Swin V2 uses post-norm here. Therefore, add extra parameter to
decide whether use post-norm or not.
This layer groups feature map by kernel_size, and applies norm and linear
layers to the grouped feature map ((used in Swin Transformer)).
Our implementation uses `nn.Unfold` to
merge patches, which is about 25% faster than the original
implementation. However, we need to modify pretrained
models for compatibility.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.
kernel_size (int | tuple, optional): the kernel size in the unfold
layer. Defaults to 2.
stride (int | tuple, optional): the stride of the sliding blocks in the
unfold layer. Defaults to None. (Would be set as `kernel_size`)
padding (int | tuple | string ): The padding length of
embedding conv. When it is a string, it means the mode
of adaptive padding, support "same" and "corner" now.
Defaults to "corner".
dilation (int | tuple, optional): dilation parameter in the unfold
layer. Default: 1.
bias (bool, optional): Whether to add bias in linear layer or not.
Defaults to False.
norm_cfg (dict, optional): Config dict for normalization layer.
Defaults to dict(type='LN').
is_post_norm (bool): Whether to use post normalization here.
Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size=2,
stride=None,
padding='corner',
dilation=1,
bias=False,
norm_cfg=dict(type='LN'),
is_post_norm=False,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.is_post_norm = is_post_norm
if stride:
stride = stride
else:
stride = kernel_size
kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride)
dilation = to_2tuple(dilation)
if isinstance(padding, str):
self.adaptive_padding = AdaptivePadding(
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding)
# disable the padding of unfold
padding = 0
else:
self.adaptive_padding = None
padding = to_2tuple(padding)
self.sampler = nn.Unfold(
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
stride=stride)
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
if norm_cfg is not None:
# build pre or post norm layer based on different channels
if self.is_post_norm:
self.norm = build_norm_layer(norm_cfg, out_channels)[1]
else:
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
else:
self.norm = None
def forward(self, x, input_size):
"""
Args:
x (Tensor): Has shape (B, H*W, C_in).
input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
Default: None.
Returns:
tuple: Contains merged results and its spatial shape.
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
- out_size (tuple[int]): Spatial shape of x, arrange as
(Merged_H, Merged_W).
"""
B, L, C = x.shape
assert isinstance(input_size, Sequence), \
f'Expect input_size is `Sequence`, but got {input_size}'
H, W = input_size
assert L == H * W, 'input feature has wrong size'
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
if self.adaptive_padding:
x = self.adaptive_padding(x)
H, W = x.shape[-2:]
# Use nn.Unfold to merge patch. About 25% faster than original method,
# but need to modify pretrained model for compatibility
# if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)
x = self.sampler(x)
out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
(self.sampler.kernel_size[0] - 1) -
1) // self.sampler.stride[0] + 1
out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
(self.sampler.kernel_size[1] - 1) -
1) // self.sampler.stride[1] + 1
output_size = (out_h, out_w)
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
if self.is_post_norm:
# use post-norm here
x = self.reduction(x)
x = self.norm(x) if self.norm else x
else:
x = self.norm(x) if self.norm else x
x = self.reduction(x)
return x, output_size
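# Editor's note: a minimal usage sketch (editor's addition) merging a 56x56
# grid of 96-dim tokens down to a 28x28 grid of 192-dim tokens.
def _demo_patch_merging():
    merging = PatchMerging(in_channels=96, out_channels=192)
    tokens = torch.rand(2, 56 * 56, 96)  # (B, H*W, C)
    # Returns a tuple: a (2, 784, 192) tensor and the new spatial size (28, 28).
    return merging(tokens, input_size=(56, 56))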
# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
import warnings
from itertools import repeat
import torch
from mmcv.utils import digit_version
def is_tracing() -> bool:
"""Determine whether the model is called during the tracing of code with
``torch.jit.trace``."""
if digit_version(torch.__version__) >= digit_version('1.6.0'):
on_trace = torch.jit.is_tracing()
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
# Refers to https://github.com/pytorch/pytorch/issues/42448
if isinstance(on_trace, bool):
return on_trace
else:
return torch._C._is_tracing()
else:
warnings.warn(
'torch.jit.is_tracing is only supported after v1.6.0. '
'Therefore is_tracing returns False automatically. Please '
'set on_trace manually if you are using trace.', UserWarning)
return False
# From PyTorch internals
def _ntuple(n):
"""A `to_tuple` function generator.
It returns a function that repeats the input to a tuple of length ``n``
if the input is not an Iterable object; otherwise, it returns the input
directly.
Args:
n (int): The number of the target length.
"""
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
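# Editor's note: an illustrative sketch (not in the upstream file) of the
# `to_ntuple` helpers defined above.
def _demo_to_ntuple():
    # Scalars are repeated, iterables are returned unchanged.
    assert to_2tuple(7) == (7, 7)
    assert to_2tuple((3, 4)) == (3, 4)
    assert to_ntuple(4)(1) == (1, 1, 1, 1)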
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule
from .se_layer import SELayer
# class InvertedResidual(nn.Module):
class InvertedResidual(BaseModule):
"""Inverted Residual Block.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
in_channels (int): The input channels of this module.
out_channels (int): The output channels of this module.
mid_channels (int): The input channels of the depthwise convolution.
kernel_size (int): The kernal size of the depthwise convolution.
Default: 3.
stride (int): The stride of the depthwise convolution. Default: 1.
se_cfg (dict): Config dict for se layer. Defaul: None, which means no
se layer.
with_expand_conv (bool): Use expand conv or not. If set False,
mid_channels must be the same with in_channels.
Default: True.
conv_cfg (dict): Config dict for convolution layer. Default: None,
kernel_size (int): The kernel size of the depthwise convolution.
Defaults to 3.
stride (int): The stride of the depthwise convolution. Defaults to 1.
se_cfg (dict, optional): Config dict for se layer. Defaults to None,
which means no se layer.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
Defaults to ``dict(type='BN')``.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
Defaults to ``dict(type='ReLU')``.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
memory while slowing down the training speed. Defaults to False.
init_cfg (dict | list[dict], optional): Initialization config dict.
"""
def __init__(self,
@@ -41,23 +39,23 @@ class InvertedResidual(BaseModule):
kernel_size=3,
stride=1,
se_cfg=None,
with_expand_conv=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_path_rate=0.,
with_cp=False,
init_cfg=None):
super(InvertedResidual, self).__init__(init_cfg)
self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
assert stride in [1, 2]
self.with_cp = with_cp
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.with_se = se_cfg is not None
self.with_expand_conv = with_expand_conv
self.with_expand_conv = (mid_channels != in_channels)
if self.with_se:
assert isinstance(se_cfg, dict)
if not self.with_expand_conv:
assert mid_channels == in_channels
if self.with_expand_conv:
self.expand_conv = ConvModule(
@@ -89,9 +87,17 @@ class InvertedResidual(BaseModule):
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
act_cfg=None)
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
def _inner_forward(x):
out = x
@@ -107,7 +113,7 @@ class InvertedResidual(BaseModule):
out = self.linear_conv(out)
if self.with_res_shortcut:
return x + out
return x + self.drop_path(out)
else:
return out
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn
class LayerScale(nn.Module):
"""LayerScale layer.
Args:
dim (int): Dimension of input features.
inplace (bool): Whether to perform the operation in-place.
Defaults to ``False``.
data_format (str): The input data format, can be 'channels_last'
or 'channels_first', representing (B, N, C) and (B, C, H, W)
format data respectively. Defaults to 'channels_last'.
"""
def __init__(self,
dim: int,
inplace: bool = False,
data_format: str = 'channels_last'):
super().__init__()
assert data_format in ('channels_last', 'channels_first'), \
"'data_format' could only be channels_last or channels_first."
self.inplace = inplace
self.data_format = data_format
self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
def forward(self, x):
if self.data_format == 'channels_first':
if self.inplace:
return x.mul_(self.weight.view(-1, 1, 1))
else:
return x * self.weight.view(-1, 1, 1)
return x.mul_(self.weight) if self.inplace else x * self.weight
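# Editor's note: a small usage sketch (editor's addition) covering both data
# formats of LayerScale; the shapes are arbitrary examples.
def _demo_layer_scale():
    tokens = torch.rand(2, 196, 64)          # (B, N, C), channels_last
    feature_map = torch.rand(2, 64, 14, 14)  # (B, C, H, W), channels_first
    scaled_tokens = LayerScale(dim=64)(tokens)
    scaled_map = LayerScale(dim=64, data_format='channels_first')(feature_map)
    return scaled_tokens, scaled_map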