# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Union

import mmcv
import numpy as np
import torch
from mmcv import BaseTransform
from mmengine import InstanceData
from numpy import dtype

from mmdet3d.registry import TRANSFORMS
from mmdet3d.structures import BaseInstance3DBoxes, Det3DDataSample, PointData
from mmdet3d.structures.points import BasePoints


def to_tensor(
        data: Union[torch.Tensor, np.ndarray, Sequence, int, float]
) -> torch.Tensor:
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data
            to be converted.

    Returns:
        torch.Tensor: The converted data.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        # Promote float64 to float32, the dtype expected by most models.
        if data.dtype is dtype('float64'):
            data = data.astype(np.float32)
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')


@TRANSFORMS.register_module()
class Pack3DDetInputs(BaseTransform):
    """Pack the inputs and annotations of a 3D detection sample.

    Args:
        keys (tuple[str]): Keys of ``results`` to be packed into ``inputs``
            or into the corresponding annotation field of the
            :obj:`Det3DDataSample`.
        meta_keys (tuple[str]): Keys of ``results`` to be collected into the
            meta information of the :obj:`Det3DDataSample`. Defaults to a
            tuple of common 2D/3D meta keys.
    """

    INPUTS_KEYS = ['points', 'img']
    INSTANCEDATA_3D_KEYS = [
        'gt_bboxes_3d', 'gt_labels_3d', 'attr_labels', 'depths', 'centers_2d'
    ]
    INSTANCEDATA_2D_KEYS = [
        'gt_bboxes',
        'gt_bboxes_labels',
    ]
    SEG_KEYS = [
        'gt_seg_map', 'pts_instance_mask', 'pts_semantic_mask',
        'gt_semantic_seg'
    ]

    def __init__(
        self,
        keys: tuple,
        meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
                            'depth2img', 'cam2img', 'pad_shape',
                            'scale_factor', 'flip', 'pcd_horizontal_flip',
                            'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
                            'img_norm_cfg', 'pcd_trans', 'sample_idx',
                            'pcd_scale_factor', 'pcd_rotation',
                            'pcd_rotation_angle', 'lidar_path',
                            'transformation_3d_flow', 'trans_mat',
                            'affine_aug')
    ) -> None:
        self.keys = keys
        self.meta_keys = meta_keys

    def _remove_prefix(self, key: str) -> str:
        if key.startswith('gt_'):
            key = key[3:]
        return key

    def transform(
        self, results: Union[dict, List[dict]]
    ) -> Union[dict, List[dict]]:
        """Pack the input data.

        When ``results`` is a list of dicts, it usually comes from test-time
        augmentation and every dict is packed separately.

        Args:
            results (dict | list[dict]): Result dict from the data pipeline.

        Returns:
            dict | list[dict]: A dict (or a list of dicts, one per
            augmentation) with the following keys:

            - 'inputs' (dict): The forward data of models, usually
              containing ``points`` and/or ``img``.
            - 'data_sample' (:obj:`Det3DDataSample`): The annotation info
              of the sample.
        """
        # Test-time augmentation passes a list of result dicts.
        if isinstance(results, list):
            if len(results) == 1:
                # simple test
                return self.pack_single_results(results[0])
            pack_results = []
            for single_result in results:
                pack_results.append(self.pack_single_results(single_result))
            return pack_results
        # Normal training and simple testing pass a single dict.
        elif isinstance(results, dict):
            return self.pack_single_results(results)
        else:
            raise NotImplementedError

    def pack_single_results(self, results: dict) -> dict:
        """Pack a single input sample.

        Args:
            results (dict): Result dict from the data pipeline.

        Returns:
            dict: A dict with the following keys:

            - 'inputs' (dict): The forward data of models, usually
              containing ``points`` and/or ``img``.
            - 'data_sample' (:obj:`Det3DDataSample`): The annotation info
              of the sample.
""" # Format 3D data if 'points' in results: if isinstance(results['points'], BasePoints): results['points'] = results['points'].tensor if 'img' in results: if isinstance(results['img'], list): # process multiple imgs in single frame imgs = [img.transpose(2, 0, 1) for img in results['img']] imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) results['img'] = to_tensor(imgs) else: img = results['img'] if len(img.shape) < 3: img = np.expand_dims(img, -1) results['img'] = to_tensor( np.ascontiguousarray(img.transpose(2, 0, 1))) for key in [ 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', 'gt_bboxes_labels', 'attr_labels', 'pts_instance_mask', 'pts_semantic_mask', 'centers_2d', 'depths', 'gt_labels_3d' ]: if key not in results: continue if isinstance(results[key], list): results[key] = [to_tensor(res) for res in results[key]] else: results[key] = to_tensor(results[key]) if 'gt_bboxes_3d' in results: if not isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes): results['gt_bboxes_3d'] = to_tensor(results['gt_bboxes_3d']) if 'gt_semantic_seg' in results: results['gt_semantic_seg'] = to_tensor( results['gt_semantic_seg'][None]) if 'gt_seg_map' in results: results['gt_seg_map'] = results['gt_seg_map'][None, ...] data_sample = Det3DDataSample() gt_instances_3d = InstanceData() gt_instances = InstanceData() gt_pts_seg = PointData() img_metas = {} for key in self.meta_keys: if key in results: img_metas[key] = results[key] data_sample.set_metainfo(img_metas) inputs = {} for key in self.keys: if key in results: if key in self.INPUTS_KEYS: inputs[key] = results[key] elif key in self.INSTANCEDATA_3D_KEYS: gt_instances_3d[self._remove_prefix(key)] = results[key] elif key in self.INSTANCEDATA_2D_KEYS: if key == 'gt_bboxes_labels': gt_instances['labels'] = results[key] else: gt_instances[self._remove_prefix(key)] = results[key] elif key in self.SEG_KEYS: gt_pts_seg[self._remove_prefix(key)] = results[key] else: raise NotImplementedError(f'Please modified ' f'`Pack3DDetInputs` ' f'to put {key} to ' f'corresponding field') data_sample.gt_instances_3d = gt_instances_3d data_sample.gt_instances = gt_instances data_sample.gt_pts_seg = gt_pts_seg if 'eval_ann_info' in results: data_sample.eval_ann_info = results['eval_ann_info'] else: data_sample.eval_ann_info = None packed_results = dict() packed_results['data_sample'] = data_sample packed_results['inputs'] = inputs return packed_results def __repr__(self) -> str: repr_str = self.__class__.__name__ repr_str += f'(keys={self.keys})' repr_str += f'(meta_keys={self.meta_keys})' return repr_str