Unverified Commit e05fb560 authored by Shilong Zhang, committed by GitHub

Refactor the baseclass related to transformer (#978)



* minor changes

* change to ModuleList

* change to Sequential

* replace dropout with attn_drop and proj_drop in MultiheadAttention

* add operation_name for attn

* add drop path and move all ffn args to ffncfgs

* fix typo

* fix a bug when use default value of ffn_cfgs

* fix ffns

* add deprecate warning

* fix deprecate warning

* change to pop kwargs

* support register FFN of transformer

* support batch first

* fix batch first wrapper

* fix forward wrapper

* fix typo

* fix lint

* add unit tests for transformer

* fix unit tests

* fix equal

* use allclose

* fix comments

* fix comments

* change configdict to dict

* move drop to a file

* add comments for drop path

* add noqa 501

* move bnc wrapper to MultiheadAttention

* move bnc wrapper to MultiheadAttention

* use dep warning

* resolve comments

* add unit tests

* rename residual to identity

* revert runner

* msda residual to identity

* rename inp_identity to identity

* fix name

* fix transformer

* remove key in msda

* remove assert for key
Co-authored-by: HIT-cwh <2892770585@qq.com>
Co-authored-by: bkhuang <congee524@gmail.com>
Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
parent 11629d52
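The refactor splits the single `dropout` argument of `MultiheadAttention` into `attn_drop`, `proj_drop` and a configurable `dropout_layer`, renames the FFN arguments `dropout`/`add_residual` to `ffn_drop`/`add_identity`, and adds `batch_first` support. A minimal sketch of the intended post-PR usage (illustrative values only, not part of the diff):

```python
import torch
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention

# `dropout` is split into `attn_drop`, `proj_drop` and a configurable
# `dropout_layer` applied on the shortcut path.
attn = MultiheadAttention(
    embed_dims=256,
    num_heads=8,
    attn_drop=0.1,
    proj_drop=0.1,
    dropout_layer=dict(type='DropPath', drop_prob=0.1),
    batch_first=True)

# FFN renames `dropout` -> `ffn_drop` and `add_residual` -> `add_identity`.
ffn = FFN(
    embed_dims=256,
    feedforward_channels=1024,
    num_fcs=2,
    ffn_drop=0.1,
    dropout_layer=dict(type='DropPath', drop_prob=0.1))

x = torch.rand(2, 100, 256)   # (batch, num_query, embed_dims)
x = attn(query=x)             # batch_first=True keeps the (batch, n, c) layout
x = ffn(x)
assert x.shape == (2, 100, 256)
```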
@@ -5,6 +5,7 @@ from .conv2d_adaptive_padding import Conv2dAdaptivePadding
 from .conv_module import ConvModule
 from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
 from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
 from .generalized_attention import GeneralizedAttention
 from .hsigmoid import HSigmoid
 from .hswish import HSwish
@@ -29,5 +30,5 @@ __all__ = [
     'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
     'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
     'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
-    'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
+    'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
 ]
import torch
import torch.nn as nn
from mmcv import build_from_cfg
from .registry import DROPOUT_LAYERS
def drop_path(x, drop_prob=0., training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
We follow the implementation
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# handle tensors with different dimensions, not just 4D tensors.
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
output = x.div(keep_prob) * random_tensor.floor()
return output
@DROPOUT_LAYERS.register_module()
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
residual blocks).
We follow the implementation
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
Args:
drop_prob (float): Probability of the path to be zeroed. Default: 0.1
"""
def __init__(self, drop_prob=0.1):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
@DROPOUT_LAYERS.register_module()
class Dropout(nn.Dropout):
"""A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
``DropPath``
Args:
drop_prob (float): Probability of the elements to be
zeroed. Default: 0.5.
inplace (bool): Do the operation inplace or not. Default: False.
"""
def __init__(self, drop_prob=0.5, inplace=False):
super().__init__(p=drop_prob, inplace=inplace)
def build_dropout(cfg, default_args=None):
"""Builder for drop out layers."""
return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
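For illustration only (not part of the commit): the new dropout bricks can be built either directly or from a config dict via `build_dropout`, which is how `MultiheadAttention` and `FFN` consume their `dropout_layer` argument.

```python
import torch
from mmcv.cnn.bricks.drop import Dropout, DropPath, build_dropout

x = torch.rand(4, 197, 256)

# Stochastic depth: whole samples in the batch are zeroed (and the rest
# rescaled by 1 / keep_prob) during training.
drop_path = DropPath(drop_prob=0.2)
drop_path.train()
out = drop_path(x)

# Config-driven construction through the DROPOUT_LAYERS registry.
layer = build_dropout(dict(type='DropPath', drop_prob=0.2))
assert isinstance(layer, DropPath)
layer = build_dropout(dict(type='Dropout', drop_prob=0.5))
assert isinstance(layer, Dropout)
```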
@@ -7,7 +7,9 @@ PADDING_LAYERS = Registry('padding layer')
 UPSAMPLE_LAYERS = Registry('upsample layer')
 PLUGIN_LAYERS = Registry('plugin layer')
-POSITIONAL_ENCODING = Registry('Position encoding')
-ATTENTION = Registry('Attention')
-TRANSFORMER_LAYER = Registry('TransformerLayer')
-TRANSFORMER_LAYER_SEQUENCE = Registry('TransformerLayerSequence')
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
 import copy
-import math
 import warnings
 import torch
 import torch.nn as nn
-from mmcv import ConfigDict
-from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
-                      constant_init, xavier_init)
-from mmcv.ops.multi_scale_deform_attn import (
-    MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
-from mmcv.runner.base_module import BaseModule
+from mmcv import ConfigDict, deprecated_api_warning
+from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
 from mmcv.utils import build_from_cfg
-from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
-                       TRANSFORMER_LAYER_SEQUENCE)
+from .drop import build_dropout
+from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
+                       TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
 def build_positional_encoding(cfg, default_args=None):
@@ -26,6 +23,11 @@ def build_attention(cfg, default_args=None):
     return build_from_cfg(cfg, ATTENTION, default_args)
+
+
+def build_feedforward_network(cfg, default_args=None):
+    """Builder for feed-forward network (FFN)."""
+    return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
 
 
 def build_transformer_layer(cfg, default_args=None):
     """Builder for transformer layer."""
     return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
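Since FFNs are now built through the `FEEDFORWARD_NETWORK` registry, a project can plug in its own feed-forward variant and select it from config. A hedged sketch under that assumption; the class name `MyGatedFFN` is purely hypothetical:

```python
import torch.nn as nn
from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK
from mmcv.runner.base_module import BaseModule


@FEEDFORWARD_NETWORK.register_module()
class MyGatedFFN(BaseModule):
    """Hypothetical gated feed-forward block, registered like `FFN`."""

    def __init__(self, embed_dims=256, feedforward_channels=1024,
                 init_cfg=None, **kwargs):
        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.gate = nn.Sequential(
            nn.Linear(embed_dims, feedforward_channels), nn.GELU(),
            nn.Linear(feedforward_channels, embed_dims), nn.Sigmoid())

    def forward(self, x, identity=None):
        # Keep the same (x, identity) contract as the built-in FFN.
        if identity is None:
            identity = x
        return identity + x * self.gate(x)


# Selected from config, e.g. inside a BaseTransformerLayer:
#   ffn_cfgs=dict(type='MyGatedFFN', embed_dims=256)
```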
@@ -38,39 +40,82 @@ def build_transformer_layer_sequence(cfg, default_args=None):
 @ATTENTION.register_module()
 class MultiheadAttention(BaseModule):
-    """A warpper for torch.nn.MultiheadAttention.
+    """A wrapper for ``torch.nn.MultiheadAttention``.
 
-    This module implements MultiheadAttention with residual connection,
-    and positional encoding used in DETR is also passed as input.
+    This module implements MultiheadAttention with identity connection,
+    and positional encoding is also passed as input.
 
     Args:
         embed_dims (int): The embedding dimension.
-        num_heads (int): Parallel attention heads. Same as
-            `nn.MultiheadAttention`.
-        dropout (float): A Dropout layer on attn_output_weights. Default: 0..
+        num_heads (int): Parallel attention heads.
+        attn_drop (float): A Dropout layer on attn_output_weights.
+            Default: 0.0.
+        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+            Default: 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
+        batch_first (bool): Key, Query and Value are shape of
+            (batch, n, embed_dim) or (n, batch, embed_dim).
+            Defaults to False.
     """
 
     def __init__(self,
                  embed_dims,
                  num_heads,
-                 dropout=0.,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                  init_cfg=None,
+                 batch_first=False,
                  **kwargs):
         super(MultiheadAttention, self).__init__(init_cfg)
+        if 'dropout' in kwargs:
+            warnings.warn('The arguments `dropout` in MultiheadAttention '
+                          'has been deprecated, now you can separately '
+                          'set `attn_drop`(float), proj_drop(float), '
+                          'and `dropout_layer`(dict) ')
+            attn_drop = kwargs['dropout']
+            dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
         self.embed_dims = embed_dims
         self.num_heads = num_heads
-        self.dropout = dropout
-        self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout,
+        self.batch_first = batch_first
+
+        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                           **kwargs)
-        self.dropout = nn.Dropout(dropout)
+        if self.batch_first:
+
+            def _bnc_to_nbc(forward):
+                """This function can adjust the shape of dataflow ('key',
+                'query', 'value') from batch_first (batch, num_query,
+                embed_dims) to num_query_first (num_query, batch,
+                embed_dims)."""
+
+                def forward_wrapper(**kwargs):
+                    convert_keys = ('key', 'query', 'value')
+                    for key in kwargs.keys():
+                        if key in convert_keys:
+                            kwargs[key] = kwargs[key].transpose(0, 1)
+                    attn_output, attn_output_weights = forward(**kwargs)
+                    return attn_output.transpose(0, 1), attn_output_weights
+
+                return forward_wrapper
+
+            self.attn.forward = _bnc_to_nbc(self.attn.forward)
+
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else nn.Identity()
 
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiheadAttention')
     def forward(self,
                 query,
                 key=None,
                 value=None,
-                residual=None,
+                identity=None,
                 query_pos=None,
                 key_pos=None,
                 attn_mask=None,
@@ -83,15 +128,17 @@ class MultiheadAttention(BaseModule):
 
         Args:
             query (Tensor): The input query with shape [num_queries, bs,
-                embed_dims]. Same in `nn.MultiheadAttention.forward`.
+                embed_dims] if self.batch_first is False, else
+                [bs, num_queries, embed_dims].
             key (Tensor): The key tensor with shape [num_keys, bs,
-                embed_dims]. Same in `nn.MultiheadAttention.forward`.
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims].
                 If None, the ``query`` will be used. Defaults to None.
             value (Tensor): The value tensor with same shape as `key`.
                 Same in `nn.MultiheadAttention.forward`. Defaults to None.
                 If None, the `key` will be used.
-            residual (Tensor): This tensor, with the same shape as x,
-                will be used for the residual link.
+            identity (Tensor): This tensor, with the same shape as x,
+                will be used for the identity link.
                 If None, `x` will be used. Defaults to None.
             query_pos (Tensor): The positional encoding for query, with
                 the same shape as `x`. If not None, it will
@@ -105,18 +152,21 @@ class MultiheadAttention(BaseModule):
                 num_keys]. Same in `nn.MultiheadAttention.forward`.
                 Defaults to None.
             key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
-                Same in `nn.MultiheadAttention.forward`. Defaults to None.
+                Defaults to None.
 
         Returns:
-            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+            Tensor: forwarded results with shape
+            [num_queries, bs, embed_dims]
+            if self.batch_first is False, else
+            [bs, num_queries, embed_dims].
         """
 
         if key is None:
             key = query
         if value is None:
             value = key
-        if residual is None:
-            residual = query
+        if identity is None:
+            identity = query
         if key_pos is None:
             if query_pos is not None:
                 # use query_pos if key_pos is not available
@@ -129,238 +179,56 @@ class MultiheadAttention(BaseModule):
             query = query + query_pos
         if key_pos is not None:
             key = key + key_pos
+
         out = self.attn(
-            query,
-            key,
+            query=query,
+            key=key,
             value=value,
             attn_mask=attn_mask,
             key_padding_mask=key_padding_mask)[0]
 
-        return residual + self.dropout(out)
+        return identity + self.dropout_layer(self.proj_drop(out))
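An illustrative check (in the spirit of the new unit tests further down) that `batch_first` only changes the expected tensor layout of the refactored `MultiheadAttention` above, not its computation; shapes and values are arbitrary:

```python
import torch
from mmcv.cnn.bricks.transformer import MultiheadAttention

attn_nbc = MultiheadAttention(embed_dims=8, num_heads=2, batch_first=False)
attn_bnc = MultiheadAttention(embed_dims=8, num_heads=2, batch_first=True)
# Share weights so the two orderings are numerically comparable.
attn_bnc.load_state_dict(attn_nbc.state_dict())

x_bnc = torch.rand(4, 10, 8)     # (batch, num_query, embed_dims)
x_nbc = x_bnc.transpose(0, 1)    # (num_query, batch, embed_dims)

out_bnc = attn_bnc(query=x_bnc)  # returns (batch, num_query, embed_dims)
out_nbc = attn_nbc(query=x_nbc)  # returns (num_query, batch, embed_dims)
assert torch.allclose(out_bnc, out_nbc.transpose(0, 1), atol=1e-6)
```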
@ATTENTION.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
num_heads (int): Parallel attention heads. Default: 64.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
dropout (float): A Dropout layer on `inp_residual`.
Default: 0..
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.init_cfg = init_cfg
self.dropout = nn.Dropout(dropout)
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weight()
def init_weight(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
def forward(self,
query,
key,
value,
residual=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
residual (Tensor): The tensor used for addition, with the
same shape as `x`. Default None. If None, `x` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different level. With shape (num_levels, 2),
last dimension represent (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape (num_levels) and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if residual is None:
inp_residual = query
if query_pos is not None:
query = query + query_pos
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_key, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_key
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_key, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available():
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
output = self.output_proj(output).permute(1, 0, 2)
# (num_query, bs ,embed_dims)
return self.dropout(output) + inp_residual
+@FEEDFORWARD_NETWORK.register_module()
 class FFN(BaseModule):
-    """Implements feed-forward networks (FFNs) with residual connection.
+    """Implements feed-forward networks (FFNs) with identity connection.
 
     Args:
         embed_dims (int): The feature dimension. Same as
-            `MultiheadAttention`.
+            `MultiheadAttention`. Defaults: 256.
         feedforward_channels (int): The hidden dimension of FFNs.
+            Defaults: 1024.
         num_fcs (int, optional): The number of fully-connected layers in
             FFNs. Default: 2.
         act_cfg (dict, optional): The activation config for FFNs.
             Default: dict(type='ReLU')
-        dropout (float, optional): Probability of an element to be
-            zeroed. Default 0..
-        add_residual (bool, optional): Whether to add the
-            residual connection. Default: `True`.
+        ffn_drop (float, optional): Probability of an element to be
+            zeroed in FFN. Default 0.0.
+        add_identity (bool, optional): Whether to add the
+            identity connection. Default: `True`.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
     """
 
+    @deprecated_api_warning(
+        {
+            'dropout': 'ffn_drop',
+            'add_residual': 'add_identity'
+        },
+        cls_name='FFN')
     def __init__(self,
-                 embed_dims,
-                 feedforward_channels,
+                 embed_dims=256,
+                 feedforward_channels=1024,
                  num_fcs=2,
                  act_cfg=dict(type='ReLU', inplace=True),
-                 dropout=0.,
-                 add_residual=True,
-                 init_cfg=None):
+                 ffn_drop=0.,
+                 dropout_layer=None,
+                 add_identity=True,
+                 init_cfg=None,
+                 **kwargs):
         super(FFN, self).__init__(init_cfg)
         assert num_fcs >= 2, 'num_fcs should be no less ' \
             f'than 2. got {num_fcs}.'
@@ -368,33 +236,35 @@ class FFN(BaseModule):
         self.feedforward_channels = feedforward_channels
         self.num_fcs = num_fcs
         self.act_cfg = act_cfg
-        self.dropout = dropout
         self.activate = build_activation_layer(act_cfg)
 
         layers = []
         in_channels = embed_dims
         for _ in range(num_fcs - 1):
             layers.append(
-                nn.Sequential(
+                Sequential(
                     Linear(in_channels, feedforward_channels), self.activate,
-                    nn.Dropout(dropout)))
+                    nn.Dropout(ffn_drop)))
             in_channels = feedforward_channels
         layers.append(Linear(feedforward_channels, embed_dims))
-        self.layers = nn.Sequential(*layers)
-        self.dropout = nn.Dropout(dropout)
-        self.add_residual = add_residual
+        layers.append(nn.Dropout(ffn_drop))
+        self.layers = Sequential(*layers)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else torch.nn.Identity()
+        self.add_identity = add_identity
 
-    def forward(self, x, residual=None):
+    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
+    def forward(self, x, identity=None):
         """Forward function for `FFN`.
 
         The function would add x to the output tensor if residue is None.
         """
         out = self.layers(x)
-        if not self.add_residual:
-            return self.dropout(out)
-        if residual is None:
-            residual = x
-        return residual + self.dropout(out)
+        if not self.add_identity:
+            return self.dropout_layer(out)
+        if identity is None:
+            identity = x
+        return identity + self.dropout_layer(out)
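A small sketch (values illustrative, not part of the diff) of how the renamed `identity` argument interacts with the shortcut; with `ffn_drop=0` and the default `dropout_layer=None` the module is deterministic:

```python
import torch
from mmcv.cnn.bricks.transformer import FFN

ffn = FFN(embed_dims=16, feedforward_channels=32, ffn_drop=0.)
x = torch.rand(2, 5, 16)

# Default: the input itself is used as the identity shortcut.
out = ffn(x)

# An explicit identity tensor replaces `x` on the shortcut path.
zero_identity = torch.zeros_like(x)
out_zero = ffn(x, identity=zero_identity)
assert torch.allclose(out_zero, out - x, atol=1e-6)
```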
 @TRANSFORMER_LAYER.register_module()
@@ -416,85 +286,121 @@ class BaseTransformerLayer(BaseModule):
             corresponding attentions in operation_order.
             If it is a dict, all of the attention modules in operation_order
             will be built with this config. Default: None.
-        feedforward_channels (int): The hidden dimension for FFNs.
-            Default: None.
-        ffn_dropout (float): Probability of an element to be zeroed
-            in ffn. Default 0..
+        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
+            Configs for FFN. The order of the configs in the list should be
+            consistent with corresponding ffn in operation_order.
+            If it is a dict, all of the FFN modules in operation_order
+            will be built with this config.
         operation_order (tuple[str]): The execution order of operation
             in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
             Support `prenorm` when you specifying first element as `norm`.
             Default:None.
-        act_cfg (dict): The activation config for FFNs.
-            Default: dict(type='ReLU')
         norm_cfg (dict): Config dict for normalization layer.
             Default: dict(type='LN').
-        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
-            Default:2.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
+        batch_first (bool): Key, Query and Value are shape of
+            (batch, n, embed_dim) or (n, batch, embed_dim).
+            Defaults to False.
     """
 
     def __init__(self,
                  attn_cfgs=None,
-                 feedforward_channels=None,
-                 ffn_dropout=0.,
+                 ffn_cfgs=dict(
+                     type='FFN',
+                     embed_dims=256,
+                     feedforward_channels=1024,
+                     num_fcs=2,
+                     ffn_drop=0.,
+                     act_cfg=dict(type='ReLU', inplace=True),
+                 ),
                  operation_order=None,
-                 act_cfg=dict(type='ReLU', inplace=True),
                  norm_cfg=dict(type='LN'),
-                 ffn_num_fcs=2,
-                 init_cfg=None):
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+
+        deprecated_args = dict(
+            feedforward_channels='feedforward_channels',
+            ffn_dropout='ffn_drop',
+            ffn_num_fcs='num_fcs')
+        for ori_name, new_name in deprecated_args.items():
+            if ori_name in kwargs:
+                warnings.warn(
+                    f'The arguments `{ori_name}` in BaseTransformerLayer '
+                    f'has been deprecated, now you should set `{new_name}` '
+                    f'and other FFN related arguments '
+                    f'to a dict named `ffn_cfgs`. ')
+                ffn_cfgs[new_name] = kwargs[ori_name]
+
         super(BaseTransformerLayer, self).__init__(init_cfg)
+
+        self.batch_first = batch_first
+
         assert set(operation_order) & set(
             ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
             set(operation_order), f'The operation_order of' \
             f' {self.__class__.__name__} should ' \
             f'contains all four operation type ' \
             f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
+
         num_attn = operation_order.count('self_attn') + operation_order.count(
             'cross_attn')
-        if isinstance(attn_cfgs, ConfigDict):
+        if isinstance(attn_cfgs, dict):
             attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
         else:
             assert num_attn == len(attn_cfgs), f'The length ' \
                 f'of attn_cfg {num_attn} is ' \
                 f'not consistent with the number of attention' \
                 f'in operation_order {operation_order}.'
 
-        self.init_cfg = init_cfg
         self.num_attn = num_attn
-        self.feedforward_channels = feedforward_channels
-        self.ffn_dropout = ffn_dropout
         self.operation_order = operation_order
-        self.act_cfg = act_cfg
         self.norm_cfg = norm_cfg
-        self.ffn_num_fcs = ffn_num_fcs
         self.pre_norm = operation_order[0] == 'norm'
-        self.attentions = nn.ModuleList()
+        self.attentions = ModuleList()
         index = 0
-        for operation in operation_order:
-            if operation in ['self_attn', 'cross_attn']:
+        for operation_name in operation_order:
+            if operation_name in ['self_attn', 'cross_attn']:
+                if 'batch_first' in attn_cfgs[index]:
+                    assert self.batch_first == attn_cfgs[index]['batch_first']
+                else:
+                    attn_cfgs[index]['batch_first'] = self.batch_first
                 attention = build_attention(attn_cfgs[index])
+                # Some custom attentions used as `self_attn`
+                # or `cross_attn` can have different behavior.
+                attention.operation_name = operation_name
                 self.attentions.append(attention)
                 index += 1
         self.embed_dims = self.attentions[0].embed_dims
 
-        self.ffns = nn.ModuleList()
+        self.ffns = ModuleList()
         num_ffns = operation_order.count('ffn')
-        for _ in range(num_ffns):
+        if isinstance(ffn_cfgs, dict):
+            ffn_cfgs = ConfigDict(ffn_cfgs)
+        if isinstance(ffn_cfgs, dict):
+            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
+        assert len(ffn_cfgs) == num_ffns
+        for ffn_index in range(num_ffns):
+            if 'embed_dims' not in ffn_cfgs[ffn_index]:
+                ffn_cfgs['embed_dims'] = self.embed_dims
+            else:
+                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
             self.ffns.append(
-                FFN(self.embed_dims, feedforward_channels, ffn_num_fcs,
-                    act_cfg, ffn_dropout))
+                build_feedforward_network(ffn_cfgs[ffn_index],
+                                          dict(type='FFN')))
 
-        self.norms = nn.ModuleList()
+        self.norms = ModuleList()
         num_norms = operation_order.count('norm')
         for _ in range(num_norms):
             self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
     def forward(self,
                 query,
-                key,
-                value,
+                key=None,
+                value=None,
                 query_pos=None,
                 key_pos=None,
                 attn_masks=None,
@@ -506,12 +412,14 @@ class BaseTransformerLayer(BaseModule):
         **kwargs contains some specific arguments of attentions.
 
         Args:
-            query (Tensor): Input query with the shape
-                `(num_queries, bs, embed_dims)`.
-            key (Tensor): The key tensor with shape
-                `(num_keys, bs, embed_dims)`.
-            value (Tensor): The value tensor with shape
-                `(num_keys, bs, embed_dims)`.
+            query (Tensor): The input query with shape
+                [num_queries, bs, embed_dims] if
+                self.batch_first is False, else
+                [bs, num_queries, embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims].
+            value (Tensor): The value tensor with same shape as `key`.
             query_pos (Tensor): The positional encoding for `query`.
                 Default: None.
             key_pos (Tensor): The positional encoding for `key`.
@@ -533,7 +441,7 @@ class BaseTransformerLayer(BaseModule):
         norm_index = 0
         attn_index = 0
         ffn_index = 0
-        inp_residual = query
+        identity = query
         if attn_masks is None:
             attn_masks = [None for _ in range(self.num_attn)]
         elif isinstance(attn_masks, torch.Tensor):
@@ -555,14 +463,14 @@ class BaseTransformerLayer(BaseModule):
                     query,
                     temp_key,
                     temp_value,
-                    inp_residual if self.pre_norm else None,
+                    identity if self.pre_norm else None,
                     query_pos=query_pos,
                     key_pos=query_pos,
                     attn_mask=attn_masks[attn_index],
                     key_padding_mask=query_key_padding_mask,
                     **kwargs)
                 attn_index += 1
-                inp_residual = query
+                identity = query
 
             elif layer == 'norm':
                 query = self.norms[norm_index](query)
@@ -573,18 +481,18 @@ class BaseTransformerLayer(BaseModule):
                     query,
                     key,
                     value,
-                    inp_residual if self.pre_norm else None,
+                    identity if self.pre_norm else None,
                     query_pos=query_pos,
                     key_pos=key_pos,
                     attn_mask=attn_masks[attn_index],
                     key_padding_mask=key_padding_mask,
                     **kwargs)
                 attn_index += 1
-                inp_residual = query
+                identity = query
 
             elif layer == 'ffn':
                 query = self.ffns[ffn_index](
-                    query, inp_residual if self.pre_norm else None)
+                    query, identity if self.pre_norm else None)
                 ffn_index += 1
 
         return query
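For illustration (not part of the diff): a decoder-style layer built with the new `ffn_cfgs` argument instead of the deprecated `feedforward_channels`/`ffn_dropout`/`ffn_num_fcs` keywords; all sizes are arbitrary:

```python
import torch
from mmcv.cnn.bricks.transformer import BaseTransformerLayer

layer = BaseTransformerLayer(
    attn_cfgs=[
        dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
        dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ],
    ffn_cfgs=dict(
        type='FFN',
        embed_dims=256,
        feedforward_channels=1024,
        num_fcs=2,
        ffn_drop=0.1),
    operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
                     'norm'))

query = torch.rand(10, 2, 256)   # (num_query, bs, embed_dims)
memory = torch.rand(30, 2, 256)  # (num_key, bs, embed_dims)
out = layer(query, key=memory, value=memory)
assert out.shape == query.shape
```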
@@ -612,7 +520,7 @@ class TransformerLayerSequence(BaseModule):
 
     def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
         super(TransformerLayerSequence, self).__init__(init_cfg)
-        if isinstance(transformerlayers, ConfigDict):
+        if isinstance(transformerlayers, dict):
             transformerlayers = [
                 copy.deepcopy(transformerlayers) for _ in range(num_layers)
             ]
@@ -620,13 +528,11 @@ class TransformerLayerSequence(BaseModule):
             assert isinstance(transformerlayers, list) and \
                 len(transformerlayers) == num_layers
         self.num_layers = num_layers
-        operation_order = transformerlayers[0]['operation_order']
-        self.pre_norm = operation_order[0] == 'norm'
-        self.layers = nn.ModuleList()
+        self.layers = ModuleList()
         for i in range(num_layers):
             self.layers.append(build_transformer_layer(transformerlayers[i]))
         self.embed_dims = self.layers[0].embed_dims
-        self.pre_norm = self.layers[0].operation_order[0] == 'norm'
+        self.pre_norm = self.layers[0].pre_norm
 
     def forward(self,
                 query,
@@ -661,7 +567,7 @@ class TransformerLayerSequence(BaseModule):
                 shape [bs, num_keys]. Default: None.
 
         Returns:
-            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
+            Tensor: results with shape [num_queries, bs, embed_dims].
         """
         for layer in self.layers:
             query = layer(
......
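A hedged sketch of stacking six identical layers from a single config with the refactored `TransformerLayerSequence` (arbitrary sizes; an encoder-style layer, so `key`/`value` can stay `None`):

```python
import torch
from mmcv.cnn.bricks.transformer import TransformerLayerSequence

encoder = TransformerLayerSequence(
    num_layers=6,
    transformerlayers=dict(
        type='BaseTransformerLayer',
        attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
        ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
        operation_order=('self_attn', 'norm', 'ffn', 'norm')))

x = torch.rand(100, 2, 256)   # (num_query, bs, embed_dims)
out = encoder(x, None, None)  # key/value are unused by pure self-attention
assert out.shape == x.shape
```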
@@ -21,6 +21,7 @@ from .masked_conv import MaskedConv2d, masked_conv2d
 from .modulated_deform_conv import (ModulatedDeformConv2d,
                                     ModulatedDeformConv2dPack,
                                     modulated_deform_conv2d)
+from .multi_scale_deform_attn import MultiScaleDeformableAttention
 from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
 from .pixel_group import pixel_group
 from .point_sample import (SimpleRoIAlign, point_sample,
@@ -50,5 +51,5 @@ __all__ = [
     'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
     'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
     'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
-    'BorderAlign', 'border_align'
+    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align'
 ]
+import math
+import warnings
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd.function import Function, once_differentiable
+
+from mmcv import deprecated_api_warning
+from mmcv.cnn import constant_init, xavier_init
+from mmcv.cnn.bricks.registry import ATTENTION
+from mmcv.runner import BaseModule
 from ..utils import ext_loader
 
 ext_module = ext_loader.load_ext(
@@ -140,3 +148,211 @@ def multi_scale_deformable_attn_pytorch(value, value_spatial_shapes,
               attention_weights).sum(-1).view(bs, num_heads * embed_dims,
                                               num_queries)
     return output.transpose(1, 2).contiguous()
@ATTENTION.register_module()
class MultiScaleDeformableAttention(BaseModule):
"""An attention module used in Deformable-Detr. `Deformable DETR:
Deformable Transformers for End-to-End Object Detection.
<https://arxiv.org/pdf/2010.04159.pdf>`_.
Args:
embed_dims (int): The embedding dimension of Attention.
Default: 256.
        num_heads (int): Parallel attention heads. Default: 8.
num_levels (int): The number of feature map used in
Attention. Default: 4.
num_points (int): The number of sampling points for
each query in each head. Default: 4.
im2col_step (int): The step used in image_to_column.
Default: 64.
        dropout (float): A Dropout layer on `identity`.
Default: 0.1.
batch_first (bool): Key, Query and Value are shape of
(batch, n, embed_dim)
or (n, batch, embed_dim). Default to False.
norm_cfg (dict): Config dict for normalization layer.
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims=256,
num_heads=8,
num_levels=4,
num_points=4,
im2col_step=64,
dropout=0.1,
batch_first=False,
norm_cfg=None,
init_cfg=None):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
f'but got {embed_dims} and {num_heads}')
dim_per_head = embed_dims // num_heads
self.norm_cfg = norm_cfg
self.dropout = nn.Dropout(dropout)
self.batch_first = batch_first
# you'd better set dim_per_head to a power of 2
# which is more efficient in the CUDA implementation
def _is_power_of_2(n):
if (not isinstance(n, int)) or (n < 0):
raise ValueError(
'invalid input for _is_power_of_2: {} (type: {})'.format(
n, type(n)))
return (n & (n - 1) == 0) and n != 0
if not _is_power_of_2(dim_per_head):
warnings.warn(
"You'd better set embed_dims in "
'MultiScaleDeformAttention to make '
'the dimension of each attention head a power of 2 '
'which is more efficient in our CUDA implementation.')
self.im2col_step = im2col_step
self.embed_dims = embed_dims
self.num_levels = num_levels
self.num_heads = num_heads
self.num_points = num_points
self.sampling_offsets = nn.Linear(
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
def init_weights(self):
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(
self.num_heads, 1, 1,
2).repeat(1, self.num_levels, self.num_points, 1)
for i in range(self.num_points):
grid_init[:, :, i, :] *= i + 1
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
@deprecated_api_warning({'residual': 'identity'},
cls_name='MultiScaleDeformableAttention')
def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_padding_mask=None,
reference_points=None,
spatial_shapes=None,
level_start_index=None,
**kwargs):
"""Forward Function of MultiScaleDeformAttention.
Args:
query (Tensor): Query of Transformer with shape
(num_query, bs, embed_dims).
key (Tensor): The key tensor with shape
`(num_key, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_key, bs, embed_dims)`.
identity (Tensor): The tensor used for addition, with the
same shape as `query`. Default None. If None,
`query` will be used.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`. Default
None.
reference_points (Tensor): The normalized reference
points with shape (bs, num_query, num_levels, 2),
all elements is range in [0, 1], top-left (0,0),
bottom-right (1, 1), including padding area.
or (N, Length_{query}, num_levels, 4), add
additional two dimensions is (w, h) to
form reference boxes.
key_padding_mask (Tensor): ByteTensor for `query`, with
shape [bs, num_key].
spatial_shapes (Tensor): Spatial shape of features in
different levels. With shape (num_levels, 2),
last dimension represents (h, w).
level_start_index (Tensor): The start index of each level.
A tensor has shape ``(num_levels, )`` and can be represented
as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
Returns:
Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if value is None:
value = query
if identity is None:
identity = query
if query_pos is not None:
query = query + query_pos
if not self.batch_first:
# change to (bs, num_query ,embed_dims)
query = query.permute(1, 0, 2)
value = value.permute(1, 0, 2)
bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = self.value_proj(value)
if key_padding_mask is not None:
value = value.masked_fill(key_padding_mask[..., None], 0.0)
value = value.view(bs, num_value, self.num_heads, -1)
sampling_offsets = self.sampling_offsets(query).view(
bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
bs, num_query, self.num_heads, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(bs, num_query,
self.num_heads,
self.num_levels,
self.num_points)
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack(
[spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets \
/ offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.num_points \
* reference_points[:, :, None, :, None, 2:] \
* 0.5
else:
raise ValueError(
f'Last dim of reference_points must be'
f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available():
output = MultiScaleDeformableAttnFunction.apply(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
else:
output = multi_scale_deformable_attn_pytorch(
value, spatial_shapes, level_start_index, sampling_locations,
attention_weights, self.im2col_step)
output = self.output_proj(output)
if not self.batch_first:
# (num_query, bs ,embed_dims)
output = output.permute(1, 0, 2)
return self.dropout(output) + identity
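An illustrative forward pass for the relocated module (shapes follow the docstring; values are random and the whole snippet is a sketch, not part of the commit). The pure-PyTorch fallback is used when CUDA is unavailable, so tensors are placed on whichever device is present:

```python
import torch
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention

device = 'cuda' if torch.cuda.is_available() else 'cpu'
bs, embed_dims, num_query = 2, 256, 100

# Four feature levels, flattened and concatenated along the key dimension.
spatial_shapes = torch.tensor(
    [[32, 32], [16, 16], [8, 8], [4, 4]], device=device)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros(1),
     (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1]))
num_value = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

attn = MultiScaleDeformableAttention(
    embed_dims=embed_dims, num_heads=8, num_levels=4).to(device)

query = torch.rand(num_query, bs, embed_dims, device=device)
value = torch.rand(num_value, bs, embed_dims, device=device)
reference_points = torch.rand(bs, num_query, 4, 2, device=device)

out = attn(
    query,
    value=value,
    reference_points=reference_points,
    spatial_shapes=spatial_shapes,
    level_start_index=level_start_index)
assert out.shape == (num_query, bs, embed_dims)
```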
import pytest
import torch
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, BaseTransformerLayer,
MultiheadAttention,
TransformerLayerSequence)
def test_multiheadattention():
MultiheadAttention(
embed_dims=5,
num_heads=5,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='Dropout', drop_prob=0.),
batch_first=True)
batch_dim = 2
embed_dim = 5
num_query = 100
attn_batch_first = MultiheadAttention(
embed_dims=5,
num_heads=5,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
batch_first=True)
attn_query_first = MultiheadAttention(
embed_dims=5,
num_heads=5,
attn_drop=0,
proj_drop=0,
dropout_layer=dict(type='DropPath', drop_prob=0.),
batch_first=False)
param_dict = dict(attn_query_first.named_parameters())
for n, v in attn_batch_first.named_parameters():
param_dict[n].data = v.data
input_batch_first = torch.rand(batch_dim, num_query, embed_dim)
input_query_first = input_batch_first.transpose(0, 1)
assert torch.allclose(
attn_query_first(input_query_first).sum(),
attn_batch_first(input_batch_first).sum())
key_batch_first = torch.rand(batch_dim, num_query, embed_dim)
key_query_first = key_batch_first.transpose(0, 1)
assert torch.allclose(
attn_query_first(input_query_first, key_query_first).sum(),
attn_batch_first(input_batch_first, key_batch_first).sum())
identity = torch.ones_like(input_query_first)
# check deprecated arguments can be used normally
assert torch.allclose(
attn_query_first(
input_query_first, key_query_first, residual=identity).sum(),
attn_batch_first(input_batch_first, key_batch_first).sum() +
identity.sum() - input_batch_first.sum())
assert torch.allclose(
attn_query_first(
input_query_first, key_query_first, identity=identity).sum(),
attn_batch_first(input_batch_first, key_batch_first).sum() +
identity.sum() - input_batch_first.sum())
def test_ffn():
with pytest.raises(AssertionError):
# num_fcs should be no less than 2
FFN(num_fcs=1)
FFN(dropout=0, add_residual=True)
ffn = FFN(dropout=0, add_identity=True)
input_tensor = torch.rand(2, 20, 256)
input_tensor_nbc = input_tensor.transpose(0, 1)
assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
residual = torch.rand_like(input_tensor)
torch.allclose(
ffn(input_tensor, residual=residual).sum(),
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
torch.allclose(
ffn(input_tensor, identity=residual).sum(),
ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())
def test_basetransformerlayer():
attn_cfgs = dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
feedforward_channels = 2048
ffn_dropout = 0.1
operation_order = ('self_attn', 'norm', 'ffn', 'norm')
# test deprecated_args
baselayer = BaseTransformerLayer(
attn_cfgs=attn_cfgs,
feedforward_channels=feedforward_channels,
ffn_dropout=ffn_dropout,
operation_order=operation_order)
assert baselayer.batch_first is False
assert baselayer.ffns[0].feedforward_channels == feedforward_channels
attn_cfgs = dict(type='MultiheadAttention', num_heads=8, embed_dims=256),
feedforward_channels = 2048
ffn_dropout = 0.1
operation_order = ('self_attn', 'norm', 'ffn', 'norm')
baselayer = BaseTransformerLayer(
attn_cfgs=attn_cfgs,
feedforward_channels=feedforward_channels,
ffn_dropout=ffn_dropout,
operation_order=operation_order,
batch_first=True)
assert baselayer.attentions[0].batch_first
in_tensor = torch.rand(2, 10, 256)
baselayer(in_tensor)
def test_transformerlayersequence():
squeue = TransformerLayerSequence(
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(type='MultiheadAttention', embed_dims=256, num_heads=4)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn',
'norm')))
assert len(squeue.layers) == 6
assert squeue.pre_norm is False
with pytest.raises(AssertionError):
# if transformerlayers is a list, len(transformerlayers)
# should be equal to num_layers
TransformerLayerSequence(
num_layers=6,
transformerlayers=[
dict(
type='BaseTransformerLayer',
attn_cfgs=[
dict(
type='MultiheadAttention',
embed_dims=256,
num_heads=8,
dropout=0.1),
dict(type='MultiheadAttention', embed_dims=256)
],
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
'ffn', 'norm'))
])
def test_drop_path():
drop_path = DropPath(drop_prob=0)
test_in = torch.rand(2, 3, 4, 5)
assert test_in is drop_path(test_in)
drop_path = DropPath(drop_prob=0.1)
drop_path.training = False
test_in = torch.rand(2, 3, 4, 5)
assert test_in is drop_path(test_in)
drop_path.training = True
assert test_in is not drop_path(test_in)
@@ -2,7 +2,8 @@ import pytest
 import torch
 
 from mmcv.ops.multi_scale_deform_attn import (
-    MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
+    MultiScaleDeformableAttention, MultiScaleDeformableAttnFunction,
+    multi_scale_deformable_attn_pytorch)
 
 _USING_PARROTS = True
 try:
@@ -98,7 +99,14 @@ def test_forward_equal_with_pytorch_float():
 
 @pytest.mark.skipif(
     not torch.cuda.is_available(), reason='requires CUDA support')
-@pytest.mark.parametrize('channels', [4, 30, 32, 64, 71, 1025, 2048, 3096])
+@pytest.mark.parametrize('channels', [
+    4,
+    30,
+    32,
+    64,
+    71,
+    1025,
+])
 def test_gradient_numerical(channels,
                             grad_value=True,
                             grad_sampling_loc=True,
@@ -134,3 +142,20 @@ def test_gradient_numerical(channels,
         assert gradcheck(func, (value.double(), shapes, level_start_index,
                                 sampling_locations.double(),
                                 attention_weights.double(), im2col_step))
def test_multiscale_deformable_attention():
with pytest.raises(ValueError):
# embed_dims must be divisible by num_heads,
MultiScaleDeformableAttention(
embed_dims=256,
num_heads=7,
)
with pytest.raises(ValueError):
# embed_dims must be divisible by num_heads,
MultiScaleDeformableAttention(
embed_dims=256,
num_heads=7,
)
MultiScaleDeformableAttention(embed_dims=256, num_heads=8)