Unverified Commit a5627bfb authored by Xiang Xu's avatar Xiang Xu Committed by GitHub
Browse files

[Feature] Support `LaserMix` augmentation (#2302)

* add lasermix

* add prob

* update description

* update
parent 7beabbdd
...@@ -295,7 +295,7 @@ class Seg3DDataset(BaseDataset): ...@@ -295,7 +295,7 @@ class Seg3DDataset(BaseDataset):
if not self.test_mode: if not self.test_mode:
data_info = self.get_data_info(idx) data_info = self.get_data_info(idx)
# Pass the dataset to the pipeline during training to support mixed # Pass the dataset to the pipeline during training to support mixed
# data augmentation, such as polarmix. # data augmentation, such as polarmix and lasermix.
data_info['dataset'] = self data_info['dataset'] = self
return self.pipeline(data_info) return self.pipeline(data_info)
else: else:
......
...@@ -11,8 +11,8 @@ from .test_time_aug import MultiScaleFlipAug3D ...@@ -11,8 +11,8 @@ from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (AffineResize, BackgroundPointsFilter, from .transforms_3d import (AffineResize, BackgroundPointsFilter,
GlobalAlignment, GlobalRotScaleTrans, GlobalAlignment, GlobalRotScaleTrans,
IndoorPatchPointSample, IndoorPointSample, IndoorPatchPointSample, IndoorPointSample,
MultiViewWrapper, ObjectNameFilter, ObjectNoise, LaserMix, MultiViewWrapper, ObjectNameFilter,
ObjectRangeFilter, ObjectSample, ObjectNoise, ObjectRangeFilter, ObjectSample,
PhotoMetricDistortion3D, PointSample, PointShuffle, PhotoMetricDistortion3D, PointSample, PointShuffle,
PointsRangeFilter, PolarMix, RandomDropPointsColor, PointsRangeFilter, PolarMix, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints, RandomResize3D, RandomFlip3D, RandomJitterPoints, RandomResize3D,
...@@ -30,5 +30,5 @@ __all__ = [ ...@@ -30,5 +30,5 @@ __all__ = [
'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize', 'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D', 'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader', 'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
'LidarDet3DInferencerLoader', 'PolarMix' 'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix'
] ]
...@@ -2521,3 +2521,148 @@ class PolarMix(BaseTransform): ...@@ -2521,3 +2521,148 @@ class PolarMix(BaseTransform):
repr_str += f'pre_transform={self.pre_transform}, ' repr_str += f'pre_transform={self.pre_transform}, '
repr_str += f'prob={self.prob})' repr_str += f'prob={self.prob})'
return repr_str return repr_str
@TRANSFORMS.register_module()
class LaserMix(BaseTransform):
"""LaserMix data augmentation.
The lasermix transform steps are as follows:
1. Another random point cloud is picked by dataset.
2. Divide the point cloud into several regions according to pitch
angles and combine the areas crossly.
Required Keys:
- points (:obj:`BasePoints`)
- pts_semantic_mask (np.int64)
- dataset (:obj:`BaseDataset`)
Modified Keys:
- points (:obj:`BasePoints`)
- pts_semantic_mask (np.int64)
Args:
num_areas (List[int]): A list of area numbers will be divided into.
pitch_angles (Sequence[float]): Pitch angles used to divide areas.
pre_transform (Sequence[dict], optional): Sequence of transform object
or config dict to be composed. Defaults to None.
prob (float): The transformation probability. Defaults to 1.0.
"""
def __init__(self,
num_areas: List[int],
pitch_angles: Sequence[float],
pre_transform: Optional[Sequence[dict]] = None,
prob: float = 1.0) -> None:
assert is_list_of(num_areas, int), \
'num_areas should be a list of int.'
self.num_areas = num_areas
assert len(pitch_angles) == 2, \
'The length of pitch_angles should be 2, ' \
f'but got {len(pitch_angles)}.'
assert pitch_angles[1] > pitch_angles[0], \
'pitch_angles[1] should be larger than pitch_angles[0].'
self.pitch_angles = pitch_angles
self.prob = prob
if pre_transform is None:
self.pre_transform = None
else:
self.pre_transform = Compose(pre_transform)
def laser_mix_transform(self, input_dict: dict, mix_results: dict) -> dict:
"""LaserMix transform function.
Args:
input_dict (dict): Result dict from loading pipeline.
mix_results (dict): Mixed dict picked from dataset.
Returns:
dict: output dict after transformation.
"""
mix_points = mix_results['points']
mix_pts_semantic_mask = mix_results['pts_semantic_mask']
points = input_dict['points']
pts_semantic_mask = input_dict['pts_semantic_mask']
rho = torch.sqrt(points.coord[:, 0]**2 + points.coord[:, 1]**2)
pitch = torch.atan2(points.coord[:, 2], rho)
pitch = torch.clip(pitch, self.pitch_angles[0] + 1e-5,
self.pitch_angles[1] - 1e-5)
mix_rho = torch.sqrt(mix_points.coord[:, 0]**2 +
mix_points.coord[:, 1]**2)
mix_pitch = torch.atan2(mix_points.coord[:, 2], mix_rho)
mix_pitch = torch.clip(mix_pitch, self.pitch_angles[0] + 1e-5,
self.pitch_angles[1] - 1e-5)
num_areas = np.random.choice(self.num_areas, size=1)[0]
angle_list = np.linspace(self.pitch_angles[1], self.pitch_angles[0],
num_areas + 1)
out_points = []
out_pts_semantic_mask = []
for i in range(num_areas):
# convert angle to radian
start_angle = angle_list[i + 1] / 180 * np.pi
end_angle = angle_list[i] / 180 * np.pi
if i % 2 == 0: # pick from original point cloud
idx = (pitch > start_angle) & (pitch <= end_angle)
out_points.append(points[idx])
out_pts_semantic_mask.append(pts_semantic_mask[idx.numpy()])
else: # pickle from mixed point cloud
idx = (mix_pitch > start_angle) & (mix_pitch <= end_angle)
out_points.append(mix_points[idx])
out_pts_semantic_mask.append(
mix_pts_semantic_mask[idx.numpy()])
out_points = points.cat(out_points)
out_pts_semantic_mask = np.concatenate(out_pts_semantic_mask, axis=0)
input_dict['points'] = out_points
input_dict['pts_semantic_mask'] = out_pts_semantic_mask
return input_dict
def transform(self, input_dict: dict) -> dict:
"""LaserMix transform function.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: output dict after transformation.
"""
if np.random.rand() > self.prob:
return input_dict
assert 'dataset' in input_dict, \
'`dataset` is needed to pass through LaserMix, while not found.'
dataset = input_dict['dataset']
# get index of other point cloud
index = np.random.randint(0, len(dataset))
mix_results = dataset.get_data_info(index)
if self.pre_transform is not None:
# pre_transform may also require dataset
mix_results.update({'dataset': dataset})
# before lasermix need to go through
# the necessary pre_transform
mix_results = self.pre_transform(mix_results)
mix_results.pop('dataset')
input_dict = self.laser_mix_transform(input_dict, mix_results)
return input_dict
def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_areas={self.num_areas}, '
repr_str += f'pitch_angles={self.pitch_angles}, '
repr_str += f'pre_transform={self.pre_transform}, '
repr_str += f'prob={self.prob})'
return repr_str
...@@ -8,7 +8,7 @@ from mmengine.testing import assert_allclose ...@@ -8,7 +8,7 @@ from mmengine.testing import assert_allclose
from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D, from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D,
SemanticKITTIDataset) SemanticKITTIDataset)
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix from mmdet3d.datasets.transforms import GlobalRotScaleTrans, LaserMix, PolarMix
from mmdet3d.structures import LiDARPoints from mmdet3d.structures import LiDARPoints
from mmdet3d.testing import create_data_info_after_loading from mmdet3d.testing import create_data_info_after_loading
from mmdet3d.utils import register_all_modules from mmdet3d.utils import register_all_modules
...@@ -222,3 +222,128 @@ class TestPolarMix(unittest.TestCase): ...@@ -222,3 +222,128 @@ class TestPolarMix(unittest.TestCase):
results = transform.transform(copy.deepcopy(self.results)) results = transform.transform(copy.deepcopy(self.results))
self.assertTrue(results['points'].shape[0] == self.assertTrue(results['points'].shape[0] ==
results['pts_semantic_mask'].shape[0]) results['pts_semantic_mask'].shape[0])
class TestLaserMix(unittest.TestCase):
def setUp(self):
self.pre_transform = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
seg_3d_dtype='np.int32'),
dict(type='PointSegClassMapping'),
]
classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
'sidewalk', 'other-ground', 'building', 'fence',
'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign')
palette = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
seg_label_mapping = {
0: 0, # "unlabeled"
1: 0, # "outlier" mapped to "unlabeled" --------------mapped
10: 1, # "car"
11: 2, # "bicycle"
13: 5, # "bus" mapped to "other-vehicle" --------------mapped
15: 3, # "motorcycle"
16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped
18: 4, # "truck"
20: 5, # "other-vehicle"
30: 6, # "person"
31: 7, # "bicyclist"
32: 8, # "motorcyclist"
40: 9, # "road"
44: 10, # "parking"
48: 11, # "sidewalk"
49: 12, # "other-ground"
50: 13, # "building"
51: 14, # "fence"
52: 0, # "other-structure" mapped to "unlabeled" ------mapped
60: 9, # "lane-marking" to "road" ---------------------mapped
70: 15, # "vegetation"
71: 16, # "trunk"
72: 17, # "terrain"
80: 18, # "pole"
81: 19, # "traffic-sign"
99: 0, # "other-object" to "unlabeled" ----------------mapped
252: 1, # "moving-car" to "car" ------------------------mapped
253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped
254: 6, # "moving-person" to "person" ------------------mapped
255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped
256: 5, # "moving-on-rails" mapped to "other-vehic------mapped
257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped
258: 4, # "moving-truck" to "truck" --------------------mapped
259: 5 # "moving-other"-vehicle to "other-vehicle"-----mapped
}
max_label = 259
self.dataset = SemanticKITTIDataset(
'./tests/data/semantickitti/',
'semantickitti_infos.pkl',
metainfo=dict(
classes=classes,
palette=palette,
seg_label_mapping=seg_label_mapping,
max_label=max_label),
data_prefix=dict(
pts='sequences/00/velodyne',
pts_semantic_mask='sequences/00/labels'),
pipeline=[],
modality=dict(use_lidar=True, use_camera=False))
points = np.random.random((100, 4))
self.results = {
'points': LiDARPoints(points, points_dim=4),
'pts_semantic_mask': np.random.randint(0, 20, (100, )),
'dataset': self.dataset
}
def test_transform(self):
# test assertion for invalid num_areas
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=3, pitch_angles=[-20, 0])
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3.0, 4.0], pitch_angles=[-20, 0])
# test assertion for invalid pitch_angles
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3, 4], pitch_angles=[-20])
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3, 4], pitch_angles=[0, -20])
transform = LaserMix(
num_areas=[3, 4, 5, 6],
pitch_angles=[-20, 0],
pre_transform=self.pre_transform)
results = transform.transform(copy.deepcopy(self.results))
self.assertTrue(results['points'].shape[0] ==
results['pts_semantic_mask'].shape[0])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment