# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmdet3d.registry import TASK_UTILS
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder


@TASK_UTILS.register_module()
class GroupFree3DBBoxCoder(PartialBinBasedBBoxCoder):
    """Modified partial bin based bbox coder for GroupFree3D.

    Args:
        num_dir_bins (int): Number of bins to encode direction angle.
        num_sizes (int): Number of size clusters.
        mean_sizes (list[list[int]]): Mean size of bboxes in each class.
        with_rot (bool, optional): Whether the bbox is with rotation.
            Defaults to True.
        size_cls_agnostic (bool, optional): Whether the predicted size is
            class-agnostic. Defaults to True.
    """

    def __init__(self,
                 num_dir_bins,
                 num_sizes,
                 mean_sizes,
                 with_rot=True,
                 size_cls_agnostic=True):
        super(GroupFree3DBBoxCoder, self).__init__(
            num_dir_bins=num_dir_bins,
            num_sizes=num_sizes,
            mean_sizes=mean_sizes,
            with_rot=with_rot)
        self.size_cls_agnostic = size_cls_agnostic

    def encode(self, gt_bboxes_3d, gt_labels_3d):
        """Encode ground truth to prediction targets.

        Args:
            gt_bboxes_3d (BaseInstance3DBoxes): Ground truth bboxes
                with shape (n, 7).
            gt_labels_3d (torch.Tensor): Ground truth classes.

        Returns:
            tuple: Targets of center, size and direction.
        """
        # generate center target
        center_target = gt_bboxes_3d.gravity_center

        # generate bbox size target
        size_target = gt_bboxes_3d.dims
        size_class_target = gt_labels_3d
        size_res_target = gt_bboxes_3d.dims - gt_bboxes_3d.tensor.new_tensor(
            self.mean_sizes)[size_class_target]

        # generate dir target
        box_num = gt_labels_3d.shape[0]
        if self.with_rot:
            (dir_class_target,
             dir_res_target) = self.angle2class(gt_bboxes_3d.yaw)
        else:
            dir_class_target = gt_labels_3d.new_zeros(box_num)
            dir_res_target = gt_bboxes_3d.tensor.new_zeros(box_num)

        return (center_target, size_target, size_class_target,
                size_res_target, dir_class_target, dir_res_target)

    def decode(self, bbox_out, prefix=''):
        """Decode predicted parts to bbox3d.

        Args:
            bbox_out (dict): Predictions from model, should contain keys below.

                - center: predicted bottom center of bboxes.
                - dir_class: predicted bbox direction class.
                - dir_res: predicted bbox direction residual.
                - size_class: predicted bbox size class.
                - size_res: predicted bbox size residual.
                - size: predicted class-agnostic bbox size.
            prefix (str, optional): Decode predictions with specific prefix.
                Defaults to ''.

        Returns:
            torch.Tensor: Decoded bbox3d with shape (batch, n, 7).
        """
        center = bbox_out[f'{prefix}center']
        batch_size, num_proposal = center.shape[:2]

        # decode heading angle
        if self.with_rot:
            dir_class = torch.argmax(bbox_out[f'{prefix}dir_class'], -1)
            dir_res = torch.gather(bbox_out[f'{prefix}dir_res'], 2,
                                   dir_class.unsqueeze(-1))
            dir_res.squeeze_(2)
            dir_angle = self.class2angle(dir_class, dir_res).reshape(
                batch_size, num_proposal, 1)
        else:
            dir_angle = center.new_zeros(batch_size, num_proposal, 1)

        # decode bbox size
        if self.size_cls_agnostic:
            bbox_size = bbox_out[f'{prefix}size'].reshape(
                batch_size, num_proposal, 3)
        else:
            size_class = torch.argmax(
                bbox_out[f'{prefix}size_class'], -1, keepdim=True)
            size_res = torch.gather(
                bbox_out[f'{prefix}size_res'], 2,
                size_class.unsqueeze(-1).repeat(1, 1, 1, 3))
            mean_sizes = center.new_tensor(self.mean_sizes)
            size_base = torch.index_select(mean_sizes, 0,
                                           size_class.reshape(-1))
            bbox_size = size_base.reshape(batch_size, num_proposal,
                                          -1) + size_res.squeeze(2)

        bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1)
        return bbox3d

    def split_pred(self, cls_preds, reg_preds, base_xyz, prefix=''):
        """Split predicted features to specific parts.

        Args:
            cls_preds (torch.Tensor): Class predicted features to split.
            reg_preds (torch.Tensor): Regression predicted features to split.
            base_xyz (torch.Tensor): Coordinates of points.
            prefix (str, optional): Decode predictions with specific prefix.
                Defaults to ''.

        Returns:
            dict[str, torch.Tensor]: Split results.
        """
        results = {}
        start, end = 0, 0

        cls_preds_trans = cls_preds.transpose(2, 1)
        reg_preds_trans = reg_preds.transpose(2, 1)

        # decode center
        end += 3
        # (batch_size, num_proposal, 3)
        results[f'{prefix}center_residual'] = \
            reg_preds_trans[..., start:end].contiguous()
        results[f'{prefix}center'] = base_xyz + \
            reg_preds_trans[..., start:end].contiguous()
        start = end

        # decode direction
        end += self.num_dir_bins
        results[f'{prefix}dir_class'] = \
            reg_preds_trans[..., start:end].contiguous()
        start = end

        end += self.num_dir_bins
        dir_res_norm = reg_preds_trans[..., start:end].contiguous()
        start = end

        results[f'{prefix}dir_res_norm'] = dir_res_norm
        results[f'{prefix}dir_res'] = dir_res_norm * (
            np.pi / self.num_dir_bins)

        # decode size
        if self.size_cls_agnostic:
            end += 3
            results[f'{prefix}size'] = \
                reg_preds_trans[..., start:end].contiguous()
        else:
            end += self.num_sizes
            results[f'{prefix}size_class'] = reg_preds_trans[
                ..., start:end].contiguous()
            start = end

            end += self.num_sizes * 3
            size_res_norm = reg_preds_trans[..., start:end]
            batch_size, num_proposal = reg_preds_trans.shape[:2]
            size_res_norm = size_res_norm.view(
                [batch_size, num_proposal, self.num_sizes, 3])
            start = end

            results[f'{prefix}size_res_norm'] = size_res_norm.contiguous()
            mean_sizes = reg_preds.new_tensor(self.mean_sizes)
            results[f'{prefix}size_res'] = (
                size_res_norm * mean_sizes.unsqueeze(0).unsqueeze(0))

        # decode objectness score
        # Group-Free-3D objectness output shape (batch, proposal, 1)
        results[f'{prefix}obj_scores'] = cls_preds_trans[..., :1].contiguous()

        # decode semantic score
        results[f'{prefix}sem_scores'] = cls_preds_trans[..., 1:].contiguous()

        return results