Commit e9ef0711 authored by liyinhao's avatar liyinhao
Browse files

add indoor_load_data test unit

parent d9c7fb38
...@@ -123,8 +123,8 @@ class IndoorRotateData(object): ...@@ -123,8 +123,8 @@ class IndoorRotateData(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class IndoorShuffleData(object): class PointShuffle(object):
"""Indoor Shuffle Data """Point Shuffle.
Shuffle points. Shuffle points.
""" """
......
...@@ -39,28 +39,30 @@ class IndoorLoadData(object): ...@@ -39,28 +39,30 @@ class IndoorLoadData(object):
gt_boxes_mask = np.zeros((1, 1)) gt_boxes_mask = np.zeros((1, 1))
if self.name == 'scannet': if self.name == 'scannet':
scan_name = info['points']['lidar_idx'] scan_name = info['point_cloud']['lidar_idx']
points = self._get_lidar(scan_name, data_path) point_cloud = self._get_lidar(scan_name, data_path)
instance_labels = self._get_instance_label(scan_name, data_path) instance_labels = self._get_instance_label(scan_name, data_path)
semantic_labels = self._get_semantic_label(scan_name, data_path) semantic_labels = self._get_semantic_label(scan_name, data_path)
else: else:
points = np.load( point_cloud = np.load(
osp.join(data_path, 'lidar', osp.join(data_path, 'lidar',
'%06d.npz' % info['points']['lidar_idx']))['pc'] '%06d.npz' % info['point_cloud']['lidar_idx']))['pc']
if not self.use_color: if not self.use_color:
points = points[:, 0:3] # do not use color for now point_cloud = point_cloud[:, 0:3] # do not use color for now
pcl_color = points[:, 3:6] pcl_color = point_cloud[:, 3:6]
else: else:
points = points[:, 0:6] point_cloud = point_cloud[:, 0:6]
pcl_color = points[:, 3:6] pcl_color = point_cloud[:, 3:6]
points[:, 3:] = (points[:, 3:] - np.array(self.mean_color)) / 256.0 point_cloud[:, 3:] = (point_cloud[:, 3:] -
np.array(self.mean_color)) / 256.0
if self.use_height: if self.use_height:
floor_height = np.percentile(points[:, 2], 0.99) floor_height = np.percentile(point_cloud[:, 2], 0.99)
height = points[:, 2] - floor_height height = point_cloud[:, 2] - floor_height
points = np.concatenate([points, np.expand_dims(height, 1)], 1) point_cloud = np.concatenate(
results['points'] = points [point_cloud, np.expand_dims(height, 1)], 1)
results['point_cloud'] = point_cloud
if self.name == 'scannet': if self.name == 'scannet':
results['pcl_color'] = pcl_color results['pcl_color'] = pcl_color
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
from mmdet3d.datasets.pipelines.indoor_augment import (IndoorFlipData, from mmdet3d.datasets.pipelines.indoor_augment import (IndoorFlipData,
IndoorRotateData, IndoorRotateData,
IndoorShuffleData) PointShuffle)
def test_indoor_flip_data(): def test_indoor_flip_data():
...@@ -80,12 +80,12 @@ def test_indoor_rotate_data(): ...@@ -80,12 +80,12 @@ def test_indoor_rotate_data():
assert scannet_gt_boxes.shape == (2, 6) assert scannet_gt_boxes.shape == (2, 6)
def test_indoor_shuffle_data(): def test_point_shuffle():
indoor_shuffle_data = IndoorShuffleData() point_shuffle = PointShuffle()
results = dict() results = dict()
results['points'] = np.array( results['points'] = np.array(
[[1.02828765e+00, 3.65790772e+00, 1.97294697e-01, 1.61959505e+00], [[1.02828765e+00, 3.65790772e+00, 1.97294697e-01, 1.61959505e+00],
[-3.95979017e-01, 1.05465031e+00, -7.49204338e-01, 6.73096001e-01]]) [-3.95979017e-01, 1.05465031e+00, -7.49204338e-01, 6.73096001e-01]])
results = indoor_shuffle_data(results) results = point_shuffle(results)
points = results.get('points') points = results.get('points')
assert points.shape == (2, 4) assert points.shape == (2, 4)
import mmcv
from mmdet3d.datasets.pipelines.indoor_loading import IndoorLoadData
def test_indoor_load_data():
train_info = mmcv.load('./tests/data/sunrgbd/sunrgbd_infos_train.pkl')
sunrgbd_load_data = IndoorLoadData('sunrgbd', False, True, [0.5, 0.5, 0.5])
sunrgbd_results = dict()
sunrgbd_results['data_path'] = './tests/data/sunrgbd/sunrgbd_trainval'
sunrgbd_results['info'] = train_info[0]
sunrgbd_results = sunrgbd_load_data(sunrgbd_results)
point_cloud = sunrgbd_results.get('point_cloud', None)
gt_boxes = sunrgbd_results.get('gt_boxes', None)
gt_classes = sunrgbd_results.get('gt_classes', None)
gt_boxes_mask = sunrgbd_results.get('gt_boxes_mask', None)
assert point_cloud.shape == (50000, 4)
assert gt_boxes.shape == (3, 7)
assert gt_classes.shape == (3, 1)
assert gt_boxes_mask.shape == (3, 1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment