"...csrc/io/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c88423b0267a140c4a6aed75aca80ed398c967e4"
Commit 929ebfe8 authored by liyinhao's avatar liyinhao
Browse files

change dict.get, enrich assertion

parent 51fa14c2
......@@ -19,8 +19,8 @@ class PointsColorNormalize(object):
self.color_mean = color_mean
def __call__(self, results):
points = results.get('results', None)
assert points.shape[1] >= 6
points = results['points']
assert points.shape[1] >= 6, 'Incomplete color channel.'
points[:, 3:6] = points[:, 3:6] - np.array(self.color_mean) / 256.0
results['points'] = points
return results
......@@ -47,13 +47,13 @@ class LoadPointsFromFile(object):
def __init__(self, use_height, load_dim=6, use_dim=[0, 1, 2]):
self.use_height = use_height
assert max(use_dim) < load_dim
assert max(use_dim) < load_dim, 'Wrong dimension is used.'
self.load_dim = load_dim
self.use_dim = use_dim
def __call__(self, results):
pts_filename = results.get('pts_filename', None)
assert osp.exists(pts_filename)
pts_filename = results['pts_filename']
assert osp.exists(pts_filename), f'{pts_filename} does not exist.'
points = np.load(pts_filename)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
......@@ -85,16 +85,17 @@ class LoadAnnotations3D(object):
pass
def __call__(self, results):
ins_labelname = results.get('ins_labelname', None)
sem_labelname = results.get('sem_labelname', None)
if ins_labelname is not None and sem_labelname is not None:
assert osp.exists(ins_labelname)
assert osp.exists(sem_labelname)
pts_instance_mask = np.load(ins_labelname)
pts_semantic_mask = np.load(sem_labelname)
results['pts_instance_mask'] = pts_instance_mask
results['pts_semantic_mask'] = pts_semantic_mask
pts_instance_mask_path = results['pts_instance_mask_path']
pts_semantic_mask_path = results['pts_semantic_mask_path']
assert osp.exists(pts_instance_mask_path
), f'{pts_instance_mask_path} does not exist.'
assert osp.exists(pts_semantic_mask_path
), f'{pts_semantic_mask_path} does not exist.'
pts_instance_mask = np.load(pts_instance_mask_path)
pts_semantic_mask = np.load(pts_semantic_mask_path)
results['pts_instance_mask'] = pts_instance_mask
results['pts_semantic_mask'] = pts_semantic_mask
return results
......
......@@ -37,8 +37,6 @@ def test_load_points_from_file():
def test_load_annotations3D():
sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')[0]
sunrgbd_load_annotations3D = LoadAnnotations3D()
sunrgbd_results = dict()
if sunrgbd_info['annos']['gt_num'] != 0:
sunrgbd_gt_bboxes_3d = sunrgbd_info['annos']['gt_boxes_upright_depth']
sunrgbd_gt_labels = sunrgbd_info['annos']['class'].reshape(-1, 1)
......@@ -47,16 +45,9 @@ def test_load_annotations3D():
sunrgbd_gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
sunrgbd_gt_labels = np.zeros((1, 1))
sunrgbd_gt_bboxes_3d_mask = np.zeros((1, 1))
sunrgbd_results['gt_bboxes_3d'] = sunrgbd_gt_bboxes_3d
sunrgbd_results['gt_labels'] = sunrgbd_gt_labels
sunrgbd_results['gt_bboxes_3d_mask'] = sunrgbd_gt_bboxes_3d_mask
sunrgbd_results = sunrgbd_load_annotations3D(sunrgbd_results)
sunrgbd_gt_boxes = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_gt_lbaels = sunrgbd_results.get('gt_labels', None)
sunrgbd_gt_boxes_mask = sunrgbd_results.get('gt_bboxes_3d_mask', None)
assert sunrgbd_gt_boxes.shape == (3, 7)
assert sunrgbd_gt_lbaels.shape == (3, 1)
assert sunrgbd_gt_boxes_mask.shape == (3, 1)
assert sunrgbd_gt_bboxes_3d.shape == (3, 7)
assert sunrgbd_gt_labels.shape == (3, 1)
assert sunrgbd_gt_bboxes_3d_mask.shape == (3, 1)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
scannet_load_annotations3D = LoadAnnotations3D()
......@@ -71,10 +62,10 @@ def test_load_annotations3D():
scannet_gt_labels = np.zeros((1, 1))
scannet_gt_bboxes_3d_mask = np.zeros((1, 1))
scan_name = scannet_info['point_cloud']['lidar_idx']
scannet_results['ins_labelname'] = osp.join(data_path,
scan_name + '_ins_label.npy')
scannet_results['sem_labelname'] = osp.join(data_path,
scan_name + '_sem_label.npy')
scannet_results['pts_instance_mask_path'] = osp.join(
data_path, scan_name + '_ins_label.npy')
scannet_results['pts_semantic_mask_path'] = osp.join(
data_path, scan_name + '_sem_label.npy')
scannet_results['info'] = scannet_info
scannet_results['gt_bboxes_3d'] = scannet_gt_bboxes_3d
scannet_results['gt_labels'] = scannet_gt_labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment