"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "7b61571cc58a72afa8b430e7cd7d2ac468bd073d"
Unverified Commit b3e792bc authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Feature] Add `RandomDropPointsColor` transform (#585)

* add RandomDropPointsColor transform

* add unit test

* add link to PAConv
parent 926beda8
...@@ -13,7 +13,8 @@ from .pipelines import (BackgroundPointsFilter, GlobalAlignment, ...@@ -13,7 +13,8 @@ from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
LoadPointsFromFile, LoadPointsFromMultiSweeps, LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter, NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler) RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)
from .s3dis_dataset import S3DISSegDataset from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset from .semantickitti_dataset import SemanticKITTIDataset
...@@ -32,5 +33,6 @@ __all__ = [ ...@@ -32,5 +33,6 @@ __all__ = [
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset', 'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline' 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline',
'RandomDropPointsColor'
] ]
...@@ -10,7 +10,8 @@ from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment, ...@@ -10,7 +10,8 @@ from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample, GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNoise, ObjectRangeFilter, IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter, ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler) RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)
__all__ = [ __all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
...@@ -20,5 +21,6 @@ __all__ = [ ...@@ -20,5 +21,6 @@ __all__ = [
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample', 'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps', 'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment', 'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D' 'IndoorPatchPointSample', 'LoadImageFromFileMono3D',
'RandomDropPointsColor'
] ]
...@@ -10,6 +10,50 @@ from ..builder import OBJECTSAMPLERS ...@@ -10,6 +10,50 @@ from ..builder import OBJECTSAMPLERS
from .data_augment_utils import noise_per_object_v3_ from .data_augment_utils import noise_per_object_v3_
@PIPELINES.register_module()
class RandomDropPointsColor(object):
r"""Randomly set the color of points to all zeros.
Once this transform is executed, all the points' color will be dropped.
Refer to `PAConv <https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/
util/transform.py#L223>`_ for more details.
Args:
drop_ratio (float): The probability of dropping point colors.
Defaults to 0.2.
"""
def __init__(self, drop_ratio=0.2):
assert isinstance(drop_ratio, (int, float)) and 0 <= drop_ratio <= 1, \
f'invalid drop_ratio value {drop_ratio}'
self.drop_ratio = drop_ratio
def __call__(self, input_dict):
"""Call function to drop point colors.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after color dropping, \
'points' key is updated in the result dict.
"""
points = input_dict['points']
assert points.attribute_dims is not None and \
'color' in points.attribute_dims, \
'Expect points have color attribute'
if np.random.rand() < self.drop_ratio:
points.color = points.color * 0.0
return input_dict
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(drop_ratio={self.drop_ratio})'
return repr_str
@PIPELINES.register_module() @PIPELINES.register_module()
class RandomFlip3D(RandomFlip): class RandomFlip3D(RandomFlip):
"""Flip the points & bbox. """Flip the points & bbox.
......
...@@ -8,7 +8,8 @@ from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes, ...@@ -8,7 +8,8 @@ from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes,
from mmdet3d.core.points import DepthPoints, LiDARPoints from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment, from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, ObjectNoise, ObjectSample, GlobalRotScaleTrans, ObjectNoise, ObjectSample,
PointShuffle, PointsRangeFilter, RandomFlip3D, PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler) VoxelBasedPointSampler)
...@@ -364,6 +365,41 @@ def test_global_rot_scale_trans(): ...@@ -364,6 +365,41 @@ def test_global_rot_scale_trans():
atol=1e-6) atol=1e-6)
def test_random_drop_points_color():
# drop_ratio should be in [0, 1]
with pytest.raises(AssertionError):
random_drop_points_color = RandomDropPointsColor(drop_ratio=1.1)
# 100% drop
random_drop_points_color = RandomDropPointsColor(drop_ratio=1)
points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
np.float32).reshape(-1, 6)
depth_points = DepthPoints(
points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
input_dict = dict(points=depth_points.clone())
input_dict = random_drop_points_color(input_dict)
trans_depth_points = input_dict['points']
trans_color = trans_depth_points.color
assert torch.all(trans_color == trans_color.new_zeros(trans_color.shape))
# 0% drop
random_drop_points_color = RandomDropPointsColor(drop_ratio=0)
input_dict = dict(points=depth_points.clone())
input_dict = random_drop_points_color(input_dict)
trans_depth_points = input_dict['points']
trans_color = trans_depth_points.color
assert torch.allclose(trans_color, depth_points.tensor[:, 3:6])
random_drop_points_color = RandomDropPointsColor(drop_ratio=0.5)
repr_str = repr(random_drop_points_color)
expected_repr_str = 'RandomDropPointsColor(drop_ratio=0.5)'
assert repr_str == expected_repr_str
def test_random_flip_3d(): def test_random_flip_3d():
random_flip_3d = RandomFlip3D( random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0) flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment