Commit 0743656b authored by Shaoshuai Shi
Browse files

refactor database_sampler.py to support extending to other multimodal datasets...

refactor database_sampler.py to support extending to other multimodal datasets (in addition to kitti)
parent 80173750
...@@ -17,14 +17,9 @@ class DataBaseSampler(object): ...@@ -17,14 +17,9 @@ class DataBaseSampler(object):
self.root_path = root_path self.root_path = root_path
self.class_names = class_names self.class_names = class_names
self.sampler_cfg = sampler_cfg self.sampler_cfg = sampler_cfg
self.aug_with_img = sampler_cfg.get('AUG_WITH_IMAGE', False)
self.joint_sample = True self.img_aug_type = sampler_cfg.get('IMG_AUG_TYPE', None)
self.keep_raw = False self.img_aug_iou_thresh = sampler_cfg.get('IMG_AUG_IOU_THRESH', 0.5)
self.box_iou_thres = 0.5
self.img_aug_type = 'by_depth'
self.aug_use_type = 'annotation'
self.point_refine = True
self.img_root_path = 'training/image_2'
self.logger = logger self.logger = logger
self.db_infos = {} self.db_infos = {}
...@@ -165,17 +160,20 @@ class DataBaseSampler(object): ...@@ -165,17 +160,20 @@ class DataBaseSampler(object):
return gt_boxes, mv_height return gt_boxes, mv_height
def copy_paste_to_image_kitti(self, data_dict, crop_feat, gt_number, point_idxes=None): def copy_paste_to_image_kitti(self, data_dict, crop_feat, gt_number, point_idxes=None):
kitti_img_aug_type = 'by_depth'
kitti_img_aug_use_type = 'annotation'
image = data_dict['images'] image = data_dict['images']
boxes3d = data_dict['gt_boxes'] boxes3d = data_dict['gt_boxes']
boxes2d = data_dict['gt_boxes2d'] boxes2d = data_dict['gt_boxes2d']
corners_lidar = box_utils.boxes_to_corners_3d(boxes3d) corners_lidar = box_utils.boxes_to_corners_3d(boxes3d)
if 'depth' in self.img_aug_type: if 'depth' in kitti_img_aug_type:
paste_order = boxes3d[:,0].argsort() paste_order = boxes3d[:,0].argsort()
paste_order = paste_order[::-1] paste_order = paste_order[::-1]
else: else:
paste_order = np.arange(len(boxes3d),dtype=np.int) paste_order = np.arange(len(boxes3d),dtype=np.int)
if 'reverse' in self.img_aug_type: if 'reverse' in kitti_img_aug_type:
paste_order = paste_order[::-1] paste_order = paste_order[::-1]
paste_mask = -255 * np.ones(image.shape[:2], dtype=np.int) paste_mask = -255 * np.ones(image.shape[:2], dtype=np.int)
...@@ -193,7 +191,7 @@ class DataBaseSampler(object): ...@@ -193,7 +191,7 @@ class DataBaseSampler(object):
(paste_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2]] > 0).astype(np.int) (paste_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2]] > 0).astype(np.int)
paste_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2]] = _order paste_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2]] = _order
if 'cover' in self.aug_use_type: if 'cover' in kitti_img_aug_use_type:
# HxWx2 for min and max depth of each box region # HxWx2 for min and max depth of each box region
depth_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2],0] = corners_lidar[_order,:,0].min() depth_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2],0] = corners_lidar[_order,:,0].min()
depth_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2],1] = corners_lidar[_order,:,0].max() depth_mask[_box2d[1]:_box2d[3],_box2d[0]:_box2d[2],1] = corners_lidar[_order,:,0].max()
...@@ -204,12 +202,12 @@ class DataBaseSampler(object): ...@@ -204,12 +202,12 @@ class DataBaseSampler(object):
data_dict['images'] = image data_dict['images'] = image
if not self.joint_sample: # if not self.joint_sample:
return data_dict # return data_dict
new_mask = paste_mask[points_2d[:,1], points_2d[:,0]]==(point_idxes+gt_number) new_mask = paste_mask[points_2d[:,1], points_2d[:,0]]==(point_idxes+gt_number)
if self.keep_raw: if False: # self.keep_raw:
raw_mask = point_idxes==-1 raw_mask = (point_idxes == -1)
else: else:
raw_fg = (fg_mask == 1) & (paste_mask >= 0) & (paste_mask < gt_number) raw_fg = (fg_mask == 1) & (paste_mask >= 0) & (paste_mask < gt_number)
raw_bg = (fg_mask == 0) & (paste_mask < 0) raw_bg = (fg_mask == 0) & (paste_mask < 0)
...@@ -217,13 +215,13 @@ class DataBaseSampler(object): ...@@ -217,13 +215,13 @@ class DataBaseSampler(object):
keep_mask = new_mask | raw_mask keep_mask = new_mask | raw_mask
data_dict['points_2d'] = points_2d data_dict['points_2d'] = points_2d
if 'annotation' in self.aug_use_type: if 'annotation' in kitti_img_aug_use_type:
data_dict['points'] = data_dict['points'][keep_mask] data_dict['points'] = data_dict['points'][keep_mask]
data_dict['points_2d'] = data_dict['points_2d'][keep_mask] data_dict['points_2d'] = data_dict['points_2d'][keep_mask]
elif 'projection' in self.aug_use_type: elif 'projection' in kitti_img_aug_use_type:
overlap_mask[overlap_mask>=1] = 1 overlap_mask[overlap_mask>=1] = 1
data_dict['overlap_mask'] = overlap_mask data_dict['overlap_mask'] = overlap_mask
if 'cover' in self.aug_use_type: if 'cover' in kitti_img_aug_use_type:
data_dict['depth_mask'] = depth_mask data_dict['depth_mask'] = depth_mask
return data_dict return data_dict
...@@ -233,7 +231,7 @@ class DataBaseSampler(object): ...@@ -233,7 +231,7 @@ class DataBaseSampler(object):
sampled_calib = calibration_kitti.Calibration(calib_file) sampled_calib = calibration_kitti.Calibration(calib_file)
points_2d, depth_2d = sampled_calib.lidar_to_img(obj_points[:,:3]) points_2d, depth_2d = sampled_calib.lidar_to_img(obj_points[:,:3])
if self.point_refine: if True: # self.point_refine:
# align calibration metrics for points # align calibration metrics for points
points_ract = data_dict['calib'].img_to_rect(points_2d[:,0], points_2d[:,1], depth_2d) points_ract = data_dict['calib'].img_to_rect(points_2d[:,0], points_2d[:,1], depth_2d)
points_lidar = data_dict['calib'].rect_to_lidar(points_ract) points_lidar = data_dict['calib'].rect_to_lidar(points_ract)
...@@ -253,7 +251,7 @@ class DataBaseSampler(object): ...@@ -253,7 +251,7 @@ class DataBaseSampler(object):
obj_idx = idx * np.ones(len(obj_points), dtype=np.int) obj_idx = idx * np.ones(len(obj_points), dtype=np.int)
# copy crops from images # copy crops from images
img_path = self.root_path / self.img_root_path / (info['image_idx']+'.png') img_path = self.root_path / f'training/image_2/{info["image_idx"]}.png'
raw_image = io.imread(img_path) raw_image = io.imread(img_path)
raw_image = raw_image.astype(np.float32) raw_image = raw_image.astype(np.float32)
raw_center = info['bbox'].reshape(2,2).mean(0) raw_center = info['bbox'].reshape(2,2).mean(0)
...@@ -271,7 +269,8 @@ class DataBaseSampler(object): ...@@ -271,7 +269,8 @@ class DataBaseSampler(object):
return new_box, img_crop2d, obj_points, obj_idx return new_box, img_crop2d, obj_points, obj_idx
def sample_gt_boxes_2d_kitti(self, data_dict, sampled_boxes, iou1, iou2): def sample_gt_boxes_2d_kitti(self, data_dict, sampled_boxes, valid_mask):
mv_height = None
# filter out box2d iou > thres # filter out box2d iou > thres
if self.sampler_cfg.get('USE_ROAD_PLANE', False): if self.sampler_cfg.get('USE_ROAD_PLANE', False):
sampled_boxes, mv_height = self.put_boxes_on_road_planes( sampled_boxes, mv_height = self.put_boxes_on_road_planes(
...@@ -289,19 +288,79 @@ class DataBaseSampler(object): ...@@ -289,19 +288,79 @@ class DataBaseSampler(object):
iou2d2[range(sampled_boxes2d.shape[0]), range(sampled_boxes2d.shape[0])] = 0 iou2d2[range(sampled_boxes2d.shape[0]), range(sampled_boxes2d.shape[0])] = 0
iou2d1 = iou2d1 if iou2d1.shape[1] > 0 else iou2d2 iou2d1 = iou2d1 if iou2d1.shape[1] > 0 else iou2d2
valid_mask = ((iou2d1.max(axis=1)<self.box_iou_thres) & ret_valid_mask = ((iou2d1.max(axis=1)<self.img_aug_iou_thresh) &
(iou2d2.max(axis=1)<self.box_iou_thres) & (iou2d2.max(axis=1)<self.img_aug_iou_thresh) &
((iou1.max(axis=1) + iou2.max(axis=1)) == 0)).nonzero()[0] (valid_mask))
sampled_boxes2d = sampled_boxes2d[ret_valid_mask].cpu().numpy()
if mv_height is not None:
mv_height = mv_height[ret_valid_mask]
return sampled_boxes2d, mv_height, ret_valid_mask
def sample_gt_boxes_2d(self, data_dict, sampled_boxes, valid_mask):
    """Dispatch 2D box sampling to the handler for the configured image-aug type.

    Args:
        data_dict: sample dictionary, forwarded unchanged to the handler.
        sampled_boxes: candidate sampled boxes, forwarded unchanged.
        valid_mask: validity mask computed by the caller, forwarded unchanged.

    Returns:
        The ``(sampled_boxes2d, mv_height, ret_valid_mask)`` triple produced
        by the type-specific handler.

    Raises:
        NotImplementedError: if ``self.img_aug_type`` is not ``'kitti'``.
    """
    # Only the KITTI handler exists so far; fail loudly and name the
    # offending value so a misconfigured IMG_AUG_TYPE is easy to spot.
    if self.img_aug_type != 'kitti':
        raise NotImplementedError(
            f'unsupported IMG_AUG_TYPE: {self.img_aug_type!r}'
        )

    sampled_boxes2d, mv_height, ret_valid_mask = self.sample_gt_boxes_2d_kitti(
        data_dict, sampled_boxes, valid_mask
    )
    return sampled_boxes2d, mv_height, ret_valid_mask
def initilize_image_aug_dict(self, data_dict, gt_boxes_mask):
    """Create the bookkeeping dict used while pasting sampled objects into the image.

    Args:
        data_dict: sample dictionary; for the ``'kitti'`` type it must hold
            ``'gt_boxes2d'`` and ``'images'``.
        gt_boxes_mask: mask used to index ``data_dict['gt_boxes2d']``; its
            sum is stored as the number of kept ground-truth boxes.

    Returns:
        ``None`` when image augmentation is disabled (``self.img_aug_type``
        is ``None``); otherwise a dict with keys ``'obj_index_list'``,
        ``'gt_crops2d'``, ``'gt_boxes2d'``, ``'gt_number'`` and
        ``'crop_boxes2d'``. The two list-valued entries start empty.

    Raises:
        NotImplementedError: if ``self.img_aug_type`` is neither ``None``
            nor ``'kitti'``.
    """
    if self.img_aug_type is None:
        return None
    if self.img_aug_type != 'kitti':
        raise NotImplementedError(
            f'unsupported IMG_AUG_TYPE: {self.img_aug_type!r}'
        )

    # ``np.int`` was only an alias of the builtin ``int`` and was removed in
    # NumPy 1.24; the builtin gives the identical dtype on every version.
    gt_number = gt_boxes_mask.sum().astype(int)
    gt_boxes2d = data_dict['gt_boxes2d'][gt_boxes_mask].astype(int)

    # Box columns 1/3 slice image rows and columns 0/2 slice image columns.
    image = data_dict['images']
    gt_crops2d = [image[_x[1]:_x[3], _x[0]:_x[2]] for _x in gt_boxes2d]

    return {
        'obj_index_list': [],
        'gt_crops2d': gt_crops2d,
        'gt_boxes2d': gt_boxes2d,
        'gt_number': gt_number,
        'crop_boxes2d': [],
    }
def collect_image_crops(self, img_aug_gt_dict, info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx):
    """Collect the image crop for one sampled object and record it in the aug dict.

    Args:
        img_aug_gt_dict: bookkeeping dict; its ``'crop_boxes2d'``,
            ``'gt_crops2d'`` and ``'obj_index_list'`` lists are appended to
            in place.
        info: database info entry of the sampled object, forwarded to the
            type-specific handler.
        data_dict: sample dictionary, forwarded to the handler.
        obj_points: points of the sampled object, forwarded to the handler.
        sampled_gt_boxes: sampled boxes, forwarded to the handler.
        sampled_gt_boxes2d: sampled 2D boxes, forwarded to the handler.
        idx: index of the sampled object, forwarded to the handler.

    Returns:
        ``(img_aug_gt_dict, obj_points)`` where ``obj_points`` is the value
        returned by the handler.

    Raises:
        NotImplementedError: if ``self.img_aug_type`` is not ``'kitti'``.
    """
    # Reject unknown types before touching the bookkeeping dict, and name
    # the offending value so the misconfiguration is easy to diagnose.
    if self.img_aug_type != 'kitti':
        raise NotImplementedError(
            f'unsupported IMG_AUG_TYPE: {self.img_aug_type!r}'
        )

    new_box, img_crop2d, obj_points, obj_idx = self.collect_image_crops_kitti(
        info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx
    )
    img_aug_gt_dict['crop_boxes2d'].append(new_box)
    img_aug_gt_dict['gt_crops2d'].append(img_crop2d)
    img_aug_gt_dict['obj_index_list'].append(obj_idx)
    return img_aug_gt_dict, obj_points
sampled_boxes2d = sampled_boxes2d[valid_mask].cpu().numpy() def copy_paste_to_image(self, img_aug_gt_dict, data_dict, points):
return sampled_boxes2d, mv_height, valid_mask if self.img_aug_type == 'kitti':
obj_points_idx = np.concatenate(img_aug_gt_dict['obj_index_list'], axis=0)
point_idxes = -1 * np.ones(len(points), dtype=np.int)
point_idxes = np.concatenate([obj_points_idx, point_idxes], axis=0)
data_dict['gt_boxes2d'] = np.concatenate([img_aug_gt_dict['gt_boxes2d'], np.array(img_aug_gt_dict['crop_boxes2d'])], axis=0)
data_dict = self.copy_paste_to_image_kitti(data_dict, img_aug_gt_dict['gt_crops2d'], img_aug_gt_dict['gt_number'], point_idxes)
if 'road_plane' in data_dict:
data_dict.pop('road_plane')
else:
raise NotImplementedError
return data_dict
def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, total_valid_sampled_dict, mv_height=None, sampled_gt_boxes2d=None): def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, total_valid_sampled_dict, mv_height=None, sampled_gt_boxes2d=None):
gt_boxes_mask = data_dict['gt_boxes_mask'] gt_boxes_mask = data_dict['gt_boxes_mask']
gt_boxes = data_dict['gt_boxes'][gt_boxes_mask] gt_boxes = data_dict['gt_boxes'][gt_boxes_mask]
gt_names = data_dict['gt_names'][gt_boxes_mask] gt_names = data_dict['gt_names'][gt_boxes_mask]
points = data_dict['points'] points = data_dict['points']
if self.sampler_cfg.get('USE_ROAD_PLANE', False) and not self.aug_with_img: if self.sampler_cfg.get('USE_ROAD_PLANE', False) and mv_height is None:
sampled_gt_boxes, mv_height = self.put_boxes_on_road_planes( sampled_gt_boxes, mv_height = self.put_boxes_on_road_planes(
sampled_gt_boxes, data_dict['road_plane'], data_dict['calib'] sampled_gt_boxes, data_dict['road_plane'], data_dict['calib']
) )
...@@ -309,12 +368,9 @@ class DataBaseSampler(object): ...@@ -309,12 +368,9 @@ class DataBaseSampler(object):
data_dict.pop('road_plane') data_dict.pop('road_plane')
obj_points_list = [] obj_points_list = []
# convert sampled 3D boxes to image plane # convert sampled 3D boxes to image plane
if self.aug_with_img: img_aug_gt_dict = self.initilize_image_aug_dict(data_dict, gt_boxes_mask)
obj_index_list, crop_boxes2d = [], []
gt_number = gt_boxes_mask.sum().astype(np.int)
gt_boxes2d = data_dict['gt_boxes2d'][gt_boxes_mask].astype(np.int)
gt_crops2d = [data_dict['images'][_x[1]:_x[3],_x[0]:_x[2]] for _x in gt_boxes2d]
if self.use_shared_memory: if self.use_shared_memory:
gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}") gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
...@@ -337,12 +393,10 @@ class DataBaseSampler(object): ...@@ -337,12 +393,10 @@ class DataBaseSampler(object):
# mv height # mv height
obj_points[:, 2] -= mv_height[idx] obj_points[:, 2] -= mv_height[idx]
if self.aug_with_img: if self.img_aug_type is not None:
new_box, img_crop2d, obj_points, obj_idx = self.collect_image_crops_kitti(info, data_dict, img_aug_gt_dict, obj_points = self.collect_image_crops(
obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx) img_aug_gt_dict, info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx
crop_boxes2d.append(new_box) )
gt_crops2d.append(img_crop2d)
obj_index_list.append(obj_idx)
obj_points_list.append(obj_points) obj_points_list.append(obj_points)
...@@ -359,15 +413,9 @@ class DataBaseSampler(object): ...@@ -359,15 +413,9 @@ class DataBaseSampler(object):
data_dict['gt_boxes'] = gt_boxes data_dict['gt_boxes'] = gt_boxes
data_dict['gt_names'] = gt_names data_dict['gt_names'] = gt_names
data_dict['points'] = points data_dict['points'] = points
if self.aug_with_img:
obj_points_idx = np.concatenate(obj_index_list, axis=0)
point_idxes = -1 * np.ones(len(points), dtype=np.int)
point_idxes = np.concatenate([obj_points_idx, point_idxes], axis=0)
data_dict['gt_boxes2d'] = np.concatenate([gt_boxes2d, np.array(crop_boxes2d)], axis=0) if self.img_aug_type is not None:
data_dict = self.copy_paste_to_image_kitti(data_dict, gt_crops2d, gt_number, point_idxes) data_dict = self.copy_paste_to_image(img_aug_gt_dict, data_dict, points)
if self.sampler_cfg.get('USE_ROAD_PLANE', False):
data_dict.pop('road_plane')
return data_dict return data_dict
...@@ -386,6 +434,7 @@ class DataBaseSampler(object): ...@@ -386,6 +434,7 @@ class DataBaseSampler(object):
total_valid_sampled_dict = [] total_valid_sampled_dict = []
sampled_mv_height = [] sampled_mv_height = []
sampled_gt_boxes2d = [] sampled_gt_boxes2d = []
for class_name, sample_group in self.sample_groups.items(): for class_name, sample_group in self.sample_groups.items():
if self.limit_whole_scene: if self.limit_whole_scene:
num_gt = np.sum(class_name == gt_names) num_gt = np.sum(class_name == gt_names)
...@@ -401,15 +450,15 @@ class DataBaseSampler(object): ...@@ -401,15 +450,15 @@ class DataBaseSampler(object):
iou2 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], sampled_boxes[:, 0:7]) iou2 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], sampled_boxes[:, 0:7])
iou2[range(sampled_boxes.shape[0]), range(sampled_boxes.shape[0])] = 0 iou2[range(sampled_boxes.shape[0]), range(sampled_boxes.shape[0])] = 0
iou1 = iou1 if iou1.shape[1] > 0 else iou2 iou1 = iou1 if iou1.shape[1] > 0 else iou2
valid_mask = ((iou1.max(axis=1) + iou2.max(axis=1)) == 0).nonzero()[0] valid_mask = ((iou1.max(axis=1) + iou2.max(axis=1)) == 0)
if self.aug_with_img: if self.img_aug_type is not None:
sampled_boxes2d, mv_height, valid_mask = self.sample_gt_boxes_2d_kitti(data_dict, sampled_boxes, iou1, iou2) sampled_boxes2d, mv_height, valid_mask = self.sample_gt_boxes_2d(data_dict, sampled_boxes, valid_mask)
sampled_gt_boxes2d.append(sampled_boxes2d) sampled_gt_boxes2d.append(sampled_boxes2d)
if self.sampler_cfg.get('USE_ROAD_PLANE', False): if mv_height is not None:
mv_height = mv_height[valid_mask] sampled_mv_height.append(mv_height)
sampled_mv_height = np.concatenate((sampled_mv_height, mv_height), axis=0)
valid_mask = valid_mask.nonzero()[0]
valid_sampled_dict = [sampled_dict[x] for x in valid_mask] valid_sampled_dict = [sampled_dict[x] for x in valid_mask]
valid_sampled_boxes = sampled_boxes[valid_mask] valid_sampled_boxes = sampled_boxes[valid_mask]
...@@ -418,10 +467,10 @@ class DataBaseSampler(object): ...@@ -418,10 +467,10 @@ class DataBaseSampler(object):
sampled_gt_boxes = existed_boxes[gt_boxes.shape[0]:, :] sampled_gt_boxes = existed_boxes[gt_boxes.shape[0]:, :]
if len(sampled_gt_boxes2d) > 0:
sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0)
if total_valid_sampled_dict.__len__() > 0: if total_valid_sampled_dict.__len__() > 0:
sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0) if len(sampled_gt_boxes2d) > 0 else None
sampled_mv_height = np.concatenate(sampled_mv_height, axis=0) if len(sampled_mv_height) > 0 else None
data_dict = self.add_sampled_boxes_to_scene( data_dict = self.add_sampled_boxes_to_scene(
data_dict, sampled_gt_boxes, total_valid_sampled_dict, sampled_mv_height, sampled_gt_boxes2d data_dict, sampled_gt_boxes, total_valid_sampled_dict, sampled_mv_height, sampled_gt_boxes2d
) )
......
...@@ -8,7 +8,8 @@ DATA_CONFIG: ...@@ -8,7 +8,8 @@ DATA_CONFIG:
AUG_CONFIG_LIST: AUG_CONFIG_LIST:
- NAME: gt_sampling - NAME: gt_sampling
AUG_WITH_IMAGE: True # use PC-Image Aug # AUG_WITH_IMAGE: True # use PC-Image Aug
IMG_AUG_TYPE: kitti
USE_ROAD_PLANE: True USE_ROAD_PLANE: True
DB_INFO_PATH: DB_INFO_PATH:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment