Unverified Commit c5d71a67 authored by yukang's avatar yukang Committed by GitHub
Browse files

Update database_sampler.py

parent 216adfe5
...@@ -9,7 +9,7 @@ import SharedArray ...@@ -9,7 +9,7 @@ import SharedArray
import torch.distributed as dist import torch.distributed as dist
from ...ops.iou3d_nms import iou3d_nms_utils from ...ops.iou3d_nms import iou3d_nms_utils
from ...utils import box_utils, common_utils, box2d_utils, calibration_kitti from ...utils import box_utils, common_utils, calibration_kitti
from pcdet.datasets.kitti.kitti_object_eval_python import kitti_common from pcdet.datasets.kitti.kitti_object_eval_python import kitti_common
class DataBaseSampler(object): class DataBaseSampler(object):
...@@ -18,13 +18,11 @@ class DataBaseSampler(object): ...@@ -18,13 +18,11 @@ class DataBaseSampler(object):
self.class_names = class_names self.class_names = class_names
self.sampler_cfg = sampler_cfg self.sampler_cfg = sampler_cfg
self.aug_with_img = sampler_cfg.get('AUG_WITH_IMAGE', False) self.aug_with_img = sampler_cfg.get('AUG_WITH_IMAGE', False)
self.joint_sample = True self.joint_sample = sampler_cfg.get('JOINT_SAMPLE', False)
self.keep_raw = False self.keep_raw = sampler_cfg.get('KEEP_RAW', False)
self.box_iou_thres = 0.5 self.box_iou_thres = sampler_cfg.get('BOX_IOU_THRES', 1.0)
self.img_aug_type = 'by_depth' self.aug_use_type = sampler_cfg.get('AUG_USE_TYPE', 'annotation')
self.aug_use_type = 'annotation' self.point_refine = sampler_cfg.get('POINT_REFINE', False)
self.point_refine = True
self.img_root_path = 'training/image_2'
self.logger = logger self.logger = logger
self.db_infos = {} self.db_infos = {}
...@@ -169,13 +167,14 @@ class DataBaseSampler(object): ...@@ -169,13 +167,14 @@ class DataBaseSampler(object):
boxes3d = data_dict['gt_boxes'] boxes3d = data_dict['gt_boxes']
boxes2d = data_dict['gt_boxes2d'] boxes2d = data_dict['gt_boxes2d']
corners_lidar = box_utils.boxes_to_corners_3d(boxes3d) corners_lidar = box_utils.boxes_to_corners_3d(boxes3d)
if 'depth' in self.img_aug_type: img_aug_type = self.sampler_cfg.IMG_AUG_TYPE
if 'depth' in img_aug_type:
paste_order = boxes3d[:,0].argsort() paste_order = boxes3d[:,0].argsort()
paste_order = paste_order[::-1] paste_order = paste_order[::-1]
else: else:
paste_order = np.arange(len(boxes3d),dtype=np.int) paste_order = np.arange(len(boxes3d),dtype=np.int)
if 'reverse' in self.img_aug_type: if 'reverse' in img_aug_type:
paste_order = paste_order[::-1] paste_order = paste_order[::-1]
paste_mask = -255 * np.ones(image.shape[:2], dtype=np.int) paste_mask = -255 * np.ones(image.shape[:2], dtype=np.int)
...@@ -253,7 +252,7 @@ class DataBaseSampler(object): ...@@ -253,7 +252,7 @@ class DataBaseSampler(object):
obj_idx = idx * np.ones(len(obj_points), dtype=np.int) obj_idx = idx * np.ones(len(obj_points), dtype=np.int)
# copy crops from images # copy crops from images
img_path = self.root_path / self.img_root_path / (info['image_idx']+'.png') img_path = self.root_path / self.sampler_cfg.IMG_ROOT_PATH / (info['image_idx']+'.png')
raw_image = io.imread(img_path) raw_image = io.imread(img_path)
raw_image = raw_image.astype(np.float32) raw_image = raw_image.astype(np.float32)
raw_center = info['bbox'].reshape(2,2).mean(0) raw_center = info['bbox'].reshape(2,2).mean(0)
...@@ -284,14 +283,14 @@ class DataBaseSampler(object): ...@@ -284,14 +283,14 @@ class DataBaseSampler(object):
data_dict['images'].shape[:2]) data_dict['images'].shape[:2])
sampled_boxes2d = torch.Tensor(sampled_boxes2d) sampled_boxes2d = torch.Tensor(sampled_boxes2d)
existed_boxes2d = torch.Tensor(data_dict['gt_boxes2d']) existed_boxes2d = torch.Tensor(data_dict['gt_boxes2d'])
iou2d1 = box2d_utils.pairwise_iou(sampled_boxes2d, existed_boxes2d).cpu().numpy() iou2d1 = box_utils.pairwise_iou(sampled_boxes2d, existed_boxes2d).cpu().numpy()
iou2d2 = box2d_utils.pairwise_iou(sampled_boxes2d, sampled_boxes2d).cpu().numpy() iou2d2 = box_utils.pairwise_iou(sampled_boxes2d, sampled_boxes2d).cpu().numpy()
iou2d2[range(sampled_boxes2d.shape[0]), range(sampled_boxes2d.shape[0])] = 0 iou2d2[range(sampled_boxes2d.shape[0]), range(sampled_boxes2d.shape[0])] = 0
iou2d1 = iou2d1 if iou2d1.shape[1] > 0 else iou2d2 iou2d1 = iou2d1 if iou2d1.shape[1] > 0 else iou2d2
valid_mask = ((iou2d1.max(axis=1)<self.box_iou_thres) & valid_mask = ((iou2d1.max(axis=1)<self.box_iou_thres) &
(iou2d2.max(axis=1)<self.box_iou_thres) & (iou2d2.max(axis=1)<self.box_iou_thres) &
((iou1.max(axis=1) + iou2.max(axis=1)) == 0)).nonzero()[0] ((iou1.max(axis=1) + iou2.max(axis=1)) == 0)).nonzero()[0]
sampled_boxes2d = sampled_boxes2d[valid_mask].cpu().numpy() sampled_boxes2d = sampled_boxes2d[valid_mask].cpu().numpy()
return sampled_boxes2d, mv_height, valid_mask return sampled_boxes2d, mv_height, valid_mask
...@@ -395,8 +394,7 @@ class DataBaseSampler(object): ...@@ -395,8 +394,7 @@ class DataBaseSampler(object):
sampled_boxes = np.stack([x['box3d_lidar'] for x in sampled_dict], axis=0).astype(np.float32) sampled_boxes = np.stack([x['box3d_lidar'] for x in sampled_dict], axis=0).astype(np.float32)
if self.sampler_cfg.get('DATABASE_WITH_FAKELIDAR', False): assert not self.sampler_cfg.get('DATABASE_WITH_FAKELIDAR', False), 'Please use latest codes to generate GT_DATABASE'
sampled_boxes = box_utils.boxes3d_kitti_fakelidar_to_lidar(sampled_boxes)
iou1 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], existed_boxes[:, 0:7]) iou1 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], existed_boxes[:, 0:7])
iou2 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], sampled_boxes[:, 0:7]) iou2 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], sampled_boxes[:, 0:7])
...@@ -419,16 +417,13 @@ class DataBaseSampler(object): ...@@ -419,16 +417,13 @@ class DataBaseSampler(object):
sampled_gt_boxes = existed_boxes[gt_boxes.shape[0]:, :] sampled_gt_boxes = existed_boxes[gt_boxes.shape[0]:, :]
if self.aug_with_img: if len(sampled_gt_boxes2d) > 0:
if len(sampled_gt_boxes2d) > 0: sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0)
sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0)
if total_valid_sampled_dict.__len__() > 0: if total_valid_sampled_dict.__len__() > 0:
data_dict = self.add_sampled_boxes_to_scene(data_dict, data_dict = self.add_sampled_boxes_to_scene(
sampled_gt_boxes, data_dict, sampled_gt_boxes, total_valid_sampled_dict, sampled_mv_height, sampled_gt_boxes2d
total_valid_sampled_dict, )
sampled_mv_height,
sampled_gt_boxes2d)
data_dict.pop('gt_boxes_mask') data_dict.pop('gt_boxes_mask')
return data_dict return data_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment