Unverified Commit 25a736f7 authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Add `RandomJitterPoints` transform (#584)

* add RandomJitterPoints transform

* add unit test

* add comments

* minor fix
parent 8f208a8a
...@@ -14,7 +14,7 @@ from .pipelines import (BackgroundPointsFilter, GlobalAlignment, ...@@ -14,7 +14,7 @@ from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter, NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectSample, PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D, RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler) RandomJitterPoints, VoxelBasedPointSampler)
from .s3dis_dataset import S3DISSegDataset from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset from .semantickitti_dataset import SemanticKITTIDataset
...@@ -34,5 +34,5 @@ __all__ = [ ...@@ -34,5 +34,5 @@ __all__ = [
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset', 'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline',
'RandomDropPointsColor' 'RandomDropPointsColor', 'RandomJitterPoints'
] ]
...@@ -11,7 +11,7 @@ from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment, ...@@ -11,7 +11,7 @@ from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
IndoorPointSample, ObjectNoise, ObjectRangeFilter, IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectSample, PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D, RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler) RandomJitterPoints, VoxelBasedPointSampler)
__all__ = [ __all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
...@@ -22,5 +22,5 @@ __all__ = [ ...@@ -22,5 +22,5 @@ __all__ = [
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps', 'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'IndoorPatchPointSample', 'LoadImageFromFileMono3D',
'RandomDropPointsColor' 'RandomDropPointsColor', 'RandomJitterPoints'
] ]
...@@ -168,6 +168,74 @@ class RandomFlip3D(RandomFlip): ...@@ -168,6 +168,74 @@ class RandomFlip3D(RandomFlip):
return repr_str return repr_str
@PIPELINES.register_module()
class RandomJitterPoints(object):
"""Randomly jitter point coordinates.
Different from the global translation in ``GlobalRotScaleTrans``, here we \
apply different noises to each point in a scene.
Args:
jitter_std (list[float]): The standard deviation of jittering noise.
This applies random noise to all points in a 3D scene, which is \
sampled from a gaussian distribution whose standard deviation is \
set by ``jitter_std``. Defaults to [0.01, 0.01, 0.01]
clip_range (list[float] | None): Clip the randomly generated jitter \
noise into this range. If None is given, don't perform clipping.
Defaults to [-0.05, 0.05]
Note:
This transform should only be used in point cloud segmentation tasks \
because we don't transform ground-truth bboxes accordingly.
For similar transform in detection task, please refer to `ObjectNoise`.
"""
def __init__(self,
jitter_std=[0.01, 0.01, 0.01],
clip_range=[-0.05, 0.05]):
seq_types = (list, tuple, np.ndarray)
if not isinstance(jitter_std, seq_types):
assert isinstance(jitter_std, (int, float)), \
f'unsupported jitter_std type {type(jitter_std)}'
jitter_std = [jitter_std, jitter_std, jitter_std]
self.jitter_std = jitter_std
if clip_range is not None:
if not isinstance(clip_range, seq_types):
assert isinstance(clip_range, (int, float)), \
f'unsupported clip_range type {type(clip_range)}'
clip_range = [-clip_range, clip_range]
self.clip_range = clip_range
def __call__(self, input_dict):
"""Call function to jitter all the points in the scene.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after adding noise to each point, \
'points' key is updated in the result dict.
"""
points = input_dict['points']
jitter_std = np.array(self.jitter_std, dtype=np.float32)
jitter_noise = \
np.random.randn(points.shape[0], 3) * jitter_std[None, :]
if self.clip_range is not None:
jitter_noise = np.clip(jitter_noise, self.clip_range[0],
self.clip_range[1])
points.translate(jitter_noise)
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(jitter_std={self.jitter_std},'
repr_str += f' clip_range={self.clip_range})'
return repr_str
@PIPELINES.register_module() @PIPELINES.register_module()
class ObjectSample(object): class ObjectSample(object):
"""Sample GT objects to the data. """Sample GT objects to the data.
...@@ -433,8 +501,8 @@ class GlobalRotScaleTrans(object): ...@@ -433,8 +501,8 @@ class GlobalRotScaleTrans(object):
Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]). Defaults to [-0.78539816, 0.78539816] (close to [-pi/4, pi/4]).
scale_ratio_range (list[float]): Range of scale ratio. scale_ratio_range (list[float]): Range of scale ratio.
Defaults to [0.95, 1.05]. Defaults to [0.95, 1.05].
translation_std (list[float]): The standard deviation of ranslation translation_std (list[float]): The standard deviation of translation
noise. This apply random translation to a scene by a noise, which noise. This applies random translation to a scene by a noise, which
is sampled from a gaussian distribution whose standard deviation is sampled from a gaussian distribution whose standard deviation
is set by ``translation_std``. Defaults to [0, 0, 0] is set by ``translation_std``. Defaults to [0, 0, 0]
shift_height (bool): Whether to shift height. shift_height (bool): Whether to shift height.
......
...@@ -10,7 +10,7 @@ from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment, ...@@ -10,7 +10,7 @@ from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, ObjectNoise, ObjectSample, GlobalRotScaleTrans, ObjectNoise, ObjectSample,
PointShuffle, PointsRangeFilter, PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D, RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler) RandomJitterPoints, VoxelBasedPointSampler)
def test_remove_points_in_boxes(): def test_remove_points_in_boxes():
...@@ -457,6 +457,62 @@ def test_random_flip_3d(): ...@@ -457,6 +457,62 @@ def test_random_flip_3d():
assert repr_str == expected_repr_str assert repr_str == expected_repr_str
def test_random_jitter_points():
# jitter_std should be a number or seq of numbers
with pytest.raises(AssertionError):
random_jitter_points = RandomJitterPoints(jitter_std='0.0')
# clip_range should be a number or seq of numbers
with pytest.raises(AssertionError):
random_jitter_points = RandomJitterPoints(clip_range='0.0')
random_jitter_points = RandomJitterPoints(jitter_std=0.01, clip_range=0.05)
np.random.seed(0)
points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
np.float32).reshape(-1, 6)[:10]
depth_points = DepthPoints(
points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
input_dict = dict(points=depth_points.clone())
input_dict = random_jitter_points(input_dict)
trans_depth_points = input_dict['points']
jitter_noise = np.array([[0.01764052, 0.00400157, 0.00978738],
[0.02240893, 0.01867558, -0.00977278],
[0.00950088, -0.00151357, -0.00103219],
[0.00410598, 0.00144044, 0.01454273],
[0.00761038, 0.00121675, 0.00443863],
[0.00333674, 0.01494079, -0.00205158],
[0.00313068, -0.00854096, -0.0255299],
[0.00653619, 0.00864436, -0.00742165],
[0.02269755, -0.01454366, 0.00045759],
[-0.00187184, 0.01532779, 0.01469359]])
trans_depth_points = trans_depth_points.tensor.numpy()
expected_depth_points = points
expected_depth_points[:, :3] += jitter_noise
assert np.allclose(trans_depth_points, expected_depth_points)
repr_str = repr(random_jitter_points)
jitter_std = [0.01, 0.01, 0.01]
clip_range = [-0.05, 0.05]
expected_repr_str = f'RandomJitterPoints(jitter_std={jitter_std},' \
f' clip_range={clip_range})'
assert repr_str == expected_repr_str
# test clipping very large noise
random_jitter_points = RandomJitterPoints(jitter_std=1.0, clip_range=0.05)
input_dict = dict(points=depth_points.clone())
input_dict = random_jitter_points(input_dict)
trans_depth_points = input_dict['points']
assert (trans_depth_points.tensor - depth_points.tensor).max().item() <= \
0.05 + 1e-6
assert (trans_depth_points.tensor - depth_points.tensor).min().item() >= \
-0.05 - 1e-6
def test_background_points_filter(): def test_background_points_filter():
np.random.seed(0) np.random.seed(0)
background_points_filter = BackgroundPointsFilter((0.5, 2.0, 0.5)) background_points_filter = BackgroundPointsFilter((0.5, 2.0, 0.5))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment