Unverified Commit a9803da1 authored by ZhangShilong, committed by GitHub

[Feature]: Add transformer base class (#892)

* add transformer base class

* fix according to comments

* register module

* change to use basemodule

* add _ between build function

* split the name
parent 5f5e8e83
@@ -6,3 +6,8 @@ ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')
POSITIONAL_ENCODING = Registry('Position encoding')
ATTENTION = Registry('Attention')
TRANSFORMER_LAYER = Registry('TransformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('TransformerLayerSequence')
import copy
import warnings
import torch.nn as nn
from mmcv import ConfigDict
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from mmcv.runner.base_module import BaseModule
from mmcv.utils import build_from_cfg
from .registry import (ATTENTION, POSITIONAL_ENCODING, TRANSFORMER_LAYER,
TRANSFORMER_LAYER_SEQUENCE)
def build_positional_encoding(cfg, default_args=None):
"""Builder for Position Encoding."""
return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
def build_attention(cfg, default_args=None):
"""Builder for attention."""
return build_from_cfg(cfg, ATTENTION, default_args)
def build_transformer_layer(cfg, default_args=None):
"""Builder for transformer layer."""
return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
def build_transformer_layer_sequence(cfg, default_args=None):
"""Builder for transformer encoder and transformer decoder."""
return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
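# The example functions below are illustrative sketches of assumed usage;
# they are not part of the original commit. Each builder instantiates a
# module registered in the corresponding registry, looked up via the
# ``type`` key of the config.
def _example_build_attention():
    """Build a registered attention module from a plain config dict."""
    attn = build_attention(
        dict(type='MultiheadAttention', embed_dims=256, num_heads=8))
    return attn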
@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
"""A warpper for torch.nn.MultiheadAttention.
This module implements MultiheadAttention with residual connection,
and positional encoding used in DETR is also passed as input.
Args:
embed_dims (int): The embedding dimension.
num_heads (int): Parallel attention heads. Same as
`nn.MultiheadAttention`.
        dropout (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
num_heads,
dropout=0.,
init_cfg=None,
**kwargs):
        super(MultiheadAttention, self).__init__(init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout,
                                          **kwargs)
        self.dropout = nn.Dropout(dropout)
def forward(self,
query,
key=None,
value=None,
residual=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `MultiheadAttention`.
        **kwargs allows passing a more general data flow when combining
        with other operations in `transformerlayer`.
Args:
query (Tensor): The input query with shape [num_queries, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
key (Tensor): The key tensor with shape [num_keys, bs,
embed_dims]. Same in `nn.MultiheadAttention.forward`.
If None, the ``query`` will be used. Defaults to None.
value (Tensor): The value tensor with same shape as `key`.
Same in `nn.MultiheadAttention.forward`. Defaults to None.
If None, the `key` will be used.
            residual (Tensor): This tensor, with the same shape as
                `query`, will be used for the residual connection.
                If None, `query` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for `query`, with
                the same shape as `query`. If not None, it will be added
                to `query` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to
                `key` before forward function. If None, and `query_pos`
                has the same shape as `key`, then `query_pos` will be
                used for `key_pos`. Defaults to None.
attn_mask (Tensor): ByteTensor mask with shape [num_queries,
num_keys]. Same in `nn.MultiheadAttention.forward`.
Defaults to None.
key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
Same in `nn.MultiheadAttention.forward`. Defaults to None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if residual is None:
residual = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
out = self.attn(
query,
key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)[0]
return residual + self.dropout(out)
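# Illustrative sketch (assumed usage, not part of the original commit):
# key/value/residual default to the query, and the query's positional
# encoding is reused for the key when their shapes match.
def _example_multihead_attention():
    import torch
    self_attn = MultiheadAttention(embed_dims=256, num_heads=8)
    x = torch.rand(100, 2, 256)  # [num_queries, bs, embed_dims]
    pos = torch.rand(100, 2, 256)  # positional encoding, same shape as x
    out = self_attn(x, query_pos=pos)
    assert out.shape == (100, 2, 256)
    return out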
class FFN(BaseModule):
"""Implements feed-forward networks (FFNs) with residual connection.
Args:
embed_dims (int): The feature dimension. Same as
`MultiheadAttention`.
feedforward_channels (int): The hidden dimension of FFNs.
num_fcs (int, optional): The number of fully-connected layers in
FFNs. Default: 2.
act_cfg (dict, optional): The activation config for FFNs.
Default: dict(type='ReLU')
        dropout (float, optional): Probability of an element to be
            zeroed. Default: 0.0.
add_residual (bool, optional): Whether to add the
residual connection. Default: `True`.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self,
embed_dims,
feedforward_channels,
num_fcs=2,
act_cfg=dict(type='ReLU', inplace=True),
dropout=0.,
add_residual=True,
init_cfg=None):
        super(FFN, self).__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.activate = build_activation_layer(act_cfg)
layers = []
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
nn.Sequential(
Linear(in_channels, feedforward_channels), self.activate,
nn.Dropout(dropout)))
in_channels = feedforward_channels
layers.append(Linear(feedforward_channels, embed_dims))
self.layers = nn.Sequential(*layers)
self.dropout = nn.Dropout(dropout)
self.add_residual = add_residual
def forward(self, x, residual=None):
"""Forward function for `FFN`.
The function would add x to the output tensor if residue is None.
"""
out = self.layers(x)
if not self.add_residual:
return out
if residual is None:
residual = x
return residual + self.dropout(out)
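# Illustrative sketch (assumed usage, not part of the original commit):
# with ``add_residual=True`` (the default), the input itself serves as
# the residual when none is passed explicitly.
def _example_ffn():
    import torch
    ffn = FFN(embed_dims=256, feedforward_channels=1024)
    out = ffn(torch.rand(100, 2, 256))
    assert out.shape == (100, 2, 256)
    return out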
@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
"""Base `TransformerLayer` for vision transformer.
    It can be built from `mmcv.ConfigDict` and supports more flexible
    customization, for example, using any number of `FFN` or `LN`, and
    using different kinds of `attention` by specifying a list of
    `ConfigDict` named `attn_cfgs`. It is worth mentioning that it
    supports `prenorm` when you specify `norm` as the first element of
    `operation_order`. More details about the `prenorm`: `On Layer
    Normalization in the Transformer Architecture
    <https://arxiv.org/abs/2002.04745>`_ .
    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            the corresponding attentions in `operation_order`.
            If it is a dict, all of the attention modules in
            `operation_order` will be built with this config.
            Default: None.
        feedforward_channels (int): The hidden dimension for FFNs.
            Default: None.
        ffn_dropout (float): Probability of an element to be zeroed
            in FFN. Default: 0.0.
        operation_order (tuple[str]): The execution order of operations
            in transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            `prenorm` is supported when you specify `norm` as the first
            element. Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
"""
def __init__(self,
attn_cfgs=None,
feedforward_channels=None,
ffn_dropout=0.,
operation_order=None,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'),
ffn_num_fcs=2,
init_cfg=None):
        super(BaseTransformerLayer, self).__init__(init_cfg)
        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should only contain operations ' \
            f"among {['self_attn', 'norm', 'ffn', 'cross_attn']}."
num_attn = operation_order.count('self_attn') + operation_order.count(
'cross_attn')
if isinstance(attn_cfgs, ConfigDict):
attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
else:
            assert num_attn == len(attn_cfgs), f'The length of ' \
                f'attn_cfgs {len(attn_cfgs)} is not consistent with the ' \
                f'number of attentions {num_attn} ' \
                f'in operation_order {operation_order}.'
self.num_attn = num_attn
self.feedforward_channels = feedforward_channels
self.ffn_dropout = ffn_dropout
self.operation_order = operation_order
self.act_cfg = act_cfg
self.norm_cfg = norm_cfg
self.ffn_num_fcs = ffn_num_fcs
self.pre_norm = operation_order[0] == 'norm'
self.attentions = nn.ModuleList()
index = 0
for operation in operation_order:
if operation in ['self_attn', 'cross_attn']:
attention = build_attention(attn_cfgs[index])
self.attentions.append(attention)
index += 1
self.embed_dims = self.attentions[0].embed_dims
self.ffns = nn.ModuleList()
num_ffns = operation_order.count('ffn')
for _ in range(num_ffns):
self.ffns.append(
FFN(self.embed_dims, feedforward_channels, ffn_num_fcs,
act_cfg, ffn_dropout))
self.norms = nn.ModuleList()
num_norms = operation_order.count('norm')
for _ in range(num_norms):
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
def forward(self,
query,
key,
value,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `TransformerDecoderLayer`.
**kwargs contains some specific arguments of attentions.
Args:
query (Tensor): Input query with the shape
`(num_queries, bs, embed_dims)`.
key (Tensor): The key tensor with shape
`(num_keys, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_keys, bs, embed_dims)`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
            attn_masks (List[Tensor] | None): 2D Tensors used in the
                calculation of the corresponding attention. The length of
                the list should equal the number of `attention` entries in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index = 0
attn_index = 0
ffn_index = 0
inp_residual = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
else:
assert len(attn_masks) == self.num_attn, f'The length of ' \
f'attn_masks {len(attn_masks)} must be equal ' \
f'to the number of attention in ' \
f'operation_order {self.num_attn}'
for layer in self.operation_order:
if layer == 'self_attn':
temp_key = temp_value = query
query = self.attentions[attn_index](
query,
temp_key,
temp_value,
inp_residual if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs)
attn_index += 1
inp_residual = query
elif layer == 'norm':
query = self.norms[norm_index](query)
norm_index += 1
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
inp_residual if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs)
attn_index += 1
inp_residual = query
elif layer == 'ffn':
query = self.ffns[ffn_index](
query, inp_residual if self.pre_norm else None)
ffn_index += 1
return query
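# Illustrative sketch (assumed usage, not part of the original commit): a
# DETR-style post-norm decoder layer. Passing a single ``ConfigDict`` as
# ``attn_cfgs`` copies it for both the self- and cross-attention modules;
# putting 'norm' first in ``operation_order`` would switch to prenorm.
def _example_base_transformer_layer():
    import torch
    layer = BaseTransformerLayer(
        attn_cfgs=ConfigDict(
            type='MultiheadAttention', embed_dims=256, num_heads=8),
        feedforward_channels=1024,
        operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                         'ffn', 'norm'))
    query = torch.rand(100, 2, 256)  # [num_queries, bs, embed_dims]
    memory = torch.rand(1024, 2, 256)  # [num_keys, bs, embed_dims]
    out = layer(query, memory, memory)
    assert out.shape == (100, 2, 256)
    return out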
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
"""Base class for TransformerEncoder and TransformerDecoder in vision
transformer.
As base-class of Encoder and Decoder in vision transformer.
Support customization such as specifying different kind
of `transformer_layer` in `transformer_coder`.
Args:
transformerlayer (list[obj:`mmcv.ConfigDict`] |
obj:`mmcv.ConfigDict`): Config of transformerlayer
in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
it would be repeated `num_layer` times to a
list[`mmcv.ConfigDict`]. Default: None.
num_layers (int): The number of `TransformerLayer`. Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
"""
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super(TransformerLayerSequence, self).__init__(init_cfg)
if isinstance(transformerlayers, ConfigDict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)
]
else:
assert isinstance(transformerlayers, list) and \
len(transformerlayers) == num_layers
        self.num_layers = num_layers
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(build_transformer_layer(transformerlayers[i]))
self.embed_dims = self.layers[0].embed_dims
self.pre_norm = self.layers[0].operation_order[0] == 'norm'
def forward(self,
query,
key,
value,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `TransformerCoder`.
Args:
query (Tensor): Input query with shape
`(num_queries, bs, embed_dims)`.
key (Tensor): The key tensor with shape
`(num_keys, bs, embed_dims)`.
value (Tensor): The value tensor with shape
`(num_keys, bs, embed_dims)`.
query_pos (Tensor): The positional encoding for `query`.
Default: None.
key_pos (Tensor): The positional encoding for `key`.
Default: None.
            attn_masks (List[Tensor], optional): Each element is a 2D
                Tensor which is used in the calculation of the
                corresponding attention in `operation_order`.
                Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention.
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.
Returns:
Tensor: forwarded results with shape [num_queries, bs, embed_dims].
"""
for layer in self.layers:
query = layer(
query,
key,
value,
query_pos=query_pos,
key_pos=key_pos,
attn_masks=attn_masks,
query_key_padding_mask=query_key_padding_mask,
key_padding_mask=key_padding_mask,
**kwargs)
return query
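# Illustrative sketch (assumed usage, not part of the original commit): a
# 6-layer encoder whose layers only self-attend, so ``key`` and ``value``
# can be left as None (each layer substitutes the query internally).
def _example_transformer_layer_sequence():
    import torch
    encoder = TransformerLayerSequence(
        transformerlayers=ConfigDict(
            type='BaseTransformerLayer',
            attn_cfgs=ConfigDict(
                type='MultiheadAttention', embed_dims=256, num_heads=8),
            feedforward_channels=1024,
            operation_order=('self_attn', 'norm', 'ffn', 'norm')),
        num_layers=6)
    x = torch.rand(1024, 2, 256)  # [num_queries, bs, embed_dims]
    out = encoder(x, None, None)
    assert out.shape == (1024, 2, 256)
    return out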