Commit e90a8be6 authored by liyinhao's avatar liyinhao
Browse files

fix some problems

parent fe799e0a
......@@ -7,10 +7,10 @@ from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
from .nuscenes_dataset import NuScenesDataset
from .pipelines import (GlobalRotScale, IndoorFlipData, IndoorGlobalRotScale,
IndoorLoadAnnotations3D, IndoorLoadPointsFromFile,
IndoorPointsColorJitter, IndoorPointsColorNormalize,
ObjectNoise, ObjectRangeFilter, ObjectSample,
PointSample, PointShuffle, PointsRangeFilter,
RandomFlip3D)
IndoorPointSample, IndoorPointsColorJitter,
IndoorPointsColorNormalize, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomFlip3D)
__all__ = [
'KittiDataset', 'GroupSampler', 'DistributedGroupSampler',
......@@ -18,7 +18,7 @@ __all__ = [
'CocoDataset', 'Kitti2DDataset', 'NuScenesDataset', 'ObjectSample',
'RandomFlip3D', 'ObjectNoise', 'GlobalRotScale', 'PointShuffle',
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'IndoorLoadPointsFromFile', 'IndoorPointsColorNormalize', 'PointSample',
'IndoorLoadAnnotations3D', 'IndoorPointsColorJitter',
'IndoorLoadPointsFromFile', 'IndoorPointsColorNormalize',
'IndoorPointSample', 'IndoorLoadAnnotations3D', 'IndoorPointsColorJitter',
'IndoorGlobalRotScale', 'IndoorFlipData'
]
......@@ -5,7 +5,7 @@ from .indoor_augment import (IndoorFlipData, IndoorGlobalRotScale,
IndoorPointsColorJitter)
from .indoor_loading import (IndoorLoadAnnotations3D, IndoorLoadPointsFromFile,
IndoorPointsColorNormalize)
from .indoor_sample import PointSample
from .indoor_sample import IndoorPointSample
from .loading import LoadMultiViewImageFromFiles, LoadPointsFromFile
from .train_aug import (GlobalRotScale, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
......@@ -18,5 +18,6 @@ __all__ = [
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'IndoorGlobalRotScale', 'IndoorPointsColorJitter', 'IndoorFlipData',
'MMDataBaseSampler', 'IndoorLoadPointsFromFile',
'IndoorPointsColorNormalize', 'IndoorLoadAnnotations3D', 'PointSample'
'IndoorPointsColorNormalize', 'IndoorLoadAnnotations3D',
'IndoorPointSample'
]
......@@ -4,7 +4,7 @@ from mmdet.datasets.builder import PIPELINES
@PIPELINES.register_module()
class PointSample(object):
class IndoorPointSample(object):
"""Point Sample.
Sampling data to a certain number.
......@@ -46,7 +46,7 @@ class PointSample(object):
return points[choices]
def __call__(self, results):
points = results.get('points', None)
points = results['points']
points, choices = self.points_random_sampling(
points, self.num_points, return_choices=True)
pts_instance_mask = results.get('pts_instance_mask', None)
......
......@@ -65,7 +65,7 @@ def test_load_annotations3D():
scannet_results['pts_instance_mask_path'] = osp.join(
data_path, f'{scan_name}_ins_label.npy')
scannet_results['pts_semantic_mask_path'] = osp.join(
data_path, scan_name + '_sem_label.npy')
data_path, f'{scan_name}_sem_label.npy')
scannet_results['info'] = scannet_info
scannet_results['gt_bboxes_3d'] = scannet_gt_bboxes_3d
scannet_results['gt_labels'] = scannet_gt_labels
......
import numpy as np
from mmdet3d.datasets.pipelines import PointSample
from mmdet3d.datasets.pipelines import IndoorPointSample
def test_indoor_sample():
np.random.seed(0)
scannet_sample_points = PointSample(5)
scannet_sample_points = IndoorPointSample(5)
scannet_results = dict()
scannet_points = np.array([[1.0719866, -0.7870435, 0.8408122, 0.9196809],
[1.103661, 0.81065744, 2.6616862, 2.7405548],
......@@ -24,11 +24,9 @@ def test_indoor_sample():
scannet_pts_semantic_mask = np.array([38, 1, 1, 40, 0, 40, 1, 1, 1, 0])
scannet_results['pts_semantic_mask'] = scannet_pts_semantic_mask
scannet_results = scannet_sample_points(scannet_results)
scannet_points_result = scannet_results.get('points', None)
scannet_instance_labels_result = scannet_results.get(
'pts_instance_mask', None)
scannet_semantic_labels_result = scannet_results.get(
'pts_semantic_mask', None)
scannet_points_result = scannet_results['points']
scannet_instance_labels_result = scannet_results['pts_instance_mask']
scannet_semantic_labels_result = scannet_results['pts_semantic_mask']
scannet_choices = np.array([2, 8, 4, 9, 1])
assert np.allclose(scannet_points[scannet_choices], scannet_points_result)
assert np.all(scannet_pts_instance_mask[scannet_choices] ==
......@@ -37,7 +35,7 @@ def test_indoor_sample():
scannet_semantic_labels_result)
np.random.seed(0)
sunrgbd_sample_points = PointSample(5)
sunrgbd_sample_points = IndoorPointSample(5)
sunrgbd_results = dict()
sunrgbd_point_cloud = np.array(
[[-1.8135729e-01, 1.4695230e+00, -1.2780589e+00, 7.8938007e-03],
......@@ -53,6 +51,6 @@ def test_indoor_sample():
sunrgbd_results['points'] = sunrgbd_point_cloud
sunrgbd_results = sunrgbd_sample_points(sunrgbd_results)
sunrgbd_choices = np.array([2, 8, 4, 9, 1])
sunrgbd_points_result = sunrgbd_results.get('points', None)
sunrgbd_points_result = sunrgbd_results['points']
assert np.allclose(sunrgbd_point_cloud[sunrgbd_choices],
sunrgbd_points_result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment