from pathlib import Path
from collections import defaultdict

import numpy as np
import torch.utils.data as torch_data

from .augmentor.data_augmentor import DataAugmentor
from .processor.data_processor import DataProcessor
from .processor.point_feature_encoder import PointFeatureEncoder
from ..utils import common_utils


class DatasetTemplate(torch_data.Dataset):
    """Base dataset that wires together feature encoding, augmentation and processing.

    Subclasses are expected to provide ``__len__`` and ``__getitem__``;
    ``prepare_data`` calls both when it has to draw a replacement sample.
    """

    def __init__(self, dataset_cfg=None, class_names=None, training=True, root_path=None, logger=None):
        """Store the configuration and build the processing pipeline.

        Args:
            dataset_cfg: configuration object; ``DATA_PATH``, ``POINT_CLOUD_RANGE``,
                ``POINT_FEATURE_ENCODING``, ``DATA_AUGMENTOR`` and ``DATA_PROCESSOR``
                are read from it. When ``None`` only the plain attributes are set
                and no pipeline component is created.
            class_names: names of the classes to keep.
            training: whether the dataset is used for training; the augmentor is
                only built in that case.
            root_path: dataset root; falls back to ``Path(dataset_cfg.DATA_PATH)``
                when ``None`` and a config is available.
            logger: logger handed to the augmentor.
        """
        super().__init__()
        self.dataset_cfg = dataset_cfg
        self.training = training
        self.class_names = class_names
        self.logger = logger

        # Only touch the config when there is one: with the default arguments
        # (no config, no root path) construction must not fail.
        if root_path is not None:
            self.root_path = root_path
        elif dataset_cfg is not None:
            self.root_path = Path(dataset_cfg.DATA_PATH)
        else:
            self.root_path = None

        if self.dataset_cfg is None:
            return

        self.point_cloud_range = np.array(self.dataset_cfg.POINT_CLOUD_RANGE, dtype=np.float32)
        self.point_feature_encoder = PointFeatureEncoder(
            self.dataset_cfg.POINT_FEATURE_ENCODING,
            point_cloud_range=self.point_cloud_range
        )
        self.data_augmentor = DataAugmentor(
            self.root_path, self.dataset_cfg.DATA_AUGMENTOR, self.class_names, logger=self.logger
        ) if self.training else None
        self.data_processor = DataProcessor(
            self.dataset_cfg.DATA_PROCESSOR, point_cloud_range=self.point_cloud_range, training=self.training
        )

        self.grid_size = self.data_processor.grid_size
        self.voxel_size = self.data_processor.voxel_size

    @property
    def mode(self):
        """Return ``'train'`` when ``self.training`` is set, otherwise ``'test'``."""
        return 'train' if self.training else 'test'

    def __len__(self):
        """Number of samples; not implemented in the template."""
        raise NotImplementedError

    def forward(self, index):
        """Not implemented in the template."""
        raise NotImplementedError

    def prepare_data(self, data_dict):
        """Run augmentation (training only), class filtering, encoding and processing.

        Args:
            data_dict:
                points: (N, 3 + C_in)
                gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
                gt_names: optional, (N), string
                ...

        Returns:
            data_dict:
                frame_id: string
                points: (N, 3 + C_in)
                gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
                gt_names: optional, (N), string
                use_lead_xyz: bool
                voxels: optional (num_voxels, max_points_per_voxel, 3 + C)
                voxel_coords: optional (num_voxels, 3)
                voxel_num_points: optional (num_voxels)
                ...
        """
        if self.training:
            assert 'gt_boxes' in data_dict, 'gt_boxes should be provided for training'
            gt_boxes_mask = np.array(
                [n in self.class_names for n in data_dict['gt_names']], dtype=np.bool_
            )

            data_dict = self.data_augmentor.forward(
                data_dict={
                    **data_dict,
                    'gt_boxes_mask': gt_boxes_mask
                }
            )
            # No boxes left after augmentation: return a randomly drawn
            # sample instead of this one.
            if len(data_dict['gt_boxes']) == 0:
                new_index = np.random.randint(len(self))
                return self.__getitem__(new_index)

        if data_dict.get('gt_boxes', None) is not None:
            selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
            data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
            data_dict['gt_names'] = data_dict['gt_names'][selected]
            # Class ids are 1-based positions in class_names, appended as the
            # last column of gt_boxes.
            gt_classes = np.array(
                [self.class_names.index(n) + 1 for n in data_dict['gt_names']], dtype=np.int32
            )
            gt_boxes = np.concatenate(
                (data_dict['gt_boxes'], gt_classes.reshape(-1, 1).astype(np.float32)), axis=1
            )
            data_dict['gt_boxes'] = gt_boxes

        data_dict = self.point_feature_encoder.forward(data_dict)

        data_dict = self.data_processor.forward(
            data_dict=data_dict
        )
        # gt_names is optional (absent for unannotated samples), so removing
        # it must not fail when the key is missing.
        data_dict.pop('gt_names', None)

        return data_dict

    @staticmethod
    def collate_batch(batch_list, _unused=False):
        """Merge a list of per-sample dicts into one batch dict.

        Args:
            batch_list: list of sample dicts.
            _unused: ignored.

        Returns:
            dict with one entry per key found in the samples plus ``batch_size``:
                * ``voxels`` / ``voxel_num_points``: concatenated along axis 0.
                * ``points`` / ``voxel_coords``: each sample gets one leading
                  column holding its index in the batch, then all samples are
                  concatenated along axis 0.
                * ``gt_boxes``: float32 array of shape
                  (batch_size, max boxes per sample, box width), zero padded.
                * any other key: ``np.stack`` along axis 0.
        """
        data_dict = defaultdict(list)
        for cur_sample in batch_list:
            for key, val in cur_sample.items():
                data_dict[key].append(val)
        batch_size = len(batch_list)

        ret = {}
        for key, val in data_dict.items():
            if key in ('voxels', 'voxel_num_points'):
                ret[key] = np.concatenate(val, axis=0)
            elif key in ('points', 'voxel_coords'):
                # Prepend the sample index as an extra first column.
                coors = [
                    np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)
                    for i, coor in enumerate(val)
                ]
                ret[key] = np.concatenate(coors, axis=0)
            elif key == 'gt_boxes':
                max_gt = max(len(x) for x in val)
                batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32)
                for k in range(batch_size):
                    batch_gt_boxes3d[k, :len(val[k]), :] = val[k]
                ret[key] = batch_gt_boxes3d
            else:
                ret[key] = np.stack(val, axis=0)

        ret['batch_size'] = batch_size
        return ret