Commit 0db30f45 authored by chenshi3's avatar chenshi3
Browse files

Adjust the hook function and nuscenes_dataset.py

parent ed2bb815
...@@ -24,7 +24,7 @@ class DataAugmentor(object): ...@@ -24,7 +24,7 @@ class DataAugmentor(object):
cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg) cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
self.data_augmentor_queue.append(cur_augmentor) self.data_augmentor_queue.append(cur_augmentor)
def disableAugmentation(self, augmentor_configs): def disable_augmentation(self, augmentor_configs):
self.data_augmentor_queue = [] self.data_augmentor_queue = []
aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \ aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \
else augmentor_configs.AUG_CONFIG_LIST else augmentor_configs.AUG_CONFIG_LIST
......
...@@ -152,36 +152,7 @@ class NuScenesDataset(DatasetTemplate): ...@@ -152,36 +152,7 @@ class NuScenesDataset(DatasetTemplate):
input_dict['camera_imgs'] = crop_images input_dict['camera_imgs'] = crop_images
return input_dict return input_dict
def __len__(self): def load_camera_info(self, input_dict, info):
if self._merge_all_iters_to_one_epoch:
return len(self.infos) * self.total_epochs
return len(self.infos)
def __getitem__(self, index):
if self._merge_all_iters_to_one_epoch:
index = index % len(self.infos)
info = copy.deepcopy(self.infos[index])
points = self.get_lidar_with_sweeps(index, max_sweeps=self.dataset_cfg.MAX_SWEEPS)
input_dict = {
'points': points,
'frame_id': Path(info['lidar_path']).stem,
'metadata': {'token': info['token']}
}
if 'gt_boxes' in info:
if self.dataset_cfg.get('FILTER_MIN_POINTS_IN_GT', False):
mask = (info['num_lidar_pts'] > self.dataset_cfg.FILTER_MIN_POINTS_IN_GT - 1)
else:
mask = None
input_dict.update({
'gt_names': info['gt_names'] if mask is None else info['gt_names'][mask],
'gt_boxes': info['gt_boxes'] if mask is None else info['gt_boxes'][mask]
})
if self.use_camera:
input_dict["image_paths"] = [] input_dict["image_paths"] = []
input_dict["lidar2camera"] = [] input_dict["lidar2camera"] = []
input_dict["lidar2image"] = [] input_dict["lidar2image"] = []
...@@ -236,6 +207,40 @@ class NuScenesDataset(DatasetTemplate): ...@@ -236,6 +207,40 @@ class NuScenesDataset(DatasetTemplate):
# resize and crop image # resize and crop image
input_dict = self.crop_image(input_dict) input_dict = self.crop_image(input_dict)
return input_dict
def __len__(self):
if self._merge_all_iters_to_one_epoch:
return len(self.infos) * self.total_epochs
return len(self.infos)
def __getitem__(self, index):
if self._merge_all_iters_to_one_epoch:
index = index % len(self.infos)
info = copy.deepcopy(self.infos[index])
points = self.get_lidar_with_sweeps(index, max_sweeps=self.dataset_cfg.MAX_SWEEPS)
input_dict = {
'points': points,
'frame_id': Path(info['lidar_path']).stem,
'metadata': {'token': info['token']}
}
if 'gt_boxes' in info:
if self.dataset_cfg.get('FILTER_MIN_POINTS_IN_GT', False):
mask = (info['num_lidar_pts'] > self.dataset_cfg.FILTER_MIN_POINTS_IN_GT - 1)
else:
mask = None
input_dict.update({
'gt_names': info['gt_names'] if mask is None else info['gt_names'][mask],
'gt_boxes': info['gt_boxes'] if mask is None else info['gt_boxes'][mask]
})
if self.use_camera:
input_dict = self.load_camera_info(input_dict, info)
data_dict = self.prepare_data(data_dict=input_dict) data_dict = self.prepare_data(data_dict=input_dict)
if self.dataset_cfg.get('SET_NAN_VELOCITY_TO_ZEROS', False) and 'gt_boxes' in info: if self.dataset_cfg.get('SET_NAN_VELOCITY_TO_ZEROS', False) and 'gt_boxes' in info:
......
...@@ -267,6 +267,6 @@ def disable_augmentation_hook(hook_config, dataloader, total_epochs, cur_epoch, ...@@ -267,6 +267,6 @@ def disable_augmentation_hook(hook_config, dataloader, total_epochs, cur_epoch,
dataset_cfg=cfg.DATA_CONFIG dataset_cfg=cfg.DATA_CONFIG
logger.info(f'Disable augmentations: {DISABLE_AUG_LIST}') logger.info(f'Disable augmentations: {DISABLE_AUG_LIST}')
dataset_cfg.DATA_AUGMENTOR.DISABLE_AUG_LIST = DISABLE_AUG_LIST dataset_cfg.DATA_AUGMENTOR.DISABLE_AUG_LIST = DISABLE_AUG_LIST
dataloader._dataset.data_augmentor.disableAugmentation(dataset_cfg.DATA_AUGMENTOR) dataloader._dataset.data_augmentor.disable_augmentation(dataset_cfg.DATA_AUGMENTOR)
flag = True flag = True
return flag return flag
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment