Unverified commit 60d848b3, authored by Xiang Xu and committed by GitHub
Browse files

[Feature] Support `PolarMix` augmentation (#2265)

* support polarmix

* Update __init__.py

* add UT

* use `BasePoints` instead of numpy

* Update transforms_3d.py

* Update transforms_3d.py

* Update test_transforms_3d.py

* update docs

* update polarmix without MultiImageMixDataset

* add comments

* fix UT

* update docstring

* fix yaw calculation

* fix UT

* refactor

* update

* update docs

* fix typo

* Update transforms_3d.py

* update ut

* fix typehint

* add prob argument
parent 21de1afe
...@@ -283,6 +283,24 @@ class Seg3DDataset(BaseDataset): ...@@ -283,6 +283,24 @@ class Seg3DDataset(BaseDataset):
return info return info
def prepare_data(self, idx: int) -> dict:
    """Fetch one sample and run it through ``self.pipeline``.

    In test mode the call is delegated to the parent class. Otherwise the
    dataset object itself is attached to the sample under the ``'dataset'``
    key before the pipeline is applied.

    Args:
        idx (int): The index of ``data_info``.

    Returns:
        dict: Results passed through ``self.pipeline``.
    """
    if self.test_mode:
        return super().prepare_data(idx)

    sample_info = self.get_data_info(idx)
    # Pass the dataset to the pipeline during training to support mixed
    # data augmentation, such as polarmix.
    sample_info['dataset'] = self
    return self.pipeline(sample_info)
def get_scene_idxs(self, scene_idxs: Union[None, str, def get_scene_idxs(self, scene_idxs: Union[None, str,
np.ndarray]) -> np.ndarray: np.ndarray]) -> np.ndarray:
"""Compute scene_idxs for data sampling. """Compute scene_idxs for data sampling.
......
...@@ -14,7 +14,7 @@ from .transforms_3d import (AffineResize, BackgroundPointsFilter, ...@@ -14,7 +14,7 @@ from .transforms_3d import (AffineResize, BackgroundPointsFilter,
MultiViewWrapper, ObjectNameFilter, ObjectNoise, MultiViewWrapper, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample, ObjectRangeFilter, ObjectSample,
PhotoMetricDistortion3D, PointSample, PointShuffle, PhotoMetricDistortion3D, PointSample, PointShuffle,
PointsRangeFilter, RandomDropPointsColor, PointsRangeFilter, PolarMix, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints, RandomResize3D, RandomFlip3D, RandomJitterPoints, RandomResize3D,
RandomShiftScale, Resize3D, VoxelBasedPointSampler) RandomShiftScale, Resize3D, VoxelBasedPointSampler)
...@@ -30,5 +30,5 @@ __all__ = [ ...@@ -30,5 +30,5 @@ __all__ = [
'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize', 'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D', 'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader', 'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
'LidarDet3DInferencerLoader' 'LidarDet3DInferencerLoader', 'PolarMix'
] ]
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import random import random
import warnings import warnings
from typing import List, Optional, Tuple, Union from typing import List, Optional, Sequence, Tuple, Union
import cv2 import cv2
import mmcv import mmcv
import numpy as np import numpy as np
import torch
from mmcv.transforms import BaseTransform, Compose, RandomResize, Resize from mmcv.transforms import BaseTransform, Compose, RandomResize, Resize
from mmdet.datasets.transforms import (PhotoMetricDistortion, RandomCrop, from mmdet.datasets.transforms import (PhotoMetricDistortion, RandomCrop,
RandomFlip) RandomFlip)
from mmengine import is_tuple_of from mmengine import is_list_of, is_tuple_of
from mmdet3d.models.task_modules import VoxelGenerator from mmdet3d.models.task_modules import VoxelGenerator
from mmdet3d.registry import TRANSFORMS from mmdet3d.registry import TRANSFORMS
...@@ -2352,3 +2353,171 @@ class MultiViewWrapper(BaseTransform): ...@@ -2352,3 +2353,171 @@ class MultiViewWrapper(BaseTransform):
if len(input_dict[key]) == 0: if len(input_dict[key]) == 0:
input_dict.pop(key) input_dict.pop(key)
return input_dict return input_dict
@TRANSFORMS.register_module()
class PolarMix(BaseTransform):
    """PolarMix data augmentation.

    The polarmix transform steps are as follows:

        1. Another random point cloud is picked by dataset.
        2. Exchange sectors of two point clouds that are cut with certain
           azimuth angles.
        3. Cut point instances from picked point cloud, rotate them by
           multiple azimuth angles, and paste the cut and rotated instances.

    Required Keys:

    - points (:obj:`BasePoints`)
    - pts_semantic_mask (np.int64)
    - dataset (:obj:`BaseDataset`)

    Modified Keys:

    - points (:obj:`BasePoints`)
    - pts_semantic_mask (np.int64)

    Args:
        instance_classes (List[int]): Semantic masks which represent the
            instance. An empty list disables the rotate-pasting step.
        swap_ratio (float): Swap ratio of two point cloud. Defaults to 0.5.
        rotate_paste_ratio (float): Rotate paste ratio. Defaults to 1.0.
        pre_transform (Sequence[dict], optional): Sequence of transform object
            or config dict to be composed. Defaults to None.
        prob (float): The transformation probability. Defaults to 1.0.
    """

    def __init__(self,
                 instance_classes: List[int],
                 swap_ratio: float = 0.5,
                 rotate_paste_ratio: float = 1.0,
                 pre_transform: Optional[Sequence[dict]] = None,
                 prob: float = 1.0) -> None:
        # Kept as an assertion: callers may rely on ``AssertionError`` being
        # raised for an invalid ``instance_classes``.
        assert is_list_of(instance_classes, int), \
            'instance_classes should be a list of int'
        self.instance_classes = instance_classes
        self.swap_ratio = swap_ratio
        self.rotate_paste_ratio = rotate_paste_ratio

        self.prob = prob
        if pre_transform is None:
            self.pre_transform = None
        else:
            self.pre_transform = Compose(pre_transform)

    def polar_mix_transform(self, input_dict: dict, mix_results: dict) -> dict:
        """PolarMix transform function.

        With probability ``swap_ratio`` a sector of ``input_dict`` is
        replaced by the same sector of ``mix_results``; with probability
        ``rotate_paste_ratio`` the points of ``mix_results`` labelled with
        one of ``instance_classes`` are appended together with two rotated
        copies. ``input_dict`` is updated in place.

        Args:
            input_dict (dict): Result dict from loading pipeline.
            mix_results (dict): Mixed dict picked from dataset.

        Returns:
            dict: output dict after transformation.
        """
        mix_points = mix_results['points']
        mix_pts_semantic_mask = mix_results['pts_semantic_mask']

        points = input_dict['points']
        pts_semantic_mask = input_dict['pts_semantic_mask']

        # 1. swap point cloud
        if np.random.random() < self.swap_ratio:
            start_angle = (np.random.random() - 1) * np.pi  # -pi~0
            end_angle = start_angle + np.pi
            # calculate horizontal angle for each point
            yaw = -torch.atan2(points.coord[:, 1], points.coord[:, 0])
            mix_yaw = -torch.atan2(mix_points.coord[:, 1],
                                   mix_points.coord[:, 0])

            # keep own points outside the sector and take the points of the
            # picked point cloud that lie strictly inside it
            idx = (yaw <= start_angle) | (yaw >= end_angle)
            mix_idx = (mix_yaw > start_angle) & (mix_yaw < end_angle)

            # swap
            points = points.cat([points[idx], mix_points[mix_idx]])
            pts_semantic_mask = np.concatenate(
                (pts_semantic_mask[idx.numpy()],
                 mix_pts_semantic_mask[mix_idx.numpy()]),
                axis=0)

        # 2. rotate-pasting
        # The random number is drawn before checking ``instance_classes`` so
        # that the sequence of random draws does not depend on the guard.
        do_rotate_paste = np.random.random() < self.rotate_paste_ratio
        # With no instance classes there is nothing to paste, and
        # ``np.concatenate`` would raise on the resulting empty list.
        if do_rotate_paste and self.instance_classes:
            # extract instance points
            class_masks = [
                mix_pts_semantic_mask == instance_class
                for instance_class in self.instance_classes
            ]
            instance_points = mix_points.cat(
                [mix_points[mask] for mask in class_masks])
            instance_pts_semantic_mask = np.concatenate(
                [mix_pts_semantic_mask[mask] for mask in class_masks], axis=0)

            # rotate-copy: one angle is drawn from each of two consecutive
            # ranges of width 2*pi/3
            angle_list = [
                np.random.random() * np.pi * 2 / 3,
                (np.random.random() + 1) * np.pi * 2 / 3
            ]
            copy_points = [instance_points]
            for angle in angle_list:
                new_points = instance_points.clone()
                new_points.rotate(angle)
                copy_points.append(new_points)
            copy_points = instance_points.cat(copy_points)
            # labels are unchanged by rotation: one copy per pasted cloud
            copy_pts_semantic_mask = np.concatenate(
                [instance_pts_semantic_mask] * (len(angle_list) + 1), axis=0)

            points = points.cat([points, copy_points])
            pts_semantic_mask = np.concatenate(
                (pts_semantic_mask, copy_pts_semantic_mask), axis=0)

        input_dict['points'] = points
        input_dict['pts_semantic_mask'] = pts_semantic_mask
        return input_dict

    def transform(self, input_dict: dict) -> dict:
        """PolarMix transform function.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: output dict after transformation.

        Raises:
            AssertionError: If ``input_dict`` has no ``dataset`` key.
        """
        if np.random.rand() > self.prob:
            return input_dict

        assert 'dataset' in input_dict, \
            '`dataset` is needed to pass through PolarMix, while not found.'
        dataset = input_dict['dataset']

        # get index of other point cloud
        index = np.random.randint(0, len(dataset))

        mix_results = dataset.get_data_info(index)

        if self.pre_transform is not None:
            # pre_transform may also require dataset
            mix_results.update({'dataset': dataset})
            # before polarmix need to go through
            # the necessary pre_transform
            mix_results = self.pre_transform(mix_results)
            # the key may already have been removed by ``pre_transform``
            mix_results.pop('dataset', None)

        input_dict = self.polar_mix_transform(input_dict, mix_results)

        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(instance_classes={self.instance_classes}, '
        repr_str += f'swap_ratio={self.swap_ratio}, '
        repr_str += f'rotate_paste_ratio={self.rotate_paste_ratio}, '
        repr_str += f'pre_transform={self.pre_transform}, '
        repr_str += f'prob={self.prob})'
        return repr_str
...@@ -6,9 +6,14 @@ import numpy as np ...@@ -6,9 +6,14 @@ import numpy as np
import torch import torch
from mmengine.testing import assert_allclose from mmengine.testing import assert_allclose
from mmdet3d.datasets import GlobalAlignment, RandomFlip3D from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D,
from mmdet3d.datasets.transforms import GlobalRotScaleTrans SemanticKITTIDataset)
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix
from mmdet3d.structures import LiDARPoints
from mmdet3d.testing import create_data_info_after_loading from mmdet3d.testing import create_data_info_after_loading
from mmdet3d.utils import register_all_modules
register_all_modules()
class TestGlobalRotScaleTrans(unittest.TestCase): class TestGlobalRotScaleTrans(unittest.TestCase):
...@@ -99,3 +104,121 @@ class TestGlobalAlignment(unittest.TestCase): ...@@ -99,3 +104,121 @@ class TestGlobalAlignment(unittest.TestCase):
# assert the rot metric # assert the rot metric
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
global_align_transform(data_info) global_align_transform(data_info)
class TestPolarMix(unittest.TestCase):
    """Unit tests for the ``PolarMix`` transform."""

    def setUp(self):
        """Build the loading pipeline, a dataset and a random input dict."""
        load_points_cfg = dict(
            type='LoadPointsFromFile',
            coord_type='LIDAR',
            load_dim=4,
            use_dim=4)
        load_annotations_cfg = dict(
            type='LoadAnnotations3D',
            with_bbox_3d=False,
            with_label_3d=False,
            with_mask_3d=False,
            with_seg_3d=True,
            seg_3d_dtype='np.int32')
        self.pre_transform = [
            load_points_cfg,
            load_annotations_cfg,
            dict(type='PointSegClassMapping'),
        ]

        class_names = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck',
                       'bus', 'person', 'bicyclist', 'motorcyclist', 'road',
                       'parking', 'sidewalk', 'other-ground', 'building',
                       'fence', 'vegetation', 'trunck', 'terrian', 'pole',
                       'traffic-sign')

        # one colour per entry of ``class_names``
        class_colors = [
            [174, 199, 232],
            [152, 223, 138],
            [31, 119, 180],
            [255, 187, 120],
            [188, 189, 34],
            [140, 86, 75],
            [255, 152, 150],
            [214, 39, 40],
            [197, 176, 213],
            [148, 103, 189],
            [196, 156, 148],
            [23, 190, 207],
            [247, 182, 210],
            [219, 219, 141],
            [255, 127, 14],
            [158, 218, 229],
            [44, 160, 44],
            [112, 128, 144],
            [227, 119, 194],
            [82, 84, 163],
        ]

        # raw label id -> training label id
        label_mapping = {
            0: 0,  # "unlabeled"
            1: 0,  # "outlier" mapped to "unlabeled"
            10: 1,  # "car"
            11: 2,  # "bicycle"
            13: 5,  # "bus" mapped to "other-vehicle"
            15: 3,  # "motorcycle"
            16: 5,  # "on-rails" mapped to "other-vehicle"
            18: 4,  # "truck"
            20: 5,  # "other-vehicle"
            30: 6,  # "person"
            31: 7,  # "bicyclist"
            32: 8,  # "motorcyclist"
            40: 9,  # "road"
            44: 10,  # "parking"
            48: 11,  # "sidewalk"
            49: 12,  # "other-ground"
            50: 13,  # "building"
            51: 14,  # "fence"
            52: 0,  # "other-structure" mapped to "unlabeled"
            60: 9,  # "lane-marking" mapped to "road"
            70: 15,  # "vegetation"
            71: 16,  # "trunk"
            72: 17,  # "terrain"
            80: 18,  # "pole"
            81: 19,  # "traffic-sign"
            99: 0,  # "other-object" mapped to "unlabeled"
            252: 1,  # "moving-car" mapped to "car"
            253: 7,  # "moving-bicyclist" mapped to "bicyclist"
            254: 6,  # "moving-person" mapped to "person"
            255: 8,  # "moving-motorcyclist" mapped to "motorcyclist"
            256: 5,  # "moving-on-rails" mapped to "other-vehicle"
            257: 5,  # "moving-bus" mapped to "other-vehicle"
            258: 4,  # "moving-truck" mapped to "truck"
            259: 5  # "moving-other-vehicle" mapped to "other-vehicle"
        }

        dataset_metainfo = dict(
            classes=class_names,
            palette=class_colors,
            seg_label_mapping=label_mapping,
            max_label=259)
        self.dataset = SemanticKITTIDataset(
            './tests/data/semantickitti/',
            'semantickitti_infos.pkl',
            metainfo=dataset_metainfo,
            data_prefix=dict(
                pts='sequences/00/velodyne',
                pts_semantic_mask='sequences/00/labels'),
            pipeline=[],
            modality=dict(use_lidar=True, use_camera=False))

        # the points are drawn before the labels, as both use ``np.random``
        random_points = np.random.random((100, 4))
        random_labels = np.random.randint(0, 20, (100, ))
        self.results = {
            'points': LiDARPoints(random_points, points_dim=4),
            'pts_semantic_mask': random_labels,
            'dataset': self.dataset
        }

    def test_transform(self):
        """Check argument validation and point/label count consistency."""
        # test assertion for invalid instance_classes
        for invalid_classes in (1, [1.0, 2.0]):
            with self.assertRaises(AssertionError):
                PolarMix(instance_classes=invalid_classes)

        polar_mix = PolarMix(
            instance_classes=[1, 2],
            swap_ratio=1.0,
            pre_transform=self.pre_transform)
        mixed = polar_mix.transform(copy.deepcopy(self.results))

        # every output point must still carry exactly one semantic label
        num_points = mixed['points'].shape[0]
        num_labels = mixed['pts_semantic_mask'].shape[0]
        self.assertEqual(num_points, num_labels)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment