Unverified commit 9a30edfc, authored by Ziyi Wu and committed by GitHub
Browse files

[Enhance] Support rotate points without bbox in GlobalRotScaleTrans (#540)

parent 9d852f17
......@@ -176,6 +176,8 @@ class BasePoints(object):
raise NotImplementedError
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
return rot_mat_T
@abstractmethod
def flip(self, bev_direction='horizontal'):
"""Flip the points in BEV along given BEV direction."""
......
......@@ -403,8 +403,23 @@ class GlobalRotScaleTrans(object):
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0],
shift_height=False):
seq_types = (list, tuple, np.ndarray)
if not isinstance(rot_range, seq_types):
assert isinstance(rot_range, (int, float)), \
f'unsupported rot_range type {type(rot_range)}'
rot_range = [-rot_range, rot_range]
self.rot_range = rot_range
assert isinstance(scale_ratio_range, seq_types), \
f'unsupported scale_ratio_range type {type(scale_ratio_range)}'
self.scale_ratio_range = scale_ratio_range
if not isinstance(translation_std, seq_types):
assert isinstance(translation_std, (int, float)), \
f'unsupported translation_std type {type(translation_std)}'
translation_std = [
translation_std, translation_std, translation_std
]
self.translation_std = translation_std
self.shift_height = shift_height
......@@ -419,14 +434,7 @@ class GlobalRotScaleTrans(object):
and keys in input_dict['bbox3d_fields'] are updated \
in the result dict.
"""
if not isinstance(self.translation_std, (list, tuple, np.ndarray)):
translation_std = [
self.translation_std, self.translation_std,
self.translation_std
]
else:
translation_std = self.translation_std
translation_std = np.array(translation_std, dtype=np.float32)
translation_std = np.array(self.translation_std, dtype=np.float32)
trans_factor = np.random.normal(scale=translation_std, size=3).T
input_dict['points'].translate(trans_factor)
......@@ -446,17 +454,21 @@ class GlobalRotScaleTrans(object):
in the result dict.
"""
rotation = self.rot_range
if not isinstance(rotation, list):
rotation = [-rotation, rotation]
noise_rotation = np.random.uniform(rotation[0], rotation[1])
# if no bbox in input_dict, only rotate points
if len(input_dict['bbox3d_fields']) == 0:
rot_mat_T = input_dict['points'].rotate(noise_rotation)
input_dict['pcd_rotation'] = rot_mat_T
return
# rotate points with bboxes
for key in input_dict['bbox3d_fields']:
if len(input_dict[key].tensor) != 0:
points, rot_mat_T = input_dict[key].rotate(
noise_rotation, input_dict['points'])
input_dict['points'] = points
input_dict['pcd_rotation'] = rot_mat_T
# input_dict['points_instance'].rotate(noise_rotation)
def _scale_bbox_points(self, input_dict):
"""Private function to scale bounding boxes and points.
......@@ -472,7 +484,8 @@ class GlobalRotScaleTrans(object):
points = input_dict['points']
points.scale(scale)
if self.shift_height:
assert 'height' in points.attribute_dims.keys()
assert 'height' in points.attribute_dims.keys(), \
'setting shift_height=True but points have no height attribute'
points.tensor[:, points.attribute_dims['height']] *= scale
input_dict['points'] = points
......
......@@ -52,9 +52,9 @@ def test_multi_scale_flip_aug_3D():
bbox3d_fields=bbox3d_fields)
results = multi_scale_flip_aug_3D(results)
expected_points = torch.tensor(
[[-2.2095, 3.3160, -0.7707, 0.4417], [-1.3739, 3.8711, 0.8524, 2.0648],
[-1.8140, 3.5389, -1.0057, 0.2067], [0.2040, 1.4268, -1.0504, 0.1620],
[1.5090, 3.2764, -1.1914, 0.0210]],
[[-2.2418, 3.2942, -0.7707, 0.4417], [-1.4116, 3.8575, 0.8524, 2.0648],
[-1.8484, 3.5210, -1.0057, 0.2067], [0.1900, 1.4287, -1.0504, 0.1620],
[1.4770, 3.2910, -1.1914, 0.0210]],
dtype=torch.float32)
assert torch.allclose(
......
......@@ -3,11 +3,12 @@ import numpy as np
import pytest
import torch
from mmdet3d.core import Box3DMode, CameraInstance3DBoxes, LiDARInstance3DBoxes
from mmdet3d.core import (Box3DMode, CameraInstance3DBoxes,
DepthInstance3DBoxes, LiDARInstance3DBoxes)
from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
ObjectNoise, ObjectSample, PointShuffle,
PointsRangeFilter, RandomFlip3D,
GlobalRotScaleTrans, ObjectNoise, ObjectSample,
PointShuffle, PointsRangeFilter, RandomFlip3D,
VoxelBasedPointSampler)
......@@ -255,6 +256,114 @@ def test_global_alignment():
assert repr_str == expected_repr_str
def test_global_rot_scale_trans():
    """Exercise GlobalRotScaleTrans: argument checks, transform and repr.

    Covers three things:

    1. Constructor arguments of an unsupported type raise AssertionError.
    2. With one bbox field present, the transformed points and boxes match
       a reference built by hand from hard-coded rotation, scale and
       translation values (presumably the values the transform draws after
       ``np.random.seed(0)`` -- confirm if the sampling order changes).
    3. With ``shift_height=True`` and no bbox field, points lacking a
       ``height`` attribute are rejected, and points that carry one get
       that column scaled together with the coordinates.
    """
    angle = 0.78539816
    scale = [0.95, 1.05]
    trans_std = 1.0

    # a string is neither a number nor a sequence, so rot_range rejects it
    with pytest.raises(AssertionError):
        GlobalRotScaleTrans(rot_range='0.0')

    # a bare number is not accepted for scale_ratio_range
    with pytest.raises(AssertionError):
        GlobalRotScaleTrans(scale_ratio_range=1.0)

    # a string is neither a number nor a sequence for translation_std
    with pytest.raises(AssertionError):
        GlobalRotScaleTrans(translation_std='0.0')

    shared_cfg = dict(
        rot_range=angle, scale_ratio_range=scale, translation_std=trans_std)
    global_rot_scale_trans = GlobalRotScaleTrans(
        shift_height=False, **shared_cfg)

    # seed before anything else so the sampled noise is reproducible
    np.random.seed(0)
    raw_points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
                             np.float32)
    points = raw_points.reshape(-1, 6)
    first_info = mmcv.load('tests/data/scannet/scannet_infos.pkl')[0]
    raw_boxes = first_info['annos']['gt_boxes_upright_depth']
    depth_points = DepthPoints(
        points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))
    gt_bboxes_3d = DepthInstance3DBoxes(
        raw_boxes.copy(),
        box_dim=raw_boxes.shape[-1],
        with_yaw=False,
        origin=(0.5, 0.5, 0.5))

    input_dict = dict(
        points=depth_points.clone(),
        bbox3d_fields=['gt_bboxes_3d'],
        gt_bboxes_3d=gt_bboxes_3d.clone())
    input_dict = global_rot_scale_trans(input_dict)
    trans_depth_points = input_dict['points']
    trans_bboxes_3d = input_dict['gt_bboxes_3d']

    # reference values the transformed data is compared against
    noise_rot = 0.07667607233534723
    scale_factor = 1.021518936637242
    trans_factor = np.array([0.97873798, 2.2408932, 1.86755799])

    # rebuild the expected result: rotate, then scale, then translate
    true_depth_points = depth_points.clone()
    true_bboxes_3d = gt_bboxes_3d.clone()
    true_depth_points, noise_rot_mat_T = true_bboxes_3d.rotate(
        noise_rot, true_depth_points)
    true_bboxes_3d.scale(scale_factor)
    true_bboxes_3d.translate(trans_factor)
    true_depth_points.scale(scale_factor)
    true_depth_points.translate(trans_factor)

    assert torch.allclose(
        trans_depth_points.tensor, true_depth_points.tensor, atol=1e-6)
    assert torch.allclose(
        trans_bboxes_3d.tensor, true_bboxes_3d.tensor, atol=1e-6)
    assert input_dict['pcd_scale_factor'] == scale_factor
    assert torch.allclose(
        input_dict['pcd_rotation'], noise_rot_mat_T, atol=1e-6)
    assert np.allclose(input_dict['pcd_trans'], trans_factor)

    # scalar rot_range / translation_std show up expanded in the repr
    expected_repr_str = (
        f'GlobalRotScaleTrans(rot_range={[-angle, angle]},'
        f' scale_ratio_range={scale},'
        f' translation_std={[trans_std for _ in range(3)]},'
        f' shift_height=False)')
    assert repr(global_rot_scale_trans) == expected_repr_str

    # second scenario: shift_height enabled and no bbox fields
    global_rot_scale_trans = GlobalRotScaleTrans(
        shift_height=True, **shared_cfg)

    # the points from the first scenario carry no height attribute
    with pytest.raises(AssertionError):
        global_rot_scale_trans(input_dict)

    # re-seed so the same noise is used as in the first scenario
    np.random.seed(0)
    shift_height = points[:, 2:3] * 0.99
    points_with_height = np.concatenate([points, shift_height], axis=1)
    depth_points = DepthPoints(
        points_with_height.copy(),
        points_dim=7,
        attribute_dims=dict(color=[3, 4, 5], height=6))
    input_dict = dict(points=depth_points.clone(), bbox3d_fields=[])
    input_dict = global_rot_scale_trans(input_dict)
    trans_depth_points = input_dict['points']

    # height is expected to be scaled only; the other columns should match
    # the reference points computed in the first scenario
    true_shift_height = shift_height * scale_factor
    expected_points = np.concatenate(
        [true_depth_points.tensor.numpy(), true_shift_height], axis=1)
    assert np.allclose(
        trans_depth_points.tensor.numpy(), expected_points, atol=1e-6)
def test_random_flip_3d():
random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment