Unverified Commit b3e792bc authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Add `RandomDropPointsColor` transform (#585)

* add RandomDropPointsColor transform

* add unit test

* add link to PAConv
parent 926beda8
......@@ -13,7 +13,8 @@ from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)
from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset
......@@ -32,5 +33,6 @@ __all__ = [
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline'
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline',
'RandomDropPointsColor'
]
......@@ -10,7 +10,8 @@ from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)
__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
......@@ -20,5 +21,6 @@ __all__ = [
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D'
'IndoorPatchPointSample', 'LoadImageFromFileMono3D',
'RandomDropPointsColor'
]
......@@ -10,6 +10,50 @@ from ..builder import OBJECTSAMPLERS
from .data_augment_utils import noise_per_object_v3_
@PIPELINES.register_module()
class RandomDropPointsColor(object):
r"""Randomly set the color of points to all zeros.
Once this transform is executed, all the points' color will be dropped.
Refer to `PAConv <https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/
util/transform.py#L223>`_ for more details.
Args:
drop_ratio (float): The probability of dropping point colors.
Defaults to 0.2.
"""
def __init__(self, drop_ratio=0.2):
assert isinstance(drop_ratio, (int, float)) and 0 <= drop_ratio <= 1, \
f'invalid drop_ratio value {drop_ratio}'
self.drop_ratio = drop_ratio
def __call__(self, input_dict):
"""Call function to drop point colors.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after color dropping, \
'points' key is updated in the result dict.
"""
points = input_dict['points']
assert points.attribute_dims is not None and \
'color' in points.attribute_dims, \
'Expect points have color attribute'
if np.random.rand() < self.drop_ratio:
points.color = points.color * 0.0
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(drop_ratio={self.drop_ratio})'
return repr_str
@PIPELINES.register_module()
class RandomFlip3D(RandomFlip):
"""Flip the points & bbox.
......
......@@ -8,7 +8,8 @@ from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes,
from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, ObjectNoise, ObjectSample,
PointShuffle, PointsRangeFilter, RandomFlip3D,
PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)
......@@ -364,6 +365,41 @@ def test_global_rot_scale_trans():
atol=1e-6)
def test_random_drop_points_color():
# drop_ratio should be in [0, 1]
with pytest.raises(AssertionError):
random_drop_points_color = RandomDropPointsColor(drop_ratio=1.1)
# 100% drop
random_drop_points_color = RandomDropPointsColor(drop_ratio=1)
points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
np.float32).reshape(-1, 6)
depth_points = DepthPoints(
points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
input_dict = dict(points=depth_points.clone())
input_dict = random_drop_points_color(input_dict)
trans_depth_points = input_dict['points']
trans_color = trans_depth_points.color
assert torch.all(trans_color == trans_color.new_zeros(trans_color.shape))
# 0% drop
random_drop_points_color = RandomDropPointsColor(drop_ratio=0)
input_dict = dict(points=depth_points.clone())
input_dict = random_drop_points_color(input_dict)
trans_depth_points = input_dict['points']
trans_color = trans_depth_points.color
assert torch.allclose(trans_color, depth_points.tensor[:, 3:6])
random_drop_points_color = RandomDropPointsColor(drop_ratio=0.5)
repr_str = repr(random_drop_points_color)
expected_repr_str = 'RandomDropPointsColor(drop_ratio=0.5)'
assert repr_str == expected_repr_str
def test_random_flip_3d():
random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment