import math import warnings import torch import torch.nn as nn from mmcv.cnn import build_norm_layer from mmcv.cnn.bricks.transformer import (BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence) from mmcv.ops import MultiScaleDeformableAttention from mmengine.model import BaseModule from mmengine.model.weight_init import xavier_init from torch.nn.init import normal_ from mmdet.models.layers.transformer import inverse_sigmoid from mmdet.registry import MODELS try: from fairscale.nn.checkpoint import checkpoint_wrapper except Exception: checkpoint_wrapper = None # In order to save the cost and effort of reproduction, # I did not refactor it into the style of mmdet 3.x DETR. class Transformer(BaseModule): """Implements the DETR transformer. Following the official DETR implementation, this module copy-paste from torch.nn.Transformer with modifications: * positional encodings are passed in MultiheadAttention * extra LN at the end of encoder is removed * decoder returns a stack of activations from all decoding layers See `paper: End-to-End Object Detection with Transformers `_ for details. Args: encoder (`mmcv.ConfigDict` | Dict): Config of TransformerEncoder. Defaults to None. decoder ((`mmcv.ConfigDict` | Dict)): Config of TransformerDecoder. Defaults to None init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. Defaults to None. """ def __init__(self, encoder=None, decoder=None, init_cfg=None): super(Transformer, self).__init__(init_cfg=init_cfg) self.encoder = build_transformer_layer_sequence(encoder) self.decoder = build_transformer_layer_sequence(decoder) self.embed_dims = self.encoder.embed_dims def init_weights(self): # follow the official DETR to init parameters for m in self.modules(): if hasattr(m, 'weight') and m.weight.dim() > 1: xavier_init(m, distribution='uniform') self._is_init = True def forward(self, x, mask, query_embed, pos_embed): """Forward function for `Transformer`. Args: x (Tensor): Input query with shape [bs, c, h, w] where c = embed_dims. mask (Tensor): The key_padding_mask used for encoder and decoder, with shape [bs, h, w]. query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. pos_embed (Tensor): The positional encoding for encoder and decoder, with the same shape as `x`. Returns: tuple[Tensor]: results of decoder containing the following tensor. - out_dec: Output from decoder. If return_intermediate_dec \ is True output has shape [num_dec_layers, bs, num_query, embed_dims], else has shape [1, bs, \ num_query, embed_dims]. - memory: Output results from encoder, with shape \ [bs, embed_dims, h, w]. """ bs, c, h, w = x.shape # use `view` instead of `flatten` for dynamically exporting to ONNX x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) query_embed = query_embed.unsqueeze(1).repeat( 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] memory = self.encoder( query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask) target = torch.zeros_like(query_embed) # out_dec: [num_layers, num_query, bs, dim] out_dec = self.decoder( query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask) out_dec = out_dec.transpose(1, 2) memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) return out_dec, memory @MODELS.register_module(force=True) class DeformableDetrTransformerDecoder(TransformerLayerSequence): """Implements the decoder in DETR transformer. Args: return_intermediate (bool): Whether to return intermediate outputs. coder_norm_cfg (dict): Config of last normalization layer. Default: `LN`. """ def __init__(self, *args, return_intermediate=False, **kwargs): super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) self.return_intermediate = return_intermediate def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): """Forward function for `TransformerDecoder`. Args: query (Tensor): Input query with shape `(num_query, bs, embed_dims)`. reference_points (Tensor): The reference points of offset. has shape (bs, num_query, 4) when as_two_stage, otherwise has shape ((bs, num_query, 2). valid_ratios (Tensor): The radios of valid points on the feature map, has shape (bs, num_levels, 2) reg_branch: (obj:`nn.ModuleList`): Used for refining the regression results. Only would be passed when with_box_refine is True, otherwise would be passed a `None`. Returns: Tensor: Results with shape [1, num_query, bs, embed_dims] when return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims]. """ output = query intermediate = [] intermediate_reference_points = [] for lid, layer in enumerate(self.layers): if reference_points.shape[-1] == 4: reference_points_input = reference_points[:, :, None] * \ torch.cat([valid_ratios, valid_ratios], -1)[:, None] else: assert reference_points.shape[-1] == 2 reference_points_input = reference_points[:, :, None] * \ valid_ratios[:, None] output = layer( output, *args, reference_points=reference_points_input, **kwargs) output = output.permute(1, 0, 2) if reg_branches is not None: tmp = reg_branches[lid](output) if reference_points.shape[-1] == 4: new_reference_points = tmp + inverse_sigmoid( reference_points) new_reference_points = new_reference_points.sigmoid() else: assert reference_points.shape[-1] == 2 new_reference_points = tmp new_reference_points[..., :2] = tmp[ ..., :2] + inverse_sigmoid(reference_points) new_reference_points = new_reference_points.sigmoid() reference_points = new_reference_points.detach() output = output.permute(1, 0, 2) if self.return_intermediate: intermediate.append(output) intermediate_reference_points.append(reference_points) if self.return_intermediate: return torch.stack(intermediate), torch.stack( intermediate_reference_points) return output, reference_points @MODELS.register_module(force=True) class DeformableDetrTransformer(Transformer): """Implements the DeformableDETR transformer. Args: as_two_stage (bool): Generate query from encoder features. Default: False. num_feature_levels (int): Number of feature maps from FPN: Default: 4. two_stage_num_proposals (int): Number of proposals when set `as_two_stage` as True. Default: 300. """ def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs): super(DeformableDetrTransformer, self).__init__(**kwargs) self.as_two_stage = as_two_stage self.num_feature_levels = num_feature_levels self.two_stage_num_proposals = two_stage_num_proposals self.embed_dims = self.encoder.embed_dims self.init_layers() def init_layers(self): """Initialize layers of the DeformableDetrTransformer.""" self.level_embeds = nn.Parameter( torch.Tensor(self.num_feature_levels, self.embed_dims)) if self.as_two_stage: self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) self.enc_output_norm = nn.LayerNorm(self.embed_dims) self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2) self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) else: self.reference_points = nn.Linear(self.embed_dims, 2) def init_weights(self): """Initialize the transformer weights.""" for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MultiScaleDeformableAttention): m.init_weights() if not self.as_two_stage: xavier_init(self.reference_points, distribution='uniform', bias=0.) normal_(self.level_embeds) def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): """Generate proposals from encoded memory. Args: memory (Tensor) : The output of encoder, has shape (bs, num_key, embed_dim). num_key is equal the number of points on feature map from all level. memory_padding_mask (Tensor): Padding mask for memory. has shape (bs, num_key). spatial_shapes (Tensor): The shape of all feature maps. has shape (num_level, 2). Returns: tuple: A tuple of feature map and bbox prediction. - output_memory (Tensor): The input of decoder, \ has shape (bs, num_key, embed_dim). num_key is \ equal the number of points on feature map from \ all levels. - output_proposals (Tensor): The normalized proposal \ after a inverse sigmoid, has shape \ (bs, num_keys, 4). """ N, S, C = memory.shape proposals = [] _cur = 0 for lvl, (H, W) in enumerate(spatial_shapes): mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view( N, H, W, 1) valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) grid_y, grid_x = torch.meshgrid( torch.linspace( 0, H - 1, H, dtype=torch.float32, device=memory.device), torch.linspace( 0, W - 1, W, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2) grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale wh = torch.ones_like(grid) * 0.05 * (2.0**lvl) proposal = torch.cat((grid, wh), -1).view(N, -1, 4) proposals.append(proposal) _cur += (H * W) output_proposals = torch.cat(proposals, 1) output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all( -1, keepdim=True) output_proposals = torch.log(output_proposals / (1 - output_proposals)) output_proposals = output_proposals.masked_fill( memory_padding_mask.unsqueeze(-1), float('inf')) output_proposals = output_proposals.masked_fill( ~output_proposals_valid, float('inf')) output_memory = memory output_memory = output_memory.masked_fill( memory_padding_mask.unsqueeze(-1), float(0)) output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) output_memory = self.enc_output_norm(self.enc_output(output_memory)) return output_memory, output_proposals @staticmethod def get_reference_points(spatial_shapes, valid_ratios, device): """Get the reference points used in decoder. Args: spatial_shapes (Tensor): The shape of all feature maps, has shape (num_level, 2). valid_ratios (Tensor): The radios of valid points on the feature map, has shape (bs, num_levels, 2) device (obj:`device`): The device where reference_points should be. Returns: Tensor: reference points used in decoder, has \ shape (bs, num_keys, num_levels, 2). """ reference_points_list = [] for lvl, (H, W) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid( torch.linspace( 0.5, H - 0.5, H, dtype=torch.float32, device=device), torch.linspace( 0.5, W - 0.5, W, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 1] * H) ref_x = ref_x.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 0] * W) ref = torch.stack((ref_x, ref_y), -1) reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] return reference_points def get_valid_ratio(self, mask): """Get the valid radios of feature maps of all level.""" _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): """Get the position embedding of proposal.""" scale = 2 * math.pi dim_t = torch.arange( num_pos_feats, dtype=torch.float32, device=proposals.device) dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) # N, L, 4 proposals = proposals.sigmoid() * scale # N, L, 4, 128 pos = proposals[:, :, :, None] / dim_t # N, L, 4, 64, 2 pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) return pos def forward(self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs): """Forward function for `Transformer`. Args: mlvl_feats (list(Tensor)): Input queries from different level. Each element has shape [bs, embed_dims, h, w]. mlvl_masks (list(Tensor)): The key_padding_mask from different level used for encoder and decoder, each element has shape [bs, h, w]. query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. mlvl_pos_embeds (list(Tensor)): The positional encoding of feats from different level, has the shape [bs, embed_dims, h, w]. reg_branches (obj:`nn.ModuleList`): Regression heads for feature maps from each decoder layer. Only would be passed when `with_box_refine` is True. Default to None. cls_branches (obj:`nn.ModuleList`): Classification heads for feature maps from each decoder layer. Only would be passed when `as_two_stage` is True. Default to None. Returns: tuple[Tensor]: results of decoder containing the following tensor. - inter_states: Outputs from decoder. If return_intermediate_dec is True output has shape \ (num_dec_layers, bs, num_query, embed_dims), else has \ shape (1, bs, num_query, embed_dims). - init_reference_out: The initial value of reference \ points, has shape (bs, num_queries, 4). - inter_references_out: The internal value of reference \ points in decoder, has shape \ (num_dec_layers, bs,num_query, embed_dims) - enc_outputs_class: The classification score of \ proposals generated from \ encoder's feature maps, has shape \ (batch, h*w, num_classes). \ Only would be returned when `as_two_stage` is True, \ otherwise None. - enc_outputs_coord_unact: The regression results \ generated from encoder's feature maps., has shape \ (batch, h*w, 4). Only would \ be returned when `as_two_stage` is True, \ otherwise None. """ assert self.as_two_stage or query_embed is not None feat_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (feat, mask, pos_embed) in enumerate( zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): bs, c, h, w = feat.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) feat = feat.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) feat_flatten.append(feat) mask_flatten.append(mask) feat_flatten = torch.cat(feat_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack( [self.get_valid_ratio(m) for m in mlvl_masks], 1) reference_points = \ self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( 1, 0, 2) # (H*W, bs, embed_dims) memory = self.encoder( query=feat_flatten, key=None, value=None, query_pos=lvl_pos_embed_flatten, query_key_padding_mask=mask_flatten, spatial_shapes=spatial_shapes, reference_points=reference_points, level_start_index=level_start_index, valid_ratios=valid_ratios, **kwargs) memory = memory.permute(1, 0, 2) bs, _, c = memory.shape if self.as_two_stage: output_memory, output_proposals = \ self.gen_encoder_output_proposals( memory, mask_flatten, spatial_shapes) enc_outputs_class = cls_branches[self.decoder.num_layers]( output_memory) enc_outputs_coord_unact = \ reg_branches[ self.decoder.num_layers](output_memory) + output_proposals topk = self.two_stage_num_proposals # We only use the first channel in enc_outputs_class as foreground, # the other (num_classes - 1) channels are actually not used. # Its targets are set to be 0s, which indicates the first # class (foreground) because we use [0, num_classes - 1] to # indicate class labels, background class is indicated by # num_classes (similar convention in RPN). # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa # This follows the official implementation of Deformable DETR. topk_proposals = torch.topk( enc_outputs_class[..., 0], topk, dim=1)[1] topk_coords_unact = torch.gather( enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) topk_coords_unact = topk_coords_unact.detach() reference_points = topk_coords_unact.sigmoid() init_reference_out = reference_points pos_trans_out = self.pos_trans_norm( self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) query_pos, query = torch.split(pos_trans_out, c, dim=2) else: query_pos, query = torch.split(query_embed, c, dim=1) query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) query = query.unsqueeze(0).expand(bs, -1, -1) reference_points = self.reference_points(query_pos).sigmoid() init_reference_out = reference_points # decoder query = query.permute(1, 0, 2) memory = memory.permute(1, 0, 2) query_pos = query_pos.permute(1, 0, 2) inter_states, inter_references = self.decoder( query=query, key=None, value=memory, query_pos=query_pos, key_padding_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=reg_branches, **kwargs) inter_references_out = inter_references if self.as_two_stage: return inter_states, init_reference_out,\ inter_references_out, enc_outputs_class,\ enc_outputs_coord_unact return inter_states, init_reference_out, \ inter_references_out, None, None @MODELS.register_module() class CoDeformableDetrTransformerDecoder(TransformerLayerSequence): """Implements the decoder in DETR transformer. Args: return_intermediate (bool): Whether to return intermediate outputs. coder_norm_cfg (dict): Config of last normalization layer. Default: `LN`. """ def __init__(self, *args, return_intermediate=False, look_forward_twice=False, **kwargs): super(CoDeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) self.return_intermediate = return_intermediate self.look_forward_twice = look_forward_twice def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): """Forward function for `TransformerDecoder`. Args: query (Tensor): Input query with shape `(num_query, bs, embed_dims)`. reference_points (Tensor): The reference points of offset. has shape (bs, num_query, 4) when as_two_stage, otherwise has shape ((bs, num_query, 2). valid_ratios (Tensor): The radios of valid points on the feature map, has shape (bs, num_levels, 2) reg_branch: (obj:`nn.ModuleList`): Used for refining the regression results. Only would be passed when with_box_refine is True, otherwise would be passed a `None`. Returns: Tensor: Results with shape [1, num_query, bs, embed_dims] when return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims]. """ output = query intermediate = [] intermediate_reference_points = [] for lid, layer in enumerate(self.layers): if reference_points.shape[-1] == 4: reference_points_input = reference_points[:, :, None] * \ torch.cat([valid_ratios, valid_ratios], -1)[:, None] else: assert reference_points.shape[-1] == 2 reference_points_input = reference_points[:, :, None] * \ valid_ratios[:, None] output = layer( output, *args, reference_points=reference_points_input, **kwargs) output = output.permute(1, 0, 2) if reg_branches is not None: tmp = reg_branches[lid](output) if reference_points.shape[-1] == 4: new_reference_points = tmp + inverse_sigmoid( reference_points) new_reference_points = new_reference_points.sigmoid() else: assert reference_points.shape[-1] == 2 new_reference_points = tmp new_reference_points[..., :2] = tmp[ ..., :2] + inverse_sigmoid(reference_points) new_reference_points = new_reference_points.sigmoid() reference_points = new_reference_points.detach() output = output.permute(1, 0, 2) if self.return_intermediate: intermediate.append(output) intermediate_reference_points.append( new_reference_points if self. look_forward_twice else reference_points) if self.return_intermediate: return torch.stack(intermediate), torch.stack( intermediate_reference_points) return output, reference_points @MODELS.register_module() class CoDeformableDetrTransformer(DeformableDetrTransformer): def __init__(self, mixed_selection=True, with_pos_coord=True, with_coord_feat=True, num_co_heads=1, **kwargs): self.mixed_selection = mixed_selection self.with_pos_coord = with_pos_coord self.with_coord_feat = with_coord_feat self.num_co_heads = num_co_heads super(CoDeformableDetrTransformer, self).__init__(**kwargs) self._init_layers() def _init_layers(self): """Initialize layers of the CoDeformableDetrTransformer.""" if self.with_pos_coord: if self.num_co_heads > 0: # bug: this code should be 'self.head_pos_embed = # nn.Embedding(self.num_co_heads, self.embed_dims)', # we keep this bug for reproducing our results with ResNet-50. # You can fix this bug when reproducing results with # swin transformer. self.head_pos_embed = nn.Embedding(self.num_co_heads, 1, 1, self.embed_dims) self.aux_pos_trans = nn.ModuleList() self.aux_pos_trans_norm = nn.ModuleList() self.pos_feats_trans = nn.ModuleList() self.pos_feats_norm = nn.ModuleList() for i in range(self.num_co_heads): self.aux_pos_trans.append( nn.Linear(self.embed_dims * 2, self.embed_dims * 2)) self.aux_pos_trans_norm.append( nn.LayerNorm(self.embed_dims * 2)) if self.with_coord_feat: self.pos_feats_trans.append( nn.Linear(self.embed_dims, self.embed_dims)) self.pos_feats_norm.append( nn.LayerNorm(self.embed_dims)) def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000): """Get the position embedding of proposal.""" num_pos_feats = self.embed_dims // 2 scale = 2 * math.pi dim_t = torch.arange( num_pos_feats, dtype=torch.float32, device=proposals.device) dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) # N, L, 4 proposals = proposals.sigmoid() * scale # N, L, 4, 128 pos = proposals[:, :, :, None] / dim_t # N, L, 4, 64, 2 pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) return pos def forward(self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, return_encoder_output=False, attn_masks=None, **kwargs): """Forward function for `Transformer`. Args: mlvl_feats (list(Tensor)): Input queries from different level. Each element has shape [bs, embed_dims, h, w]. mlvl_masks (list(Tensor)): The key_padding_mask from different level used for encoder and decoder, each element has shape [bs, h, w]. query_embed (Tensor): The query embedding for decoder, with shape [num_query, c]. mlvl_pos_embeds (list(Tensor)): The positional encoding of feats from different level, has the shape [bs, embed_dims, h, w]. reg_branches (obj:`nn.ModuleList`): Regression heads for feature maps from each decoder layer. Only would be passed when `with_box_refine` is True. Default to None. cls_branches (obj:`nn.ModuleList`): Classification heads for feature maps from each decoder layer. Only would be passed when `as_two_stage` is True. Default to None. Returns: tuple[Tensor]: results of decoder containing the following tensor. - inter_states: Outputs from decoder. If return_intermediate_dec is True output has shape \ (num_dec_layers, bs, num_query, embed_dims), else has \ shape (1, bs, num_query, embed_dims). - init_reference_out: The initial value of reference \ points, has shape (bs, num_queries, 4). - inter_references_out: The internal value of reference \ points in decoder, has shape \ (num_dec_layers, bs,num_query, embed_dims) - enc_outputs_class: The classification score of \ proposals generated from \ encoder's feature maps, has shape \ (batch, h*w, num_classes). \ Only would be returned when `as_two_stage` is True, \ otherwise None. - enc_outputs_coord_unact: The regression results \ generated from encoder's feature maps., has shape \ (batch, h*w, 4). Only would \ be returned when `as_two_stage` is True, \ otherwise None. """ assert self.as_two_stage or query_embed is not None feat_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (feat, mask, pos_embed) in enumerate( zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): bs, c, h, w = feat.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) feat = feat.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) feat_flatten.append(feat) mask_flatten.append(mask) feat_flatten = torch.cat(feat_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack( [self.get_valid_ratio(m) for m in mlvl_masks], 1) reference_points = \ self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device) feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( 1, 0, 2) # (H*W, bs, embed_dims) memory = self.encoder( query=feat_flatten, key=None, value=None, query_pos=lvl_pos_embed_flatten, query_key_padding_mask=mask_flatten, spatial_shapes=spatial_shapes, reference_points=reference_points, level_start_index=level_start_index, valid_ratios=valid_ratios, **kwargs) memory = memory.permute(1, 0, 2) bs, _, c = memory.shape if self.as_two_stage: output_memory, output_proposals = \ self.gen_encoder_output_proposals( memory, mask_flatten, spatial_shapes) enc_outputs_class = cls_branches[self.decoder.num_layers]( output_memory) enc_outputs_coord_unact = \ reg_branches[ self.decoder.num_layers](output_memory) + output_proposals topk = self.two_stage_num_proposals topk = query_embed.shape[0] topk_proposals = torch.topk( enc_outputs_class[..., 0], topk, dim=1)[1] topk_coords_unact = torch.gather( enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) topk_coords_unact = topk_coords_unact.detach() reference_points = topk_coords_unact.sigmoid() init_reference_out = reference_points pos_trans_out = self.pos_trans_norm( self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) if not self.mixed_selection: query_pos, query = torch.split(pos_trans_out, c, dim=2) else: # query_embed here is the content embed for deformable DETR query = query_embed.unsqueeze(0).expand(bs, -1, -1) query_pos, _ = torch.split(pos_trans_out, c, dim=2) else: query_pos, query = torch.split(query_embed, c, dim=1) query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) query = query.unsqueeze(0).expand(bs, -1, -1) reference_points = self.reference_points(query_pos).sigmoid() init_reference_out = reference_points # decoder query = query.permute(1, 0, 2) memory = memory.permute(1, 0, 2) query_pos = query_pos.permute(1, 0, 2) inter_states, inter_references = self.decoder( query=query, key=None, value=memory, query_pos=query_pos, key_padding_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=reg_branches, attn_masks=attn_masks, **kwargs) inter_references_out = inter_references if self.as_two_stage: if return_encoder_output: return inter_states, init_reference_out,\ inter_references_out, enc_outputs_class,\ enc_outputs_coord_unact, memory return inter_states, init_reference_out,\ inter_references_out, enc_outputs_class,\ enc_outputs_coord_unact if return_encoder_output: return inter_states, init_reference_out, \ inter_references_out, None, None, memory return inter_states, init_reference_out, \ inter_references_out, None, None def forward_aux(self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, pos_anchors, pos_feats=None, reg_branches=None, cls_branches=None, return_encoder_output=False, attn_masks=None, head_idx=0, **kwargs): feat_flatten = [] mask_flatten = [] spatial_shapes = [] for lvl, (feat, mask, pos_embed) in enumerate( zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): bs, c, h, w = feat.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) feat = feat.flatten(2).transpose(1, 2) mask = mask.flatten(1) feat_flatten.append(feat) mask_flatten.append(mask) feat_flatten = torch.cat(feat_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack( [self.get_valid_ratio(m) for m in mlvl_masks], 1) feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) memory = feat_flatten memory = memory.permute(1, 0, 2) bs, _, c = memory.shape topk_coords_unact = inverse_sigmoid(pos_anchors) reference_points = pos_anchors init_reference_out = reference_points if self.num_co_heads > 0: pos_trans_out = self.aux_pos_trans_norm[head_idx]( self.aux_pos_trans[head_idx]( self.get_proposal_pos_embed(topk_coords_unact))) query_pos, query = torch.split(pos_trans_out, c, dim=2) if self.with_coord_feat: query = query + self.pos_feats_norm[head_idx]( self.pos_feats_trans[head_idx](pos_feats)) query_pos = query_pos + self.head_pos_embed.weight[head_idx] # decoder query = query.permute(1, 0, 2) memory = memory.permute(1, 0, 2) query_pos = query_pos.permute(1, 0, 2) inter_states, inter_references = self.decoder( query=query, key=None, value=memory, query_pos=query_pos, key_padding_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=reg_branches, attn_masks=attn_masks, **kwargs) inter_references_out = inter_references return inter_states, init_reference_out, \ inter_references_out def build_MLP(input_dim, hidden_dim, output_dim, num_layers): assert num_layers > 1, \ f'num_layers should be greater than 1 but got {num_layers}' h = [hidden_dim] * (num_layers - 1) layers = list() for n, k in zip([input_dim] + h[:-1], h): layers.extend((nn.Linear(n, k), nn.ReLU())) # Note that the relu func of MLP in original DETR repo is set # 'inplace=False', however the ReLU cfg of FFN in mmdet is set # 'inplace=True' by default. layers.append(nn.Linear(hidden_dim, output_dim)) return nn.Sequential(*layers) @MODELS.register_module() class DinoTransformerDecoder(DeformableDetrTransformerDecoder): def __init__(self, *args, **kwargs): super(DinoTransformerDecoder, self).__init__(*args, **kwargs) self._init_layers() def _init_layers(self): self.ref_point_head = build_MLP(self.embed_dims * 2, self.embed_dims, self.embed_dims, 2) self.norm = nn.LayerNorm(self.embed_dims) @staticmethod def gen_sineembed_for_position(pos_tensor, pos_feat): # n_query, bs, _ = pos_tensor.size() # sineembed_tensor = torch.zeros(n_query, bs, 256) scale = 2 * math.pi dim_t = torch.arange( pos_feat, dtype=torch.float32, device=pos_tensor.device) dim_t = 10000**(2 * (dim_t // 2) / pos_feat) x_embed = pos_tensor[:, :, 0] * scale y_embed = pos_tensor[:, :, 1] * scale pos_x = x_embed[:, :, None] / dim_t pos_y = y_embed[:, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) if pos_tensor.size(-1) == 2: pos = torch.cat((pos_y, pos_x), dim=2) elif pos_tensor.size(-1) == 4: w_embed = pos_tensor[:, :, 2] * scale pos_w = w_embed[:, :, None] / dim_t pos_w = torch.stack( (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) h_embed = pos_tensor[:, :, 3] * scale pos_h = h_embed[:, :, None] / dim_t pos_h = torch.stack( (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) else: raise ValueError('Unknown pos_tensor shape(-1):{}'.format( pos_tensor.size(-1))) return pos def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs): output = query intermediate = [] intermediate_reference_points = [reference_points] for lid, layer in enumerate(self.layers): if reference_points.shape[-1] == 4: reference_points_input = \ reference_points[:, :, None] * torch.cat( [valid_ratios, valid_ratios], -1)[:, None] else: assert reference_points.shape[-1] == 2 reference_points_input = \ reference_points[:, :, None] * valid_ratios[:, None] query_sine_embed = self.gen_sineembed_for_position( reference_points_input[:, :, 0, :], self.embed_dims // 2) query_pos = self.ref_point_head(query_sine_embed) query_pos = query_pos.permute(1, 0, 2) output = layer( output, *args, query_pos=query_pos, reference_points=reference_points_input, **kwargs) output = output.permute(1, 0, 2) if reg_branches is not None: tmp = reg_branches[lid](output) assert reference_points.shape[-1] == 4 new_reference_points = tmp + inverse_sigmoid( reference_points, eps=1e-3) new_reference_points = new_reference_points.sigmoid() reference_points = new_reference_points.detach() output = output.permute(1, 0, 2) if self.return_intermediate: intermediate.append(self.norm(output)) intermediate_reference_points.append(new_reference_points) # NOTE this is for the "Look Forward Twice" module, # in the DeformDETR, reference_points was appended. if self.return_intermediate: return torch.stack(intermediate), torch.stack( intermediate_reference_points) return output, reference_points @MODELS.register_module() class CoDinoTransformer(CoDeformableDetrTransformer): def __init__(self, *args, **kwargs): super(CoDinoTransformer, self).__init__(*args, **kwargs) def init_layers(self): """Initialize layers of the DinoTransformer.""" self.level_embeds = nn.Parameter( torch.Tensor(self.num_feature_levels, self.embed_dims)) self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) self.enc_output_norm = nn.LayerNorm(self.embed_dims) self.query_embed = nn.Embedding(self.two_stage_num_proposals, self.embed_dims) def _init_layers(self): if self.with_pos_coord: if self.num_co_heads > 0: self.aux_pos_trans = nn.ModuleList() self.aux_pos_trans_norm = nn.ModuleList() self.pos_feats_trans = nn.ModuleList() self.pos_feats_norm = nn.ModuleList() for i in range(self.num_co_heads): self.aux_pos_trans.append( nn.Linear(self.embed_dims * 2, self.embed_dims)) self.aux_pos_trans_norm.append( nn.LayerNorm(self.embed_dims)) if self.with_coord_feat: self.pos_feats_trans.append( nn.Linear(self.embed_dims, self.embed_dims)) self.pos_feats_norm.append( nn.LayerNorm(self.embed_dims)) def init_weights(self): super().init_weights() nn.init.normal_(self.query_embed.weight.data) def forward(self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, dn_label_query, dn_bbox_query, attn_mask, reg_branches=None, cls_branches=None, **kwargs): assert self.as_two_stage and query_embed is None, \ 'as_two_stage must be True for DINO' feat_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (feat, mask, pos_embed) in enumerate( zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): bs, c, h, w = feat.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) feat = feat.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) feat_flatten.append(feat) mask_flatten.append(mask) feat_flatten = torch.cat(feat_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack( [self.get_valid_ratio(m) for m in mlvl_masks], 1) reference_points = self.get_reference_points( spatial_shapes, valid_ratios, device=feat.device) feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( 1, 0, 2) # (H*W, bs, embed_dims) memory = self.encoder( query=feat_flatten, key=None, value=None, query_pos=lvl_pos_embed_flatten, query_key_padding_mask=mask_flatten, spatial_shapes=spatial_shapes, reference_points=reference_points, level_start_index=level_start_index, valid_ratios=valid_ratios, **kwargs) memory = memory.permute(1, 0, 2) bs, _, c = memory.shape output_memory, output_proposals = self.gen_encoder_output_proposals( memory, mask_flatten, spatial_shapes) enc_outputs_class = cls_branches[self.decoder.num_layers]( output_memory) enc_outputs_coord_unact = reg_branches[self.decoder.num_layers]( output_memory) + output_proposals cls_out_features = cls_branches[self.decoder.num_layers].out_features topk = self.two_stage_num_proposals # NOTE In DeformDETR, enc_outputs_class[..., 0] is used for topk topk_indices = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] topk_score = torch.gather( enc_outputs_class, 1, topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) topk_coords_unact = torch.gather( enc_outputs_coord_unact, 1, topk_indices.unsqueeze(-1).repeat(1, 1, 4)) topk_anchor = topk_coords_unact.sigmoid() topk_coords_unact = topk_coords_unact.detach() query = self.query_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1) # NOTE the query_embed here is not spatial query as in DETR. # It is actually content query, which is named tgt in other # DETR-like models if dn_label_query is not None: query = torch.cat([dn_label_query, query], dim=1) if dn_bbox_query is not None: reference_points = torch.cat([dn_bbox_query, topk_coords_unact], dim=1) else: reference_points = topk_coords_unact reference_points = reference_points.sigmoid() # decoder query = query.permute(1, 0, 2) memory = memory.permute(1, 0, 2) inter_states, inter_references = self.decoder( query=query, key=None, value=memory, attn_masks=attn_mask, key_padding_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=reg_branches, **kwargs) inter_references_out = inter_references return inter_states, inter_references_out, \ topk_score, topk_anchor, memory def forward_aux(self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, pos_anchors, pos_feats=None, reg_branches=None, cls_branches=None, return_encoder_output=False, attn_masks=None, head_idx=0, **kwargs): feat_flatten = [] mask_flatten = [] spatial_shapes = [] for lvl, (feat, mask, pos_embed) in enumerate( zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): bs, c, h, w = feat.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) feat = feat.flatten(2).transpose(1, 2) mask = mask.flatten(1) feat_flatten.append(feat) mask_flatten.append(mask) feat_flatten = torch.cat(feat_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) spatial_shapes = torch.as_tensor( spatial_shapes, dtype=torch.long, device=feat_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros( (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack( [self.get_valid_ratio(m) for m in mlvl_masks], 1) feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) memory = feat_flatten memory = memory.permute(1, 0, 2) bs, _, c = memory.shape topk_coords_unact = inverse_sigmoid(pos_anchors) reference_points = pos_anchors if self.num_co_heads > 0: pos_trans_out = self.aux_pos_trans_norm[head_idx]( self.aux_pos_trans[head_idx]( self.get_proposal_pos_embed(topk_coords_unact))) query = pos_trans_out if self.with_coord_feat: query = query + self.pos_feats_norm[head_idx]( self.pos_feats_trans[head_idx](pos_feats)) # decoder query = query.permute(1, 0, 2) memory = memory.permute(1, 0, 2) inter_states, inter_references = self.decoder( query=query, key=None, value=memory, attn_masks=None, key_padding_mask=mask_flatten, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, reg_branches=reg_branches, **kwargs) inter_references_out = inter_references return inter_states, inter_references_out @MODELS.register_module() class DetrTransformerEncoder(TransformerLayerSequence): """TransformerEncoder of DETR. Args: post_norm_cfg (dict): Config of last normalization layer. Default: `LN`. Only used when `self.pre_norm` is `True` """ def __init__(self, *args, post_norm_cfg=dict(type='LN'), with_cp=-1, **kwargs): super(DetrTransformerEncoder, self).__init__(*args, **kwargs) if post_norm_cfg is not None: self.post_norm = build_norm_layer( post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None else: assert not self.pre_norm, f'Use prenorm in ' \ f'{self.__class__.__name__},' \ f'Please specify post_norm_cfg' self.post_norm = None self.with_cp = with_cp if self.with_cp > 0: if checkpoint_wrapper is None: warnings.warn('If you want to reduce GPU memory usage, \ please install fairscale by executing the \ following command: pip install fairscale.') return for i in range(self.with_cp): self.layers[i] = checkpoint_wrapper(self.layers[i]) @MODELS.register_module() class DetrTransformerDecoderLayer(BaseTransformerLayer): """Implements decoder layer in DETR transformer. Args: attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): Configs for self_attention or cross_attention, the order should be consistent with it in `operation_order`. If it is a dict, it would be expand to the number of attention in `operation_order`. feedforward_channels (int): The hidden dimension for FFNs. ffn_dropout (float): Probability of an element to be zeroed in ffn. Default 0.0. operation_order (tuple[str]): The execution order of operation in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). Default:None act_cfg (dict): The activation config for FFNs. Default: `LN` norm_cfg (dict): Config dict for normalization layer. Default: `LN`. ffn_num_fcs (int): The number of fully-connected layers in FFNs. Default:2. """ def __init__(self, attn_cfgs, feedforward_channels, ffn_dropout=0.0, operation_order=None, act_cfg=dict(type='ReLU', inplace=True), norm_cfg=dict(type='LN'), ffn_num_fcs=2, **kwargs): super(DetrTransformerDecoderLayer, self).__init__( attn_cfgs=attn_cfgs, feedforward_channels=feedforward_channels, ffn_dropout=ffn_dropout, operation_order=operation_order, act_cfg=act_cfg, norm_cfg=norm_cfg, ffn_num_fcs=ffn_num_fcs, **kwargs) assert len(operation_order) == 6 assert set(operation_order) == set( ['self_attn', 'norm', 'cross_attn', 'ffn'])