# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.runner import BaseModule
from mmdet.core import bbox_xyxy_to_cxcywh
from mmdet.models.utils.transformer import inverse_sigmoid


class DnQueryGenerator(BaseModule):
    """Generator of noised queries for query denoising (DN) training.

    Ground-truth labels and boxes are noised and embedded as extra decoder
    queries so that the decoder learns to reconstruct them, following
    DN-DETR / DINO.

    Args:
        num_queries (int): Number of matching queries of the model.
        hidden_dim (int): Dimension of the query embeddings.
        num_classes (int): Number of classes used for label noising.
        noise_scale (dict): Scales of the label and box noise, with keys
            'label' and 'box'. Defaults to dict(label=0.5, box=0.4).
        group_cfg (dict): Config of the denoising groups. When 'dynamic'
            is True, 'num_dn_queries' gives the total denoising query
            budget and the number of groups is derived per batch;
            otherwise 'num_groups' fixes the number of groups.
    """

    def __init__(self,
                 num_queries,
                 hidden_dim,
                 num_classes,
                 noise_scale=dict(label=0.5, box=0.4),
                 group_cfg=dict(
                     dynamic=True, num_groups=None, num_dn_queries=None)):
        super(DnQueryGenerator, self).__init__()
        self.num_queries = num_queries
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes
        self.label_noise_scale = noise_scale['label']
        self.box_noise_scale = noise_scale['box']
        self.dynamic_dn_groups = group_cfg.get('dynamic', False)
        if self.dynamic_dn_groups:
            assert 'num_dn_queries' in group_cfg, \
                'num_dn_queries should be set when using ' \
                'dynamic dn groups'
            self.num_dn = group_cfg['num_dn_queries']
        else:
            assert 'num_groups' in group_cfg, \
                'num_groups should be set when using ' \
                'static dn groups'
            self.num_dn = group_cfg['num_groups']
        assert isinstance(self.num_dn, int) and self.num_dn >= 1, \
            f'Expected the num in group_cfg to be a positive int. ' \
            f'Found {type(self.num_dn)}.'

    def get_num_groups(self, group_queries=None):
        """Calculate the number of denoising query groups.

        Args:
            group_queries (int, optional): Number of denoising queries in
                one group. ``forward`` passes the maximum number of GT
                boxes over the images in the batch. Required when using
                dynamic dn groups.

        Returns:
            int: Number of denoising groups, at least 1.
        """
        if self.dynamic_dn_groups:
            assert group_queries is not None, \
                'group_queries should be provided when using ' \
                'dynamic dn groups'
            if group_queries == 0:
                num_groups = 1
            else:
                num_groups = self.num_dn // group_queries
        else:
            num_groups = self.num_dn
        if num_groups < 1:
            # avoid num_groups < 1 in the query generator
            num_groups = 1
        return int(num_groups)
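    # A minimal usage sketch of the dynamic grouping above; the numbers
    # are illustrative assumptions, not defaults from any released config.
    # With a budget of 100 dn queries and a batch whose largest image has
    # 20 GT boxes, `100 // 20 = 5` denoising groups are generated:
    #
    #   generator = DnQueryGenerator(
    #       num_queries=900,
    #       hidden_dim=256,
    #       num_classes=80,
    #       group_cfg=dict(dynamic=True, num_dn_queries=100))
    #   generator.get_num_groups(group_queries=20)  # -> 5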
    def forward(self,
                gt_bboxes,
                gt_labels=None,
                label_enc=None,
                img_metas=None):
        """Generate noised queries, their embeddings and the attention mask.

        Args:
            gt_bboxes (list[Tensor]): List of ground truth bboxes of each
                image, each with shape (num_gts, 4) in xyxy format.
            gt_labels (list[Tensor], optional): List of ground truth labels
                of each image, each with shape (num_gts,). If None, the
                noised label queries cannot be generated; training currently
                requires it (see the TODOs below).
            label_enc (nn.Embedding, optional): Embedding layer that maps
                (noised) labels to query features.
            img_metas (list[dict], optional): Meta information of each
                image, e.g. image shape.

        Returns:
            tuple: (input_query_label, input_query_bbox, attn_mask,
            dn_meta). The first two are the label and box query embeddings
            with the denoising part padded in front of the matching
            queries, attn_mask prevents information leakage between the
            matching queries and the denoising groups, and dn_meta records
            'pad_size' and 'num_dn_group'. All four are None when not in
            training mode.
        """
        # TODO: temp only support for CDN
        # TODO: temp assert gt_labels is not None and label_enc is not None
        if self.training:
            if gt_labels is not None:
                assert len(gt_bboxes) == len(gt_labels), \
                    f'the length of provided gt_labels ' \
                    f'{len(gt_labels)} should be equal to' \
                    f' that of gt_bboxes {len(gt_bboxes)}'
            assert gt_labels is not None \
                and label_enc is not None \
                and img_metas is not None  # TODO: adjust args
            batch_size = len(gt_bboxes)
            device = gt_bboxes[0].device

            # normalize the gt bboxes to (cx, cy, w, h) in [0, 1]
            gt_bboxes_list = []
            for img_meta, bboxes in zip(img_metas, gt_bboxes):
                img_h, img_w, _ = img_meta['img_shape']
                factor = bboxes.new_tensor([img_w, img_h, img_w,
                                            img_h]).unsqueeze(0)
                bboxes_normalized = bbox_xyxy_to_cxcywh(bboxes) / factor
                gt_bboxes_list.append(bboxes_normalized)
            gt_bboxes = gt_bboxes_list

            known = [torch.ones_like(labels) for labels in gt_labels]
            known_num = [sum(k) for k in known]

            num_groups = self.get_num_groups(int(max(known_num)))

            unmask_bbox = unmask_label = torch.cat(known)
            labels = torch.cat(gt_labels)
            boxes = torch.cat(gt_bboxes)
            batch_idx = torch.cat([
                torch.full_like(t.long(), i) for i, t in enumerate(gt_labels)
            ])

            known_indice = torch.nonzero(unmask_label + unmask_bbox)
            known_indice = known_indice.view(-1)

            # each group holds a positive and a negative copy of every gt
            known_indice = known_indice.repeat(2 * num_groups, 1).view(-1)
            known_labels = labels.repeat(2 * num_groups, 1).view(-1)
            known_bid = batch_idx.repeat(2 * num_groups, 1).view(-1)
            known_bboxs = boxes.repeat(2 * num_groups, 1)
            known_labels_expand = known_labels.clone()
            known_bbox_expand = known_bboxs.clone()

            if self.label_noise_scale > 0:
                # flip a fraction of the labels to random classes
                p = torch.rand_like(known_labels_expand.float())
                chosen_indice = torch.nonzero(
                    p < (self.label_noise_scale * 0.5)).view(-1)
                new_label = torch.randint_like(chosen_indice, 0,
                                               self.num_classes)
                known_labels_expand.scatter_(0, chosen_indice, new_label)

            # pad every image to the max number of gts in the batch
            single_pad = int(max(known_num))
            pad_size = int(single_pad * 2 * num_groups)
            positive_idx = torch.arange(
                len(boxes), device=device).unsqueeze(0).repeat(num_groups, 1)
            positive_idx += torch.arange(
                num_groups, device=device).unsqueeze(1) * len(boxes) * 2
            positive_idx = positive_idx.flatten()
            negative_idx = positive_idx + len(boxes)

            if self.box_noise_scale > 0:
                # convert (cx, cy, w, h) to corner form (x1, y1, x2, y2)
                known_bbox_ = torch.zeros_like(known_bboxs)
                known_bbox_[:, :2] = \
                    known_bboxs[:, :2] - known_bboxs[:, 2:] / 2
                known_bbox_[:, 2:] = \
                    known_bboxs[:, :2] + known_bboxs[:, 2:] / 2

                # max jitter of each corner is half the box size
                diff = torch.zeros_like(known_bboxs)
                diff[:, :2] = known_bboxs[:, 2:] / 2
                diff[:, 2:] = known_bboxs[:, 2:] / 2

                rand_sign = torch.randint_like(
                    known_bboxs, low=0, high=2, dtype=torch.float32)
                rand_sign = rand_sign * 2.0 - 1.0
                rand_part = torch.rand_like(known_bboxs)
                # negative queries are pushed further away from the gt
                rand_part[negative_idx] += 1.0
                rand_part *= rand_sign
                known_bbox_ += \
                    torch.mul(rand_part, diff) * self.box_noise_scale
                known_bbox_ = known_bbox_.clamp(min=0.0, max=1.0)

                # convert back to (cx, cy, w, h)
                known_bbox_expand[:, :2] = \
                    (known_bbox_[:, :2] + known_bbox_[:, 2:]) / 2
                known_bbox_expand[:, 2:] = \
                    known_bbox_[:, 2:] - known_bbox_[:, :2]
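            # A worked sketch of the noising above, assuming the default
            # box_noise_scale=0.4: rand_part lies in [0, 1) for positive
            # queries and [1, 2) for negative ones, and diff holds half the
            # box size per coordinate, so each corner of a positive box
            # shifts by less than 0.4 * (w/2, h/2), while each corner of a
            # negative box shifts by 0.4 to 0.8 times (w/2, h/2). Negatives
            # therefore stay recognisably off the gt box.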
            m = known_labels_expand.long().to(device)
            input_label_embed = label_enc(m)
            input_bbox_embed = inverse_sigmoid(known_bbox_expand, eps=1e-3)

            padding_label = torch.zeros(
                pad_size, self.hidden_dim, device=device)
            padding_bbox = torch.zeros(pad_size, 4, device=device)

            input_query_label = padding_label.repeat(batch_size, 1, 1)
            input_query_bbox = padding_bbox.repeat(batch_size, 1, 1)

            # scatter the noised queries into their padded per-image slots
            map_known_indice = torch.tensor([], device=device)
            if len(known_num):
                map_known_indice = torch.cat([
                    torch.arange(int(num), device=device)
                    for num in known_num
                ])
                map_known_indice = torch.cat([
                    map_known_indice + single_pad * i
                    for i in range(2 * num_groups)
                ]).long()
            if len(known_bid):
                input_query_label[(known_bid.long(),
                                   map_known_indice)] = input_label_embed
                input_query_bbox[(known_bid.long(),
                                  map_known_indice)] = input_bbox_embed

            tgt_size = pad_size + self.num_queries
            attn_mask = torch.zeros(
                tgt_size, tgt_size, device=device, dtype=torch.bool)
            # matching queries cannot see the reconstruction (dn) part
            attn_mask[pad_size:, :pad_size] = True
            # denoising groups cannot see each other
            for i in range(num_groups):
                if i == 0:
                    attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                              single_pad * 2 * (i + 1):pad_size] = True
                if i == num_groups - 1:
                    attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                              :single_pad * i * 2] = True
                else:
                    attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                              single_pad * 2 * (i + 1):pad_size] = True
                    attn_mask[single_pad * 2 * i:single_pad * 2 * (i + 1),
                              :single_pad * 2 * i] = True
            dn_meta = {
                'pad_size': pad_size,
                'num_dn_group': num_groups,
            }
        else:
            input_query_label = None
            input_query_bbox = None
            attn_mask = None
            dn_meta = None

        return input_query_label, input_query_bbox, attn_mask, dn_meta


class CdnQueryGenerator(DnQueryGenerator):
    """Contrastive denoising (CDN) query generator of DINO.

    The contrastive behaviour (paired positive and negative noised queries)
    currently lives in ``DnQueryGenerator.forward`` (see the TODOs there),
    so this subclass only keeps the name for config compatibility.
    """

    def __init__(self, *args, **kwargs):
        super(CdnQueryGenerator, self).__init__(*args, **kwargs)


def build_dn_generator(dn_args):
    """Build a denoising query generator from config.

    Note that the 'type' key is popped from ``dn_args``; the remaining
    keys are forwarded to the constructor of the chosen class.

    Args:
        dn_args (dict, optional): Config dict with a 'type' key naming the
            generator class. If None, no generator is built.

    Returns:
        DnQueryGenerator | None: The constructed query generator.
    """
    if dn_args is None:
        return None
    dn_type = dn_args.pop('type')
    if dn_type == 'DnQueryGenerator':
        return DnQueryGenerator(**dn_args)
    elif dn_type == 'CdnQueryGenerator':
        return CdnQueryGenerator(**dn_args)
    else:
        raise NotImplementedError(f'{dn_type} is not supported yet')
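

# A minimal config sketch for `build_dn_generator`; the values are
# illustrative assumptions, not defaults taken from a released config:
#
#   dn_cfg = dict(
#       type='CdnQueryGenerator',
#       num_queries=900,
#       hidden_dim=256,
#       num_classes=80,
#       noise_scale=dict(label=0.5, box=0.4),
#       group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100))
#   dn_generator = build_dn_generator(dn_cfg)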