import copy
import math
import warnings
from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
                      build_norm_layer)
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
                        to_2tuple)

from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.registry import (ATTENTION, FEEDFORWARD_NETWORK,
                                      POSITIONAL_ENCODING, TRANSFORMER_LAYER,
                                      TRANSFORMER_LAYER_SEQUENCE)


@ATTENTION.register_module()
class GroupMultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention`` with grouped queries.

    This module implements MultiheadAttention with an identity connection,
    and positional encoding is also passed as input. During training, the
    queries are split into ``group`` groups that are folded into the batch
    dimension, so attention is computed independently within each group;
    at inference time it behaves like standard MultiheadAttention.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        group (int): Number of query groups used during training.
            Default: 1.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 group=1,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super().__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The argument `dropout` in MultiheadAttention '
                'has been deprecated; now you can separately '
                'set `attn_drop`(float), `proj_drop`(float), '
                'and `dropout_layer`(dict).', DeprecationWarning)
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.group = group
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='GroupMultiheadAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `GroupMultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims]. If None, the ``query`` will be
                used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link. If None, `x` will be
                used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. If not None, it will be added to
                `x` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. If not None, it will be added to `key`
                before forward function. If None, and `query_pos` has the
                same shape as `key`, then `query_pos` will be used for
                `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims] if self.batch_first is False,
            else [bs, num_queries, embed_dims].
        """

        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow ('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), we should adjust the shape of the dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query, batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        num_queries = query.shape[0]
        bs = query.shape[1]
        if self.training:
            # During training, split the queries into ``self.group`` chunks
            # along the query dimension and fold them into the batch
            # dimension, so attention is computed independently per group.
            query = torch.cat(
                query.split(num_queries // self.group, dim=0), dim=1)
            key = torch.cat(
                key.split(num_queries // self.group, dim=0), dim=1)
            value = torch.cat(
                value.split(num_queries // self.group, dim=0), dim=1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.training:
            # Recover the grouped output from (num_queries // group,
            # bs * group, embed_dims) back to (num_queries, bs, embed_dims).
            out = torch.cat(out.split(bs, dim=1), dim=0)

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))
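

# A minimal usage sketch of the grouped self-attention above. The shapes are
# illustrative assumptions (2 groups of 300 queries, batch size 4,
# embed_dims 256), not values mandated by the module: in training mode the
# groups are folded into the batch dimension so queries never attend across
# groups, while in eval mode all queries attend to each other.
if __name__ == '__main__':
    num_groups, queries_per_group, bs, embed_dims = 2, 300, 4, 256
    self_attn = GroupMultiheadAttention(
        embed_dims=embed_dims, num_heads=8, group=num_groups)

    # (num_queries, bs, embed_dims) because batch_first=False by default.
    query = torch.randn(num_groups * queries_per_group, bs, embed_dims)
    query_pos = torch.randn_like(query)

    self_attn.train()
    out_train = self_attn(query, query_pos=query_pos)  # per-group attention
    self_attn.eval()
    out_eval = self_attn(query, query_pos=query_pos)  # standard attention
    assert out_train.shape == out_eval.shape == query.shape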