Unverified Commit 9a30edfc authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Support rotate points without bbox in GlobalRotScaleTrans (#540)

parent 9d852f17
...@@ -176,6 +176,8 @@ class BasePoints(object): ...@@ -176,6 +176,8 @@ class BasePoints(object):
raise NotImplementedError raise NotImplementedError
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
return rot_mat_T
@abstractmethod @abstractmethod
def flip(self, bev_direction='horizontal'): def flip(self, bev_direction='horizontal'):
"""Flip the points in BEV along given BEV direction.""" """Flip the points in BEV along given BEV direction."""
......
...@@ -403,8 +403,23 @@ class GlobalRotScaleTrans(object): ...@@ -403,8 +403,23 @@ class GlobalRotScaleTrans(object):
scale_ratio_range=[0.95, 1.05], scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0], translation_std=[0, 0, 0],
shift_height=False): shift_height=False):
seq_types = (list, tuple, np.ndarray)
if not isinstance(rot_range, seq_types):
assert isinstance(rot_range, (int, float)), \
f'unsupported rot_range type {type(rot_range)}'
rot_range = [-rot_range, rot_range]
self.rot_range = rot_range self.rot_range = rot_range
assert isinstance(scale_ratio_range, seq_types), \
f'unsupported scale_ratio_range type {type(scale_ratio_range)}'
self.scale_ratio_range = scale_ratio_range self.scale_ratio_range = scale_ratio_range
if not isinstance(translation_std, seq_types):
assert isinstance(translation_std, (int, float)), \
f'unsupported translation_std type {type(translation_std)}'
translation_std = [
translation_std, translation_std, translation_std
]
self.translation_std = translation_std self.translation_std = translation_std
self.shift_height = shift_height self.shift_height = shift_height
...@@ -419,14 +434,7 @@ class GlobalRotScaleTrans(object): ...@@ -419,14 +434,7 @@ class GlobalRotScaleTrans(object):
and keys in input_dict['bbox3d_fields'] are updated \ and keys in input_dict['bbox3d_fields'] are updated \
in the result dict. in the result dict.
""" """
if not isinstance(self.translation_std, (list, tuple, np.ndarray)): translation_std = np.array(self.translation_std, dtype=np.float32)
translation_std = [
self.translation_std, self.translation_std,
self.translation_std
]
else:
translation_std = self.translation_std
translation_std = np.array(translation_std, dtype=np.float32)
trans_factor = np.random.normal(scale=translation_std, size=3).T trans_factor = np.random.normal(scale=translation_std, size=3).T
input_dict['points'].translate(trans_factor) input_dict['points'].translate(trans_factor)
...@@ -446,17 +454,21 @@ class GlobalRotScaleTrans(object): ...@@ -446,17 +454,21 @@ class GlobalRotScaleTrans(object):
in the result dict. in the result dict.
""" """
rotation = self.rot_range rotation = self.rot_range
if not isinstance(rotation, list):
rotation = [-rotation, rotation]
noise_rotation = np.random.uniform(rotation[0], rotation[1]) noise_rotation = np.random.uniform(rotation[0], rotation[1])
# if no bbox in input_dict, only rotate points
if len(input_dict['bbox3d_fields']) == 0:
rot_mat_T = input_dict['points'].rotate(noise_rotation)
input_dict['pcd_rotation'] = rot_mat_T
return
# rotate points with bboxes
for key in input_dict['bbox3d_fields']: for key in input_dict['bbox3d_fields']:
if len(input_dict[key].tensor) != 0: if len(input_dict[key].tensor) != 0:
points, rot_mat_T = input_dict[key].rotate( points, rot_mat_T = input_dict[key].rotate(
noise_rotation, input_dict['points']) noise_rotation, input_dict['points'])
input_dict['points'] = points input_dict['points'] = points
input_dict['pcd_rotation'] = rot_mat_T input_dict['pcd_rotation'] = rot_mat_T
# input_dict['points_instance'].rotate(noise_rotation)
def _scale_bbox_points(self, input_dict): def _scale_bbox_points(self, input_dict):
"""Private function to scale bounding boxes and points. """Private function to scale bounding boxes and points.
...@@ -472,7 +484,8 @@ class GlobalRotScaleTrans(object): ...@@ -472,7 +484,8 @@ class GlobalRotScaleTrans(object):
points = input_dict['points'] points = input_dict['points']
points.scale(scale) points.scale(scale)
if self.shift_height: if self.shift_height:
assert 'height' in points.attribute_dims.keys() assert 'height' in points.attribute_dims.keys(), \
'setting shift_height=True but points have no height attribute'
points.tensor[:, points.attribute_dims['height']] *= scale points.tensor[:, points.attribute_dims['height']] *= scale
input_dict['points'] = points input_dict['points'] = points
......
...@@ -52,9 +52,9 @@ def test_multi_scale_flip_aug_3D(): ...@@ -52,9 +52,9 @@ def test_multi_scale_flip_aug_3D():
bbox3d_fields=bbox3d_fields) bbox3d_fields=bbox3d_fields)
results = multi_scale_flip_aug_3D(results) results = multi_scale_flip_aug_3D(results)
expected_points = torch.tensor( expected_points = torch.tensor(
[[-2.2095, 3.3160, -0.7707, 0.4417], [-1.3739, 3.8711, 0.8524, 2.0648], [[-2.2418, 3.2942, -0.7707, 0.4417], [-1.4116, 3.8575, 0.8524, 2.0648],
[-1.8140, 3.5389, -1.0057, 0.2067], [0.2040, 1.4268, -1.0504, 0.1620], [-1.8484, 3.5210, -1.0057, 0.2067], [0.1900, 1.4287, -1.0504, 0.1620],
[1.5090, 3.2764, -1.1914, 0.0210]], [1.4770, 3.2910, -1.1914, 0.0210]],
dtype=torch.float32) dtype=torch.float32)
assert torch.allclose( assert torch.allclose(
......
...@@ -3,11 +3,12 @@ import numpy as np ...@@ -3,11 +3,12 @@ import numpy as np
import pytest import pytest
import torch import torch
from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes,
DepthInstance3DBoxes, LiDARInstance3DBoxes)
from mmdet3d.core.points import DepthPoints, LiDARPoints from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment, from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
ObjectNoise, ObjectSample, PointShuffle, GlobalRotScaleTrans, ObjectNoise, ObjectSample,
PointsRangeFilter, RandomFlip3D, PointShuffle, PointsRangeFilter, RandomFlip3D,
VoxelBasedPointSampler) VoxelBasedPointSampler)
...@@ -255,6 +256,114 @@ def test_global_alignment(): ...@@ -255,6 +256,114 @@ def test_global_alignment():
assert repr_str == expected_repr_str assert repr_str == expected_repr_str
def test_global_rot_scale_trans():
    """Check GlobalRotScaleTrans end to end.

    Covers: constructor argument validation, the joint bbox+points
    transform path, the __repr__ string, and the points-only path
    (empty ``bbox3d_fields``) with ``shift_height=True``.
    """
    rot_angle = 0.78539816
    scale_range = [0.95, 1.05]
    std = 1.0

    # Constructor must reject malformed arguments:
    # rot_range / translation_std accept a number or a sequence of numbers,
    # scale_ratio_range accepts a sequence of numbers only.
    for bad_kwargs in (dict(rot_range='0.0'),
                       dict(scale_ratio_range=1.0),
                       dict(translation_std='0.0')):
        with pytest.raises(AssertionError):
            GlobalRotScaleTrans(**bad_kwargs)

    transform = GlobalRotScaleTrans(
        rot_range=rot_angle,
        scale_ratio_range=scale_range,
        translation_std=std,
        shift_height=False)

    # Fix the RNG so the sampled rotation/scale/translation are reproducible.
    np.random.seed(0)
    raw_points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
                             np.float32).reshape(-1, 6)
    annos = mmcv.load('tests/data/scannet/scannet_infos.pkl')
    raw_boxes = annos[0]['annos']['gt_boxes_upright_depth']

    depth_points = DepthPoints(
        raw_points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
    gt_bboxes_3d = DepthInstance3DBoxes(
        raw_boxes.copy(),
        box_dim=raw_boxes.shape[-1],
        with_yaw=False,
        origin=(0.5, 0.5, 0.5))

    results = transform(
        dict(
            points=depth_points.clone(),
            bbox3d_fields=['gt_bboxes_3d'],
            gt_bboxes_3d=gt_bboxes_3d.clone()))
    out_points = results['points']
    out_boxes = results['gt_bboxes_3d']

    # Values drawn from np.random with seed 0 by the transform above.
    noise_rot = 0.07667607233534723
    scale_factor = 1.021518936637242
    trans_factor = np.array([0.97873798, 2.2408932, 1.86755799])

    # Re-derive the expected result by applying the same ops manually.
    exp_points = depth_points.clone()
    exp_boxes = gt_bboxes_3d.clone()
    exp_points, noise_rot_mat_T = exp_boxes.rotate(noise_rot, exp_points)
    exp_boxes.scale(scale_factor)
    exp_boxes.translate(trans_factor)
    exp_points.scale(scale_factor)
    exp_points.translate(trans_factor)

    assert torch.allclose(out_points.tensor, exp_points.tensor, atol=1e-6)
    assert torch.allclose(out_boxes.tensor, exp_boxes.tensor, atol=1e-6)
    assert results['pcd_scale_factor'] == scale_factor
    assert torch.allclose(results['pcd_rotation'], noise_rot_mat_T, atol=1e-6)
    assert np.allclose(results['pcd_trans'], trans_factor)

    # A scalar rot_range/translation_std is normalized in the repr.
    expected_repr_str = f'GlobalRotScaleTrans(rot_range={[-rot_angle, rot_angle]},' \
                        f' scale_ratio_range={scale_range},' \
                        f' translation_std={[std for _ in range(3)]},' \
                        f' shift_height=False)'
    assert repr(transform) == expected_repr_str

    # points with shift_height but no bbox
    transform = GlobalRotScaleTrans(
        rot_range=rot_angle,
        scale_ratio_range=scale_range,
        translation_std=std,
        shift_height=True)

    # points should have height attribute when shift_height=True
    with pytest.raises(AssertionError):
        results = transform(results)

    # Re-seed so the points-only run samples the same factors as before.
    np.random.seed(0)
    height_col = raw_points[:, 2:3] * 0.99
    raw_points = np.concatenate([raw_points, height_col], axis=1)
    depth_points = DepthPoints(
        raw_points.copy(),
        points_dim=7,
        attribute_dims=dict(color=[3, 4, 5], height=6))

    results = transform(dict(points=depth_points.clone(), bbox3d_fields=[]))
    out_points = results['points']
    exp_height = height_col * scale_factor

    # The xyz/color part matches the earlier run; height is only scaled.
    assert np.allclose(
        out_points.tensor.numpy(),
        np.concatenate([exp_points.tensor.numpy(), exp_height], axis=1),
        atol=1e-6)
def test_random_flip_3d(): def test_random_flip_3d():
random_flip_3d = RandomFlip3D( random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0) flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment