Commit 7b9bb85b authored by liyinhao's avatar liyinhao
Browse files

change some keys according API doc

parent d86dffa4
...@@ -28,32 +28,30 @@ class LoadPointsFromFile(object): ...@@ -28,32 +28,30 @@ class LoadPointsFromFile(object):
info = results.get('info', None) info = results.get('info', None)
name = 'scannet' if info.get('image', None) is None else 'sunrgbd' name = 'scannet' if info.get('image', None) is None else 'sunrgbd'
if name == 'scannet': if name == 'scannet':
scan_name = info['point_cloud']['lidar_idx'] pts_filename = info['point_cloud']['lidar_idx']
point_cloud = self._get_lidar(scan_name, data_path) points = self._get_lidar(pts_filename, data_path)
else: else:
point_cloud = np.load( points = np.load(
osp.join(data_path, 'lidar', osp.join(data_path, 'lidar',
'%06d.npz' % info['point_cloud']['lidar_idx']))['pc'] '%06d.npz' % info['point_cloud']['lidar_idx']))['pc']
if not self.use_color: if not self.use_color:
if name == 'scannet': if name == 'scannet':
pcl_color = point_cloud[:, 3:6] pts_color = points[:, 3:6]
point_cloud = point_cloud[:, 0:3] points = points[:, 0:3]
else: else:
if name == 'scannet': if name == 'scannet':
pcl_color = point_cloud[:, 3:6] pts_color = points[:, 3:6]
point_cloud = point_cloud[:, 0:6] points = points[:, 0:6]
point_cloud[:, 3:] = (point_cloud[:, 3:] - points[:, 3:] = (points[:, 3:] - np.array(self.color_mean)) / 256.0
np.array(self.color_mean)) / 256.0
if self.use_height: if self.use_height:
floor_height = np.percentile(point_cloud[:, 2], 0.99) floor_height = np.percentile(points[:, 2], 0.99)
height = point_cloud[:, 2] - floor_height height = points[:, 2] - floor_height
point_cloud = np.concatenate( points = np.concatenate([points, np.expand_dims(height, 1)], 1)
[point_cloud, np.expand_dims(height, 1)], 1) results['points'] = points
results['point_cloud'] = point_cloud
if name == 'scannet': if name == 'scannet':
results['pcl_color'] = pcl_color results['pts_color'] = pts_color
return results return results
def __repr__(self): def __repr__(self):
...@@ -86,13 +84,13 @@ class LoadAnnotations3D(object): ...@@ -86,13 +84,13 @@ class LoadAnnotations3D(object):
data_path = results.get('data_path', None) data_path = results.get('data_path', None)
info = results.get('info', None) info = results.get('info', None)
if info['annos']['gt_num'] != 0: if info['annos']['gt_num'] != 0:
gt_boxes = info['annos']['gt_boxes_upright_depth'] gt_bboxes_3d = info['annos']['gt_boxes_upright_depth']
gt_classes = info['annos']['class'].reshape(-1, 1) gt_classes = info['annos']['class'].reshape(-1, 1)
gt_boxes_mask = np.ones_like(gt_classes) gt_bboxes_3d_mask = np.ones_like(gt_classes)
else: else:
gt_boxes = np.zeros((1, 6), dtype=np.float32) gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
gt_classes = np.zeros((1, 1)) gt_classes = np.zeros((1, 1))
gt_boxes_mask = np.zeros((1, 1)) gt_bboxes_3d_mask = np.zeros((1, 1))
name = 'scannet' if info.get('image', None) is None else 'sunrgbd' name = 'scannet' if info.get('image', None) is None else 'sunrgbd'
if name == 'scannet': if name == 'scannet':
...@@ -102,9 +100,9 @@ class LoadAnnotations3D(object): ...@@ -102,9 +100,9 @@ class LoadAnnotations3D(object):
results['instance_labels'] = instance_labels results['instance_labels'] = instance_labels
results['semantic_labels'] = semantic_labels results['semantic_labels'] = semantic_labels
results['gt_boxes'] = gt_boxes results['gt_bboxes_3d'] = gt_bboxes_3d
results['gt_classes'] = gt_classes results['gt_classes'] = gt_classes
results['gt_boxes_mask'] = gt_boxes_mask results['gt_bboxes_3d_mask'] = gt_bboxes_3d_mask
return results return results
def __repr__(self): def __repr__(self):
......
...@@ -12,7 +12,7 @@ def test_load_points_from_file(): ...@@ -12,7 +12,7 @@ def test_load_points_from_file():
sunrgbd_results['data_path'] = './tests/data/sunrgbd/sunrgbd_trainval' sunrgbd_results['data_path'] = './tests/data/sunrgbd/sunrgbd_trainval'
sunrgbd_results['info'] = sunrgbd_info[0] sunrgbd_results['info'] = sunrgbd_info[0]
sunrgbd_results = sunrgbd_load_points_from_file(sunrgbd_results) sunrgbd_results = sunrgbd_load_points_from_file(sunrgbd_results)
sunrgbd_point_cloud = sunrgbd_results.get('point_cloud', None) sunrgbd_point_cloud = sunrgbd_results.get('points', None)
assert sunrgbd_point_cloud.shape == (1000, 4) assert sunrgbd_point_cloud.shape == (1000, 4)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl') scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')
...@@ -22,8 +22,8 @@ def test_load_points_from_file(): ...@@ -22,8 +22,8 @@ def test_load_points_from_file():
'data_path'] = './tests/data/scannet/scannet_train_instance_data' 'data_path'] = './tests/data/scannet/scannet_train_instance_data'
scannet_results['info'] = scannet_info[0] scannet_results['info'] = scannet_info[0]
scannet_results = scannet_load_data(scannet_results) scannet_results = scannet_load_data(scannet_results)
scannet_point_cloud = scannet_results.get('point_cloud', None) scannet_point_cloud = scannet_results.get('points', None)
scannet_pcl_color = scannet_results.get('pcl_color', None) scannet_pcl_color = scannet_results.get('pts_color', None)
assert scannet_point_cloud.shape == (1000, 4) assert scannet_point_cloud.shape == (1000, 4)
assert scannet_pcl_color.shape == (1000, 3) assert scannet_pcl_color.shape == (1000, 3)
...@@ -35,9 +35,9 @@ def test_load_annotations3D(): ...@@ -35,9 +35,9 @@ def test_load_annotations3D():
sunrgbd_results['data_path'] = './tests/data/sunrgbd/sunrgbd_trainval' sunrgbd_results['data_path'] = './tests/data/sunrgbd/sunrgbd_trainval'
sunrgbd_results['info'] = sunrgbd_info[0] sunrgbd_results['info'] = sunrgbd_info[0]
sunrgbd_results = sunrgbd_load_annotations3D(sunrgbd_results) sunrgbd_results = sunrgbd_load_annotations3D(sunrgbd_results)
sunrgbd_gt_boxes = sunrgbd_results.get('gt_boxes', None) sunrgbd_gt_boxes = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_gt_classes = sunrgbd_results.get('gt_classes', None) sunrgbd_gt_classes = sunrgbd_results.get('gt_classes', None)
sunrgbd_gt_boxes_mask = sunrgbd_results.get('gt_boxes_mask', None) sunrgbd_gt_boxes_mask = sunrgbd_results.get('gt_bboxes_3d_mask', None)
assert sunrgbd_gt_boxes.shape == (3, 7) assert sunrgbd_gt_boxes.shape == (3, 7)
assert sunrgbd_gt_classes.shape == (3, 1) assert sunrgbd_gt_classes.shape == (3, 1)
assert sunrgbd_gt_boxes_mask.shape == (3, 1) assert sunrgbd_gt_boxes_mask.shape == (3, 1)
...@@ -49,9 +49,9 @@ def test_load_annotations3D(): ...@@ -49,9 +49,9 @@ def test_load_annotations3D():
'data_path'] = './tests/data/scannet/scannet_train_instance_data' 'data_path'] = './tests/data/scannet/scannet_train_instance_data'
scannet_results['info'] = scannet_info[0] scannet_results['info'] = scannet_info[0]
scannet_results = scannet_load_annotations3D(scannet_results) scannet_results = scannet_load_annotations3D(scannet_results)
scannet_gt_boxes = scannet_results.get('gt_boxes', None) scannet_gt_boxes = scannet_results.get('gt_bboxes_3d', None)
scannet_gt_classes = scannet_results.get('gt_classes', None) scannet_gt_classes = scannet_results.get('gt_classes', None)
scannet_gt_boxes_mask = scannet_results.get('gt_boxes_mask', None) scannet_gt_boxes_mask = scannet_results.get('gt_bboxes_3d_mask', None)
scannet_instance_labels = scannet_results.get('instance_labels', None) scannet_instance_labels = scannet_results.get('instance_labels', None)
scannet_semantic_labels = scannet_results.get('semantic_labels', None) scannet_semantic_labels = scannet_results.get('semantic_labels', None)
assert scannet_gt_boxes.shape == (27, 6) assert scannet_gt_boxes.shape == (27, 6)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment