# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR3D (https://github.com/WangYueFt/detr3d)
# Copyright (c) 2021 Wang, Yue
# ------------------------------------------------------------------------
# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import (BaseTransformerLayer,
                                         TransformerLayerSequence)
from mmengine.model import BaseModule
from mmengine.model.weight_init import xavier_init

# from mmcv.utils import deprecated_api_warning
from mmdet3d.registry import MODELS, TASK_UTILS


@MODELS.register_module()
class PETRTransformer(BaseModule):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module copy-paste
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerEncoder. Defaults to None, in which case no encoder
            is built and ``self.encoder`` is None.
        decoder ((`mmcv.ConfigDict` | Dict)): Config of
            TransformerDecoder. It is always built, so a valid config is
            required in practice.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Stored as ``self.cross``; it is not read anywhere in
            this module. Defaults to False.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None,
                 cross=False):
        super(PETRTransformer, self).__init__(init_cfg=init_cfg)
        if encoder is not None:
            self.encoder = MODELS.build(encoder)
        else:
            self.encoder = None
        self.decoder = MODELS.build(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.cross = cross

    def init_weights(self):
        """Xavier-uniform initialise every sub-module with a >1-D weight."""
        # follow the official DETR to init parameters
        for m in self.modules():
            # ``weight`` may exist but be None (e.g. a norm layer without
            # affine parameters); skip those instead of raising.
            weight = getattr(m, 'weight', None)
            if weight is not None and weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self, x, mask, query_embed, pos_embed, reg_branch=None):
        """Forward function for `Transformer`.

        Note that the encoder (if any) is not invoked here: the flattened
        input features are used directly as the decoder memory.

        Args:
            x (Tensor): Input features with shape [bs, n, c, h, w] where
                c = embed_dims (``n`` is presumably the number of camera
                views -- confirm against the caller).
            mask (Tensor): The key_padding_mask used for the decoder. It is
                flattened to [bs, n*h*w], so it is expected to have shape
                [bs, n, h, w].
            query_embed (Tensor): The query embedding for decoder, with
                shape [num_query, c]. It is broadcast over the batch.
            pos_embed (Tensor): The positional encoding for the keys, with
                the same shape as `x`.
            reg_branch (obj, optional): Passed through to the decoder
                unchanged. Defaults to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following
            tensor.

                - out_dec: Output from decoder. If the decoder returns
                  intermediate results it has shape
                  [num_dec_layers, bs, num_query, embed_dims], else
                  [1, bs, num_query, embed_dims].
                - memory: The input features restored to shape
                  [bs, n, embed_dims, h, w].
        """
        bs, n, c, h, w = x.shape
        # [bs, n, c, h, w] -> [n*h*w, bs, c]
        memory = x.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        # [bs, n, c, h, w] -> [n*h*w, bs, c]
        pos_embed = pos_embed.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        # [num_query, dim] -> [num_query, bs, dim]
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        # [bs, n, h, w] -> [bs, n*h*w]
        mask = mask.view(bs, -1)
        # The decoder starts from all-zero content queries; the learned
        # information enters only through ``query_pos``.
        target = torch.zeros_like(query_embed)

        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask,
            reg_branch=reg_branch,
        )
        # -> [num_layers, bs, num_query, dim]
        out_dec = out_dec.transpose(1, 2)
        # [n*h*w, bs, c] -> [bs, n, c, h, w]
        memory = memory.reshape(n, h, w, bs, c).permute(3, 0, 4, 1, 2)
        return out_dec, memory


@MODELS.register_module()
class PETRDNTransformer(BaseModule):
    """Implements the DETR transformer with an extra attention mask input.

    Identical to :class:`PETRTransformer` except that ``forward`` takes
    batched query embeddings and an ``attn_masks`` tensor which is handed
    to the decoder (presumably for denoising-query training, as the class
    name suggests -- confirm against the caller).

    Following the official DETR implementation, this module copy-paste
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerEncoder. Defaults to None, in which case no encoder
            is built and ``self.encoder`` is None.
        decoder ((`mmcv.ConfigDict` | Dict)): Config of
            TransformerDecoder. It is always built, so a valid config is
            required in practice.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
        cross (bool): Stored as ``self.cross``; it is not read anywhere in
            this module. Defaults to False.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None,
                 cross=False):
        super(PETRDNTransformer, self).__init__(init_cfg=init_cfg)
        if encoder is not None:
            self.encoder = MODELS.build(encoder)
        else:
            self.encoder = None
        self.decoder = MODELS.build(decoder)
        self.embed_dims = self.decoder.embed_dims
        self.cross = cross

    def init_weights(self):
        """Xavier-uniform initialise every sub-module with a >1-D weight."""
        # follow the official DETR to init parameters
        for m in self.modules():
            # ``weight`` may exist but be None (e.g. a norm layer without
            # affine parameters); skip those instead of raising.
            weight = getattr(m, 'weight', None)
            if weight is not None and weight.dim() > 1:
                xavier_init(m, distribution='uniform')
        self._is_init = True

    def forward(self,
                x,
                mask,
                query_embed,
                pos_embed,
                attn_masks=None,
                reg_branch=None):
        """Forward function for `Transformer`.

        Note that the encoder (if any) is not invoked here: the flattened
        input features are used directly as the decoder memory.

        Args:
            x (Tensor): Input features with shape [bs, n, c, h, w] where
                c = embed_dims (``n`` is presumably the number of camera
                views -- confirm against the caller).
            mask (Tensor): The key_padding_mask used for the decoder. It is
                flattened to [bs, n*h*w], so it is expected to have shape
                [bs, n, h, w].
            query_embed (Tensor): The query embedding for decoder. Its
                first two axes are swapped before use, so it is presumably
                [bs, num_query, c] -- confirm against the caller.
            pos_embed (Tensor): The positional encoding for the keys, with
                the same shape as `x`.
            attn_masks (Tensor, optional): Attention mask passed to the
                decoder as the first entry of ``[attn_masks, None]``
                (presumably consumed by self-attention, with no mask for
                cross-attention -- depends on the layer's operation
                order). Defaults to None.
            reg_branch (obj, optional): Passed through to the decoder
                unchanged. Defaults to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following
            tensor.

                - out_dec: Output from decoder. If the decoder returns
                  intermediate results it has shape
                  [num_dec_layers, bs, num_query, embed_dims], else
                  [1, bs, num_query, embed_dims].
                - memory: The input features restored to shape
                  [bs, n, embed_dims, h, w].
        """
        bs, n, c, h, w = x.shape
        # [bs, n, c, h, w] -> [n*h*w, bs, c]
        memory = x.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        # [bs, n, c, h, w] -> [n*h*w, bs, c]
        pos_embed = pos_embed.permute(1, 3, 4, 0, 2).reshape(-1, bs, c)
        # [bs, num_query, dim] -> [num_query, bs, dim]
        query_embed = query_embed.transpose(0, 1)
        # [bs, n, h, w] -> [bs, n*h*w]
        mask = mask.view(bs, -1)
        # The decoder starts from all-zero content queries; the learned
        # information enters only through ``query_pos``.
        target = torch.zeros_like(query_embed)

        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_embed,
            query_pos=query_embed,
            key_padding_mask=mask,
            attn_masks=[attn_masks, None],
            reg_branch=reg_branch,
        )
        # -> [num_layers, bs, num_query, dim]
        out_dec = out_dec.transpose(1, 2)
        # [n*h*w, bs, c] -> [bs, n, c, h, w]
        memory = memory.reshape(n, h, w, bs, c).permute(3, 0, 4, 1, 2)
        return out_dec, memory


@MODELS.register_module()
class PETRTransformerDecoderLayer(BaseTransformerLayer):
    """Implements decoder layer in DETR transformer.

    Adds optional gradient checkpointing on top of
    ``BaseTransformerLayer``.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it would be expand to the number of attention in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Must contain exactly six entries drawn from
            'self_attn', 'norm', 'cross_attn' and 'ffn', with each of the
            four kinds present. Default: None (which fails validation).
        act_cfg (dict): The activation config for FFNs.
            Default: ``dict(type='ReLU', inplace=True)``.
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
        with_cp (bool): Whether to run the layer under
            ``torch.utils.checkpoint`` while training. Default: True.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 with_cp=True,
                 **kwargs):
        super(PETRTransformerDecoderLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        assert len(operation_order) == 6
        assert set(operation_order) == set(
            ['self_attn', 'norm', 'cross_attn', 'ffn'])
        self.use_checkpoint = with_cp

    def _forward(
        self,
        query,
        key=None,
        value=None,
        query_pos=None,
        key_pos=None,
        attn_masks=None,
        query_key_padding_mask=None,
        key_padding_mask=None,
    ):
        """Run the parent layer's forward pass.

        Kept as a separate method with a purely positional-friendly
        signature so that it can be handed to ``cp.checkpoint``.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        x = super(PETRTransformerDecoderLayer, self).forward(
            query,
            key=key,
            value=value,
            query_pos=query_pos,
            key_pos=key_pos,
            attn_masks=attn_masks,
            query_key_padding_mask=query_key_padding_mask,
            key_padding_mask=key_padding_mask,
        )

        return x

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Uses gradient checkpointing when ``with_cp`` was set and the module
        is in training mode. Extra ``**kwargs`` (e.g. ``reg_branch``) are
        accepted for interface compatibility and ignored.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """

        if self.use_checkpoint and self.training:
            # Arguments are passed positionally, in ``_forward`` order.
            x = cp.checkpoint(
                self._forward,
                query,
                key,
                value,
                query_pos,
                key_pos,
                attn_masks,
                query_key_padding_mask,
                key_padding_mask,
            )
        else:
            x = self._forward(
                query,
                key=key,
                value=value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask)
        return x


@MODELS.register_module()
class PETRMultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. A falsy value selects
            ``nn.Identity``. The passed config is never modified.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape
            of (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default to False.
        **kwargs: Forwarded to ``nn.MultiheadAttention``. A deprecated
            ``dropout`` entry is consumed here and overrides both
            ``attn_drop`` and ``dropout_layer['drop_prob']``.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super(PETRMultiheadAttention, self).__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn(
                'The arguments `dropout` in MultiheadAttention '
                'has been deprecated, now you can separately '
                'set `attn_drop`(float), proj_drop(float), '
                'and `dropout_layer`(dict) ', DeprecationWarning)
            deprecated_dropout = kwargs.pop('dropout')
            attn_drop = deprecated_dropout
            # Work on a copy: writing into the argument would mutate the
            # shared default dict (or the caller's config) for every
            # instance created afterwards.
            dropout_layer = dict(dropout_layer)
            dropout_layer['drop_prob'] = deprecated_dropout

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = MODELS.build(
            dropout_layer) if dropout_layer else nn.Identity()

    # @deprecated_api_warning({'residual': 'identity'},
    #                         cls_name='MultiheadAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims] .
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used (before ``key_pos`` is
                added to it).
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link.
                If None, `query` (before ``query_pos`` is added) will be
                used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. If not None, it will
                be added to `x` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries embed_dims].
        """

        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        # ``self.attn`` returns (output, weights); only the output is used.
        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))


@MODELS.register_module()
class PETRTransformerEncoder(TransformerLayerSequence):
    """TransformerEncoder of DETR.

    Args:
        post_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`. Only used when `self.pre_norm` is `True`
    """

    def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs):
        super(PETRTransformerEncoder, self).__init__(*args, **kwargs)
        # ``self.pre_norm`` and ``self.embed_dims`` are expected to be set
        # by the parent ``TransformerLayerSequence``.
        if post_norm_cfg is not None:
            # ``build_norm_layer`` returns (name, layer). A registry
            # ``build`` call cannot take the feature count, so the norm
            # layer is built the same way as in the decoder.
            self.post_norm = build_norm_layer(
                post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
        else:
            assert not self.pre_norm, f'Use prenorm in ' \
                                      f'{self.__class__.__name__},' \
                                      f'Please specify post_norm_cfg'
            self.post_norm = None

    def forward(self, *args, **kwargs):
        """Forward function for `TransformerCoder`.

        Runs the layer sequence and then the final normalization layer,
        if one was built.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        x = super(PETRTransformerEncoder, self).forward(*args, **kwargs)
        if self.post_norm is not None:
            x = self.post_norm(x)
        return x


@MODELS.register_module()
class PETRTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        post_norm_cfg (dict): Config of last normalization layer. Default:
            `LN`. If None, no normalization is applied to the outputs.
    """

    def __init__(self,
                 *args,
                 post_norm_cfg=dict(type='LN'),
                 return_intermediate=False,
                 **kwargs):

        super(PETRTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        if post_norm_cfg is not None:
            # ``build_norm_layer`` returns (name, layer).
            self.post_norm = build_norm_layer(post_norm_cfg,
                                              self.embed_dims)[1]
        else:
            self.post_norm = None

    def forward(self, query, *args, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        if not self.return_intermediate:
            x = super().forward(query, *args, **kwargs)
            if self.post_norm is not None:
                x = self.post_norm(x)
            # Always add the leading "layer" axis so the output rank does
            # not depend on whether a post-norm layer is configured.
            return x[None]

        intermediate = []
        for layer in self.layers:
            query = layer(query, *args, **kwargs)
            if self.post_norm is not None:
                intermediate.append(self.post_norm(query))
            else:
                intermediate.append(query)
        return torch.stack(intermediate)