Commit eaae7f47 authored by Gus-Guo's avatar Gus-Guo
Browse files

delete usage of register function

parent b6f114e5
...@@ -4,15 +4,6 @@ from . import augmentor_utils, database_sampler ...@@ -4,15 +4,6 @@ from . import augmentor_utils, database_sampler
from ...utils import common_utils from ...utils import common_utils
def register_function_augmentor(src_func):
def kernel_func(self, data_dict=None, config=None):
if data_dict is None:
return partial(src_func, self=self, config=config)
return src_func
return kernel_func
class DataAugmentor(object): class DataAugmentor(object):
def __init__(self, root_path, augmentor_configs, class_names, logger=None): def __init__(self, root_path, augmentor_configs, class_names, logger=None):
self.root_path = root_path self.root_path = root_path
...@@ -33,8 +24,10 @@ class DataAugmentor(object): ...@@ -33,8 +24,10 @@ class DataAugmentor(object):
) )
return db_sampler return db_sampler
@register_function_augmentor
def random_world_flip(self, data_dict=None, config=None): def random_world_flip(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_flip, config=config)
gt_boxes, points = data_dict['gt_boxes'], data_dict['points'] gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
for cur_axis in config['ALONG_AXIS_LIST']: for cur_axis in config['ALONG_AXIS_LIST']:
assert cur_axis in ['x', 'y'] assert cur_axis in ['x', 'y']
...@@ -46,8 +39,10 @@ class DataAugmentor(object): ...@@ -46,8 +39,10 @@ class DataAugmentor(object):
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
@register_function_augmentor
def random_world_rotation(self, data_dict=None, config=None): def random_world_rotation(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_rotation, config=config)
rot_range = config['WORLD_ROT_ANGLE'] rot_range = config['WORLD_ROT_ANGLE']
if not isinstance(rot_range, list): if not isinstance(rot_range, list):
rot_range = [-rot_range, rot_range] rot_range = [-rot_range, rot_range]
...@@ -59,8 +54,10 @@ class DataAugmentor(object): ...@@ -59,8 +54,10 @@ class DataAugmentor(object):
data_dict['points'] = points data_dict['points'] = points
return data_dict return data_dict
@register_function_augmentor
def random_world_scaling(self, data_dict=None, config=None): def random_world_scaling(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_scaling, config=config)
gt_boxes, points = augmentor_utils.global_scaling( gt_boxes, points = augmentor_utils.global_scaling(
data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE'] data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE']
) )
......
...@@ -3,15 +3,6 @@ import numpy as np ...@@ -3,15 +3,6 @@ import numpy as np
from ...utils import box_utils, common_utils from ...utils import box_utils, common_utils
def register_function_processor(src_func):
def kernel_func(self, data_dict=None, config=None):
if data_dict is None:
return partial(src_func, self=self, config=config)
return src_func
return kernel_func
class DataProcessor(object): class DataProcessor(object):
def __init__(self, processor_configs, point_cloud_range, training): def __init__(self, processor_configs, point_cloud_range, training):
self.point_cloud_range = point_cloud_range self.point_cloud_range = point_cloud_range
...@@ -23,8 +14,10 @@ class DataProcessor(object): ...@@ -23,8 +14,10 @@ class DataProcessor(object):
cur_processor = getattr(self, cur_cfg.NAME)(config=cur_cfg) cur_processor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
self.data_processor_queue.append(cur_processor) self.data_processor_queue.append(cur_processor)
@register_function_processor
def mask_points_and_boxes_outside_range(self, data_dict=None, config=None): def mask_points_and_boxes_outside_range(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.mask_points_and_boxes_outside_range, config=config)
mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range) mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range)
data_dict['points'] = data_dict['points'][mask] data_dict['points'] = data_dict['points'][mask]
if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training: if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training:
...@@ -34,7 +27,6 @@ class DataProcessor(object): ...@@ -34,7 +27,6 @@ class DataProcessor(object):
data_dict['gt_boxes'] = data_dict['gt_boxes'][mask] data_dict['gt_boxes'] = data_dict['gt_boxes'][mask]
return data_dict return data_dict
@register_function_processor
def shuffle_points(self, data_dict=None, config=None): def shuffle_points(self, data_dict=None, config=None):
if data_dict is None: if data_dict is None:
return partial(self.shuffle_points, config=config) return partial(self.shuffle_points, config=config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment