Commit 80173750 authored by Shaoshuai Shi's avatar Shaoshuai Shi
Browse files

bugfixed: remove data_dict.pop('calib') to support lidar+image in kitti

parent 16f9c032
...@@ -11,18 +11,18 @@ class DataAugmentor(object): ...@@ -11,18 +11,18 @@ class DataAugmentor(object):
self.root_path = root_path self.root_path = root_path
self.class_names = class_names self.class_names = class_names
self.logger = logger self.logger = logger
self.data_augmentor_queue = [] self.data_augmentor_queue = []
aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \ aug_config_list = augmentor_configs if isinstance(augmentor_configs, list) \
else augmentor_configs.AUG_CONFIG_LIST else augmentor_configs.AUG_CONFIG_LIST
for cur_cfg in aug_config_list: for cur_cfg in aug_config_list:
if not isinstance(augmentor_configs, list): if not isinstance(augmentor_configs, list):
if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST: if cur_cfg.NAME in augmentor_configs.DISABLE_AUG_LIST:
continue continue
cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg) cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
self.data_augmentor_queue.append(cur_augmentor) self.data_augmentor_queue.append(cur_augmentor)
def gt_sampling(self, config=None): def gt_sampling(self, config=None):
db_sampler = database_sampler.DataBaseSampler( db_sampler = database_sampler.DataBaseSampler(
root_path=self.root_path, root_path=self.root_path,
...@@ -31,15 +31,15 @@ class DataAugmentor(object): ...@@ -31,15 +31,15 @@ class DataAugmentor(object):
logger=self.logger logger=self.logger
) )
return db_sampler return db_sampler
def __getstate__(self): def __getstate__(self):
d = dict(self.__dict__) d = dict(self.__dict__)
del d['logger'] del d['logger']
return d return d
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
def random_world_flip(self, data_dict=None, config=None): def random_world_flip(self, data_dict=None, config=None):
if data_dict is None: if data_dict is None:
return partial(self.random_world_flip, config=config) return partial(self.random_world_flip, config=config)
...@@ -54,7 +54,7 @@ class DataAugmentor(object): ...@@ -54,7 +54,7 @@ class DataAugmentor(object):
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
def random_world_rotation(self, data_dict=None, config=None): def random_world_rotation(self, data_dict=None, config=None):
if data_dict is None: if data_dict is None:
return partial(self.random_world_rotation, config=config) return partial(self.random_world_rotation, config=config)
...@@ -64,24 +64,24 @@ class DataAugmentor(object): ...@@ -64,24 +64,24 @@ class DataAugmentor(object):
gt_boxes, points, noise_rot = augmentor_utils.global_rotation( gt_boxes, points, noise_rot = augmentor_utils.global_rotation(
data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range, return_rot=True
) )
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
data_dict['noise_rot'] = noise_rot data_dict['noise_rot'] = noise_rot
return data_dict return data_dict
def random_world_scaling(self, data_dict=None, config=None): def random_world_scaling(self, data_dict=None, config=None):
if data_dict is None: if data_dict is None:
return partial(self.random_world_scaling, config=config) return partial(self.random_world_scaling, config=config)
gt_boxes, points, noise_scale = augmentor_utils.global_scaling( gt_boxes, points, noise_scale = augmentor_utils.global_scaling(
data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'], return_scale=True
) )
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
data_dict['noise_scale'] = noise_scale data_dict['noise_scale'] = noise_scale
return data_dict return data_dict
def random_image_flip(self, data_dict=None, config=None): def random_image_flip(self, data_dict=None, config=None):
if data_dict is None: if data_dict is None:
return partial(self.random_image_flip, config=config) return partial(self.random_image_flip, config=config)
...@@ -95,12 +95,12 @@ class DataAugmentor(object): ...@@ -95,12 +95,12 @@ class DataAugmentor(object):
images, depth_maps, gt_boxes = getattr(augmentor_utils, 'random_image_flip_%s' % cur_axis)( images, depth_maps, gt_boxes = getattr(augmentor_utils, 'random_image_flip_%s' % cur_axis)(
images, depth_maps, gt_boxes, calib, images, depth_maps, gt_boxes, calib,
) )
data_dict['images'] = images data_dict['images'] = images
data_dict['depth_maps'] = depth_maps data_dict['depth_maps'] = depth_maps
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
return data_dict return data_dict
def random_world_translation(self, data_dict=None, config=None): def random_world_translation(self, data_dict=None, config=None):
if data_dict is None: if data_dict is None:
return partial(self.random_world_translation, config=config) return partial(self.random_world_translation, config=config)
...@@ -131,11 +131,11 @@ class DataAugmentor(object): ...@@ -131,11 +131,11 @@ class DataAugmentor(object):
gt_boxes, points = getattr(augmentor_utils, 'random_local_translation_along_%s' % cur_axis)( gt_boxes, points = getattr(augmentor_utils, 'random_local_translation_along_%s' % cur_axis)(
gt_boxes, points, offset_range, gt_boxes, points, offset_range,
) )
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
def random_local_rotation(self, data_dict=None, config=None): def random_local_rotation(self, data_dict=None, config=None):
""" """
Please check the correctness of it before using. Please check the correctness of it before using.
...@@ -148,11 +148,11 @@ class DataAugmentor(object): ...@@ -148,11 +148,11 @@ class DataAugmentor(object):
gt_boxes, points = augmentor_utils.local_rotation( gt_boxes, points = augmentor_utils.local_rotation(
data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range data_dict['gt_boxes'], data_dict['points'], rot_range=rot_range
) )
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
def random_local_scaling(self, data_dict=None, config=None): def random_local_scaling(self, data_dict=None, config=None):
""" """
Please check the correctness of it before using. Please check the correctness of it before using.
...@@ -162,18 +162,18 @@ class DataAugmentor(object): ...@@ -162,18 +162,18 @@ class DataAugmentor(object):
gt_boxes, points = augmentor_utils.local_scaling( gt_boxes, points = augmentor_utils.local_scaling(
data_dict['gt_boxes'], data_dict['points'], config['LOCAL_SCALE_RANGE'] data_dict['gt_boxes'], data_dict['points'], config['LOCAL_SCALE_RANGE']
) )
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
def random_world_frustum_dropout(self, data_dict=None, config=None): def random_world_frustum_dropout(self, data_dict=None, config=None):
""" """
Please check the correctness of it before using. Please check the correctness of it before using.
""" """
if data_dict is None: if data_dict is None:
return partial(self.random_world_frustum_dropout, config=config) return partial(self.random_world_frustum_dropout, config=config)
intensity_range = config['INTENSITY_RANGE'] intensity_range = config['INTENSITY_RANGE']
gt_boxes, points = data_dict['gt_boxes'], data_dict['points'] gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
for direction in config['DIRECTION']: for direction in config['DIRECTION']:
...@@ -181,18 +181,18 @@ class DataAugmentor(object): ...@@ -181,18 +181,18 @@ class DataAugmentor(object):
gt_boxes, points = getattr(augmentor_utils, 'global_frustum_dropout_%s' % direction)( gt_boxes, points = getattr(augmentor_utils, 'global_frustum_dropout_%s' % direction)(
gt_boxes, points, intensity_range, gt_boxes, points, intensity_range,
) )
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
def random_local_frustum_dropout(self, data_dict=None, config=None): def random_local_frustum_dropout(self, data_dict=None, config=None):
""" """
Please check the correctness of it before using. Please check the correctness of it before using.
""" """
if data_dict is None: if data_dict is None:
return partial(self.random_local_frustum_dropout, config=config) return partial(self.random_local_frustum_dropout, config=config)
intensity_range = config['INTENSITY_RANGE'] intensity_range = config['INTENSITY_RANGE']
gt_boxes, points = data_dict['gt_boxes'], data_dict['points'] gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
for direction in config['DIRECTION']: for direction in config['DIRECTION']:
...@@ -200,21 +200,21 @@ class DataAugmentor(object): ...@@ -200,21 +200,21 @@ class DataAugmentor(object):
gt_boxes, points = getattr(augmentor_utils, 'local_frustum_dropout_%s' % direction)( gt_boxes, points = getattr(augmentor_utils, 'local_frustum_dropout_%s' % direction)(
gt_boxes, points, intensity_range, gt_boxes, points, intensity_range,
) )
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
def random_local_pyramid_aug(self, data_dict=None, config=None): def random_local_pyramid_aug(self, data_dict=None, config=None):
""" """
Refer to the paper: Refer to the paper:
SE-SSD: Self-Ensembling Single-Stage Object Detector From Point Cloud SE-SSD: Self-Ensembling Single-Stage Object Detector From Point Cloud
""" """
if data_dict is None: if data_dict is None:
return partial(self.random_local_pyramid_aug, config=config) return partial(self.random_local_pyramid_aug, config=config)
gt_boxes, points = data_dict['gt_boxes'], data_dict['points'] gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
gt_boxes, points, pyramids = augmentor_utils.local_pyramid_dropout(gt_boxes, points, config['DROP_PROB']) gt_boxes, points, pyramids = augmentor_utils.local_pyramid_dropout(gt_boxes, points, config['DROP_PROB'])
gt_boxes, points, pyramids = augmentor_utils.local_pyramid_sparsify(gt_boxes, points, gt_boxes, points, pyramids = augmentor_utils.local_pyramid_sparsify(gt_boxes, points,
config['SPARSIFY_PROB'], config['SPARSIFY_PROB'],
...@@ -227,7 +227,7 @@ class DataAugmentor(object): ...@@ -227,7 +227,7 @@ class DataAugmentor(object):
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
def forward(self, data_dict): def forward(self, data_dict):
""" """
Args: Args:
...@@ -241,12 +241,12 @@ class DataAugmentor(object): ...@@ -241,12 +241,12 @@ class DataAugmentor(object):
""" """
for cur_augmentor in self.data_augmentor_queue: for cur_augmentor in self.data_augmentor_queue:
data_dict = cur_augmentor(data_dict=data_dict) data_dict = cur_augmentor(data_dict=data_dict)
data_dict['gt_boxes'][:, 6] = common_utils.limit_period( data_dict['gt_boxes'][:, 6] = common_utils.limit_period(
data_dict['gt_boxes'][:, 6], offset=0.5, period=2 * np.pi data_dict['gt_boxes'][:, 6], offset=0.5, period=2 * np.pi
) )
if 'calib' in data_dict: # if 'calib' in data_dict:
data_dict.pop('calib') # data_dict.pop('calib')
if 'road_plane' in data_dict: if 'road_plane' in data_dict:
data_dict.pop('road_plane') data_dict.pop('road_plane')
if 'gt_boxes_mask' in data_dict: if 'gt_boxes_mask' in data_dict:
...@@ -255,6 +255,6 @@ class DataAugmentor(object): ...@@ -255,6 +255,6 @@ class DataAugmentor(object):
data_dict['gt_names'] = data_dict['gt_names'][gt_boxes_mask] data_dict['gt_names'] = data_dict['gt_names'][gt_boxes_mask]
if 'gt_boxes2d' in data_dict: if 'gt_boxes2d' in data_dict:
data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][gt_boxes_mask] data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][gt_boxes_mask]
data_dict.pop('gt_boxes_mask') data_dict.pop('gt_boxes_mask')
return data_dict return data_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment