Unverified Commit da1f5b12 authored by yukang's avatar yukang Committed by GitHub
Browse files

Update database_sampler.py

parent c5d71a67
...@@ -18,11 +18,13 @@ class DataBaseSampler(object): ...@@ -18,11 +18,13 @@ class DataBaseSampler(object):
self.class_names = class_names self.class_names = class_names
self.sampler_cfg = sampler_cfg self.sampler_cfg = sampler_cfg
self.aug_with_img = sampler_cfg.get('AUG_WITH_IMAGE', False) self.aug_with_img = sampler_cfg.get('AUG_WITH_IMAGE', False)
self.joint_sample = sampler_cfg.get('JOINT_SAMPLE', False) self.joint_sample = True
self.keep_raw = sampler_cfg.get('KEEP_RAW', False) self.keep_raw = False
self.box_iou_thres = sampler_cfg.get('BOX_IOU_THRES', 1.0) self.box_iou_thres = 0.5
self.aug_use_type = sampler_cfg.get('AUG_USE_TYPE', 'annotation') self.img_aug_type = 'by_depth'
self.point_refine = sampler_cfg.get('POINT_REFINE', False) self.aug_use_type = 'annotation'
self.point_refine = True
self.img_root_path = 'training/image_2'
self.logger = logger self.logger = logger
self.db_infos = {} self.db_infos = {}
...@@ -167,14 +169,13 @@ class DataBaseSampler(object): ...@@ -167,14 +169,13 @@ class DataBaseSampler(object):
boxes3d = data_dict['gt_boxes'] boxes3d = data_dict['gt_boxes']
boxes2d = data_dict['gt_boxes2d'] boxes2d = data_dict['gt_boxes2d']
corners_lidar = box_utils.boxes_to_corners_3d(boxes3d) corners_lidar = box_utils.boxes_to_corners_3d(boxes3d)
img_aug_type = self.sampler_cfg.IMG_AUG_TYPE if 'depth' in self.img_aug_type:
if 'depth' in img_aug_type:
paste_order = boxes3d[:,0].argsort() paste_order = boxes3d[:,0].argsort()
paste_order = paste_order[::-1] paste_order = paste_order[::-1]
else: else:
paste_order = np.arange(len(boxes3d),dtype=np.int) paste_order = np.arange(len(boxes3d),dtype=np.int)
if 'reverse' in img_aug_type: if 'reverse' in self.img_aug_type:
paste_order = paste_order[::-1] paste_order = paste_order[::-1]
paste_mask = -255 * np.ones(image.shape[:2], dtype=np.int) paste_mask = -255 * np.ones(image.shape[:2], dtype=np.int)
...@@ -252,7 +253,7 @@ class DataBaseSampler(object): ...@@ -252,7 +253,7 @@ class DataBaseSampler(object):
obj_idx = idx * np.ones(len(obj_points), dtype=np.int) obj_idx = idx * np.ones(len(obj_points), dtype=np.int)
# copy crops from images # copy crops from images
img_path = self.root_path / self.sampler_cfg.IMG_ROOT_PATH / (info['image_idx']+'.png') img_path = self.root_path / self.img_root_path / (info['image_idx']+'.png')
raw_image = io.imread(img_path) raw_image = io.imread(img_path)
raw_image = raw_image.astype(np.float32) raw_image = raw_image.astype(np.float32)
raw_center = info['bbox'].reshape(2,2).mean(0) raw_center = info['bbox'].reshape(2,2).mean(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment