Commit e9029c0e authored by zhangwenwei
Browse files

Merge branch 'merge_rot_scale' into 'master'

Merge rot scale

See merge request open-mmlab/mmdet.3d!82
parents 5a1575a0 92ae69fb
...@@ -8,7 +8,6 @@ db_sampler = dict( ...@@ -8,7 +8,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict( prepare=dict(
filter_by_difficulty=[-1], filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)), filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
...@@ -40,7 +39,7 @@ train_pipeline = [ ...@@ -40,7 +39,7 @@ train_pipeline = [
translation_std=[1.0, 1.0, 0.5], translation_std=[1.0, 1.0, 0.5],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.78539816, 0.78539816]), rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -8,7 +8,6 @@ db_sampler = dict( ...@@ -8,7 +8,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)), prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)),
classes=class_names, classes=class_names,
sample_groups=dict(Car=15)) sample_groups=dict(Car=15))
...@@ -38,7 +37,7 @@ train_pipeline = [ ...@@ -38,7 +37,7 @@ train_pipeline = [
translation_std=[1.0, 1.0, 0.5], translation_std=[1.0, 1.0, 0.5],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.78539816, 0.78539816]), rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -42,7 +42,7 @@ train_pipeline = [ ...@@ -42,7 +42,7 @@ train_pipeline = [
rot_range=[-0.3925, 0.3925], rot_range=[-0.3925, 0.3925],
scale_ratio_range=[0.95, 1.05], scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'), dict(type='PointShuffle'),
......
...@@ -22,12 +22,16 @@ train_pipeline = [ ...@@ -22,12 +22,16 @@ train_pipeline = [
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)), 36, 39)),
dict(type='IndoorPointSample', num_points=40000), dict(type='IndoorPointSample', num_points=40000),
dict(type='IndoorFlipData', flip_ratio_yz=0.5, flip_ratio_xz=0.5),
dict( dict(
type='IndoorGlobalRotScaleTrans', type='RandomFlip3D',
shift_height=True, sync_2d=False,
rot_range=[-1 / 36, 1 / 36], flip_ratio_bev_horizontal=0.5,
scale_range=None), flip_ratio_bev_vertical=0.5),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.087266, 0.087266],
scale_ratio_range=[1.0, 1.0],
shift_height=True),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict( dict(
type='Collect3D', type='Collect3D',
...@@ -53,7 +57,11 @@ test_pipeline = [ ...@@ -53,7 +57,11 @@ test_pipeline = [
rot_range=[0, 0], rot_range=[0, 0],
scale_ratio_range=[1., 1.], scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'), dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
flip_ratio_bev_vertical=0.5),
dict(type='IndoorPointSample', num_points=40000), dict(type='IndoorPointSample', num_points=40000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
......
...@@ -9,12 +9,16 @@ train_pipeline = [ ...@@ -9,12 +9,16 @@ train_pipeline = [
load_dim=6, load_dim=6,
use_dim=[0, 1, 2]), use_dim=[0, 1, 2]),
dict(type='LoadAnnotations3D'), dict(type='LoadAnnotations3D'),
dict(type='IndoorFlipData', flip_ratio_yz=0.5),
dict( dict(
type='IndoorGlobalRotScaleTrans', type='RandomFlip3D',
shift_height=True, sync_2d=False,
rot_range=[-1 / 6, 1 / 6], flip_ratio_bev_horizontal=0.5,
scale_range=[0.85, 1.15]), ),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=20000), dict(type='IndoorPointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
...@@ -36,7 +40,11 @@ test_pipeline = [ ...@@ -36,7 +40,11 @@ test_pipeline = [
rot_range=[0, 0], rot_range=[0, 0],
scale_ratio_range=[1., 1.], scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]), translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'), dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
),
dict(type='IndoorPointSample', num_points=20000), dict(type='IndoorPointSample', num_points=20000),
dict( dict(
type='DefaultFormatBundle3D', type='DefaultFormatBundle3D',
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# This schedule is mainly used by models on indoor dataset, # This schedule is mainly used by models on indoor dataset,
# e.g., VoteNet on SUNRGBD and ScanNet # e.g., VoteNet on SUNRGBD and ScanNet
lr = 0.008 # max learning rate lr = 0.008 # max learning rate
optimizer = dict(type='Adam', lr=lr) optimizer = dict(type='AdamW', lr=lr, weight_decay=0.01)
optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2))
lr_config = dict(policy='step', warmup=None, step=[24, 32]) lr_config = dict(policy='step', warmup=None, step=[24, 32])
# runtime settings # runtime settings
......
...@@ -81,7 +81,6 @@ db_sampler = dict( ...@@ -81,7 +81,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)), prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)),
sample_groups=dict(Car=15), sample_groups=dict(Car=15),
classes=class_names) classes=class_names)
...@@ -96,7 +95,7 @@ train_pipeline = [ ...@@ -96,7 +95,7 @@ train_pipeline = [
loc_noise_std=[0.25, 0.25, 0.25], loc_noise_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_uniform_noise=[-0.15707963267, 0.15707963267]), rot_uniform_noise=[-0.15707963267, 0.15707963267]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScale', type='GlobalRotScale',
rot_uniform_noise=[-0.78539816, 0.78539816], rot_uniform_noise=[-0.78539816, 0.78539816],
......
...@@ -110,7 +110,6 @@ db_sampler = dict( ...@@ -110,7 +110,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict( prepare=dict(
filter_by_difficulty=[-1], filter_by_difficulty=[-1],
filter_by_min_points=dict( filter_by_min_points=dict(
...@@ -135,7 +134,7 @@ train_pipeline = [ ...@@ -135,7 +134,7 @@ train_pipeline = [
translation_std=[1.0, 1.0, 0.1], translation_std=[1.0, 1.0, 0.1],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.78539816, 0.78539816]), rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -99,7 +99,6 @@ db_sampler = dict( ...@@ -99,7 +99,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict( prepare=dict(
filter_by_difficulty=[-1], filter_by_difficulty=[-1],
filter_by_min_points=dict( filter_by_min_points=dict(
...@@ -135,7 +134,7 @@ train_pipeline = [ ...@@ -135,7 +134,7 @@ train_pipeline = [
translation_std=[1.0, 1.0, 0.1], translation_std=[1.0, 1.0, 0.1],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.78539816, 0.78539816]), rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -148,7 +148,7 @@ train_pipeline = [ ...@@ -148,7 +148,7 @@ train_pipeline = [
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05], scale_ratio_range=[0.95, 1.05],
translation_std=[0.2, 0.2, 0.2]), translation_std=[0.2, 0.2, 0.2]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointShuffle'), dict(type='PointShuffle'),
......
...@@ -204,7 +204,6 @@ db_sampler = dict( ...@@ -204,7 +204,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict( prepare=dict(
filter_by_difficulty=[-1], filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)), filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
...@@ -220,7 +219,7 @@ train_pipeline = [ ...@@ -220,7 +219,7 @@ train_pipeline = [
translation_std=[1.0, 1.0, 0.5], translation_std=[1.0, 1.0, 0.5],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.78539816, 0.78539816]), rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -78,7 +78,6 @@ db_sampler = dict( ...@@ -78,7 +78,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)), prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)),
classes=class_names, classes=class_names,
sample_groups=dict(Car=15)) sample_groups=dict(Car=15))
...@@ -92,7 +91,7 @@ train_pipeline = [ ...@@ -92,7 +91,7 @@ train_pipeline = [
translation_std=[1.0, 1.0, 0.5], translation_std=[1.0, 1.0, 0.5],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.78539816, 0.78539816]), rot_range=[-0.78539816, 0.78539816]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -13,7 +13,6 @@ db_sampler = dict( ...@@ -13,7 +13,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict( prepare=dict(
filter_by_difficulty=[-1], filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)), filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
...@@ -31,7 +30,7 @@ train_pipeline = [ ...@@ -31,7 +30,7 @@ train_pipeline = [
translation_std=[0.25, 0.25, 0.25], translation_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.15707963267, 0.15707963267]), rot_range=[-0.15707963267, 0.15707963267]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -35,7 +35,6 @@ db_sampler = dict( ...@@ -35,7 +35,6 @@ db_sampler = dict(
data_root=data_root, data_root=data_root,
info_path=data_root + 'kitti_dbinfos_train.pkl', info_path=data_root + 'kitti_dbinfos_train.pkl',
rate=1.0, rate=1.0,
object_rot_range=[0.0, 0.0],
prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)), prepare=dict(filter_by_difficulty=[-1], filter_by_min_points=dict(Car=5)),
sample_groups=dict(Car=15), sample_groups=dict(Car=15),
classes=class_names) classes=class_names)
...@@ -50,7 +49,7 @@ train_pipeline = [ ...@@ -50,7 +49,7 @@ train_pipeline = [
translation_std=[0.25, 0.25, 0.25], translation_std=[0.25, 0.25, 0.25],
global_rot_range=[0.0, 0.0], global_rot_range=[0.0, 0.0],
rot_range=[-0.15707963267, 0.15707963267]), rot_range=[-0.15707963267, 0.15707963267]),
dict(type='RandomFlip3D', flip_ratio=0.5), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict( dict(
type='GlobalRotScaleTrans', type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816], rot_range=[-0.78539816, 0.78539816],
......
...@@ -101,6 +101,9 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -101,6 +101,9 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Returns: Returns:
torch.Tensor: corners of each box with size (N, 8, 3) torch.Tensor: corners of each box with size (N, 8, 3)
""" """
# TODO: rotation_3d_in_axis function do not support
# empty tensor currently.
assert len(self.tensor) != 0
dims = self.dims dims = self.dims
corners_norm = torch.from_numpy( corners_norm = torch.from_numpy(
np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)).to( np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)).to(
...@@ -150,11 +153,18 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -150,11 +153,18 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
bev_boxes = torch.cat([centers - dims / 2, centers + dims / 2], dim=-1) bev_boxes = torch.cat([centers - dims / 2, centers + dims / 2], dim=-1)
return bev_boxes return bev_boxes
def rotate(self, angle): def rotate(self, angle, points=None):
"""Calculate whether the points is in any of the boxes """Rotate boxes with points (optional) with the given angle.
Args: Args:
angle (float | torch.Tensor): rotation angle angle (float, torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate.
Defaults to None.
Returns:
tuple or None: When ``points`` is None, the function returns None,
otherwise it returns the rotated points and the
rotation matrix ``rot_mat_T``.
""" """
if not isinstance(angle, torch.Tensor): if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle) angle = self.tensor.new_tensor(angle)
...@@ -166,13 +176,28 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -166,13 +176,28 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
self.tensor[:, 6] += angle self.tensor[:, 6] += angle
def flip(self, bev_direction='horizontal'): if points is not None:
if isinstance(points, torch.Tensor):
points[:, :3] = points[:, :3] @ rot_mat_T
elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T)
else:
raise ValueError
return points, rot_mat_T
def flip(self, bev_direction='horizontal', points=None):
"""Flip the boxes in BEV along given BEV direction """Flip the boxes in BEV along given BEV direction
In CAM coordinates, it flips the x (horizontal) or z (vertical) axis. In CAM coordinates, it flips the x (horizontal) or z (vertical) axis.
Args: Args:
bev_direction (str): Flip direction (horizontal or vertical). bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip.
Defaults to None.
Returns:
torch.Tensor, numpy.ndarray or None: Flipped points.
""" """
assert bev_direction in ('horizontal', 'vertical') assert bev_direction in ('horizontal', 'vertical')
if bev_direction == 'horizontal': if bev_direction == 'horizontal':
...@@ -184,6 +209,14 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -184,6 +209,14 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
if self.with_yaw: if self.with_yaw:
self.tensor[:, 6] = -self.tensor[:, 6] self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray))
if bev_direction == 'horizontal':
points[:, 0] = -points[:, 0]
elif bev_direction == 'vertical':
points[:, 2] = -points[:, 2]
return points
def in_range_bev(self, box_range): def in_range_bev(self, box_range):
"""Check whether the boxes are in the given range """Check whether the boxes are in the given range
......
...@@ -69,6 +69,9 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -69,6 +69,9 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
Returns: Returns:
torch.Tensor: corners of each box with size (N, 8, 3) torch.Tensor: corners of each box with size (N, 8, 3)
""" """
# TODO: rotation_3d_in_axis function do not support
# empty tensor currently.
assert len(self.tensor) != 0
dims = self.dims dims = self.dims
corners_norm = torch.from_numpy( corners_norm = torch.from_numpy(
np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)).to( np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)).to(
...@@ -118,23 +121,31 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -118,23 +121,31 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
bev_boxes = torch.cat([centers - dims / 2, centers + dims / 2], dim=-1) bev_boxes = torch.cat([centers - dims / 2, centers + dims / 2], dim=-1)
return bev_boxes return bev_boxes
def rotate(self, angle): def rotate(self, angle, points=None):
"""Calculate whether the points is in any of the boxes """Rotate boxes with points (optional) with the given angle.
Args: Args:
angle (float | torch.Tensor): rotation angle angle (float, torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate.
Defaults to None.
Returns:
tuple or None: When ``points`` is None, the function returns None,
otherwise it returns the rotated points and the
rotation matrix ``rot_mat_T``.
""" """
if not isinstance(angle, torch.Tensor): if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle) angle = self.tensor.new_tensor(angle)
rot_sin = torch.sin(angle) rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle) rot_cos = torch.cos(angle)
rot_mat = self.tensor.new_tensor([[rot_cos, -rot_sin, 0], rot_mat_T = self.tensor.new_tensor([[rot_cos, -rot_sin, 0],
[rot_sin, rot_cos, 0], [0, 0, 1]]) [rot_sin, rot_cos, 0], [0, 0,
self.tensor[:, 0:3] = self.tensor[:, 0:3] @ rot_mat.T 1]]).T
self.tensor[:, 0:3] = self.tensor[:, 0:3] @ rot_mat_T
if self.with_yaw: if self.with_yaw:
self.tensor[:, 6] -= angle self.tensor[:, 6] -= angle
else: else:
corners_rot = self.corners @ rot_mat.T corners_rot = self.corners @ rot_mat_T
new_x_size = corners_rot[..., 0].max( new_x_size = corners_rot[..., 0].max(
dim=1, keepdim=True)[0] - corners_rot[..., 0].min( dim=1, keepdim=True)[0] - corners_rot[..., 0].min(
dim=1, keepdim=True)[0] dim=1, keepdim=True)[0]
...@@ -143,13 +154,28 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -143,13 +154,28 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
dim=1, keepdim=True)[0] dim=1, keepdim=True)[0]
self.tensor[:, 3:5] = torch.cat((new_x_size, new_y_size), dim=-1) self.tensor[:, 3:5] = torch.cat((new_x_size, new_y_size), dim=-1)
def flip(self, bev_direction='horizontal'): if points is not None:
if isinstance(points, torch.Tensor):
points[:, :3] = points[:, :3] @ rot_mat_T
elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T)
else:
raise ValueError
return points, rot_mat_T
def flip(self, bev_direction='horizontal', points=None):
"""Flip the boxes in BEV along given BEV direction """Flip the boxes in BEV along given BEV direction
In Depth coordinates, it flips x (horizontal) or y (vertical) axis. In Depth coordinates, it flips x (horizontal) or y (vertical) axis.
Args: Args:
bev_direction (str): Flip direction (horizontal or vertical). bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip.
Defaults to None.
Returns:
torch.Tensor, numpy.ndarray or None: Flipped points.
""" """
assert bev_direction in ('horizontal', 'vertical') assert bev_direction in ('horizontal', 'vertical')
if bev_direction == 'horizontal': if bev_direction == 'horizontal':
...@@ -161,6 +187,14 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -161,6 +187,14 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
if self.with_yaw: if self.with_yaw:
self.tensor[:, 6] = -self.tensor[:, 6] self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray))
if bev_direction == 'horizontal':
points[:, 0] = -points[:, 0]
elif bev_direction == 'vertical':
points[:, 1] = -points[:, 1]
return points
def in_range_bev(self, box_range): def in_range_bev(self, box_range):
"""Check whether the boxes are in the given range """Check whether the boxes are in the given range
......
...@@ -69,6 +69,9 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -69,6 +69,9 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Returns: Returns:
torch.Tensor: corners of each box with size (N, 8, 3) torch.Tensor: corners of each box with size (N, 8, 3)
""" """
# TODO: rotation_3d_in_axis function do not support
# empty tensor currently.
assert len(self.tensor) != 0
dims = self.dims dims = self.dims
corners_norm = torch.from_numpy( corners_norm = torch.from_numpy(
np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)).to( np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)).to(
...@@ -118,11 +121,18 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -118,11 +121,18 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
bev_boxes = torch.cat([centers - dims / 2, centers + dims / 2], dim=-1) bev_boxes = torch.cat([centers - dims / 2, centers + dims / 2], dim=-1)
return bev_boxes return bev_boxes
def rotate(self, angle): def rotate(self, angle, points=None):
"""Calculate whether the points is in any of the boxes """Rotate boxes with points (optional) with the given angle.
Args: Args:
angle (float | torch.Tensor): rotation angle angle (float | torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate.
Defaults to None.
Returns:
tuple or None: When ``points`` is None, the function returns None,
otherwise it returns the rotated points and the
rotation matrix ``rot_mat_T``.
""" """
if not isinstance(angle, torch.Tensor): if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle) angle = self.tensor.new_tensor(angle)
...@@ -138,13 +148,28 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -138,13 +148,28 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
# rotate velo vector # rotate velo vector
self.tensor[:, 7:9] = self.tensor[:, 7:9] @ rot_mat_T[:2, :2] self.tensor[:, 7:9] = self.tensor[:, 7:9] @ rot_mat_T[:2, :2]
def flip(self, bev_direction='horizontal'): if points is not None:
if isinstance(points, torch.Tensor):
points[:, :3] = points[:, :3] @ rot_mat_T
elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T)
else:
raise ValueError
return points, rot_mat_T
def flip(self, bev_direction='horizontal', points=None):
"""Flip the boxes in BEV along given BEV direction """Flip the boxes in BEV along given BEV direction
In LIDAR coordinates, it flips the y (horizontal) or x (vertical) axis. In LIDAR coordinates, it flips the y (horizontal) or x (vertical) axis.
Args: Args:
bev_direction (str): Flip direction (horizontal or vertical). bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip.
Defaults to None.
Returns:
torch.Tensor, numpy.ndarray or None: Flipped points.
""" """
assert bev_direction in ('horizontal', 'vertical') assert bev_direction in ('horizontal', 'vertical')
if bev_direction == 'horizontal': if bev_direction == 'horizontal':
...@@ -156,6 +181,14 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -156,6 +181,14 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
if self.with_yaw: if self.with_yaw:
self.tensor[:, 6] = -self.tensor[:, 6] self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray))
if bev_direction == 'horizontal':
points[:, 1] = -points[:, 1]
elif bev_direction == 'vertical':
points[:, 0] = -points[:, 0]
return points
def in_range_bev(self, box_range): def in_range_bev(self, box_range):
"""Check whether the boxes are in the given range """Check whether the boxes are in the given range
......
import torch import torch
def bbox3d_mapping_back(bboxes, scale_factor, flip): def bbox3d_mapping_back(bboxes, scale_factor, flip_horizontal, flip_vertical):
"""Map bboxes from testing scale to original image scale""" """Map bboxes from testing scale to original image scale"""
new_bboxes = bboxes.clone() new_bboxes = bboxes.clone()
if flip: if flip_horizontal:
new_bboxes.flip() new_bboxes.flip('horizontal')
if flip_vertical:
new_bboxes.flip('vertical')
new_bboxes.scale(1 / scale_factor) new_bboxes.scale(1 / scale_factor)
return new_bboxes return new_bboxes
......
...@@ -33,10 +33,12 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg): ...@@ -33,10 +33,12 @@ def merge_aug_bboxes_3d(aug_results, img_metas, test_cfg):
for bboxes, img_info in zip(aug_results, img_metas): for bboxes, img_info in zip(aug_results, img_metas):
scale_factor = img_info[0]['pcd_scale_factor'] scale_factor = img_info[0]['pcd_scale_factor']
flip = img_info[0]['pcd_flip'] pcd_horizontal_flip = img_info[0]['pcd_horizontal_flip']
pcd_vertical_flip = img_info[0]['pcd_vertical_flip']
recovered_scores.append(bboxes['scores_3d']) recovered_scores.append(bboxes['scores_3d'])
recovered_labels.append(bboxes['labels_3d']) recovered_labels.append(bboxes['labels_3d'])
bboxes = bbox3d_mapping_back(bboxes['boxes_3d'], scale_factor, flip) bboxes = bbox3d_mapping_back(bboxes['boxes_3d'], scale_factor,
pcd_horizontal_flip, pcd_vertical_flip)
recovered_bboxes.append(bboxes) recovered_bboxes.append(bboxes)
aug_bboxes = recovered_bboxes[0].cat(recovered_bboxes) aug_bboxes = recovered_bboxes[0].cat(recovered_bboxes)
......
...@@ -3,12 +3,11 @@ from .custom_3d import Custom3DDataset ...@@ -3,12 +3,11 @@ from .custom_3d import Custom3DDataset
from .kitti2d_dataset import Kitti2DDataset from .kitti2d_dataset import Kitti2DDataset
from .kitti_dataset import KittiDataset from .kitti_dataset import KittiDataset
from .nuscenes_dataset import NuScenesDataset from .nuscenes_dataset import NuScenesDataset
from .pipelines import (GlobalRotScaleTrans, IndoorFlipData, from .pipelines import (GlobalRotScaleTrans, IndoorPointSample,
IndoorGlobalRotScaleTrans, IndoorPointSample, LoadAnnotations3D, LoadPointsFromFile,
IndoorPointsColorJitter, LoadAnnotations3D, NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
LoadPointsFromFile, NormalizePointsColor, ObjectNoise, ObjectSample, PointShuffle, PointsRangeFilter,
ObjectRangeFilter, ObjectSample, PointShuffle, RandomFlip3D)
PointsRangeFilter, RandomFlip3D)
from .scannet_dataset import ScanNetDataset from .scannet_dataset import ScanNetDataset
from .sunrgbd_dataset import SUNRGBDDataset from .sunrgbd_dataset import SUNRGBDDataset
...@@ -20,7 +19,5 @@ __all__ = [ ...@@ -20,7 +19,5 @@ __all__ = [
'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle',
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample', 'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
'LoadAnnotations3D', 'IndoorPointsColorJitter', 'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset'
'IndoorGlobalRotScaleTrans', 'IndoorFlipData', 'SUNRGBDDataset',
'ScanNetDataset', 'Custom3DDataset'
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment