Unverified Commit f4071498 authored by yukang's avatar yukang Committed by GitHub
Browse files

Update database_sampler.py

parent e80ebf6b
......@@ -162,7 +162,7 @@ class DataBaseSampler(object):
gt_boxes[:, 2] -= mv_height # lidar view
return gt_boxes, mv_height
def copy_paste_to_image(self, data_dict, crop_feat, gt_number, point_idxes=None):
def copy_paste_to_image_kitti(self, data_dict, crop_feat, gt_number, point_idxes=None):
image = data_dict['images']
boxes3d = data_dict['gt_boxes']
boxes2d = data_dict['gt_boxes2d']
......@@ -227,11 +227,78 @@ class DataBaseSampler(object):
return data_dict
def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, mv_height, sampled_gt_boxes2d, total_valid_sampled_dict):
def collect_image_crops_kitti(self, info, data_dict, obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx):
    """Collect the 2D image crop for one sampled GT object (KITTI).

    Loads the calibration of the frame the sample was taken from, optionally
    re-aligns the sampled points/box to the current frame's calibration, and
    cuts the corresponding patch out of the source image.

    Args:
        info: GT-database entry for the sampled object; must provide
            'image_idx' and 'bbox' (source-image 2D box, shape (4,)).
        data_dict: current frame's data dict; 'calib' and 'images' are read.
        obj_points: (N, >=3) sampled object points; columns 0:3 may be
            rewritten in place when self.point_refine is set.
        sampled_gt_boxes: (M, 7+) sampled 3D boxes; row idx may be rewritten.
        sampled_gt_boxes2d: (M, 4) sampled 2D boxes; row idx may be rewritten.
        idx: index of the sampled object within the sampled arrays.

    Returns:
        new_box: (4,) int box in the current image where the crop is pasted.
        img_crop2d: float image patch from the source image, scaled to [0, 1].
        obj_points: the (possibly refined) object points.
        obj_idx: (N,) int array filled with idx, tagging each object point.
    """
    calib_file = kitti_common.get_calib_path(int(info['image_idx']), self.root_path, relative_path=False)
    sampled_calib = calibration_kitti.Calibration(calib_file)
    points_2d, depth_2d = sampled_calib.lidar_to_img(obj_points[:, :3])

    if self.point_refine:
        # align calibration metrics for points: project with the source
        # frame's calib, then back-project with the current frame's calib
        points_rect = data_dict['calib'].img_to_rect(points_2d[:, 0], points_2d[:, 1], depth_2d)
        points_lidar = data_dict['calib'].rect_to_lidar(points_rect)
        obj_points[:, :3] = points_lidar
        # align calibration metrics for boxes the same way, via the 8 corners
        box3d_raw = sampled_gt_boxes[idx].reshape(1, -1)
        box3d_coords = box_utils.boxes_to_corners_3d(box3d_raw)[0]
        box3d_box, box3d_depth = sampled_calib.lidar_to_img(box3d_coords)
        box3d_coord_rect = data_dict['calib'].img_to_rect(box3d_box[:, 0], box3d_box[:, 1], box3d_depth)
        box3d_rect = box_utils.corners_rect_to_camera(box3d_coord_rect).reshape(1, -1)
        box3d_lidar = box_utils.boxes3d_kitti_camera_to_lidar(box3d_rect, data_dict['calib'])
        box2d = box_utils.boxes3d_kitti_camera_to_imageboxes(box3d_rect, data_dict['calib'],
                                                             data_dict['images'].shape[:2])
        sampled_gt_boxes[idx] = box3d_lidar[0]
        sampled_gt_boxes2d[idx] = box2d[0]

    # np.int was removed in NumPy 1.24 — use np.int64 explicitly
    obj_idx = idx * np.ones(len(obj_points), dtype=np.int64)

    # copy crops from images
    img_path = self.root_path / self.sampler_cfg.IMG_ROOT_PATH / (info['image_idx'] + '.png')
    raw_image = io.imread(img_path)
    raw_image = raw_image.astype(np.float32)
    raw_center = info['bbox'].reshape(2, 2).mean(0)
    new_box = sampled_gt_boxes2d[idx].astype(np.int64)
    new_shape = np.array([new_box[2] - new_box[0], new_box[3] - new_box[1]])
    raw_box = np.concatenate([raw_center - new_shape / 2, raw_center + new_shape / 2]).astype(np.int64)
    # clip the source crop to the source image bounds (x to width, y to height)
    raw_box[0::2] = np.clip(raw_box[0::2], a_min=0, a_max=raw_image.shape[1])
    raw_box[1::2] = np.clip(raw_box[1::2], a_min=0, a_max=raw_image.shape[0])
    if (raw_box[2] - raw_box[0]) != new_shape[0] or (raw_box[3] - raw_box[1]) != new_shape[1]:
        # clipping shrank the source crop — shrink the destination box to match,
        # keeping its center fixed so crop and paste regions stay the same size
        new_center = new_box.reshape(2, 2).mean(0)
        new_shape = np.array([raw_box[2] - raw_box[0], raw_box[3] - raw_box[1]])
        new_box = np.concatenate([new_center - new_shape / 2, new_center + new_shape / 2]).astype(np.int64)
    img_crop2d = raw_image[raw_box[1]:raw_box[3], raw_box[0]:raw_box[2]] / 255

    return new_box, img_crop2d, obj_points, obj_idx
def sample_gt_boxes_2d_kitti(self, data_dict, sampled_boxes, iou1, iou2):
    """Project sampled 3D boxes to the image and filter them by 2D IoU.

    Sampled boxes whose 2D projection overlaps an existing GT box or another
    sampled box by more than self.box_iou_thres are rejected, in addition to
    the 3D-IoU constraint already encoded in iou1/iou2.

    Args:
        data_dict: frame data; reads 'calib', 'images', 'gt_boxes2d' and,
            when USE_ROAD_PLANE is set, 'road_plane'.
        sampled_boxes: (M, 7+) sampled 3D boxes in lidar coordinates.
        iou1: (M, K) 3D IoU of sampled boxes vs existing boxes.
        iou2: (M, M) 3D IoU of sampled boxes vs each other (diag zeroed).

    Returns:
        sampled_boxes2d: (V, 4) 2D boxes of the surviving samples.
        mv_height: (M,) height shifts from put_boxes_on_road_planes, or
            None when USE_ROAD_PLANE is disabled.
        valid_mask: (V,) indices of the surviving samples.
    """
    # Bug fix: mv_height was previously unbound (NameError at return)
    # whenever USE_ROAD_PLANE is disabled.
    mv_height = None
    if self.sampler_cfg.get('USE_ROAD_PLANE', False):
        sampled_boxes, mv_height = self.put_boxes_on_road_planes(
            sampled_boxes, data_dict['road_plane'], data_dict['calib']
        )

    # project the sampled 3D boxes into the current image plane
    boxes3d_camera = box_utils.boxes3d_lidar_to_kitti_camera(sampled_boxes, data_dict['calib'])
    sampled_boxes2d = box_utils.boxes3d_kitti_camera_to_imageboxes(boxes3d_camera, data_dict['calib'],
                                                                   data_dict['images'].shape[:2])
    sampled_boxes2d = torch.Tensor(sampled_boxes2d)
    existed_boxes2d = torch.Tensor(data_dict['gt_boxes2d'])

    # filter out box2d iou > thres
    iou2d1 = box2d_utils.pairwise_iou(sampled_boxes2d, existed_boxes2d).cpu().numpy()
    iou2d2 = box2d_utils.pairwise_iou(sampled_boxes2d, sampled_boxes2d).cpu().numpy()
    # zero self-IoU on the diagonal so a box never rejects itself
    iou2d2[range(sampled_boxes2d.shape[0]), range(sampled_boxes2d.shape[0])] = 0
    # no existing GT boxes in the frame — fall back to sampled-vs-sampled IoU
    iou2d1 = iou2d1 if iou2d1.shape[1] > 0 else iou2d2

    valid_mask = ((iou2d1.max(axis=1) < self.box_iou_thres) &
                  (iou2d2.max(axis=1) < self.box_iou_thres) &
                  ((iou1.max(axis=1) + iou2.max(axis=1)) == 0)).nonzero()[0]
    sampled_boxes2d = sampled_boxes2d[valid_mask].cpu().numpy()

    return sampled_boxes2d, mv_height, valid_mask
def add_sampled_boxes_to_scene(self, data_dict, sampled_gt_boxes, total_valid_sampled_dict, mv_height=None, sampled_gt_boxes2d=None):
gt_boxes_mask = data_dict['gt_boxes_mask']
gt_boxes = data_dict['gt_boxes'][gt_boxes_mask]
gt_names = data_dict['gt_names'][gt_boxes_mask]
gt_number = gt_boxes_mask.sum().astype(np.int)
points = data_dict['points']
if self.sampler_cfg.get('USE_ROAD_PLANE', False) and not self.aug_with_img:
sampled_gt_boxes, mv_height = self.put_boxes_on_road_planes(
......@@ -240,11 +307,14 @@ class DataBaseSampler(object):
data_dict.pop('calib')
data_dict.pop('road_plane')
obj_points_list, obj_index_list, crop_boxes2d = [], [], []
obj_points_list = []
# convert sampled 3D boxes to image plane
if self.aug_with_img:
obj_index_list, crop_boxes2d = [], []
gt_number = gt_boxes_mask.sum().astype(np.int)
gt_boxes2d = data_dict['gt_boxes2d'][gt_boxes_mask].astype(np.int)
gt_crops2d = [data_dict['images'][_x[1]:_x[3],_x[0]:_x[2]] for _x in gt_boxes2d]
if self.use_shared_memory:
gt_database_data = SharedArray.attach(f"shm://{self.gt_database_data_key}")
gt_database_data.setflags(write=0)
......@@ -267,77 +337,37 @@ class DataBaseSampler(object):
obj_points[:, 2] -= mv_height[idx]
if self.aug_with_img:
calib_file = kitti_common.get_calib_path(int(info['image_idx']), self.root_path, relative_path=False)
sampled_calib = calibration_kitti.Calibration(calib_file)
points_2d, depth_2d = sampled_calib.lidar_to_img(obj_points[:,:3])
if self.point_refine:
# align calibration metrics for points
points_ract = data_dict['calib'].img_to_rect(points_2d[:,0], points_2d[:,1], depth_2d)
points_lidar = data_dict['calib'].rect_to_lidar(points_ract)
obj_points[:, :3] = points_lidar
# align calibration metrics for boxes
box3d_raw = sampled_gt_boxes[idx].reshape(1,-1)
box3d_coords = box_utils.boxes_to_corners_3d(box3d_raw)[0]
box3d_box, box3d_depth = sampled_calib.lidar_to_img(box3d_coords)
box3d_coord_rect = data_dict['calib'].img_to_rect(box3d_box[:,0], box3d_box[:,1], box3d_depth)
box3d_rect = box_utils.corners_rect_to_camera(box3d_coord_rect).reshape(1,-1)
box3d_lidar = box_utils.boxes3d_kitti_camera_to_lidar(box3d_rect, data_dict['calib'])
box2d = box_utils.boxes3d_kitti_camera_to_imageboxes(box3d_rect, data_dict['calib'],
data_dict['images'].shape[:2])
sampled_gt_boxes[idx] = box3d_lidar[0]
sampled_gt_boxes2d[idx] = box2d[0]
obj_idx = idx * np.ones(len(obj_points), dtype=np.int)
obj_points_list.append(obj_points)
obj_index_list.append(obj_idx)
# copy crops from images
if self.aug_with_img:
img_path = self.root_path / self.sampler_cfg.IMG_ROOT_PATH / (info['image_idx']+'.png')
raw_image = io.imread(img_path)
raw_image = raw_image.astype(np.float32)
raw_center = info['bbox'].reshape(2,2).mean(0)
new_box = sampled_gt_boxes2d[idx].astype(np.int)
new_shape = np.array([new_box[2]-new_box[0], new_box[3]-new_box[1]])
raw_box = np.concatenate([raw_center-new_shape/2, raw_center+new_shape/2]).astype(np.int)
raw_box[0::2] = np.clip(raw_box[0::2], a_min=0, a_max=raw_image.shape[1])
raw_box[1::2] = np.clip(raw_box[1::2], a_min=0, a_max=raw_image.shape[0])
if (raw_box[2]-raw_box[0])!=new_shape[0] or (raw_box[3]-raw_box[1])!=new_shape[1]:
new_center = new_box.reshape(2,2).mean(0)
new_shape = np.array([raw_box[2]-raw_box[0], raw_box[3]-raw_box[1]])
new_box = np.concatenate([new_center-new_shape/2, new_center+new_shape/2]).astype(np.int)
img_crop2d = raw_image[raw_box[1]:raw_box[3],raw_box[0]:raw_box[2]] / 255
new_box, img_crop2d, obj_points, obj_idx = self.collect_image_crops_kitti(info, data_dict,
obj_points, sampled_gt_boxes, sampled_gt_boxes2d, idx)
crop_boxes2d.append(new_box)
gt_crops2d.append(img_crop2d)
obj_index_list.append(obj_idx)
obj_points_list.append(obj_points)
obj_points = np.concatenate(obj_points_list, axis=0)
obj_points_idx = np.concatenate(obj_index_list, axis=0)
sampled_gt_names = np.array([x['name'] for x in total_valid_sampled_dict])
large_sampled_gt_boxes = box_utils.enlarge_box3d(
sampled_gt_boxes[:, 0:7], extra_width=self.sampler_cfg.REMOVE_EXTRA_WIDTH
)
points = box_utils.remove_points_in_boxes3d(points, large_sampled_gt_boxes)
point_idxes = -1 * np.ones(len(points), dtype=np.int)
points = np.concatenate([points, obj_points], axis=0)
point_idxes = np.concatenate([point_idxes, obj_points_idx], axis=0)
points = np.concatenate([obj_points, points], axis=0)
gt_names = np.concatenate([gt_names, sampled_gt_names], axis=0)
gt_boxes = np.concatenate([gt_boxes, sampled_gt_boxes], axis=0)
data_dict['gt_boxes'] = gt_boxes
data_dict['gt_names'] = gt_names
data_dict['points'] = points
if self.aug_with_img:
obj_points_idx = np.concatenate(obj_index_list, axis=0)
point_idxes = -1 * np.ones(len(points), dtype=np.int)
point_idxes = np.concatenate([obj_points_idx, point_idxes], axis=0)
data_dict['gt_boxes2d'] = np.concatenate([gt_boxes2d, np.array(crop_boxes2d)], axis=0)
data_dict = self.copy_paste_to_image(data_dict, gt_crops2d, gt_number, point_idxes)
data_dict = self.copy_paste_to_image_kitti(data_dict, gt_crops2d, gt_number, point_idxes)
if self.sampler_cfg.get('USE_ROAD_PLANE', False):
data_dict.pop('road_plane')
if self.sampler_cfg.get('USE_ROAD_PLANE', False) and self.aug_with_img:
# data_dict.pop('calib')
data_dict.pop('road_plane')
return data_dict
def __call__(self, data_dict):
......@@ -372,48 +402,32 @@ class DataBaseSampler(object):
iou2[range(sampled_boxes.shape[0]), range(sampled_boxes.shape[0])] = 0
iou1 = iou1 if iou1.shape[1] > 0 else iou2
valid_mask = ((iou1.max(axis=1) + iou2.max(axis=1)) == 0).nonzero()[0]
# filter out box2d iou > thres
if self.sampler_cfg.get('USE_ROAD_PLANE', False):
sampled_boxes, mv_height = self.put_boxes_on_road_planes(
sampled_boxes, data_dict['road_plane'], data_dict['calib']
)
if self.aug_with_img:
# sampled_boxes2d = np.stack([x['bbox'] for x in sampled_dict], axis=0).astype(np.float32)
boxes3d_camera = box_utils.boxes3d_lidar_to_kitti_camera(sampled_boxes, data_dict['calib'])
sampled_boxes2d = box_utils.boxes3d_kitti_camera_to_imageboxes(boxes3d_camera, data_dict['calib'],
data_dict['images'].shape[:2])
sampled_boxes2d = torch.Tensor(sampled_boxes2d)
existed_boxes2d = torch.Tensor(data_dict['gt_boxes2d'])
iou2d1 = box2d_utils.pairwise_iou(sampled_boxes2d, existed_boxes2d).cpu().numpy()
iou2d2 = box2d_utils.pairwise_iou(sampled_boxes2d, sampled_boxes2d).cpu().numpy()
iou2d2[range(sampled_boxes2d.shape[0]), range(sampled_boxes2d.shape[0])] = 0
iou2d1 = iou2d1 if iou2d1.shape[1] > 0 else iou2d2
valid_mask = ((iou2d1.max(axis=1)<self.box_iou_thres) &
(iou2d2.max(axis=1)<self.box_iou_thres) &
((iou1.max(axis=1) + iou2.max(axis=1)) == 0)).nonzero()[0]
sampled_boxes2d = sampled_boxes2d[valid_mask].cpu().numpy()
sampled_boxes2d, mv_height, valid_mask = self.sample_gt_boxes_2d_kitti(data_dict, sampled_boxes, iou1, iou2)
sampled_gt_boxes2d.append(sampled_boxes2d)
if self.sampler_cfg.get('USE_ROAD_PLANE', False):
mv_height = mv_height[valid_mask]
sampled_mv_height = np.concatenate((sampled_mv_height, mv_height), axis=0)
valid_sampled_dict = [sampled_dict[x] for x in valid_mask]
valid_sampled_boxes = sampled_boxes[valid_mask]
mv_height = mv_height[valid_mask]
existed_boxes = np.concatenate((existed_boxes, valid_sampled_boxes), axis=0)
sampled_mv_height = np.concatenate((sampled_mv_height, mv_height), axis=0)
total_valid_sampled_dict.extend(valid_sampled_dict)
sampled_gt_boxes = existed_boxes[gt_boxes.shape[0]:, :]
if len(sampled_gt_boxes2d) > 0:
sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0)
if self.aug_with_img:
if len(sampled_gt_boxes2d) > 0:
sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0)
if total_valid_sampled_dict.__len__() > 0:
data_dict = self.add_sampled_boxes_to_scene(data_dict,
sampled_gt_boxes,
total_valid_sampled_dict,
sampled_mv_height,
sampled_gt_boxes2d,
total_valid_sampled_dict)
sampled_gt_boxes2d)
data_dict.pop('gt_boxes_mask')
return data_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment