Unverified Commit fe719e8d authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Add missing unit tests for pipeline functions (#615)

* unit test for LoadImageFromFileMono3D

* fix GlobalRotScaleTrans assertion

* add unit test for ObjectNameFilter

* add unit test for ObjectRangeFilter
parent 20c0fbdc
...@@ -11,9 +11,9 @@ from .pipelines import (BackgroundPointsFilter, GlobalAlignment, ...@@ -11,9 +11,9 @@ from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample, GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, LoadAnnotations3D, IndoorPointSample, LoadAnnotations3D,
LoadPointsFromFile, LoadPointsFromMultiSweeps, LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter, NormalizePointsColor, ObjectNameFilter, ObjectNoise,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectRangeFilter, ObjectSample, PointShuffle,
RandomDropPointsColor, RandomFlip3D, PointsRangeFilter, RandomDropPointsColor, RandomFlip3D,
RandomJitterPoints, VoxelBasedPointSampler) RandomJitterPoints, VoxelBasedPointSampler)
from .s3dis_dataset import S3DISSegDataset from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
...@@ -34,5 +34,5 @@ __all__ = [ ...@@ -34,5 +34,5 @@ __all__ = [
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset', 'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline',
'RandomDropPointsColor', 'RandomJitterPoints' 'RandomDropPointsColor', 'RandomJitterPoints', 'ObjectNameFilter'
] ]
...@@ -8,10 +8,11 @@ from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D, ...@@ -8,10 +8,11 @@ from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
from .test_time_aug import MultiScaleFlipAug3D from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment, from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample, GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNoise, ObjectRangeFilter, IndoorPointSample, ObjectNameFilter, ObjectNoise,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectRangeFilter, ObjectSample, PointShuffle,
RandomDropPointsColor, RandomFlip3D, PointsRangeFilter, RandomDropPointsColor,
RandomJitterPoints, VoxelBasedPointSampler) RandomFlip3D, RandomJitterPoints,
VoxelBasedPointSampler)
__all__ = [ __all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
...@@ -21,6 +22,6 @@ __all__ = [ ...@@ -21,6 +22,6 @@ __all__ = [
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample', 'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps', 'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter',
'RandomDropPointsColor', 'RandomJitterPoints' 'RandomDropPointsColor', 'RandomJitterPoints'
] ]
...@@ -532,6 +532,8 @@ class GlobalRotScaleTrans(object): ...@@ -532,6 +532,8 @@ class GlobalRotScaleTrans(object):
translation_std = [ translation_std = [
translation_std, translation_std, translation_std translation_std, translation_std, translation_std
] ]
assert all([std >= 0 for std in translation_std]), \
'translation_std should be positive'
self.translation_std = translation_std self.translation_std = translation_std
self.shift_height = shift_height self.shift_height = shift_height
......
...@@ -7,7 +7,8 @@ from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes, ...@@ -7,7 +7,8 @@ from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes,
DepthInstance3DBoxes, LiDARInstance3DBoxes) DepthInstance3DBoxes, LiDARInstance3DBoxes)
from mmdet3d.core.points import DepthPoints, LiDARPoints from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment, from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, ObjectNoise, ObjectSample, GlobalRotScaleTrans, ObjectNameFilter,
ObjectNoise, ObjectRangeFilter, ObjectSample,
PointShuffle, PointsRangeFilter, PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D, RandomDropPointsColor, RandomFlip3D,
RandomJitterPoints, VoxelBasedPointSampler) RandomJitterPoints, VoxelBasedPointSampler)
...@@ -142,6 +143,47 @@ def test_object_noise(): ...@@ -142,6 +143,47 @@ def test_object_noise():
assert torch.allclose(gt_bboxes_3d, expected_gt_bboxes_3d, 1e-3) assert torch.allclose(gt_bboxes_3d, expected_gt_bboxes_3d, 1e-3)
def test_object_name_filter():
    """Test ObjectNameFilter: boxes/labels whose class is not in the given
    class list must be dropped, and the repr must reflect the classes."""
    class_names = ['Pedestrian']
    object_name_filter = ObjectNameFilter(class_names)

    annos = mmcv.load('./tests/data/kitti/kitti_infos_train.pkl')
    info = annos[0]
    rect = info['calib']['R0_rect'].astype(np.float32)
    Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
    annos = info['annos']
    loc = annos['location']
    dims = annos['dimensions']
    rots = annos['rotation_y']
    gt_names = annos['name']

    # Build camera-frame boxes (x, y, z, dx, dy, dz, yaw) and convert them to
    # the LiDAR frame, as the KITTI pipeline does.
    gt_bboxes_3d = np.concatenate([loc, dims, rots[..., np.newaxis]],
                                  axis=1).astype(np.float32)
    gt_bboxes_3d = CameraInstance3DBoxes(gt_bboxes_3d).convert_to(
        Box3DMode.LIDAR, np.linalg.inv(rect @ Trv2c))
    CLASSES = ('Pedestrian', 'Cyclist', 'Car')
    gt_labels = []
    for cat in gt_names:
        if cat in CLASSES:
            gt_labels.append(CLASSES.index(cat))
        else:
            # classes outside CLASSES (e.g. DontCare) are marked -1
            gt_labels.append(-1)
    # np.long was removed in NumPy 1.24; use the explicit np.int64 instead
    gt_labels = np.array(gt_labels, dtype=np.int64)

    input_dict = dict(
        gt_bboxes_3d=gt_bboxes_3d.clone(), gt_labels_3d=gt_labels.copy())
    results = object_name_filter(input_dict)
    bboxes_3d = results['gt_bboxes_3d']
    labels_3d = results['gt_labels_3d']

    # only annotations whose name is in class_names should survive
    keep_mask = np.array([name in class_names for name in gt_names])
    assert torch.allclose(gt_bboxes_3d.tensor[keep_mask], bboxes_3d.tensor)
    assert np.all(gt_labels[keep_mask] == labels_3d)

    repr_str = repr(object_name_filter)
    expected_repr_str = f'ObjectNameFilter(classes={class_names})'
    assert repr_str == expected_repr_str
def test_point_shuffle(): def test_point_shuffle():
np.random.seed(0) np.random.seed(0)
torch.manual_seed(0) torch.manual_seed(0)
...@@ -224,6 +266,38 @@ def test_points_range_filter(): ...@@ -224,6 +266,38 @@ def test_points_range_filter():
assert repr_str == expected_repr_str assert repr_str == expected_repr_str
def test_object_range_filter():
    """Test ObjectRangeFilter: boxes whose center falls outside the BEV
    point-cloud range must be dropped and kept boxes get their yaw limited."""
    point_cloud_range = [0, -40, -3, 70.4, 40, 1]
    object_range_filter = ObjectRangeFilter(point_cloud_range)

    # rows 2, 3 and 5 lie outside the x/y range and must be filtered out
    bbox = np.array(
        [[8.7314, -1.8559, -0.6547, 0.4800, 1.2000, 1.8900, 0.0100],
         [28.7314, -18.559, 0.6547, 2.4800, 1.6000, 1.9200, 5.0100],
         [-2.54, -1.8559, -0.6547, 0.4800, 1.2000, 1.8900, 0.0100],
         [72.7314, -18.559, 0.6547, 6.4800, 11.6000, 4.9200, -0.0100],
         [18.7314, -18.559, 20.6547, 6.4800, 8.6000, 3.9200, -1.0100],
         [3.7314, 42.559, -0.6547, 6.4800, 8.6000, 2.9200, 3.0100]])
    gt_bboxes_3d = LiDARInstance3DBoxes(bbox, origin=(0.5, 0.5, 0.5))
    # np.long was removed in NumPy 1.24; use the explicit np.int64 instead
    gt_labels_3d = np.array([0, 2, 1, 1, 2, 0], dtype=np.int64)

    input_dict = dict(
        gt_bboxes_3d=gt_bboxes_3d.clone(), gt_labels_3d=gt_labels_3d.copy())
    results = object_range_filter(input_dict)
    bboxes_3d = results['gt_bboxes_3d']
    labels_3d = results['gt_labels_3d']

    keep_mask = np.array([True, True, False, False, True, False])
    expected_bbox = gt_bboxes_3d.tensor[keep_mask]
    # the filter wraps yaw into [-pi, pi); row 1 starts at 5.01 > pi
    expected_bbox[1, 6] -= 2 * np.pi  # limit yaw

    assert torch.allclose(expected_bbox, bboxes_3d.tensor)
    assert np.all(gt_labels_3d[keep_mask] == labels_3d)

    repr_str = repr(object_range_filter)
    expected_repr_str = 'ObjectRangeFilter(point_cloud_range=' \
        '[0.0, -40.0, -3.0, 70.4000015258789, 40.0, 1.0])'
    assert repr_str == expected_repr_str
def test_global_alignment(): def test_global_alignment():
np.random.seed(0) np.random.seed(0)
global_alignment = GlobalAlignment(rotation_axis=2) global_alignment = GlobalAlignment(rotation_axis=2)
...@@ -270,9 +344,11 @@ def test_global_rot_scale_trans(): ...@@ -270,9 +344,11 @@ def test_global_rot_scale_trans():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
global_rot_scale_trans = GlobalRotScaleTrans(scale_ratio_range=1.0) global_rot_scale_trans = GlobalRotScaleTrans(scale_ratio_range=1.0)
# translation_std should be a number or seq of numbers # translation_std should be a positive number or seq of positive numbers
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
global_rot_scale_trans = GlobalRotScaleTrans(translation_std='0.0') global_rot_scale_trans = GlobalRotScaleTrans(translation_std='0.0')
with pytest.raises(AssertionError):
global_rot_scale_trans = GlobalRotScaleTrans(translation_std=-1.0)
global_rot_scale_trans = GlobalRotScaleTrans( global_rot_scale_trans = GlobalRotScaleTrans(
rot_range=angle, rot_range=angle,
......
...@@ -5,11 +5,16 @@ from os import path as osp ...@@ -5,11 +5,16 @@ from os import path as osp
from mmdet3d.core.bbox import DepthInstance3DBoxes from mmdet3d.core.bbox import DepthInstance3DBoxes
from mmdet3d.core.points import DepthPoints, LiDARPoints from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets.pipelines import (LoadAnnotations3D, LoadPointsFromFile, # yapf: disable
from mmdet3d.datasets.pipelines import (LoadAnnotations3D,
LoadImageFromFileMono3D,
LoadPointsFromFile,
LoadPointsFromMultiSweeps, LoadPointsFromMultiSweeps,
NormalizePointsColor, NormalizePointsColor,
PointSegClassMapping) PointSegClassMapping)
# yapf: enable
def test_load_points_from_indoor_file(): def test_load_points_from_indoor_file():
# test on SUN RGB-D dataset with shifted height # test on SUN RGB-D dataset with shifted height
...@@ -288,6 +293,25 @@ def test_load_points_from_multi_sweeps(): ...@@ -288,6 +293,25 @@ def test_load_points_from_multi_sweeps():
assert points.shape == (403, 4) assert points.shape == (403, 4)
def test_load_image_from_file_mono_3d():
    """Test LoadImageFromFileMono3D: it should load the image, carry the
    camera intrinsic matrix through, and have the expected repr."""
    pipeline = LoadImageFromFileMono3D()

    img_path = 'tests/data/nuscenes/samples/CAM_BACK_LEFT/' \
        'n015-2018-07-18-11-07-57+0800__CAM_BACK_LEFT__1531883530447423.jpg'
    intrinsic = np.array([[1256.74, 0.0, 792.11], [0.0, 1256.74, 492.78],
                          [0.0, 0.0, 1.0]])

    img_info = dict(filename=img_path, cam_intrinsic=intrinsic.copy())
    results = pipeline(dict(img_prefix=None, img_info=img_info))

    # nuScenes camera images are 1600x900 RGB
    assert results['img'].shape == (900, 1600, 3)
    # the intrinsic matrix must be passed through unchanged
    assert np.all(results['cam_intrinsic'] == intrinsic)

    expected_repr_str = 'LoadImageFromFileMono3D(to_float32=False, ' \
        "color_type='color', file_client_args={'backend': 'disk'})"
    assert repr(pipeline) == expected_repr_str
def test_point_seg_class_mapping(): def test_point_seg_class_mapping():
# max_cat_id should larger tham max id in valid_cat_ids # max_cat_id should larger tham max id in valid_cat_ids
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment