# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Union

import numpy as np
from mmcv import BaseTransform
from mmcv.transforms import to_tensor
from mmengine import InstanceData

from mmdet3d.core import Det3DDataSample, PointData
from mmdet3d.core.bbox import BaseInstance3DBoxes
from mmdet3d.core.points import BasePoints
from mmdet3d.registry import TRANSFORMS


@TRANSFORMS.register_module()
class Pack3DDetInputs(BaseTransform):
    """Pack pipeline results into model inputs and a ``Det3DDataSample``.

    Every key listed in ``keys`` that is present in the result dict is routed
    to one of four places according to the class-level key lists below:

    - ``INPUTS_KEYS`` -> the returned ``inputs`` dict.
    - ``INSTANCEDATA_3D_KEYS`` -> ``data_sample.gt_instances_3d``.
    - ``INSTANCEDATA_2D_KEYS`` -> ``data_sample.gt_instances``.
    - ``SEG_KEYS`` -> ``data_sample.gt_pts_seg``.

    A leading ``gt_`` is stripped from the key before it is stored on the
    annotation containers (e.g. ``gt_bboxes_3d`` is stored as ``bboxes_3d``).

    Args:
        keys (Sequence[str]): Keys of the result dict to pack. A key that is
            present in the results but belongs to none of the key lists above
            raises ``NotImplementedError``.
        meta_keys (Sequence[str]): Keys of the result dict copied into the
            meta information of the data sample when present.
    """

    INPUTS_KEYS = ['points', 'img']
    INSTANCEDATA_3D_KEYS = [
        'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
    ]
    INSTANCEDATA_2D_KEYS = [
        'gt_bboxes',
        'gt_labels',
    ]

    SEG_KEYS = [
        'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
        'gt_semantic_seg'
    ]

    def __init__(
        self,
        keys: Sequence[str],
        meta_keys: Sequence[str] = (
            'filename', 'ori_shape', 'img_shape', 'lidar2img', 'depth2img',
            'cam2img', 'pad_shape', 'scale_factor', 'flip',
            'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d',
            'box_type_3d', 'img_norm_cfg', 'pcd_trans', 'sample_idx',
            'pcd_scale_factor', 'pcd_rotation', 'pcd_rotation_angle',
            'pts_filename', 'transformation_3d_flow', 'trans_mat',
            'affine_aug')
    ) -> None:
        self.keys = keys
        self.meta_keys = meta_keys

    def _remove_prefix(self, key: str) -> str:
        """Return ``key`` without a leading ``gt_``, if it has one."""
        if key.startswith('gt_'):
            key = key[3:]
        return key

    def transform(self, results: Union[dict, List[dict]]
                  ) -> Union[dict, List[dict]]:
        """Method to pack the input data.

        When ``results`` is a list, it usually comes from test-time
        augmentation. A list holding exactly one dict is treated as simple
        testing and packed into a single dict rather than a one-element list.

        Args:
            results (dict | list[dict]): Result dict(s) from the data
                pipeline.

        Returns:
            dict | List[dict]: For each input dict, a dict with:

            - 'inputs' (dict): The forward data of models. It usually
              contains following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info
              of the sample.

        Raises:
            NotImplementedError: If ``results`` is neither a dict nor a list.
        """
        # augtest
        if isinstance(results, list):
            if len(results) == 1:
                # simple test
                return self.pack_single_results(results[0])
            return [
                self.pack_single_results(single_result)
                for single_result in results
            ]
        # norm training and simple testing
        if isinstance(results, dict):
            return self.pack_single_results(results)
        raise NotImplementedError

    def pack_single_results(self, results: dict) -> dict:
        """Method to pack a single result dict.

        Note that ``results`` is modified in place: the formatted values
        (tensors, channel-first images, ...) are written back into it before
        they are packed.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict: A dict contains

            - 'inputs' (dict): The forward data of models. It usually
              contains following keys:

                - points
                - img

            - 'data_sample' (obj:`Det3DDataSample`): The annotation info
              of the sample.

        Raises:
            NotImplementedError: If a key of ``self.keys`` is present in
                ``results`` but is not listed in any of the class-level key
                lists.
        """
        # Format 3D data
        if 'points' in results:
            if isinstance(results['points'], BasePoints):
                results['points'] = results['points'].tensor

        if 'img' in results:
            if isinstance(results['img'], list):
                # process multiple imgs in single frame: each image is moved
                # to channel-first layout, then all are stacked on a new
                # leading axis
                imgs = [img.transpose(2, 0, 1) for img in results['img']]
                imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
                results['img'] = to_tensor(imgs)
            else:
                img = results['img']
                if len(img.shape) < 3:
                    # add a trailing channel axis so the transpose is valid
                    img = np.expand_dims(img, -1)
                results['img'] = to_tensor(
                    np.ascontiguousarray(img.transpose(2, 0, 1)))

        # 'centers_2d' is the spelling used by INSTANCEDATA_3D_KEYS, so it
        # must be converted here as well; the legacy 'centers2d' spelling is
        # kept for pipelines that still produce it.
        for key in (
                'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels',
                'gt_labels_3d', 'attr_labels', 'pts_instance_mask',
                'pts_semantic_mask', 'centers2d', 'centers_2d', 'depths'):
            if key not in results:
                continue
            if isinstance(results[key], list):
                results[key] = [to_tensor(res) for res in results[key]]
            else:
                results[key] = to_tensor(results[key])

        if 'gt_bboxes_3d' in results:
            # box objects are packed as they are; anything else is converted
            if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes):
                results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d'])

        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = to_tensor(
                results['gt_semantic_seg'][None])
        if 'gt_seg_map' in results:
            # only a leading axis is added here; no tensor conversion
            results['gt_seg_map'] = results['gt_seg_map'][None, ...]

        data_sample = Det3DDataSample()
        gt_instances_3d = InstanceData()
        gt_instances = InstanceData()
        gt_pts_seg = PointData()

        img_metas = {
            key: results[key]
            for key in self.meta_keys if key in results
        }
        data_sample.set_metainfo(img_metas)

        inputs = {}
        for key in self.keys:
            if key not in results:
                continue
            if key in self.INPUTS_KEYS:
                inputs[key] = results[key]
            elif key in self.INSTANCEDATA_3D_KEYS:
                gt_instances_3d[self._remove_prefix(key)] = results[key]
            elif key in self.INSTANCEDATA_2D_KEYS:
                gt_instances[self._remove_prefix(key)] = results[key]
            elif key in self.SEG_KEYS:
                gt_pts_seg[self._remove_prefix(key)] = results[key]
            else:
                raise NotImplementedError(f'Please modified '
                                          f'`Pack3DDetInputs` '
                                          f'to put {key} to '
                                          f'corresponding field')

        data_sample.gt_instances_3d = gt_instances_3d
        data_sample.gt_instances = gt_instances
        data_sample.gt_pts_seg = gt_pts_seg
        # ``None`` when the pipeline supplied no evaluation annotations
        data_sample.eval_ann_info = results.get('eval_ann_info')

        packed_results = dict()
        packed_results['data_sample'] = data_sample
        packed_results['inputs'] = inputs

        return packed_results

    def __repr__(self) -> str:
        """str: Return a string that describes the transform."""
        return (f'{self.__class__.__name__}('
                f'keys={self.keys}, meta_keys={self.meta_keys})')