import os.path as osp
import tempfile

import mmcv
import numpy as np
from torch.utils.data import Dataset

from mmdet.datasets import DATASETS
from ..core.bbox import (Box3DMode, CameraInstance3DBoxes,
                         DepthInstance3DBoxes, LiDARInstance3DBoxes)
from .pipelines import Compose


@DATASETS.register_module()
class Custom3DDataset(Dataset):
    """Customized 3D dataset.

    This is the base dataset of SUNRGB-D, ScanNet, nuScenes, and KITTI
    dataset.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        box_type_3d (str, optional): Type of 3D box of this dataset.
            Based on the `box_type_3d`, the dataset will encapsulate the box
            in its original format and then convert it to `box_type_3d`.
            Defaults to 'LiDAR'. Available options include

            - 'LiDAR': box in LiDAR coordinates
            - 'Depth': box in depth coordinates, usually for indoor dataset
            - 'Camera': box in camera coordinates
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to True.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='LiDAR',
                 filter_empty_gt=True,
                 test_mode=False):
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        self.filter_empty_gt = filter_empty_gt
        # Sets self.box_type_3d and self.box_mode_3d.
        self.get_box_type(box_type_3d)

        self.CLASSES = self.get_classes(classes)
        self.data_infos = self.load_annotations(self.ann_file)

        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        # set group flag for the sampler
        if not self.test_mode:
            self._set_group_flag()

    def load_annotations(self, ann_file):
        """Load annotation infos from ``ann_file`` via ``mmcv.load``.

        Args:
            ann_file (str): Path of the annotation file.

        Returns:
            The deserialized annotation infos (indexed per sample by
            ``get_data_info``).
        """
        return mmcv.load(ann_file)

    def get_box_type(self, box_type):
        """Resolve ``box_type`` into a box class and a ``Box3DMode``.

        Sets ``self.box_type_3d`` and ``self.box_mode_3d``.

        Args:
            box_type (str): One of 'LiDAR', 'Camera' or 'Depth'
                (case-insensitive).

        Raises:
            ValueError: If ``box_type`` is not one of the supported types.
        """
        box_type_lower = box_type.lower()
        if box_type_lower == 'lidar':
            self.box_type_3d = LiDARInstance3DBoxes
            self.box_mode_3d = Box3DMode.LIDAR
        elif box_type_lower == 'camera':
            self.box_type_3d = CameraInstance3DBoxes
            self.box_mode_3d = Box3DMode.CAM
        elif box_type_lower == 'depth':
            self.box_type_3d = DepthInstance3DBoxes
            self.box_mode_3d = Box3DMode.DEPTH
        else:
            raise ValueError('Only "box_type" of "camera", "lidar", "depth"'
                             f' are supported, got {box_type}')

    def get_data_info(self, index):
        """Build the input dict for the sample at ``index``.

        Args:
            index (int): Index of the sample in ``self.data_infos``.

        Returns:
            dict | None: Dict with ``pts_filename``, ``sample_idx`` and
                ``file_name`` (and ``ann_info`` outside test mode). Returns
                None when ``filter_empty_gt`` is set, the dataset is not in
                test mode and the sample has no 3D GT boxes.
        """
        info = self.data_infos[index]
        sample_idx = info['point_cloud']['lidar_idx']
        pts_filename = osp.join(self.data_root, info['pts_path'])

        input_dict = dict(
            pts_filename=pts_filename,
            sample_idx=sample_idx,
            file_name=pts_filename)

        if not self.test_mode:
            # NOTE(review): get_ann_info is not defined in this class;
            # subclasses are expected to provide it.
            annos = self.get_ann_info(index)
            input_dict['ann_info'] = annos
            if self.filter_empty_gt and len(annos['gt_bboxes_3d']) == 0:
                return None
        return input_dict

    def pre_pipeline(self, results):
        """Initialize the field lists and box type info in ``results``.

        Args:
            results (dict): Input dict that is modified in place before
                being fed to the pipeline.
        """
        results['img_fields'] = []
        results['bbox3d_fields'] = []
        results['pts_mask_fields'] = []
        results['pts_seg_fields'] = []
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []
        results['box_type_3d'] = self.box_type_3d
        results['box_mode_3d'] = self.box_mode_3d

    def prepare_train_data(self, index):
        """Run the training pipeline on the sample at ``index``.

        Returns:
            dict | None: The pipeline output, or None when the sample should
                be skipped (no GT before or after the pipeline while
                ``filter_empty_gt`` is set, or the pipeline returned None).
        """
        input_dict = self.get_data_info(index)
        if input_dict is None:
            return None
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        # The pipeline may drop every box (e.g. via range filtering); treat
        # such samples as empty too. ``_data`` unwraps the container that
        # holds the boxes.
        if self.filter_empty_gt and (example is None or len(
                example['gt_bboxes_3d']._data) == 0):
            return None
        return example

    def prepare_test_data(self, index):
        """Run the test pipeline on the sample at ``index``.

        Returns:
            dict: The pipeline output.
        """
        input_dict = self.get_data_info(index)
        self.pre_pipeline(input_dict)
        example = self.pipeline(input_dict)
        return example

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            list[str]: The list of class names (a tuple is returned
                unchanged if one was passed).

        Raises:
            ValueError: If ``classes`` has an unsupported type.
        """
        if classes is None:
            # NOTE(review): CLASSES is not defined on this base class; a
            # subclass is expected to set it.
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def format_results(self,
                       outputs,
                       pklfile_prefix=None,
                       submission_prefix=None):
        """Dump the raw outputs to ``{pklfile_prefix}.pkl``.

        Args:
            outputs (list[dict]): Testing results of the dataset.
            pklfile_prefix (str | None): Path prefix of the pkl file, without
                the ``.pkl`` suffix. If None, a temporary directory is
                created and the file is written to ``<tmp>/results.pkl``.
            submission_prefix (str | None): Not used by this base class.

        Returns:
            tuple: ``(outputs, tmp_dir)``. ``tmp_dir`` is the
                ``tempfile.TemporaryDirectory`` created for the dump, or None
                when ``pklfile_prefix`` was given.
        """
        # Bind tmp_dir on both paths so that returning it does not raise
        # UnboundLocalError when the caller supplies pklfile_prefix.
        tmp_dir = None
        if pklfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            pklfile_prefix = osp.join(tmp_dir.name, 'results')
        out = f'{pklfile_prefix}.pkl'
        mmcv.dump(outputs, out)
        return outputs, tmp_dir

    def evaluate(self,
                 results,
                 metric=None,
                 iou_thr=(0.25, 0.5),
                 logger=None,
                 show=False,
                 out_dir=None):
        """Evaluate.

        Evaluation in indoor protocol.

        Args:
            results (list[dict]): List of results, one per sample in
                ``self.data_infos``.
            metric (str | list[str]): Metrics to be evaluated. Not used by
                this implementation.
            iou_thr (list[float]): AP IoU thresholds.
            logger: Logger passed through to ``indoor_eval``.
            show (bool): Whether to visualize. Default: False.
            out_dir (str): Path to save the visualization results.
                Default: None.

        Returns:
            dict: Evaluation results.
        """
        from mmdet3d.core.evaluation import indoor_eval
        assert isinstance(
            results, list), f'Expect results to be list, got {type(results)}.'
        assert len(results) > 0, 'Expect length of results > 0.'
        assert len(results) == len(self.data_infos)
        assert isinstance(
            results[0], dict
        ), f'Expect elements in results to be dict, got {type(results[0])}.'
        gt_annos = [info['annos'] for info in self.data_infos]
        label2cat = dict(enumerate(self.CLASSES))
        ret_dict = indoor_eval(
            gt_annos,
            results,
            iou_thr,
            label2cat,
            logger=logger,
            box_type_3d=self.box_type_3d,
            box_mode_3d=self.box_mode_3d)
        if show:
            # NOTE(review): show() is not defined in this class; subclasses
            # are expected to implement it.
            self.show(results, out_dir)

        return ret_dict

    def __len__(self):
        """Return the number of samples (length of ``self.data_infos``)."""
        return len(self.data_infos)

    def _rand_another(self, idx):
        """Randomly pick another index from the same group as ``idx``."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get the sample at ``idx``.

        In training mode, samples for which ``prepare_train_data`` returns
        None are replaced by a random sample from the same group, retrying
        until a valid one is found.
        """
        if self.test_mode:
            return self.prepare_test_data(idx)
        while True:
            data = self.prepare_train_data(idx)
            if data is None:
                idx = self._rand_another(idx)
                continue
            return data

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0. In 3D datasets, they are all the same, thus are all
        zeros.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)