# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch

from mmdet3d.models.test_time_augs import merge_aug_bboxes_3d
from mmdet3d.registry import MODELS
from .mvx_two_stage import MVXTwoStageDetector


@MODELS.register_module()
class CenterPoint(MVXTwoStageDetector):
    """CenterPoint detector built on :class:`MVXTwoStageDetector`.

    The constructor forwards every argument unchanged to the parent class.
    Test-time augmentation (``aug_test`` / ``aug_test_pts``) is currently
    disabled: both methods raise ``NotImplementedError``.

    Args:
        pts_voxel_layer (dict, optional): Point cloud voxelization layer.
            Defaults to None.
        pts_voxel_encoder (dict, optional): Point voxelization
            encoder layer. Defaults to None.
        pts_middle_encoder (dict, optional): Middle encoder layer
            of points cloud modality. Defaults to None.
        pts_fusion_layer (dict, optional): Fusion layer.
            Defaults to None.
        img_backbone (dict, optional): Backbone of extracting
            images feature. Defaults to None.
        pts_backbone (dict, optional): Backbone of extracting
            points features. Defaults to None.
        img_neck (dict, optional): Neck of extracting
            image features. Defaults to None.
        pts_neck (dict, optional): Neck of extracting
            points features. Defaults to None.
        pts_bbox_head (dict, optional): Bboxes head of
            point cloud modality. Defaults to None.
        img_roi_head (dict, optional): RoI head of image
            modality. Defaults to None.
        img_rpn_head (dict, optional): RPN head of image
            modality. Defaults to None.
        train_cfg (dict, optional): Train config of model.
            Defaults to None.
        test_cfg (dict, optional): Test config of model.
            Defaults to None.
        init_cfg (dict, optional): Initialize config of
            model. Defaults to None.
        data_preprocessor (dict or ConfigDict, optional): The pre-process
            config of :class:`Det3DDataPreprocessor`. Defaults to None.
    """

    def __init__(self,
                 pts_voxel_layer: Optional[dict] = None,
                 pts_voxel_encoder: Optional[dict] = None,
                 pts_middle_encoder: Optional[dict] = None,
                 pts_fusion_layer: Optional[dict] = None,
                 img_backbone: Optional[dict] = None,
                 pts_backbone: Optional[dict] = None,
                 img_neck: Optional[dict] = None,
                 pts_neck: Optional[dict] = None,
                 pts_bbox_head: Optional[dict] = None,
                 img_roi_head: Optional[dict] = None,
                 img_rpn_head: Optional[dict] = None,
                 train_cfg: Optional[dict] = None,
                 test_cfg: Optional[dict] = None,
                 init_cfg: Optional[dict] = None,
                 data_preprocessor: Optional[dict] = None,
                 **kwargs):
        # Arguments are forwarded positionally, so their order here must
        # stay in sync with the parent constructor's parameter order.
        super().__init__(pts_voxel_layer, pts_voxel_encoder,
                         pts_middle_encoder, pts_fusion_layer, img_backbone,
                         pts_backbone, img_neck, pts_neck, pts_bbox_head,
                         img_roi_head, img_rpn_head, train_cfg, test_cfg,
                         init_cfg, data_preprocessor, **kwargs)

    # TODO support this
    def aug_test_pts(self, feats, img_metas, rescale=False):
        """Test function of point cloud branch with augmentation.

        Currently disabled: the method raises immediately. The statements
        after the ``raise`` are unreachable and kept only as the reference
        for re-enabling this feature.

        The (disabled) implementation process is as follows:

            - step 1: map features back for double-flip augmentation.
            - step 2: merge all features and generate boxes.
            - step 3: map boxes back for scale augmentation.
            - step 4: merge results.

        Args:
            feats (list[torch.Tensor]): Feature of point cloud.
            img_metas (list[dict]): Meta information of samples.
            rescale (bool, optional): Whether to rescale bboxes.
                Default: False.

        Returns:
            dict: Returned bboxes consists of the following keys:

                - boxes_3d (:obj:`LiDARInstance3DBoxes`): Predicted bboxes.
                - scores_3d (torch.Tensor): Scores of predicted boxes.
                - labels_3d (torch.Tensor): Labels of predicted boxes.

        Raises:
            NotImplementedError: Always, until the TODO above is resolved.
        """
        raise NotImplementedError(
            'aug_test_pts is not supported by CenterPoint yet.')

        # ---- Unreachable reference implementation below this line ----
        # only support aug_test for one sample
        outs_list = []
        for x, img_meta in zip(feats, img_metas):
            outs = self.pts_bbox_head(x)
            # merge augmented outputs before decoding bboxes
            for task_id, out in enumerate(outs):
                for key in out[0].keys():
                    if img_meta[0]['pcd_horizontal_flip']:
                        outs[task_id][0][key] = torch.flip(
                            outs[task_id][0][key], dims=[2])
                        if key == 'reg':
                            outs[task_id][0][key][:, 1, ...] = 1 - outs[
                                task_id][0][key][:, 1, ...]
                        elif key == 'rot':
                            outs[task_id][0][key][:, 0, ...] = -outs[
                                task_id][0][key][:, 0, ...]
                        elif key == 'vel':
                            outs[task_id][0][key][:, 1, ...] = -outs[
                                task_id][0][key][:, 1, ...]
                    if img_meta[0]['pcd_vertical_flip']:
                        outs[task_id][0][key] = torch.flip(
                            outs[task_id][0][key], dims=[3])
                        if key == 'reg':
                            outs[task_id][0][key][:, 0, ...] = 1 - outs[
                                task_id][0][key][:, 0, ...]
                        elif key == 'rot':
                            outs[task_id][0][key][:, 1, ...] = -outs[
                                task_id][0][key][:, 1, ...]
                        elif key == 'vel':
                            outs[task_id][0][key][:, 0, ...] = -outs[
                                task_id][0][key][:, 0, ...]

            outs_list.append(outs)

        preds_dicts = dict()
        scale_img_metas = []

        # concat outputs sharing the same pcd_scale_factor
        for img_meta, outs in zip(img_metas, outs_list):
            pcd_scale_factor = img_meta[0]['pcd_scale_factor']
            if pcd_scale_factor not in preds_dicts:
                preds_dicts[pcd_scale_factor] = outs
                scale_img_metas.append(img_meta)
            else:
                for task_id, out in enumerate(outs):
                    for key in out[0].keys():
                        preds_dicts[pcd_scale_factor][task_id][0][key] += out[
                            0][key]

        aug_bboxes = []

        for preds_dict in preds_dicts.values():
            for task_id, pred_dict in enumerate(preds_dict):
                # merge outputs with different flips before decoding bboxes
                for key in pred_dict[0].keys():
                    preds_dict[task_id][0][key] /= len(outs_list) / len(
                        preds_dicts)
            bbox_list = self.pts_bbox_head.get_bboxes(
                preds_dict, img_metas[0], rescale=rescale)
            bbox_list = [
                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
                for bboxes, scores, labels in bbox_list
            ]
            aug_bboxes.append(bbox_list[0])

        if len(preds_dicts) > 1:
            # merge outputs with different scales after decoding bboxes
            merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, scale_img_metas,
                                                self.pts_bbox_head.test_cfg)
            return merged_bboxes
        else:
            for key in bbox_list[0].keys():
                bbox_list[0][key] = bbox_list[0][key].to('cpu')
            return bbox_list[0]

    # TODO support this
    def aug_test(self, points, img_metas, imgs=None, rescale=False):
        """Test function with augmentation.

        Currently disabled: the method raises immediately. The statements
        after the ``raise`` are unreachable and kept only as the reference
        for re-enabling this feature.

        Raises:
            NotImplementedError: Always, until the TODO above is resolved.
        """
        raise NotImplementedError(
            'aug_test is not supported by CenterPoint yet.')

        # ---- Unreachable reference implementation below this line ----
        img_feats, pts_feats = self.extract_feats(points, img_metas, imgs)
        bbox_list = dict()
        if pts_feats and self.with_pts_bbox:
            pts_bbox = self.aug_test_pts(pts_feats, img_metas, rescale)
            bbox_list.update(pts_bbox=pts_bbox)
        return [bbox_list]