Unverified Commit e05fb560 authored by Shilong Zhang's avatar Shilong Zhang Committed by GitHub
Browse files

Refactor the baseclass related to transformer (#978)



* minor changes

* change to modulist

* change to Sequential

* replace dropout with attn_drop and proj_drop in MultiheadAttention

* add operation_name for attn

* add drop path and move all ffn args to ffncfgs

* fix typo

* fix a bug when use default value of ffn_cfgs

* fix ffns

* add deprecate warning

* fix deprecate warning

* change to pop kwargs

* support register FFN of transformer

* support batch first

* fix batch first wapper

* fix forward wapper

* fix typo

* fix lint

* add unitest for transformer

* fix unitest

* fix equal

* use allclose

* fix comments

* fix comments

* change configdict to dict

* move drop to a file

* add comments for drop path

* add noqa 501

* move bnc wapper to MultiheadAttention

* move bnc wapper to MultiheadAttention

* use dep warning

* resolve comments

* add unitest:

* rename residual to identity

* revert runner

* msda residual to identity

* rename inp_identity to identity

* fix name

* fix transformer

* remove key in msda

* remove assert for key
Co-authored-by: default avatarHIT-cwh <2892770585@qq.com>
Co-authored-by: default avatarbkhuang <congee524@gmail.com>
Co-authored-by: default avatarWenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
parent 11629d52
......@@ -5,6 +5,7 @@ from .conv2d_adaptive_padding import Conv2dAdaptivePadding
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
from .drop import Dropout, DropPath
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
......@@ -29,5 +30,5 @@ __all__ = [
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
]
import torch
import torch.nn as nn
from mmcv import build_from_cfg
from .registry import DROPOUT_LAYERS
def drop_path(x, drop_prob=0., training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
We follow the implementation
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# handle tensors with different dimensions, not just 4D tensors.
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
output = x.div(keep_prob) * random_tensor.floor()
return output
@DROPOUT_LAYERS.register_module()
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
We follow the implementation
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
Args:
drop_prob (float): Probability of the path to be zeroed. Default: 0.1
"""
def __init__(self, drop_prob=0.1):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
@DROPOUT_LAYERS.register_module()
class Dropout(nn.Dropout):
"""A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
``DropPath``
Args:
drop_prob (float): Probability of the elements to be
zeroed. Default: 0.5.
inplace (bool): Do the operation inplace or not. Default: False.
"""
def __init__(self, drop_prob=0.5, inplace=False):
super().__init__(p=drop_prob, inplace=inplace)
def build_dropout(cfg, default_args=None):
"""Builder for drop out layers."""
return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
......@@ -7,7 +7,9 @@ PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')
POSITIONAL_ENCODING = Registry('Position encoding')
ATTENTION = Registry('Attention')
TRANSFORMER_LAYER = Registry('TransformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('TransformerLayerSequence')
DROPOUT_LAYERS = Registry('drop out layers')
POSITIONAL_ENCODING = Registry('position encoding')
ATTENTION = Registry('attention')
FEEDFORWARD_NETWORK = Registry('feed-forward Network')
TRANSFORMER_LAYER = Registry('transformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
import copy
import math
import warnings
import torch
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
constant_init, xavier_init)
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmcv.runner.base_module import BaseModule
from mmcv import ConfigDict, deprecated_api_warning
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import build_from_cfg
from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
from .drop import build_dropout
from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
def build_positional_encoding(cfg, default_args=None):
......@@ -26,6 +23,11 @@ def build_attention(cfg, default_args=None):
return build_from_cfg(cfg, ATTENTION, default_args)
def build_feedforward_network(cfg, default_args=None):
"""Builder for feed-forward network (FFN)."""
return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
def build_transformer_layer(cfg, default_args=None):
"""Builder for transformer layer."""
return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
......@@ -38,39 +40,82 @@ def build_transformer_layer_sequence(cfg, default_args=None):
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
"""A warpper for torch.nn.MultiheadAttention.
"""A wrapper for ``torch.nn.MultiheadAttention``.
This module implements MultiheadAttention with residual connection,
and positional encoding used in DETR is also passed as input.
This module implements MultiheadAttention with identity connection,
and positional encoding is also passed as input.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads. Same as
`nn.MultiheadAttention`.
dropout (float):w A Dropout layer on attn_output_weights. Default: 0..
num_heads (int): Parallel attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
Default: 0.0.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
"""
def __init__(self,
embed_dims,
num_heads,
dropout=0.,
attn_drop=0.,
proj_drop=0.,
dropout_layer=dict(type='Dropout', drop_prob=0.),
init_cfg=None,
batch_first=False,
**kwargs):
super(MultiheadAttention, self).__init__(init_cfg)
if 'dropout' in kwargs:
warnings.warn('The arguments `dropout` in MultiheadAttention '
'has been deprecated, now you can separately '
'set `attn_drop`(float), proj_drop(float), '
'and `dropout_layer`(dict) ')
attn_drop = kwargs['dropout']
dropout_layer['drop_prob'] = kwargs.pop('dropout')
self.embed_dims = embed_dims
self.num_heads = num_heads
self.dropout = dropout
self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout,
self.batch_first = batch_first
self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
**kwargs)
self.dropout = nn.Dropout(dropout)
if self.batch_first:
def _bnc_to_nbc(forward):
"""This function can adjust the shape of dataflow('key',
'query', 'value') from batch_first (batch, num_query,
embed_dims) to num_query_first (num_query ,batch,
embed_dims)."""
def forward_wrapper(**kwargs):
convert_keys = ('key', 'query', 'value')
for key in kwargs.keys():
if key in convert_keys:
kwargs[key] = kwargs[key].transpose(0, 1)
attn_output, attn_output_weights = forward(**kwargs)
return attn_output.transpose(0, 1), attn_output_weights
return forward_wrapper
self.attn.forward = _bnc_to_nbc(self.attn.forward)
self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else nn.Identity()
@deprecated_api_warning({'residual': 'identity'},
cls_name='MultiheadAttention')
def forward(self,
query,
key=None,
value=None,
residual=None,
identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
......@@ -83,15 +128,17 @@ class MultiheadAttention(BaseModule):
Args:
query (Tensor): The input query with shape [num_queries, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
embed_dims] if self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
If None, the ``query`` will be used. Defaults to None.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Defaults to None.
If None, the `key` will be used.
residual (Tensor): This tensor, with the same shape as x,
will be used for the residual link.
identity (Tensor): This tensor, with the same shape as x,
will be used for the identity link.
If None, `x` will be used. Defaults to None.
query_pos (Tensor): The positional encoding for query, with
the same shape as `x`. If not None, it will
......@@ -105,18 +152,21 @@ class MultiheadAttention(BaseModule):
num_keys]. Same in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
Same in `nn.MultiheadAttention.forward`. Defaults to None.
Defaults to None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
Tensor: forwarded results with shape
[num_queries, bs, embed_dims]
if self.batch_first is False, else
[bs, num_queries embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if residual is None:
residual = query
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
......@@ -129,238 +179,56 @@ class MultiheadAttention(BaseModule):
query = query + query_pos
if key_pos is not None:
key = key + key_pos
out = self.attn(
query,
key,
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
return residual + self.dropout(out)
@ATTENTION.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_residual`.
Default: 0..
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weight()
def init_weight(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if residual is None:
inp_residual = query
if query_pos is not None:
query = query + query_pos
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_key, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_key, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available():
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
output = self.output_proj(output).permute(1, 0, 2)
# (num_query, bs ,embed_dims)
return self.dropout(output) + inp_residual
return identity + self.dropout_layer(self.proj_drop(out))
@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with residual connection.
"""Implements feed-forward networks (FFNs) with identity connection.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
`MultiheadAttention`. Defaults: 256.
feedforward_channels (int): The hidden dimension of FFNs.
Defaults: 1024.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Default: 2.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
dropout (float, optional): Probability of an element to be
zeroed. Default 0..
add_residual (bool, optional): Whether to add the
residual connection. Default: `True`.
ffn_drop (float, optional): Probability of an element to be
zeroed in FFN. Default 0.0.
add_identity (bool, optional): Whether to add the
identity connection. Default: `True`.
dropout_layer (obj:`ConfigDict`): The dropout_layer used
when adding the shortcut.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
@deprecated_api_warning(
{
'dropout': 'ffn_drop',
'add_residual': 'add_identity'
},
cls_name='FFN')
def __init__(self,
embed_dims,
feedforward_channels,
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
dropout=0.,
add_residual=True,
init_cfg=None):
ffn_drop=0.,
dropout_layer=None,
add_identity=True,
init_cfg=None,
**kwargs):
super(FFN, self).__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
......@@ -368,33 +236,35 @@ class FFN(BaseModule):
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.dropout = dropout
self.activate = build_activation_layer(act_cfg)
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
nn.Sequential(
Sequential(
Linear(in_channels, feedforward_channels), self.activate,
nn.Dropout(dropout)))
nn.Dropout(ffn_drop)))
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
self.layers = nn.Sequential(*layers)
self.dropout = nn.Dropout(dropout)
self.add_residual = add_residual
def forward(self, x, residual=None):
layers.append(nn.Dropout(ffn_drop))
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(
dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity
@deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
def forward(self, x, identity=None):
"""Forward function for `FFN`.
The function would add x to the output tensor if residue is None.
"""
out = self.layers(x)
if not self.add_residual:
return self.dropout(out)
if residual is None:
residual = x
return residual + self.dropout(out)
if not self.add_identity:
return self.dropout_layer(out)
if identity is None:
identity = x
return identity + self.dropout_layer(out)
@TRANSFORMER_LAYER.register_module()
......@@ -416,85 +286,121 @@ class BaseTransformerLayer(BaseModule):
corresponding attentions in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config. Default: None.
feedforward_channels (int): The hidden dimension for FFNs.
Default: None.
ffn_dropout (float): Probability of an element to be zeroed
in ffn. Default 0..
ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
Configs for FFN, The order of the configs in the list should be
consistent with corresponding ffn in operation_order.
If it is a dict, all of the attention modules in operation_order
will be built with this config.
operation_order (tuple[str]): The execution order of operation
in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
Support `prenorm` when you specifying first element as `norm`.
Default:None.
act_cfg (dict): The activation config for FFNs.
Default: dict(type='ReLU')
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN').
ffn_num_fcs (int): The number of fully-connected layers in FFNs.
Default:2.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
batch_first (bool): Key, Query and Value are shape
of (batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
"""
def __init__(self,
attn_cfgs=None,
feedforward_channels=None,
ffn_dropout=0.,
ffn_cfgs=dict(
type='FFN',
embed_dims=256,
feedforward_channels=1024,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='ReLU', inplace=True),
),
operation_order=None,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'),
ffn_num_fcs=2,
init_cfg=None):
init_cfg=None,
batch_first=False,
**kwargs):
deprecated_args = dict(
feedforward_channels='feedforward_channels',
ffn_dropout='ffn_drop',
ffn_num_fcs='num_fcs')
for ori_name, new_name in deprecated_args.items():
if ori_name in kwargs:
warnings.warn(
f'The arguments `{ori_name}` in BaseTransformerLayer '
f'has been deprecated, now you should set `{new_name}` '
f'and other FFN related arguments '
f'to a dict named `ffn_cfgs`. ')
ffn_cfgs[new_name] = kwargs[ori_name]
super(BaseTransformerLayer, self).__init__(init_cfg)
self.batch_first = batch_first
assert set(operation_order) & set(
['self_attn', 'norm', 'ffn', 'cross_attn']) == \
set(operation_order), f'The operation_order of' \
f' {self.__class__.__name__} should ' \
f'contains all four operation type ' \
f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
num_attn = operation_order.count('self_attn') + operation_order.count(
'cross_attn')
if isinstance(attn_cfgs, ConfigDict):
if isinstance(attn_cfgs, dict):
attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
else:
assert num_attn == len(attn_cfgs), f'The length ' \
f'of attn_cfg {num_attn} is ' \
f'not consistent with the number of attention' \
f'in operation_order {operation_order}.'
self.init_cfg = init_cfg
self.num_attn = num_attn
self.feedforward_channels = feedforward_channels
self.ffn_dropout = ffn_dropout
self.operation_order = operation_order
self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
self.ffn_num_fcs = ffn_num_fcs
self.pre_norm = operation_order[0] == 'norm'
self.attentions = nn.ModuleList()
self.attentions = ModuleList()
index = 0
for operation in operation_order:
if operation in ['self_attn', 'cross_attn']:
for operation_name in operation_order:
if operation_name in ['self_attn', 'cross_attn']:
if 'batch_first' in attn_cfgs[index]:
assert self.batch_first == attn_cfgs[index]['batch_first']
else:
attn_cfgs[index]['batch_first'] = self.batch_first
attention = build_attention(attn_cfgs[index])
# Some custom attentions used as `self_attn`
# or `cross_attn` can have different behavior.
attention.operation_name = operation_name
self.attentions.append(attention)
index += 1
self.embed_dims = self.attentions[0].embed_dims
self.ffns = nn.ModuleList()
self.ffns = ModuleList()
num_ffns = operation_order.count('ffn')
for _ in range(num_ffns):
if isinstance(ffn_cfgs, dict):
ffn_cfgs = ConfigDict(ffn_cfgs)
if isinstance(ffn_cfgs, dict):
ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
assert len(ffn_cfgs) == num_ffns
for ffn_index in range(num_ffns):
if 'embed_dims' not in ffn_cfgs[ffn_index]:
ffn_cfgs['embed_dims'] = self.embed_dims
else:
assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
self.ffns.append(
FFN(self.embed_dims, feedforward_channels, ffn_num_fcs,
act_cfg, ffn_dropout))
build_feedforward_network(ffn_cfgs[ffn_index],
dict(type='FFN')))
self.norms = nn.ModuleList()
self.norms = ModuleList()
num_norms = operation_order.count('norm')
for _ in range(num_norms):
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
def forward(self,
query,
key,
value,
key=None,
value=None,
query_pos=None,
key_pos=None,
attn_masks=None,
......@@ -506,12 +412,14 @@ class BaseTransformerLayer(BaseModule):
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): Input query with the shape
`(num_queries, bs, embed_dims)`.
key (Tensor): The key tensor with shape
`(num_keys, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_keys, bs, embed_dims)`.
query (Tensor): The input query with shape
[num_queries, bs, embed_dims] if
self.batch_first is False, else
[bs, num_queries embed_dims].
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims] .
value (Tensor): The value tensor with same shape as `key`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
......@@ -533,7 +441,7 @@ class BaseTransformerLayer(BaseModule):
norm_index = 0
attn_index = 0
ffn_index = 0
inp_residual = query
identity = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
elif isinstance(attn_masks, torch.Tensor):
......@@ -555,14 +463,14 @@ class BaseTransformerLayer(BaseModule):
query,
temp_key,
temp_value,
inp_residual if self.pre_norm else None,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs)
attn_index += 1
inp_residual = query
identity = query
elif layer == 'norm':
query = self.norms[norm_index](query)
......@@ -573,18 +481,18 @@ class BaseTransformerLayer(BaseModule):
query,
key,
value,
inp_residual if self.pre_norm else None,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs)
attn_index += 1
inp_residual = query
identity = query
elif layer == 'ffn':
query = self.ffns[ffn_index](
query, inp_residual if self.pre_norm else None)
query, identity if self.pre_norm else None)
ffn_index += 1
return query
......@@ -612,7 +520,7 @@ class TransformerLayerSequence(BaseModule):
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
super(TransformerLayerSequence, self).__init__(init_cfg)
if isinstance(transformerlayers, ConfigDict):
if isinstance(transformerlayers, dict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)
]
......@@ -620,13 +528,11 @@ class TransformerLayerSequence(BaseModule):
assert isinstance(transformerlayers, list) and \
len(transformerlayers) == num_layers
self.num_layers = num_layers
operation_order = transformerlayers[0]['operation_order']
self.pre_norm = operation_order[0] == 'norm'
self.layers = nn.ModuleList()
self.layers = ModuleList()
for i in range(num_layers):
self.layers.append(build_transformer_layer(transformerlayers[i]))
self.embed_dims = self.layers[0].embed_dims
self.pre_norm = self.layers[0].operation_order[0] == 'norm'
self.pre_norm = self.layers[0].pre_norm
def forward(self,
query,
......@@ -661,7 +567,7 @@ class TransformerLayerSequence(BaseModule):
shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
Tensor: results with shape [num_queries, bs, embed_dims].
"""
for layer in self.layers:
query = layer(
......
......@@ -21,6 +21,7 @@ from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
from .multi_scale_deform_attn import MultiScaleDeformableAttention
from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
......@@ -50,5 +51,5 @@ __all__ = [
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'BorderAlign', 'border_align'
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align'
]
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd.function import Function, once_differentiable
from mmcv import deprecated_api_warning
from mmcv.cnn import constant_init, xavier_init
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner import BaseModule
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
......@@ -140,3 +148,211 @@ def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
attention_weights).sum(-1).view(bs, num_heads * embed_dims,
num_queries)
return output.transpose(1, 2).contiguous()
@ATTENTION.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
batch_first=False,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.dropout = nn.Dropout(dropout)
self.batch_first = batch_first
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
def init_weights(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
@deprecated_api_warning({'residual': 'identity'},
cls_name='MultiScaleDeformableAttention')
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if value is None:
value = query
if identity is None:
identity = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available():
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
output = self.output_proj(output)
if not self.batch_first:
# (num_query, bs ,embed_dims)
output = output.permute(1, 0, 2)
return self.dropout(output) + identity
import pytest
import torch
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, BaseTransformerLayer,
MultiheadAttention,
TransformerLayerSequence)
def test_multiheadattention():
MultiheadAttention(
embed_dims=5,
num_heads=5,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='Dropout', drop_prob=0.),
batch_first=True)
batch_dim = 2
embed_dim = 5
num_query = 100
attn_batch_first = MultiheadAttention(
embed_dims=5,
num_heads=5,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
batch_first=True)
attn_query_first = MultiheadAttention(
embed_dims=5,
num_heads=5,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
batch_first=False)
param_dict = dict(attn_query_first.named_parameters())
for n, v in attn_batch_first.named_parameters():
param_dict[n].data = v.data
input_batch_first = torch.rand(batch_dim, num_query, embed_dim)
input_query_first = input_batch_first.transpose(0, 1)
assert torch.allclose(
attn_query_first(input_query_first).sum(),
attn_batch_first(input_batch_first).sum())
key_batch_first = torch.rand(batch_dim, num_query, embed_dim)
key_query_first = key_batch_first.transpose(0, 1)
assert torch.allclose(
attn_query_first(input_query_first, key_query_first).sum(),
attn_batch_first(input_batch_first, key_batch_first).sum())
identity = torch.ones_like(input_query_first)
# check deprecated arguments can be used normally
assert torch.allclose(
attn_query_first(
input_query_first, key_query_first, residual=identity).sum(),
attn_batch_first(input_batch_first, key_batch_first).sum() +
identity.sum() - input_batch_first.sum())
assert torch.allclose(
attn_query_first(
input_query_first, key_query_first, identity=identity).sum(),
attn_batch_first(input_batch_first, key_batch_first).sum() +
identity.sum() - input_batch_first.sum())
attn_query_first(
input_query_first, key_query_first, identity=identity).sum(),
def test_ffn():
with pytest.raises(AssertionError):
# num_fcs should be no less than 2
FFN(num_fcs=1)
FFN(dropout=0, add_residual=True)
ffn = FFN(dropout=0, add_identity=True)
input_tensor = torch.rand(2, 20, 256)
input_tensor_nbc = input_tensor.transpose(0, 1)
assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
residual = torch.rand_like(input_tensor)
torch.allclose(
ffn(input_tensor, residual=residual).sum(),
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
torch.allclose(
ffn(input_tensor, identity=residual).sum(),
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
def test_basetransformerlayer():
attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
feedforward_channels = 2048
ffn_dropout = 0.1
operation_order = ('self_attn', 'norm', 'ffn', 'norm')
# test deprecated_args
baselayer = BaseTransformerLayer(
attn_cfgs=attn_cfgs,
feedforward_channels=feedforward_channels,
ffn_dropout=ffn_dropout,
operation_order=operation_order)
assert baselayer.batch_first is False
assert baselayer.ffns[0].feedforward_channels == feedforward_channels
attn_cfgs = dict(type='MultiheadAttention', num_heads=8, embed_dims=256),
feedforward_channels = 2048
ffn_dropout = 0.1
operation_order = ('self_attn', 'norm', 'ffn', 'norm')
baselayer = BaseTransformerLayer(
attn_cfgs=attn_cfgs,
feedforward_channels=feedforward_channels,
ffn_dropout=ffn_dropout,
operation_order=operation_order,
batch_first=True)
assert baselayer.attentions[0].batch_first
in_tensor = torch.rand(2, 10, 256)
baselayer(in_tensor)
def test_transformerlayersequence():
squeue = TransformerLayerSequence(
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(type='MultiheadAttention', embed_dims=256, num_heads=4)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
'norm')))
assert len(squeue.layers) == 6
assert squeue.pre_norm is False
with pytest.raises(AssertionError):
# if transformerlayers is a list, len(transformerlayers)
# should be equal to num_layers
TransformerLayerSequence(
num_layers=6,
transformerlayers=[
dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(type='MultiheadAttention', embed_dims=256)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm'))
])
def test_drop_path():
drop_path = DropPath(drop_prob=0)
test_in = torch.rand(2, 3, 4, 5)
assert test_in is drop_path(test_in)
drop_path = DropPath(drop_prob=0.1)
drop_path.training = False
test_in = torch.rand(2, 3, 4, 5)
assert test_in is drop_path(test_in)
drop_path.training = True
assert test_in is not drop_path(test_in)
......@@ -2,7 +2,8 @@ import pytest
import torch
from mmcv.ops.multi_scale_deform_attn import (
MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
multi_scale_deformable_attn_pytorch)
_USING_PARROTS = True
try:
......@@ -98,7 +99,14 @@ def test_forward_equal_with_pytorch_float():
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
@pytest.mark.parametrize('channels', [4, 30, 32, 64, 71, 1025, 2048, 3096])
@pytest.mark.parametrize('channels', [
4,
30,
32,
64,
71,
1025,
])
def test_gradient_numerical(channels,
grad_value=True,
grad_sampling_loc=True,
......@@ -134,3 +142,20 @@ def test_gradient_numerical(channels,
assert gradcheck(func, (value.double(), shapes, level_start_index,
sampling_locations.double(),
attention_weights.double(), im2col_step))
def test_multiscale_deformable_attention():
with pytest.raises(ValueError):
# embed_dims must be divisible by num_heads,
MultiScaleDeformableAttention(
embed_dims=256,
num_heads=7,
)
with pytest.raises(ValueError):
# embed_dims must be divisible by num_heads,
MultiScaleDeformableAttention(
embed_dims=256,
num_heads=7,
)
MultiScaleDeformableAttention(embed_dims=256, num_heads=8)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment