Commit eaae7f47 authored by Gus-Guo's avatar Gus-Guo
Browse files

delete usage of register function

parent b6f114e5
......@@ -4,15 +4,6 @@ from . import augmentor_utils, database_sampler
from ...utils import common_utils
def register_function_augmentor(src_func):
def kernel_func(self, data_dict=None, config=None):
if data_dict is None:
return partial(src_func, self=self, config=config)
return src_func
return kernel_func
class DataAugmentor(object):
def __init__(self, root_path, augmentor_configs, class_names, logger=None):
self.root_path = root_path
......@@ -23,7 +14,7 @@ class DataAugmentor(object):
for cur_cfg in augmentor_configs:
cur_augmentor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
self.data_augmentor_queue.append(cur_augmentor)
def gt_sampling(self, config=None):
db_sampler = database_sampler.DataBaseSampler(
root_path=self.root_path,
......@@ -33,8 +24,10 @@ class DataAugmentor(object):
)
return db_sampler
@register_function_augmentor
def random_world_flip(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_flip, config=config)
gt_boxes, points = data_dict['gt_boxes'], data_dict['points']
for cur_axis in config['ALONG_AXIS_LIST']:
assert cur_axis in ['x', 'y']
......@@ -46,8 +39,10 @@ class DataAugmentor(object):
data_dict['points'] = points
return data_dict
@register_function_augmentor
def random_world_rotation(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_rotation, config=config)
rot_range = config['WORLD_ROT_ANGLE']
if not isinstance(rot_range, list):
rot_range = [-rot_range, rot_range]
......@@ -59,8 +54,10 @@ class DataAugmentor(object):
data_dict['points'] = points
return data_dict
@register_function_augmentor
def random_world_scaling(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.random_world_scaling, config=config)
gt_boxes, points = augmentor_utils.global_scaling(
data_dict['gt_boxes'], data_dict['points'], config['WORLD_SCALE_RANGE']
)
......
......@@ -3,15 +3,6 @@ import numpy as np
from ...utils import box_utils, common_utils
def register_function_processor(src_func):
def kernel_func(self, data_dict=None, config=None):
if data_dict is None:
return partial(src_func, self=self, config=config)
return src_func
return kernel_func
class DataProcessor(object):
def __init__(self, processor_configs, point_cloud_range, training):
self.point_cloud_range = point_cloud_range
......@@ -22,9 +13,11 @@ class DataProcessor(object):
for cur_cfg in processor_configs:
cur_processor = getattr(self, cur_cfg.NAME)(config=cur_cfg)
self.data_processor_queue.append(cur_processor)
@register_function_processor
def mask_points_and_boxes_outside_range(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.mask_points_and_boxes_outside_range, config=config)
mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range)
data_dict['points'] = data_dict['points'][mask]
if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training:
......@@ -34,7 +27,6 @@ class DataProcessor(object):
data_dict['gt_boxes'] = data_dict['gt_boxes'][mask]
return data_dict
@register_function_processor
def shuffle_points(self, data_dict=None, config=None):
if data_dict is None:
return partial(self.shuffle_points, config=config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment