Commit 7b9bb85b authored by liyinhao's avatar liyinhao
Browse files

change some keys according API doc

parent d86dffa4
......@@ -28,32 +28,30 @@ class LoadPointsFromFile(object):
info = results.get('info', None)
name = 'scannet' if info.get('image', None) is None else 'sunrgbd'
if name == 'scannet':
scan_name = info['point_cloud']['lidar_idx']
point_cloud = self._get_lidar(scan_name, data_path)
pts_filename = info['point_cloud']['lidar_idx']
points = self._get_lidar(pts_filename, data_path)
else:
point_cloud = np.load(
points = np.load(
osp.join(data_path, 'lidar',
'%06d.npz' % info['point_cloud']['lidar_idx']))['pc']
if not self.use_color:
if name == 'scannet':
pcl_color = point_cloud[:, 3:6]
point_cloud = point_cloud[:, 0:3]
pts_color = points[:, 3:6]
points = points[:, 0:3]
else:
if name == 'scannet':
pcl_color = point_cloud[:, 3:6]
point_cloud = point_cloud[:, 0:6]
point_cloud[:, 3:] = (point_cloud[:, 3:] -
np.array(self.color_mean)) / 256.0
pts_color = points[:, 3:6]
points = points[:, 0:6]
points[:, 3:] = (points[:, 3:] - np.array(self.color_mean)) / 256.0
if self.use_height:
floor_height = np.percentile(point_cloud[:, 2], 0.99)
height = point_cloud[:, 2] - floor_height
point_cloud = np.concatenate(
[point_cloud, np.expand_dims(height, 1)], 1)
results['point_cloud'] = point_cloud
floor_height = np.percentile(points[:, 2], 0.99)
height = points[:, 2] - floor_height
points = np.concatenate([points, np.expand_dims(height, 1)], 1)
results['points'] = points
if name == 'scannet':
results['pcl_color'] = pcl_color
results['pts_color'] = pts_color
return results
def __repr__(self):
......@@ -86,13 +84,13 @@ class LoadAnnotations3D(object):
data_path = results.get('data_path', None)
info = results.get('info', None)
if info['annos']['gt_num'] != 0:
gt_boxes = info['annos']['gt_boxes_upright_depth']
gt_bboxes_3d = info['annos']['gt_boxes_upright_depth']
gt_classes = info['annos']['class'].reshape(-1, 1)
gt_boxes_mask = np.ones_like(gt_classes)
gt_bboxes_3d_mask = np.ones_like(gt_classes)
else:
gt_boxes = np.zeros((1, 6), dtype=np.float32)
gt_bboxes_3d = np.zeros((1, 6), dtype=np.float32)
gt_classes = np.zeros((1, 1))
gt_boxes_mask = np.zeros((1, 1))
gt_bboxes_3d_mask = np.zeros((1, 1))
name = 'scannet' if info.get('image', None) is None else 'sunrgbd'
if name == 'scannet':
......@@ -102,9 +100,9 @@ class LoadAnnotations3D(object):
results['instance_labels'] = instance_labels
results['semantic_labels'] = semantic_labels
results['gt_boxes'] = gt_boxes
results['gt_bboxes_3d'] = gt_bboxes_3d
results['gt_classes'] = gt_classes
results['gt_boxes_mask'] = gt_boxes_mask
results['gt_bboxes_3d_mask'] = gt_bboxes_3d_mask
return results
def __repr__(self):
......
......@@ -12,7 +12,7 @@ def test_load_points_from_file():
sunrgbd_results['data_path'] = './tests/data/sunrgbd/sunrgbd_trainval'
sunrgbd_results['info'] = sunrgbd_info[0]
sunrgbd_results = sunrgbd_load_points_from_file(sunrgbd_results)
sunrgbd_point_cloud = sunrgbd_results.get('point_cloud', None)
sunrgbd_point_cloud = sunrgbd_results.get('points', None)
assert sunrgbd_point_cloud.shape == (1000, 4)
scannet_info = mmcv.load('./tests/data/scannet/scannet_infos.pkl')
......@@ -22,8 +22,8 @@ def test_load_points_from_file():
'data_path'] = './tests/data/scannet/scannet_train_instance_data'
scannet_results['info'] = scannet_info[0]
scannet_results = scannet_load_data(scannet_results)
scannet_point_cloud = scannet_results.get('point_cloud', None)
scannet_pcl_color = scannet_results.get('pcl_color', None)
scannet_point_cloud = scannet_results.get('points', None)
scannet_pcl_color = scannet_results.get('pts_color', None)
assert scannet_point_cloud.shape == (1000, 4)
assert scannet_pcl_color.shape == (1000, 3)
......@@ -35,9 +35,9 @@ def test_load_annotations3D():
sunrgbd_results['data_path'] = './tests/data/sunrgbd/sunrgbd_trainval'
sunrgbd_results['info'] = sunrgbd_info[0]
sunrgbd_results = sunrgbd_load_annotations3D(sunrgbd_results)
sunrgbd_gt_boxes = sunrgbd_results.get('gt_boxes', None)
sunrgbd_gt_boxes = sunrgbd_results.get('gt_bboxes_3d', None)
sunrgbd_gt_classes = sunrgbd_results.get('gt_classes', None)
sunrgbd_gt_boxes_mask = sunrgbd_results.get('gt_boxes_mask', None)
sunrgbd_gt_boxes_mask = sunrgbd_results.get('gt_bboxes_3d_mask', None)
assert sunrgbd_gt_boxes.shape == (3, 7)
assert sunrgbd_gt_classes.shape == (3, 1)
assert sunrgbd_gt_boxes_mask.shape == (3, 1)
......@@ -49,9 +49,9 @@ def test_load_annotations3D():
'data_path'] = './tests/data/scannet/scannet_train_instance_data'
scannet_results['info'] = scannet_info[0]
scannet_results = scannet_load_annotations3D(scannet_results)
scannet_gt_boxes = scannet_results.get('gt_boxes', None)
scannet_gt_boxes = scannet_results.get('gt_bboxes_3d', None)
scannet_gt_classes = scannet_results.get('gt_classes', None)
scannet_gt_boxes_mask = scannet_results.get('gt_boxes_mask', None)
scannet_gt_boxes_mask = scannet_results.get('gt_bboxes_3d_mask', None)
scannet_instance_labels = scannet_results.get('instance_labels', None)
scannet_semantic_labels = scannet_results.get('semantic_labels', None)
assert scannet_gt_boxes.shape == (27, 6)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment