Commit 929ebfe8 authored by liyinhao's avatar liyinhao
Browse files

change dict.get, enrich assertion

parent 51fa14c2
...@@ -19,8 +19,8 @@ class PointsColorNormalize(object): ...@@ -19,8 +19,8 @@ class PointsColorNormalize(object):
self.color_mean = color_mean self.color_mean = color_mean
def __call__(self, results): def __call__(self, results):
points = results.get('results', None) points = results['points']
assert points.shape[1] >= 6 assert points.shape[1] >= 6, 'Incomplete color channel.'
points[:, 3:6] = points[:, 3:6] - np.array(self.color_mean) / 256.0 points[:, 3:6] = points[:, 3:6] - np.array(self.color_mean) / 256.0
results['points'] = points results['points'] = points
return results return results
...@@ -47,13 +47,13 @@ class LoadPointsFromFile(object): ...@@ -47,13 +47,13 @@ class LoadPointsFromFile(object):
def __init__(self, use_height, load_dim=6, use_dim=[0, 1, 2]): def __init__(self, use_height, load_dim=6, use_dim=[0, 1, 2]):
self.use_height = use_height self.use_height = use_height
assert max(use_dim) < load_dim assert max(use_dim) < load_dim, 'Wrong dimension is used.'
self.load_dim = load_dim self.load_dim = load_dim
self.use_dim = use_dim self.use_dim = use_dim
def __call__(self, results): def __call__(self, results):
pts_filename = results.get('pts_filename', None) pts_filename = results['pts_filename']
assert osp.exists(pts_filename) assert osp.exists(pts_filename), f'{pts_filename} does not exist.'
points = np.load(pts_filename) points = np.load(pts_filename)
points = points.reshape(-1, self.load_dim) points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim] points = points[:, self.use_dim]
...@@ -85,14 +85,15 @@ class LoadAnnotations3D(object): ...@@ -85,14 +85,15 @@ class LoadAnnotations3D(object):
pass pass
def __call__(self, results): def __call__(self, results):
ins_labelname = results.get('ins_labelname', None) pts_instance_mask_path = results['pts_instance_mask_path']
sem_labelname = results.get('sem_labelname', None) pts_semantic_mask_path = results['pts_semantic_mask_path']
if ins_labelname is not None and sem_labelname is not None: assert osp.exists(pts_instance_mask_path
assert osp.exists(ins_labelname) ), f'{pts_instance_mask_path} does not exist.'
assert osp.exists(sem_labelname) assert osp.exists(pts_semantic_mask_path
pts_instance_mask = np.load(ins_labelname) ), f'{pts_semantic_mask_path} does not exist.'
pts_semantic_mask = np.load(sem_labelname) pts_instance_mask = np.load(pts_instance_mask_path)
pts_semantic_mask = np.load(pts_semantic_mask_path)
results['pts_instance_mask'] = pts_instance_mask results['pts_instance_mask'] = pts_instance_mask
results['pts_semantic_mask'] = pts_semantic_mask results['pts_semantic_mask'] = pts_semantic_mask
......
...@@ -37,8 +37,6 @@ def test_load_points_from_file(): ...@@ -37,8 +37,6 @@ def test_load_points_from_file():
def test_load_annotations3D(): def test_load_annotations3D():
sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')[0] sunrgbd_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos.pkl')[0]
sunrgbd_load_annotations3D = LoadAnnotations3D()
sunrgbd_results = dict()
if sunrgbd_info['annos']['gt_num'] != 0: if sunrgbd_info['annos']['gt_num'] != 0:
sunrgbd_gt_bboxes_3d = sunrgbd_info['annos']['gt_boxes_upright_depth'] sunrgbd_gt_bboxes_3d = sunrgbd_info['annos']['gt_boxes_upright_depth']
sunrgbd_gt_labels = sunrgbd_info['annos']['class'].reshape(-1, 1) sunrgbd_gt_labels = sunrgbd_info['annos']['class'].reshape(-1, 1)
...@@ -47,16 +45,9 @@ def test_load_annotations3D(): ...@@ -47,16 +45,9 @@ def test_load_annotations3D():
sunrgbd_gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32) sunrgbd_gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
sunrgbd_gt_labels = np.zeros((1, 1)) sunrgbd_gt_labels = np.zeros((1, 1))
sunrgbd_gt_bboxes_3d_mask = np.zeros((1, 1)) sunrgbd_gt_bboxes_3d_mask = np.zeros((1, 1))
sunrgbd_results['gt_bboxes_3d'] = sunrgbd_gt_bboxes_3d assert sunrgbd_gt_bboxes_3d.shape == (3, 7)
sunrgbd_results['gt_labels'] = sunrgbd_gt_labels assert sunrgbd_gt_labels.shape == (3, 1)
sunrgbd_results['gt_bboxes_3d_mask'] = sunrgbd_gt_bboxes_3d_mask assert sunrgbd_gt_bboxes_3d_mask.shape == (3, 1)
sunrgbd_results = sunrgbd_load_annotations3D(sunrgbd_results)
sunrgbd_gt_boxes = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_gt_lbaels = sunrgbd_results.get('gt_labels', None)
sunrgbd_gt_boxes_mask = sunrgbd_results.get('gt_bboxes_3d_mask', None)
assert sunrgbd_gt_boxes.shape == (3, 7)
assert sunrgbd_gt_lbaels.shape == (3, 1)
assert sunrgbd_gt_boxes_mask.shape == (3, 1)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0] scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')[0]
scannet_load_annotations3D = LoadAnnotations3D() scannet_load_annotations3D = LoadAnnotations3D()
...@@ -71,10 +62,10 @@ def test_load_annotations3D(): ...@@ -71,10 +62,10 @@ def test_load_annotations3D():
scannet_gt_labels = np.zeros((1, 1)) scannet_gt_labels = np.zeros((1, 1))
scannet_gt_bboxes_3d_mask = np.zeros((1, 1)) scannet_gt_bboxes_3d_mask = np.zeros((1, 1))
scan_name = scannet_info['point_cloud']['lidar_idx'] scan_name = scannet_info['point_cloud']['lidar_idx']
scannet_results['ins_labelname'] = osp.join(data_path, scannet_results['pts_instance_mask_path'] = osp.join(
scan_name + '_ins_label.npy') data_path, scan_name + '_ins_label.npy')
scannet_results['sem_labelname'] = osp.join(data_path, scannet_results['pts_semantic_mask_path'] = osp.join(
scan_name + '_sem_label.npy') data_path, scan_name + '_sem_label.npy')
scannet_results['info'] = scannet_info scannet_results['info'] = scannet_info
scannet_results['gt_bboxes_3d'] = scannet_gt_bboxes_3d scannet_results['gt_bboxes_3d'] = scannet_gt_bboxes_3d
scannet_results['gt_labels'] = scannet_gt_labels scannet_results['gt_labels'] = scannet_gt_labels
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment