Unverified Commit 23768cba authored by encore-zhou's avatar encore-zhou Committed by GitHub
Browse files

[Enhance] Use Points structure in augmentation and models (#204)

* add h3d backbone

* add h3d backbone

* add h3dnet

* modify scannet config

* fix bugs for proposal refine

* fix bugs for test backbone

* add primitive head test

* modify h3dhead

* modify h3d head

* update loss weight config

* fix bugs for h3d head loss

* modify h3d head get targets function

* update h3dnet base config

* modify weighted loss

* Revert "Merge branch 'h3d_u2' into 'master'"

This reverts merge request !5

* modify pipeline

* modify kitti pipeline

* fix bugs for points rotation

* modify multi sweeps

* modify multi sweep points

* fix bugs for points slicing

* modify BackgroundPointsFilter

* modify pipeline

* modify unittest

* modify unittest

* modify docstring

* modify config files

* update configs

* modify docstring
parent a97fc87b
...@@ -37,6 +37,7 @@ file_client_args = dict( ...@@ -37,6 +37,7 @@ file_client_args = dict(
train_pipeline = [ train_pipeline = [
dict( dict(
type='LoadPointsFromFile', type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5, load_dim=5,
use_dim=5, use_dim=5,
file_client_args=file_client_args), file_client_args=file_client_args),
......
...@@ -37,6 +37,7 @@ file_client_args = dict( ...@@ -37,6 +37,7 @@ file_client_args = dict(
train_pipeline = [ train_pipeline = [
dict( dict(
type='LoadPointsFromFile', type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=5, load_dim=5,
use_dim=5, use_dim=5,
file_client_args=file_client_args), file_client_args=file_client_args),
......
...@@ -135,7 +135,7 @@ img_norm_cfg = dict( ...@@ -135,7 +135,7 @@ img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
input_modality = dict(use_lidar=True, use_camera=True) input_modality = dict(use_lidar=True, use_camera=True)
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict( dict(
...@@ -160,7 +160,7 @@ train_pipeline = [ ...@@ -160,7 +160,7 @@ train_pipeline = [
keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']), keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']),
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadImageFromFile'), dict(type='LoadImageFromFile'),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
......
...@@ -210,7 +210,7 @@ db_sampler = dict( ...@@ -210,7 +210,7 @@ db_sampler = dict(
classes=class_names, classes=class_names,
sample_groups=dict(Car=12, Pedestrian=6, Cyclist=6)) sample_groups=dict(Car=12, Pedestrian=6, Cyclist=6))
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler), dict(type='ObjectSample', db_sampler=db_sampler),
dict( dict(
...@@ -232,7 +232,7 @@ train_pipeline = [ ...@@ -232,7 +232,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
img_scale=(1333, 800), img_scale=(1333, 800),
......
...@@ -82,7 +82,7 @@ db_sampler = dict( ...@@ -82,7 +82,7 @@ db_sampler = dict(
classes=class_names, classes=class_names,
sample_groups=dict(Car=15)) sample_groups=dict(Car=15))
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler), dict(type='ObjectSample', db_sampler=db_sampler),
dict( dict(
...@@ -104,7 +104,7 @@ train_pipeline = [ ...@@ -104,7 +104,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
img_scale=(1333, 800), img_scale=(1333, 800),
......
...@@ -21,7 +21,7 @@ db_sampler = dict( ...@@ -21,7 +21,7 @@ db_sampler = dict(
# PointPillars uses different augmentation hyper parameters # PointPillars uses different augmentation hyper parameters
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler), dict(type='ObjectSample', db_sampler=db_sampler),
dict( dict(
...@@ -42,7 +42,7 @@ train_pipeline = [ ...@@ -42,7 +42,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
img_scale=(1333, 800), img_scale=(1333, 800),
......
...@@ -40,7 +40,7 @@ db_sampler = dict( ...@@ -40,7 +40,7 @@ db_sampler = dict(
classes=class_names) classes=class_names)
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler), dict(type='ObjectSample', db_sampler=db_sampler),
dict( dict(
...@@ -61,7 +61,7 @@ train_pipeline = [ ...@@ -61,7 +61,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=4, use_dim=4), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
img_scale=(1333, 800), img_scale=(1333, 800),
......
...@@ -24,7 +24,7 @@ db_sampler = dict( ...@@ -24,7 +24,7 @@ db_sampler = dict(
type='LoadPointsFromFile', load_dim=5, use_dim=[0, 1, 2, 3, 4])) type='LoadPointsFromFile', load_dim=5, use_dim=[0, 1, 2, 3, 4]))
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=6, use_dim=5), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler), dict(type='ObjectSample', db_sampler=db_sampler),
dict( dict(
...@@ -44,7 +44,7 @@ train_pipeline = [ ...@@ -44,7 +44,7 @@ train_pipeline = [
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=6, use_dim=5), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=6, use_dim=5),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
img_scale=(1333, 800), img_scale=(1333, 800),
......
...@@ -13,7 +13,7 @@ class_names = [ ...@@ -13,7 +13,7 @@ class_names = [
] ]
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10), dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict( dict(
...@@ -33,7 +33,7 @@ train_pipeline = [ ...@@ -33,7 +33,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10), dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
......
...@@ -13,7 +13,7 @@ class_names = [ ...@@ -13,7 +13,7 @@ class_names = [
] ]
train_pipeline = [ train_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10), dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict( dict(
...@@ -33,7 +33,7 @@ train_pipeline = [ ...@@ -33,7 +33,7 @@ train_pipeline = [
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
] ]
test_pipeline = [ test_pipeline = [
dict(type='LoadPointsFromFile', load_dim=5, use_dim=5), dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5),
dict(type='LoadPointsFromMultiSweeps', sweeps_num=10), dict(type='LoadPointsFromMultiSweeps', sweeps_num=10),
dict( dict(
type='MultiScaleFlipAug3D', type='MultiScaleFlipAug3D',
......
import numpy as np import numpy as np
import torch import torch
from mmdet3d.core.points import BasePoints
from .base_box3d import BaseInstance3DBoxes from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis from .utils import limit_period, rotation_3d_in_axis
...@@ -96,7 +97,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -96,7 +97,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
@property @property
def corners(self): def corners(self):
"""torch.Tensor: Coordinates of corners of all the boxes in shape (N, 8, 3). """torch.Tensor: Coordinates of corners of all the boxes in
shape (N, 8, 3).
Convert the boxes to in clockwise order, in the form of Convert the boxes to in clockwise order, in the form of
(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z1, x1y1z0) (x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z1, x1y1z0)
...@@ -168,8 +170,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -168,8 +170,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Args: Args:
angle (float, torch.Tensor): Rotation angle. angle (float, torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate. points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Defaults to None. Points to rotate. Defaults to None.
Returns: Returns:
tuple or None: When ``points`` is None, the function returns \ tuple or None: When ``points`` is None, the function returns \
...@@ -192,6 +194,9 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -192,6 +194,9 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
elif isinstance(points, np.ndarray): elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy() rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T) points[:, :3] = np.dot(points[:, :3], rot_mat_T)
elif isinstance(points, BasePoints):
# clockwise
points.rotate(-angle)
else: else:
raise ValueError raise ValueError
return points, rot_mat_T return points, rot_mat_T
...@@ -203,8 +208,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -203,8 +208,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Args: Args:
bev_direction (str): Flip direction (horizontal or vertical). bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip. points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, None):
Defaults to None. Points to flip. Defaults to None.
Returns: Returns:
torch.Tensor, numpy.ndarray or None: Flipped points. torch.Tensor, numpy.ndarray or None: Flipped points.
...@@ -220,11 +225,14 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes): ...@@ -220,11 +225,14 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
self.tensor[:, 6] = -self.tensor[:, 6] self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None: if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray)) assert isinstance(points, (torch.Tensor, np.ndarray, BasePoints))
if bev_direction == 'horizontal': if isinstance(points, (torch.Tensor, np.ndarray)):
points[:, 0] = -points[:, 0] if bev_direction == 'horizontal':
elif bev_direction == 'vertical': points[:, 0] = -points[:, 0]
points[:, 2] = -points[:, 2] elif bev_direction == 'vertical':
points[:, 2] = -points[:, 2]
elif isinstance(points, BasePoints):
points.flip(bev_direction)
return points return points
def in_range_bev(self, box_range): def in_range_bev(self, box_range):
......
import numpy as np import numpy as np
import torch import torch
from mmdet3d.core.points import BasePoints
from mmdet3d.ops import points_in_boxes_batch from mmdet3d.ops import points_in_boxes_batch
from .base_box3d import BaseInstance3DBoxes from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis from .utils import limit_period, rotation_3d_in_axis
...@@ -114,8 +115,8 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -114,8 +115,8 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
Args: Args:
angle (float, torch.Tensor): Rotation angle. angle (float, torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate. points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Defaults to None. Points to rotate. Defaults to None.
Returns: Returns:
tuple or None: When ``points`` is None, the function returns \ tuple or None: When ``points`` is None, the function returns \
...@@ -148,6 +149,9 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -148,6 +149,9 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
elif isinstance(points, np.ndarray): elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy() rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T) points[:, :3] = np.dot(points[:, :3], rot_mat_T)
elif isinstance(points, BasePoints):
# anti-clockwise
points.rotate(angle)
else: else:
raise ValueError raise ValueError
return points, rot_mat_T return points, rot_mat_T
...@@ -159,8 +163,8 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -159,8 +163,8 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
Args: Args:
bev_direction (str): Flip direction (horizontal or vertical). bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip. points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, None):
Defaults to None. Points to flip. Defaults to None.
Returns: Returns:
torch.Tensor, numpy.ndarray or None: Flipped points. torch.Tensor, numpy.ndarray or None: Flipped points.
...@@ -176,11 +180,14 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes): ...@@ -176,11 +180,14 @@ class DepthInstance3DBoxes(BaseInstance3DBoxes):
self.tensor[:, 6] = -self.tensor[:, 6] self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None: if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray)) assert isinstance(points, (torch.Tensor, np.ndarray, BasePoints))
if bev_direction == 'horizontal': if isinstance(points, (torch.Tensor, np.ndarray)):
points[:, 0] = -points[:, 0] if bev_direction == 'horizontal':
elif bev_direction == 'vertical': points[:, 0] = -points[:, 0]
points[:, 1] = -points[:, 1] elif bev_direction == 'vertical':
points[:, 1] = -points[:, 1]
elif isinstance(points, BasePoints):
points.flip(bev_direction)
return points return points
def in_range_bev(self, box_range): def in_range_bev(self, box_range):
......
import numpy as np import numpy as np
import torch import torch
from mmdet3d.core.points import BasePoints
from mmdet3d.ops.roiaware_pool3d import points_in_boxes_gpu from mmdet3d.ops.roiaware_pool3d import points_in_boxes_gpu
from .base_box3d import BaseInstance3DBoxes from .base_box3d import BaseInstance3DBoxes
from .utils import limit_period, rotation_3d_in_axis from .utils import limit_period, rotation_3d_in_axis
...@@ -114,8 +115,8 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -114,8 +115,8 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Args: Args:
angle (float | torch.Tensor): Rotation angle. angle (float | torch.Tensor): Rotation angle.
points (torch.Tensor, numpy.ndarray, optional): Points to rotate. points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Defaults to None. Points to rotate. Defaults to None.
Returns: Returns:
tuple or None: When ``points`` is None, the function returns \ tuple or None: When ``points`` is None, the function returns \
...@@ -142,6 +143,9 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -142,6 +143,9 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
elif isinstance(points, np.ndarray): elif isinstance(points, np.ndarray):
rot_mat_T = rot_mat_T.numpy() rot_mat_T = rot_mat_T.numpy()
points[:, :3] = np.dot(points[:, :3], rot_mat_T) points[:, :3] = np.dot(points[:, :3], rot_mat_T)
elif isinstance(points, BasePoints):
# clockwise
points.rotate(-angle)
else: else:
raise ValueError raise ValueError
return points, rot_mat_T return points, rot_mat_T
...@@ -153,8 +157,8 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -153,8 +157,8 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Args: Args:
bev_direction (str): Flip direction (horizontal or vertical). bev_direction (str): Flip direction (horizontal or vertical).
points (torch.Tensor, numpy.ndarray, None): Points to flip. points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, None):
Defaults to None. Points to flip. Defaults to None.
Returns: Returns:
torch.Tensor, numpy.ndarray or None: Flipped points. torch.Tensor, numpy.ndarray or None: Flipped points.
...@@ -170,11 +174,14 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes): ...@@ -170,11 +174,14 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
self.tensor[:, 6] = -self.tensor[:, 6] self.tensor[:, 6] = -self.tensor[:, 6]
if points is not None: if points is not None:
assert isinstance(points, (torch.Tensor, np.ndarray)) assert isinstance(points, (torch.Tensor, np.ndarray, BasePoints))
if bev_direction == 'horizontal': if isinstance(points, (torch.Tensor, np.ndarray)):
points[:, 1] = -points[:, 1] if bev_direction == 'horizontal':
elif bev_direction == 'vertical': points[:, 1] = -points[:, 1]
points[:, 0] = -points[:, 0] elif bev_direction == 'vertical':
points[:, 0] = -points[:, 0]
elif isinstance(points, BasePoints):
points.flip(bev_direction)
return points return points
def in_range_bev(self, box_range): def in_range_bev(self, box_range):
......
...@@ -4,3 +4,26 @@ from .depth_points import DepthPoints ...@@ -4,3 +4,26 @@ from .depth_points import DepthPoints
from .lidar_points import LiDARPoints from .lidar_points import LiDARPoints
__all__ = ['BasePoints', 'CameraPoints', 'DepthPoints', 'LiDARPoints'] __all__ = ['BasePoints', 'CameraPoints', 'DepthPoints', 'LiDARPoints']
def get_points_type(points_type):
"""Get the class of points according to coordinate type.
Args:
points_type (str): The type of points coordinate.
The valid value are "CAMERA", "LIDAR", or "DEPTH".
Returns:
class: Points type.
"""
if points_type == 'CAMERA':
points_cls = CameraPoints
elif points_type == 'LIDAR':
points_cls = LiDARPoints
elif points_type == 'DEPTH':
points_cls = DepthPoints
else:
raise ValueError('Only "points_type" of "CAMERA", "LIDAR", or "DEPTH"'
f' are supported, got {points_type}')
return points_cls
import numpy as np
import torch import torch
from abc import abstractmethod from abc import abstractmethod
...@@ -18,6 +19,7 @@ class BasePoints(object): ...@@ -18,6 +19,7 @@ class BasePoints(object):
Each row is (x, y, z, ...). Each row is (x, y, z, ...).
attribute_dims (bool): Dictionary to indicate the meaning of extra attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None. dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
""" """
def __init__(self, tensor, points_dim=3, attribute_dims=None): def __init__(self, tensor, points_dim=3, attribute_dims=None):
...@@ -37,6 +39,7 @@ class BasePoints(object): ...@@ -37,6 +39,7 @@ class BasePoints(object):
self.tensor = tensor self.tensor = tensor
self.points_dim = points_dim self.points_dim = points_dim
self.attribute_dims = attribute_dims self.attribute_dims = attribute_dims
self.rotation_axis = 0
@property @property
def coord(self): def coord(self):
...@@ -61,24 +64,32 @@ class BasePoints(object): ...@@ -61,24 +64,32 @@ class BasePoints(object):
else: else:
return None return None
@property
def shape(self):
"""torch.Shape: Shape of points."""
return self.tensor.shape
def shuffle(self): def shuffle(self):
"""Shuffle the points.""" """Shuffle the points."""
self.tensor = self.tensor[torch.randperm( self.tensor = self.tensor[torch.randperm(
self.__len__(), device=self.tensor.device)] self.__len__(), device=self.tensor.device)]
def rotate(self, rotation, axis=2): def rotate(self, rotation, axis=None):
"""Rotate points with the given rotation matrix or angle. """Rotate points with the given rotation matrix or angle.
Args: Args:
rotation (float, np.ndarray, torch.Tensor): Rotation matrix rotation (float, np.ndarray, torch.Tensor): Rotation matrix
or angle. or angle.
axis (int): Axis to rotate at. Defaults to 2. axis (int): Axis to rotate at. Defaults to None.
""" """
if not isinstance(rotation, torch.Tensor): if not isinstance(rotation, torch.Tensor):
rotation = self.tensor.new_tensor(rotation) rotation = self.tensor.new_tensor(rotation)
assert rotation.shape == torch.Size([3, 3]) or \ assert rotation.shape == torch.Size([3, 3]) or \
rotation.numel() == 1 rotation.numel() == 1
if axis is None:
axis = self.rotation_axis
if rotation.numel() == 1: if rotation.numel() == 1:
rot_sin = torch.sin(rotation) rot_sin = torch.sin(rotation)
rot_cos = torch.cos(rotation) rot_cos = torch.cos(rotation)
...@@ -204,12 +215,14 @@ class BasePoints(object): ...@@ -204,12 +215,14 @@ class BasePoints(object):
3. `new_points = points[vector]`: 3. `new_points = points[vector]`:
where vector is a torch.BoolTensor with `length = len(points)`. where vector is a torch.BoolTensor with `length = len(points)`.
Nonzero elements in the vector will be selected. Nonzero elements in the vector will be selected.
4. `new_points = points[3:11, vector]`:
return a slice of points and attribute dims.
Note that the returned Points might share storage with this Points, Note that the returned Points might share storage with this Points,
subject to Pytorch's indexing semantics. subject to Pytorch's indexing semantics.
Returns: Returns:
:obj:`BaseInstancesPints`: A new object of \ :obj:`BasePoints`: A new object of \
:class:`BaseInstancesPints` after indexing. :class:`BasePoints` after indexing.
""" """
original_type = type(self) original_type = type(self)
if isinstance(item, int): if isinstance(item, int):
...@@ -217,11 +230,43 @@ class BasePoints(object): ...@@ -217,11 +230,43 @@ class BasePoints(object):
self.tensor[item].view(1, -1), self.tensor[item].view(1, -1),
points_dim=self.points_dim, points_dim=self.points_dim,
attribute_dims=self.attribute_dims) attribute_dims=self.attribute_dims)
p = self.tensor[item] elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
start = 0 if item[1].start is None else item[1].start
stop = self.tensor.shape[1] + \
1 if item[1].stop is None else item[1].stop
step = 1 if item[1].step is None else item[1].step
item[1] = list(range(start, stop, step))
p = self.tensor[item[0], item[1]]
keep_dims = list(
set(item[1]).intersection(set(range(3, self.tensor.shape[1]))))
if self.attribute_dims is not None:
attribute_dims = self.attribute_dims.copy()
for key in self.attribute_dims.keys():
cur_attribute_dim = attribute_dims[key]
if isinstance(cur_attribute_dim, int):
cur_attribute_dims = [cur_attribute_dim]
intersect_attr = list(
set(cur_attribute_dims).intersection(set(keep_dims)))
if len(intersect_attr) == 1:
attribute_dims[key] = intersect_attr[0]
elif len(intersect_attr) > 1:
attribute_dims[key] = intersect_attr
else:
attribute_dims.pop(key)
else:
attribute_dims = None
elif isinstance(item, (slice, np.ndarray, torch.Tensor)):
p = self.tensor[item]
attribute_dims = self.attribute_dims
else:
raise NotImplementedError(f'Invalid slice {item}!')
assert p.dim() == 2, \ assert p.dim() == 2, \
f'Indexing on Points with {item} failed to return a matrix!' f'Indexing on Points with {item} failed to return a matrix!'
return original_type( return original_type(
p, points_dim=self.points_dim, attribute_dims=self.attribute_dims) p, points_dim=p.shape[1], attribute_dims=attribute_dims)
def __len__(self): def __len__(self):
"""int: Number of points in the current object.""" """int: Number of points in the current object."""
...@@ -236,10 +281,10 @@ class BasePoints(object): ...@@ -236,10 +281,10 @@ class BasePoints(object):
"""Concatenate a list of Points into a single Points. """Concatenate a list of Points into a single Points.
Args: Args:
points_list (list[:obj:`BaseInstancesPoints`]): List of points. points_list (list[:obj:`BasePoints`]): List of points.
Returns: Returns:
:obj:`BaseInstancesPoints`: The concatenated Points. :obj:`BasePoints`: The concatenated Points.
""" """
assert isinstance(points_list, (list, tuple)) assert isinstance(points_list, (list, tuple))
if len(points_list) == 0: if len(points_list) == 0:
......
...@@ -17,11 +17,13 @@ class CameraPoints(BasePoints): ...@@ -17,11 +17,13 @@ class CameraPoints(BasePoints):
Each row is (x, y, z, ...). Each row is (x, y, z, ...).
attribute_dims (bool): Dictionary to indicate the meaning of extra attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None. dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
""" """
def __init__(self, tensor, points_dim=3, attribute_dims=None): def __init__(self, tensor, points_dim=3, attribute_dims=None):
super(CameraPoints, self).__init__( super(CameraPoints, self).__init__(
tensor, points_dim=points_dim, attribute_dims=attribute_dims) tensor, points_dim=points_dim, attribute_dims=attribute_dims)
self.rotation_axis = 1
def flip(self, bev_direction='horizontal'): def flip(self, bev_direction='horizontal'):
"""Flip the boxes in BEV along given BEV direction.""" """Flip the boxes in BEV along given BEV direction."""
......
...@@ -17,11 +17,13 @@ class DepthPoints(BasePoints): ...@@ -17,11 +17,13 @@ class DepthPoints(BasePoints):
Each row is (x, y, z, ...). Each row is (x, y, z, ...).
attribute_dims (bool): Dictionary to indicate the meaning of extra attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None. dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
""" """
def __init__(self, tensor, points_dim=3, attribute_dims=None): def __init__(self, tensor, points_dim=3, attribute_dims=None):
super(DepthPoints, self).__init__( super(DepthPoints, self).__init__(
tensor, points_dim=points_dim, attribute_dims=attribute_dims) tensor, points_dim=points_dim, attribute_dims=attribute_dims)
self.rotation_axis = 2
def flip(self, bev_direction='horizontal'): def flip(self, bev_direction='horizontal'):
"""Flip the boxes in BEV along given BEV direction.""" """Flip the boxes in BEV along given BEV direction."""
......
...@@ -17,11 +17,13 @@ class LiDARPoints(BasePoints): ...@@ -17,11 +17,13 @@ class LiDARPoints(BasePoints):
Each row is (x, y, z, ...). Each row is (x, y, z, ...).
attribute_dims (bool): Dictionary to indicate the meaning of extra attribute_dims (bool): Dictionary to indicate the meaning of extra
dimension. Default to None. dimension. Default to None.
rotation_axis (int): Default rotation axis for points rotation.
""" """
def __init__(self, tensor, points_dim=3, attribute_dims=None): def __init__(self, tensor, points_dim=3, attribute_dims=None):
super(LiDARPoints, self).__init__( super(LiDARPoints, self).__init__(
tensor, points_dim=points_dim, attribute_dims=attribute_dims) tensor, points_dim=points_dim, attribute_dims=attribute_dims)
self.rotation_axis = 2
def flip(self, bev_direction='horizontal'): def flip(self, bev_direction='horizontal'):
"""Flip the boxes in BEV along given BEV direction.""" """Flip the boxes in BEV along given BEV direction."""
......
...@@ -100,6 +100,7 @@ class DataBaseSampler(object): ...@@ -100,6 +100,7 @@ class DataBaseSampler(object):
classes=None, classes=None,
points_loader=dict( points_loader=dict(
type='LoadPointsFromFile', type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4, load_dim=4,
use_dim=[0, 1, 2, 3])): use_dim=[0, 1, 2, 3])):
super().__init__() super().__init__()
...@@ -253,7 +254,7 @@ class DataBaseSampler(object): ...@@ -253,7 +254,7 @@ class DataBaseSampler(object):
info['path']) if self.data_root else info['path'] info['path']) if self.data_root else info['path']
results = dict(pts_filename=file_path) results = dict(pts_filename=file_path)
s_points = self.points_loader(results)['points'] s_points = self.points_loader(results)['points']
s_points[:, :3] += info['box3d_lidar'][:3] s_points.translate(info['box3d_lidar'][:3])
count += 1 count += 1
...@@ -267,7 +268,7 @@ class DataBaseSampler(object): ...@@ -267,7 +268,7 @@ class DataBaseSampler(object):
'gt_bboxes_3d': 'gt_bboxes_3d':
sampled_gt_bboxes, sampled_gt_bboxes,
'points': 'points':
np.concatenate(s_points_list, axis=0), s_points_list[0].cat(s_points_list),
'group_ids': 'group_ids':
np.arange(gt_bboxes.shape[0], np.arange(gt_bboxes.shape[0],
gt_bboxes.shape[0] + len(sampled)) gt_bboxes.shape[0] + len(sampled))
......
...@@ -2,6 +2,7 @@ import numpy as np ...@@ -2,6 +2,7 @@ import numpy as np
from mmcv.parallel import DataContainer as DC from mmcv.parallel import DataContainer as DC
from mmdet3d.core.bbox import BaseInstance3DBoxes from mmdet3d.core.bbox import BaseInstance3DBoxes
from mmdet3d.core.points import BasePoints
from mmdet.datasets.builder import PIPELINES from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor from mmdet.datasets.pipelines import to_tensor
...@@ -71,6 +72,7 @@ class DefaultFormatBundle(object): ...@@ -71,6 +72,7 @@ class DefaultFormatBundle(object):
if 'gt_semantic_seg' in results: if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = DC( results['gt_semantic_seg'] = DC(
to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
return results return results
def __repr__(self): def __repr__(self):
...@@ -202,9 +204,11 @@ class DefaultFormatBundle3D(DefaultFormatBundle): ...@@ -202,9 +204,11 @@ class DefaultFormatBundle3D(DefaultFormatBundle):
default bundle. default bundle.
""" """
# Format 3D data # Format 3D data
for key in [ if 'points' in results:
'voxels', 'coors', 'voxel_centers', 'num_points', 'points' assert isinstance(results['points'], BasePoints)
]: results['points'] = DC(results['points'].tensor)
for key in ['voxels', 'coors', 'voxel_centers', 'num_points']:
if key not in results: if key not in results:
continue continue
results[key] = DC(to_tensor(results[key]), stack=False) results[key] = DC(to_tensor(results[key]), stack=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment