Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
30101f73
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "2cc854f8cb5b9670fc53134f8104569c60d535be"
Commit
30101f73
authored
Jun 08, 2020
by
wuyuefeng
Committed by
zhangwenwei
Jun 08, 2020
Browse files
Depth box3d
parent
7a129d97
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
515 additions
and
71 deletions
+515
-71
mmdet3d/core/bbox/__init__.py
mmdet3d/core/bbox/__init__.py
+3
-2
mmdet3d/core/bbox/structures/__init__.py
mmdet3d/core/bbox/structures/__init__.py
+5
-1
mmdet3d/core/bbox/structures/base_box3d.py
mmdet3d/core/bbox/structures/base_box3d.py
+56
-19
mmdet3d/core/bbox/structures/box_3d_mode.py
mmdet3d/core/bbox/structures/box_3d_mode.py
+15
-2
mmdet3d/core/bbox/structures/cam_box3d.py
mmdet3d/core/bbox/structures/cam_box3d.py
+27
-17
mmdet3d/core/bbox/structures/depth_box3d.py
mmdet3d/core/bbox/structures/depth_box3d.py
+184
-0
mmdet3d/core/bbox/structures/lidar_box3d.py
mmdet3d/core/bbox/structures/lidar_box3d.py
+29
-19
tests/test_box3d.py
tests/test_box3d.py
+196
-11
No files found.
mmdet3d/core/bbox/__init__.py
View file @
30101f73
...
...
@@ -7,7 +7,8 @@ from .iou_calculators import (BboxOverlaps3D, BboxOverlapsNearest3D,
from
.samplers
import
(
BaseSampler
,
CombinedSampler
,
InstanceBalancedPosSampler
,
IoUBalancedNegSampler
,
PseudoSampler
,
RandomSampler
,
SamplingResult
)
from
.structures
import
Box3DMode
,
CameraInstance3DBoxes
,
LiDARInstance3DBoxes
from
.structures
import
(
Box3DMode
,
CameraInstance3DBoxes
,
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
)
from
.transforms
import
(
bbox3d2result
,
bbox3d2roi
,
box3d_to_corner3d_upright_depth
,
boxes3d_to_bev_torch_lidar
)
...
...
@@ -25,5 +26,5 @@ __all__ = [
'BboxOverlapsNearest3D'
,
'BboxOverlaps3D'
,
'bbox_overlaps_nearest_3d'
,
'bbox_overlaps_3d'
,
'Box3DMode'
,
'LiDARInstance3DBoxes'
,
'CameraInstance3DBoxes'
,
'bbox3d2roi'
,
'bbox3d2result'
,
'box3d_to_corner3d_upright_depth'
'box3d_to_corner3d_upright_depth'
,
'DepthInstance3DBoxes'
]
mmdet3d/core/bbox/structures/__init__.py
View file @
30101f73
from
.box_3d_mode
import
Box3DMode
from
.cam_box3d
import
CameraInstance3DBoxes
from
.depth_box3d
import
DepthInstance3DBoxes
from
.lidar_box3d
import
LiDARInstance3DBoxes
__all__
=
[
'Box3DMode'
,
'LiDARInstance3DBoxes'
,
'CameraInstance3DBoxes'
]
__all__
=
[
'Box3DMode'
,
'LiDARInstance3DBoxes'
,
'CameraInstance3DBoxes'
,
'DepthInstance3DBoxes'
]
mmdet3d/core/bbox/structures/base_box3d.py
View file @
30101f73
...
...
@@ -14,9 +14,11 @@ class BaseInstance3DBoxes(object):
tensor (torch.Tensor | np.ndarray): a Nxbox_dim matrix.
box_dim (int): number of the dimension of a box
Each row is (x, y, z, x_size, y_size, z_size, yaw).
with_yaw (bool): if True, the value of yaw will be
set to 0 as minmax boxes.
"""
def
__init__
(
self
,
tensor
,
box_dim
=
7
):
def
__init__
(
self
,
tensor
,
box_dim
=
7
,
with_yaw
=
True
):
if
isinstance
(
tensor
,
torch
.
Tensor
):
device
=
tensor
.
device
else
:
...
...
@@ -28,7 +30,15 @@ class BaseInstance3DBoxes(object):
tensor
=
tensor
.
reshape
((
0
,
box_dim
)).
to
(
dtype
=
torch
.
float32
,
device
=
device
)
assert
tensor
.
dim
()
==
2
and
tensor
.
size
(
-
1
)
==
box_dim
,
tensor
.
size
()
self
.
box_dim
=
box_dim
if
not
with_yaw
and
tensor
.
shape
[
-
1
]
==
6
:
assert
box_dim
==
6
fake_rot
=
tensor
.
new_zeros
(
tensor
.
shape
[
0
],
1
)
tensor
=
torch
.
cat
((
tensor
,
fake_rot
),
dim
=-
1
)
self
.
box_dim
=
box_dim
+
1
else
:
self
.
box_dim
=
box_dim
self
.
with_yaw
=
with_yaw
self
.
tensor
=
tensor
@
property
...
...
@@ -135,8 +145,8 @@ class BaseInstance3DBoxes(object):
pass
@
abstractmethod
def
flip
(
self
):
"""Flip the boxes in
horizontal
direction
def
flip
(
self
,
bev_direction
=
'horizontal'
):
"""Flip the boxes in
BEV along given BEV
direction
"""
pass
...
...
@@ -184,8 +194,8 @@ class BaseInstance3DBoxes(object):
(x_min, y_min, x_max, y_max)
Returns:
a binary vector, i
ndicating whether each box is inside
the reference range.
torch.Tensor: I
ndicating whether each box is inside
the reference range.
"""
pass
...
...
@@ -193,8 +203,7 @@ class BaseInstance3DBoxes(object):
"""Scale the box with horizontal and vertical scaling factors
Args:
scale_factors (float):
scale factors to scale the boxes.
scale_factors (float): scale factors to scale the boxes.
"""
self
.
tensor
[:,
:
6
]
*=
scale_factor
self
.
tensor
[:,
7
:]
*=
scale_factor
...
...
@@ -218,9 +227,8 @@ class BaseInstance3DBoxes(object):
threshold (float): the threshold of minimal sizes
Returns:
Tensor:
a binary vector which represents whether each box is empty
(False) or non-empty (True).
torch.Tensor: a binary vector which represents whether each
box is empty (False) or non-empty (True).
"""
box
=
self
.
tensor
size_x
=
box
[...,
3
]
...
...
@@ -245,15 +253,19 @@ class BaseInstance3DBoxes(object):
subject to Pytorch's indexing semantics.
Returns:
Boxes: Create a new :class:`Boxes` by indexing.
BaseInstance3DBoxes: Create a new :class:`BaseInstance3DBoxes`
by indexing.
"""
original_type
=
type
(
self
)
if
isinstance
(
item
,
int
):
return
original_type
(
self
.
tensor
[
item
].
view
(
1
,
-
1
))
return
original_type
(
self
.
tensor
[
item
].
view
(
1
,
-
1
),
box_dim
=
self
.
box_dim
,
with_yaw
=
self
.
with_yaw
)
b
=
self
.
tensor
[
item
]
assert
b
.
dim
()
==
2
,
\
f
'Indexing on Boxes with
{
item
}
failed to return a matrix!'
return
original_type
(
b
)
return
original_type
(
b
,
box_dim
=
self
.
box_dim
,
with_yaw
=
self
.
with_yaw
)
def
__len__
(
self
):
return
self
.
tensor
.
shape
[
0
]
...
...
@@ -283,24 +295,30 @@ class BaseInstance3DBoxes(object):
def
to
(
self
,
device
):
original_type
=
type
(
self
)
return
original_type
(
self
.
tensor
.
to
(
device
))
return
original_type
(
self
.
tensor
.
to
(
device
),
box_dim
=
self
.
box_dim
,
with_yaw
=
self
.
with_yaw
)
def
clone
(
self
):
"""Clone the Boxes.
Returns:
B
oxes
B
aseInstance3DBoxes: Box object with the same properties as self.
"""
original_type
=
type
(
self
)
return
original_type
(
self
.
tensor
.
clone
())
return
original_type
(
self
.
tensor
.
clone
(),
box_dim
=
self
.
box_dim
,
with_yaw
=
self
.
with_yaw
)
@
property
def
device
(
self
):
return
self
.
tensor
.
device
def
__iter__
(
self
):
"""
Yield a box as a Tensor of shape (4,) at a time.
"""Yield a box as a Tensor of shape (4,) at a time.
Returns:
torch.Tensor: a box of shape (4,).
"""
yield
from
self
.
tensor
...
...
@@ -387,3 +405,22 @@ class BaseInstance3DBoxes(object):
iou3d
=
overlaps_3d
/
torch
.
clamp
(
volume1
,
min
=
1e-8
)
return
iou3d
def
new_box
(
self
,
data
):
"""Create a new box object with data.
The new box and its tensor has the similar properties
as self and self.tensor, respectively.
Args:
data (torch.Tensor | numpy.array | list): Data which the
returned Tensor copies.
Returns:
BaseInstance3DBoxes: A new bbox with data and other
properties are similar to self.
"""
new_tensor
=
self
.
tensor
.
new_tensor
(
data
)
original_type
=
type
(
self
)
return
original_type
(
new_tensor
,
box_dim
=
self
.
box_dim
,
with_yaw
=
self
.
with_yaw
)
mmdet3d/core/bbox/structures/box_3d_mode.py
View file @
30101f73
...
...
@@ -5,6 +5,7 @@ import torch
from
.base_box3d
import
BaseInstance3DBoxes
from
.cam_box3d
import
CameraInstance3DBoxes
from
.depth_box3d
import
DepthInstance3DBoxes
from
.lidar_box3d
import
LiDARInstance3DBoxes
...
...
@@ -61,7 +62,8 @@ class Box3DMode(IntEnum):
"""Convert boxes from `src` mode to `dst` mode.
Args:
box (tuple | list | np.ndarray | torch.Tensor):
box (tuple | list | np.ndarray |
torch.Tensor | BaseInstance3DBoxes):
can be a k-tuple, k-list or an Nxk array/tensor, where k = 7
src (BoxMode): the src Box mode
dst (BoxMode): the target Box mode
...
...
@@ -113,6 +115,14 @@ class Box3DMode(IntEnum):
if
rt_mat
is
None
:
rt_mat
=
arr
.
new_tensor
([[
1
,
0
,
0
],
[
0
,
0
,
-
1
],
[
0
,
1
,
0
]])
xyz_size
=
torch
.
cat
([
x_size
,
z_size
,
y_size
],
dim
=-
1
)
elif
src
==
Box3DMode
.
LIDAR
and
dst
==
Box3DMode
.
DEPTH
:
if
rt_mat
is
None
:
rt_mat
=
arr
.
new_tensor
([[
0
,
-
1
,
0
],
[
1
,
0
,
0
],
[
0
,
0
,
1
]])
xyz_size
=
torch
.
cat
([
y_size
,
x_size
,
z_size
],
dim
=-
1
)
elif
src
==
Box3DMode
.
DEPTH
and
dst
==
Box3DMode
.
LIDAR
:
if
rt_mat
is
None
:
rt_mat
=
arr
.
new_tensor
([[
0
,
1
,
0
],
[
-
1
,
0
,
0
],
[
0
,
0
,
1
]])
xyz_size
=
torch
.
cat
([
y_size
,
x_size
,
z_size
],
dim
=-
1
)
else
:
raise
NotImplementedError
(
f
'Conversion from Box3DMode
{
src
}
to
{
dst
}
'
...
...
@@ -141,10 +151,13 @@ class Box3DMode(IntEnum):
target_type
=
CameraInstance3DBoxes
elif
dst
==
Box3DMode
.
LIDAR
:
target_type
=
LiDARInstance3DBoxes
elif
dst
==
Box3DMode
.
DEPTH
:
target_type
=
DepthInstance3DBoxes
else
:
raise
NotImplementedError
(
f
'Conversion to
{
dst
}
through
{
original_type
}
'
' is not supported yet'
)
return
target_type
(
arr
,
box_dim
=
arr
.
size
(
-
1
))
return
target_type
(
arr
,
box_dim
=
arr
.
size
(
-
1
),
with_yaw
=
box
.
with_yaw
)
else
:
return
arr
mmdet3d/core/bbox/structures/cam_box3d.py
View file @
30101f73
...
...
@@ -11,10 +11,10 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
Coordinates in camera:
.. code-block:: none
z front
z front
(yaw=0.5*pi)
/
/
0 ------> x right
0 ------> x right
(yaw=0)
|
|
v
...
...
@@ -22,11 +22,15 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
The relative coordinate of bottom center in a CAM box is [0.5, 1.0, 0.5],
and the yaw is around the y axis, thus the rotation axis=1.
The yaw is 0 at the positive direction of x axis, and increases from
the positive direction of x to the positive direction of z.
Attributes:
tensor (torch.Tensor): float matrix of N x box_dim.
box_dim (int): integer indicates the dimension of a box
Each row is (x, y, z, x_size, y_size, z_size, yaw, ...).
with_yaw (bool): if True, the value of yaw will be
set to 0 as minmax boxes.
"""
@
property
...
...
@@ -75,7 +79,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate the coordinates of corners of all the boxes.
Convert the boxes to in clockwise order, in the form of
(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z
0
, x1y1z
1
)
(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z
1
, x1y1z
0
)
.. code-block:: none
...
...
@@ -85,7 +89,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
(x0, y0, z1) + ----------- + (x1, y0, z1)
/| / |
/ | / |
(x0, y0, z0) + ----------- + + (x1, y1, z
0
)
(x0, y0, z0) + ----------- + + (x1, y1, z
1
)
| / . | /
| / oriign | /
(x0, y1, z0) + ----------- + -------> x right
...
...
@@ -123,7 +127,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
return
self
.
tensor
[:,
[
0
,
2
,
3
,
5
,
6
]]
@
property
def
near
s
et_bev
(
self
):
def
neare
s
t_bev
(
self
):
"""Calculate the 2D bounding boxes in BEV without rotation
Returns:
...
...
@@ -150,11 +154,7 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate whether the points is in any of the boxes
Args:
angles (float | torch.Tensor): rotation angle
Returns:
None if `return_rot_mat=False`,
torch.Tensor if `return_rot_mat=True`
angle (float | torch.Tensor): rotation angle
"""
if
not
isinstance
(
angle
,
torch
.
Tensor
):
angle
=
self
.
tensor
.
new_tensor
(
angle
)
...
...
@@ -166,13 +166,23 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
self
.
tensor
[:,
:
3
]
=
self
.
tensor
[:,
:
3
]
@
rot_mat_T
self
.
tensor
[:,
6
]
+=
angle
def
flip
(
self
):
"""Flip the boxes in
horizontal
direction
def
flip
(
self
,
bev_direction
=
'horizontal'
):
"""Flip the boxes in
BEV along given BEV
direction
In CAM coordinates, it flips the x axis.
In CAM coordinates, it flips the x (horizontal) or z (vertical) axis.
Args:
bev_direction (str): Flip direction (horizontal or vertical).
"""
self
.
tensor
[:,
0
::
7
]
=
-
self
.
tensor
[:,
0
::
7
]
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
+
np
.
pi
assert
bev_direction
in
(
'horizontal'
,
'vertical'
)
if
bev_direction
==
'horizontal'
:
self
.
tensor
[:,
0
::
7
]
=
-
self
.
tensor
[:,
0
::
7
]
if
self
.
with_yaw
:
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
+
np
.
pi
elif
bev_direction
==
'vertical'
:
self
.
tensor
[:,
2
::
7
]
=
-
self
.
tensor
[:,
2
::
7
]
if
self
.
with_yaw
:
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
def
in_range_bev
(
self
,
box_range
):
"""Check whether the boxes are in the given range
...
...
@@ -188,8 +198,8 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
TODO: check whether this will effect the performance
Returns:
a binary vector, i
ndicating whether each box is inside
the reference range.
torch.Tensor: I
ndicating whether each box is inside
the reference range.
"""
in_range_flags
=
((
self
.
tensor
[:,
0
]
>
box_range
[
0
])
&
(
self
.
tensor
[:,
2
]
>
box_range
[
1
])
...
...
mmdet3d/core/bbox/structures/depth_box3d.py
0 → 100644
View file @
30101f73
import
numpy
as
np
import
torch
from
.base_box3d
import
BaseInstance3DBoxes
from
.utils
import
limit_period
,
rotation_3d_in_axis
class
DepthInstance3DBoxes
(
BaseInstance3DBoxes
):
"""3D boxes of instances in Depth coordinates
Coordinates in Depth:
.. code-block:: none
up z y front (yaw=0.5*pi)
^ ^
| /
| /
0 ------> x right (yaw=0)
The relative coordinate of bottom center in a Depth box is [0.5, 0.5, 0],
and the yaw is around the z axis, thus the rotation axis=2.
The yaw is 0 at the positive direction of x axis, and increases from
the positive direction of x to the positive direction of y.
Attributes:
tensor (torch.Tensor): float matrix of N x box_dim.
box_dim (int): integer indicates the dimension of a box
Each row is (x, y, z, x_size, y_size, z_size, yaw, ...).
with_yaw (bool): if True, the value of yaw will be
set to 0 as minmax boxes.
"""
@
property
def
gravity_center
(
self
):
"""Calculate the gravity center of all the boxes.
Returns:
torch.Tensor: a tensor with center of each box.
"""
bottom_center
=
self
.
bottom_center
gravity_center
=
torch
.
zeros_like
(
bottom_center
)
gravity_center
[:,
:
2
]
=
bottom_center
[:,
:
2
]
gravity_center
[:,
2
]
=
bottom_center
[:,
2
]
+
self
.
tensor
[:,
5
]
*
0.5
return
gravity_center
@
property
def
corners
(
self
):
"""Calculate the coordinates of corners of all the boxes.
Convert the boxes to corners in clockwise order, in form of
(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z1, x1y1z0)
.. code-block:: none
up z
front y ^
/ |
/ |
(x0, y1, z1) + ----------- + (x1, y1, z1)
/| / |
/ | / |
(x0, y0, z1) + ----------- + + (x1, y1, z0)
| / . | /
| / oriign | /
(x0, y0, z0) + ----------- + --------> right x
(x1, y0, z0)
Returns:
torch.Tensor: corners of each box with size (N, 8, 3)
"""
dims
=
self
.
dims
corners_norm
=
torch
.
from_numpy
(
np
.
stack
(
np
.
unravel_index
(
np
.
arange
(
8
),
[
2
]
*
3
),
axis
=
1
)).
to
(
device
=
dims
.
device
,
dtype
=
dims
.
dtype
)
corners_norm
=
corners_norm
[[
0
,
1
,
3
,
2
,
4
,
5
,
7
,
6
]]
# use relative origin [0.5, 0.5, 0]
corners_norm
=
corners_norm
-
dims
.
new_tensor
([
0.5
,
0.5
,
0
])
corners
=
dims
.
view
([
-
1
,
1
,
3
])
*
corners_norm
.
reshape
([
1
,
8
,
3
])
# rotate around z axis
corners
=
rotation_3d_in_axis
(
corners
,
self
.
tensor
[:,
6
],
axis
=
2
)
corners
+=
self
.
tensor
[:,
:
3
].
view
(
-
1
,
1
,
3
)
return
corners
@
property
def
bev
(
self
):
"""Calculate the 2D bounding boxes in BEV with rotation
Returns:
torch.Tensor: a nx5 tensor of 2D BEV box of each box.
The box is in XYWHR format
"""
return
self
.
tensor
[:,
[
0
,
1
,
3
,
4
,
6
]]
@
property
def
nearest_bev
(
self
):
"""Calculate the 2D bounding boxes in BEV without rotation
Returns:
torch.Tensor: a tensor of 2D BEV box of each box.
"""
# Obtain BEV boxes with rotation in XYWHR format
bev_rotated_boxes
=
self
.
bev
# convert the rotation to a valid range
rotations
=
bev_rotated_boxes
[:,
-
1
]
normed_rotations
=
torch
.
abs
(
limit_period
(
rotations
,
0.5
,
np
.
pi
))
# find the center of boxes
conditions
=
(
normed_rotations
>
np
.
pi
/
4
)[...,
None
]
bboxes_xywh
=
torch
.
where
(
conditions
,
bev_rotated_boxes
[:,
[
0
,
1
,
3
,
2
]],
bev_rotated_boxes
[:,
:
4
])
centers
=
bboxes_xywh
[:,
:
2
]
dims
=
bboxes_xywh
[:,
2
:]
bev_boxes
=
torch
.
cat
([
centers
-
dims
/
2
,
centers
+
dims
/
2
],
dim
=-
1
)
return
bev_boxes
def
rotate
(
self
,
angle
):
"""Calculate whether the points is in any of the boxes
Args:
angle (float | torch.Tensor): rotation angle
"""
if
not
isinstance
(
angle
,
torch
.
Tensor
):
angle
=
self
.
tensor
.
new_tensor
(
angle
)
rot_sin
=
torch
.
sin
(
angle
)
rot_cos
=
torch
.
cos
(
angle
)
rot_mat
=
self
.
tensor
.
new_tensor
([[
rot_cos
,
-
rot_sin
,
0
],
[
rot_sin
,
rot_cos
,
0
],
[
0
,
0
,
1
]])
self
.
tensor
[:,
0
:
3
]
=
self
.
tensor
[:,
0
:
3
]
@
rot_mat
.
T
if
self
.
with_yaw
:
self
.
tensor
[:,
6
]
-=
angle
else
:
corners_rot
=
self
.
corners
@
rot_mat
.
T
new_x_size
=
corners_rot
[...,
0
].
max
(
dim
=
1
,
keepdim
=
True
)[
0
]
-
corners_rot
[...,
0
].
min
(
dim
=
1
,
keepdim
=
True
)[
0
]
new_y_size
=
corners_rot
[...,
1
].
max
(
dim
=
1
,
keepdim
=
True
)[
0
]
-
corners_rot
[...,
1
].
min
(
dim
=
1
,
keepdim
=
True
)[
0
]
self
.
tensor
[:,
3
:
5
]
=
torch
.
cat
((
new_x_size
,
new_y_size
),
dim
=-
1
)
def
flip
(
self
,
bev_direction
=
'horizontal'
):
"""Flip the boxes in BEV along given BEV direction
In Depth coordinates, it flips x (horizontal) or y (vertical) axis.
Args:
bev_direction (str): Flip direction (horizontal or vertical).
"""
assert
bev_direction
in
(
'horizontal'
,
'vertical'
)
if
bev_direction
==
'horizontal'
:
self
.
tensor
[:,
0
::
7
]
=
-
self
.
tensor
[:,
0
::
7
]
if
self
.
with_yaw
:
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
+
np
.
pi
elif
bev_direction
==
'vertical'
:
self
.
tensor
[:,
1
::
7
]
=
-
self
.
tensor
[:,
1
::
7
]
if
self
.
with_yaw
:
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
def
in_range_bev
(
self
,
box_range
):
"""Check whether the boxes are in the given range
Args:
box_range (list | torch.Tensor): the range of box
(x_min, y_min, x_max, y_max)
Note:
In the original implementation of SECOND, checking whether
a box in the range checks whether the points are in a convex
polygon, we try to reduce the burdun for simpler cases.
TODO: check whether this will effect the performance
Returns:
torch.Tensor: Indicating whether each box is inside
the reference range.
"""
in_range_flags
=
((
self
.
tensor
[:,
0
]
>
box_range
[
0
])
&
(
self
.
tensor
[:,
1
]
>
box_range
[
1
])
&
(
self
.
tensor
[:,
0
]
<
box_range
[
2
])
&
(
self
.
tensor
[:,
1
]
<
box_range
[
3
]))
return
in_range_flags
mmdet3d/core/bbox/structures/lidar_box3d.py
View file @
30101f73
...
...
@@ -11,19 +11,23 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
Coordinates in LiDAR:
.. code-block:: none
up z x front
^ ^
| /
| /
left y <------ 0
up z x front
(yaw=0.5*pi)
^ ^
| /
| /
(yaw=pi)
left y <------ 0
The relative coordinate of bottom center in a LiDAR box is [0.5, 0.5, 0],
and the yaw is around the z axis, thus the rotation axis=2.
The yaw is 0 at the negative direction of y axis, and increases from
the negative direction of y to the positive direction of x.
Attributes:
tensor (torch.Tensor): float matrix of N x box_dim.
box_dim (int): integer indicates the dimension of a box
Each row is (x, y, z, x_size, y_size, z_size, yaw, ...).
with_yaw (bool): if True, the value of yaw will be
set to 0 as minmax boxes.
"""
@
property
...
...
@@ -44,7 +48,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate the coordinates of corners of all the boxes.
Convert the boxes to corners in clockwise order, in form of
(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z
0
, x1y1z
1
)
(x0y0z0, x0y0z1, x0y1z1, x0y1z0, x1y0z0, x1y0z1, x1y1z
1
, x1y1z
0
)
.. code-block:: none
...
...
@@ -90,7 +94,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
return
self
.
tensor
[:,
[
0
,
1
,
3
,
4
,
6
]]
@
property
def
near
s
et_bev
(
self
):
def
neare
s
t_bev
(
self
):
"""Calculate the 2D bounding boxes in BEV without rotation
Returns:
...
...
@@ -117,11 +121,7 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
"""Calculate whether the points is in any of the boxes
Args:
angles (float | torch.Tensor): rotation angle
Returns:
None if `return_rot_mat=False`,
torch.Tensor if `return_rot_mat=True`
angle (float | torch.Tensor): rotation angle
"""
if
not
isinstance
(
angle
,
torch
.
Tensor
):
angle
=
self
.
tensor
.
new_tensor
(
angle
)
...
...
@@ -133,13 +133,23 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
self
.
tensor
[:,
:
3
]
=
self
.
tensor
[:,
:
3
]
@
rot_mat_T
self
.
tensor
[:,
6
]
+=
angle
def
flip
(
self
):
"""Flip the boxes in
horizontal
direction
def
flip
(
self
,
bev_direction
=
'horizontal'
):
"""Flip the boxes in
BEV along given BEV
direction
In LIDAR coordinates, it flips the y axis.
In LIDAR coordinates, it flips the y (horizontal) or x (vertical) axis.
Args:
bev_direction (str): Flip direction (horizontal or vertical).
"""
self
.
tensor
[:,
1
::
7
]
=
-
self
.
tensor
[:,
1
::
7
]
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
+
np
.
pi
assert
bev_direction
in
(
'horizontal'
,
'vertical'
)
if
bev_direction
==
'horizontal'
:
self
.
tensor
[:,
1
::
7
]
=
-
self
.
tensor
[:,
1
::
7
]
if
self
.
with_yaw
:
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
+
np
.
pi
elif
bev_direction
==
'vertical'
:
self
.
tensor
[:,
0
::
7
]
=
-
self
.
tensor
[:,
0
::
7
]
if
self
.
with_yaw
:
self
.
tensor
[:,
6
]
=
-
self
.
tensor
[:,
6
]
def
in_range_bev
(
self
,
box_range
):
"""Check whether the boxes are in the given range
...
...
@@ -155,8 +165,8 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
TODO: check whether this will effect the performance
Returns:
a binary vector, i
ndicating whether each box is inside
the reference range.
torch.Tensor: I
ndicating whether each box is inside
the reference range.
"""
in_range_flags
=
((
self
.
tensor
[:,
0
]
>
box_range
[
0
])
&
(
self
.
tensor
[:,
1
]
>
box_range
[
1
])
...
...
tests/test_box3d.py
View file @
30101f73
...
...
@@ -3,7 +3,7 @@ import pytest
import
torch
from
mmdet3d.core.bbox
import
(
Box3DMode
,
CameraInstance3DBoxes
,
LiDARInstance3DBoxes
)
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
)
def
test_lidar_boxes3d
():
...
...
@@ -70,9 +70,19 @@ def test_lidar_boxes3d():
[
28.2967
,
0.5557558
,
-
1.303325
,
1.47
,
2.23
,
1.48
,
4.7115927
],
[
26.66902
,
-
21.82302
,
-
1.736057
,
1.56
,
3.48
,
1.4
,
4.8315926
],
[
31.31978
,
-
8.162144
,
-
1.6217787
,
1.74
,
3.77
,
1.48
,
0.35159278
]])
boxes
.
flip
()
boxes
.
flip
(
'horizontal'
)
assert
torch
.
allclose
(
boxes
.
tensor
,
expected_tensor
)
expected_tensor
=
torch
.
tensor
(
[[
-
1.7802
,
-
2.5162
,
-
1.7501
,
1.7500
,
3.3900
,
1.6500
,
-
1.6616
],
[
-
8.9594
,
-
2.4567
,
-
1.6357
,
1.5400
,
4.0100
,
1.5700
,
-
1.5216
],
[
-
28.2967
,
0.5558
,
-
1.3033
,
1.4700
,
2.2300
,
1.4800
,
-
4.7116
],
[
-
26.6690
,
-
21.8230
,
-
1.7361
,
1.5600
,
3.4800
,
1.4000
,
-
4.8316
],
[
-
31.3198
,
-
8.1621
,
-
1.6218
,
1.7400
,
3.7700
,
1.4800
,
-
0.3516
]])
boxes_flip_vert
=
boxes
.
clone
()
boxes_flip_vert
.
flip
(
'vertical'
)
assert
torch
.
allclose
(
boxes_flip_vert
.
tensor
,
expected_tensor
,
1e-4
)
# test box rotation
expected_tensor
=
torch
.
tensor
(
[[
1.0385344
,
-
2.9020846
,
-
1.7501148
,
1.75
,
3.39
,
1.65
,
1.9336663
],
...
...
@@ -223,7 +233,7 @@ def test_lidar_boxes3d():
[
27.3398
,
-
18.3976
,
29.0896
,
-
14.6065
]])
# the pytorch print loses some precision
assert
torch
.
allclose
(
boxes
.
near
s
et_bev
,
expected_tensor
,
rtol
=
1e-4
,
atol
=
1e-7
)
boxes
.
neare
s
t_bev
,
expected_tensor
,
rtol
=
1e-4
,
atol
=
1e-7
)
# obtained by the print of the original implementation
expected_tensor
=
torch
.
tensor
([[[
2.4093e+00
,
-
4.4784e+00
,
-
1.9169e+00
],
...
...
@@ -269,6 +279,25 @@ def test_lidar_boxes3d():
# the pytorch print loses some precision
assert
torch
.
allclose
(
boxes
.
corners
,
expected_tensor
,
rtol
=
1e-4
,
atol
=
1e-7
)
# test new_box
new_box1
=
boxes
.
new_box
([[
1
,
2
,
3
,
4
,
5
,
6
,
7
]])
assert
torch
.
allclose
(
new_box1
.
tensor
,
torch
.
tensor
([[
1
,
2
,
3
,
4
,
5
,
6
,
7
]],
dtype
=
boxes
.
tensor
.
dtype
))
assert
new_box1
.
device
==
boxes
.
device
assert
new_box1
.
with_yaw
==
boxes
.
with_yaw
assert
new_box1
.
box_dim
==
boxes
.
box_dim
new_box2
=
boxes
.
new_box
(
np
.
array
([[
1
,
2
,
3
,
4
,
5
,
6
,
7
]]))
assert
torch
.
allclose
(
new_box2
.
tensor
,
torch
.
tensor
([[
1
,
2
,
3
,
4
,
5
,
6
,
7
]],
dtype
=
boxes
.
tensor
.
dtype
))
new_box3
=
boxes
.
new_box
(
torch
.
tensor
([[
1
,
2
,
3
,
4
,
5
,
6
,
7
]]))
assert
torch
.
allclose
(
new_box3
.
tensor
,
torch
.
tensor
([[
1
,
2
,
3
,
4
,
5
,
6
,
7
]],
dtype
=
boxes
.
tensor
.
dtype
))
def
test_boxes_conversion
():
"""Test the conversion of boxes between different modes.
...
...
@@ -310,12 +339,6 @@ def test_boxes_conversion():
Box3DMode
.
DEPTH
,
Box3DMode
.
CAM
)
assert
torch
.
allclose
(
cam_box_tensor
,
depth_to_cam_box_tensor
)
# test error raise with not supported conversion
with
pytest
.
raises
(
NotImplementedError
):
Box3DMode
.
convert
(
lidar_box_tensor
,
Box3DMode
.
LIDAR
,
Box3DMode
.
DEPTH
)
with
pytest
.
raises
(
NotImplementedError
):
Box3DMode
.
convert
(
depth_box_tensor
,
Box3DMode
.
DEPTH
,
Box3DMode
.
LIDAR
)
# test similar mode conversion
same_results
=
Box3DMode
.
convert
(
depth_box_tensor
,
Box3DMode
.
DEPTH
,
Box3DMode
.
DEPTH
)
...
...
@@ -389,6 +412,31 @@ def test_boxes_conversion():
rt_mat
.
inverse
().
numpy
())
assert
np
.
allclose
(
np
.
array
(
cam_to_lidar_box
),
expected_tensor
[
0
].
numpy
())
# test convert from depth to lidar
depth_boxes
=
torch
.
tensor
(
[[
2.4593
,
2.5870
,
-
0.4321
,
0.8597
,
0.6193
,
1.0204
,
3.0693
],
[
1.4856
,
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
3.0601
]],
dtype
=
torch
.
float32
)
depth_boxes
=
DepthInstance3DBoxes
(
depth_boxes
)
depth_to_lidar_box
=
Box3DMode
.
convert
(
depth_boxes
,
Box3DMode
.
DEPTH
,
Box3DMode
.
LIDAR
)
lidar_to_depth_box
=
Box3DMode
.
convert
(
depth_to_lidar_box
,
Box3DMode
.
LIDAR
,
Box3DMode
.
DEPTH
)
assert
torch
.
allclose
(
depth_boxes
.
tensor
,
lidar_to_depth_box
.
tensor
)
assert
torch
.
allclose
(
depth_boxes
.
volume
,
lidar_to_depth_box
.
volume
)
# test convert from depth to camera
depth_to_cam_box
=
Box3DMode
.
convert
(
depth_boxes
,
Box3DMode
.
DEPTH
,
Box3DMode
.
CAM
)
cam_to_depth_box
=
Box3DMode
.
convert
(
depth_to_cam_box
,
Box3DMode
.
CAM
,
Box3DMode
.
DEPTH
)
assert
torch
.
allclose
(
depth_boxes
.
tensor
,
cam_to_depth_box
.
tensor
)
assert
torch
.
allclose
(
depth_boxes
.
volume
,
cam_to_depth_box
.
volume
)
with
pytest
.
raises
(
NotImplementedError
):
# assert invalid convert mode
Box3DMode
.
convert
(
depth_boxes
,
Box3DMode
.
DEPTH
,
3
)
def
test_camera_boxes3d
():
# Test init with numpy array
...
...
@@ -449,9 +497,19 @@ def test_camera_boxes3d():
[
26.66902
,
-
21.82302
,
-
1.736057
,
1.56
,
3.48
,
1.4
,
4.8315926
],
[
31.31978
,
-
8.162144
,
-
1.6217787
,
1.74
,
3.77
,
1.48
,
0.35159278
]]),
Box3DMode
.
LIDAR
,
Box3DMode
.
CAM
)
boxes
.
flip
()
boxes
.
flip
(
'horizontal'
)
assert
torch
.
allclose
(
boxes
.
tensor
,
expected_tensor
)
expected_tensor
=
torch
.
tensor
(
[[
2.5162
,
1.7501
,
-
1.7802
,
3.3900
,
1.6500
,
1.7500
,
-
1.6616
],
[
2.4567
,
1.6357
,
-
8.9594
,
4.0100
,
1.5700
,
1.5400
,
-
1.5216
],
[
-
0.5558
,
1.3033
,
-
28.2967
,
2.2300
,
1.4800
,
1.4700
,
-
4.7116
],
[
21.8230
,
1.7361
,
-
26.6690
,
3.4800
,
1.4000
,
1.5600
,
-
4.8316
],
[
8.1621
,
1.6218
,
-
31.3198
,
3.7700
,
1.4800
,
1.7400
,
-
0.3516
]])
boxes_flip_vert
=
boxes
.
clone
()
boxes_flip_vert
.
flip
(
'vertical'
)
assert
torch
.
allclose
(
boxes_flip_vert
.
tensor
,
expected_tensor
,
1e-4
)
# test box rotation
expected_tensor
=
Box3DMode
.
convert
(
torch
.
tensor
(
...
...
@@ -560,7 +618,7 @@ def test_camera_boxes3d():
expected_tensor
[:,
1
::
2
]
=
lidar_expected_tensor
[:,
0
::
2
]
# the pytorch print loses some precision
assert
torch
.
allclose
(
boxes
.
near
s
et_bev
,
expected_tensor
,
rtol
=
1e-4
,
atol
=
1e-7
)
boxes
.
neare
s
t_bev
,
expected_tensor
,
rtol
=
1e-4
,
atol
=
1e-7
)
# obtained by the print of the original implementation
expected_tensor
=
torch
.
tensor
([[[
3.2684e+00
,
2.5769e-01
,
-
7.7767e-01
],
...
...
@@ -659,3 +717,130 @@ def test_boxes3d_overlaps():
cam_boxes1
.
overlaps
(
cam_boxes1
,
boxes1
)
with
pytest
.
raises
(
AssertionError
):
boxes1
.
overlaps
(
cam_boxes1
,
boxes1
)
def
test_depth_boxes3d
():
# test empty initialization
empty_boxes
=
[]
boxes
=
DepthInstance3DBoxes
(
empty_boxes
)
assert
boxes
.
tensor
.
shape
[
0
]
==
0
assert
boxes
.
tensor
.
shape
[
1
]
==
7
# Test init with numpy array
np_boxes
=
np
.
array
(
[[
1.4856
,
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
3.0601
],
[
2.3262
,
3.3065
,
--
0.44255
,
0.8234
,
0.5325
,
1.0099
,
2.9971
]],
dtype
=
np
.
float32
)
boxes_1
=
DepthInstance3DBoxes
(
np_boxes
)
assert
torch
.
allclose
(
boxes_1
.
tensor
,
torch
.
from_numpy
(
np_boxes
))
# test properties
assert
boxes_1
.
volume
.
size
(
0
)
==
2
assert
(
boxes_1
.
center
==
boxes_1
.
bottom_center
).
all
()
expected_tensor
=
torch
.
tensor
([[
1.4856
,
2.5299
,
-
0.1093
],
[
2.3262
,
3.3065
,
0.9475
]])
assert
torch
.
allclose
(
boxes_1
.
gravity_center
,
expected_tensor
)
expected_tensor
=
torch
.
tensor
([[
1.4856
,
2.5299
,
0.9385
,
2.1404
,
3.0601
],
[
2.3262
,
3.3065
,
0.8234
,
0.5325
,
2.9971
]])
assert
torch
.
allclose
(
boxes_1
.
bev
,
expected_tensor
)
expected_tensor
=
torch
.
tensor
([[
1.0164
,
1.4597
,
1.9548
,
3.6001
],
[
1.9145
,
3.0402
,
2.7379
,
3.5728
]])
assert
torch
.
allclose
(
boxes_1
.
nearest_bev
,
expected_tensor
,
1e-4
)
assert
repr
(
boxes
)
==
(
'DepthInstance3DBoxes(
\n
tensor([], size=(0, 7)))'
)
# test init with torch.Tensor
th_boxes
=
torch
.
tensor
(
[[
2.4593
,
2.5870
,
-
0.4321
,
0.8597
,
0.6193
,
1.0204
,
3.0693
],
[
1.4856
,
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
3.0601
]],
dtype
=
torch
.
float32
)
boxes_2
=
DepthInstance3DBoxes
(
th_boxes
)
assert
torch
.
allclose
(
boxes_2
.
tensor
,
th_boxes
)
# test clone/to/device
boxes_2
=
boxes_2
.
clone
()
boxes_1
=
boxes_1
.
to
(
boxes_2
.
device
)
# test box concatenation
expected_tensor
=
torch
.
tensor
(
[[
1.4856
,
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
3.0601
],
[
2.3262
,
3.3065
,
--
0.44255
,
0.8234
,
0.5325
,
1.0099
,
2.9971
],
[
2.4593
,
2.5870
,
-
0.4321
,
0.8597
,
0.6193
,
1.0204
,
3.0693
],
[
1.4856
,
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
3.0601
]])
boxes
=
DepthInstance3DBoxes
.
cat
([
boxes_1
,
boxes_2
])
assert
torch
.
allclose
(
boxes
.
tensor
,
expected_tensor
)
# concatenate empty list
empty_boxes
=
DepthInstance3DBoxes
.
cat
([])
assert
empty_boxes
.
tensor
.
shape
[
0
]
==
0
assert
empty_boxes
.
tensor
.
shape
[
-
1
]
==
7
# test box flip
expected_tensor
=
torch
.
tensor
(
[[
-
1.4856
,
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
0.0815
],
[
-
2.3262
,
3.3065
,
0.4426
,
0.8234
,
0.5325
,
1.0099
,
0.1445
],
[
-
2.4593
,
2.5870
,
-
0.4321
,
0.8597
,
0.6193
,
1.0204
,
0.0723
],
[
-
1.4856
,
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
0.0815
]])
boxes
.
flip
(
bev_direction
=
'horizontal'
)
assert
torch
.
allclose
(
boxes
.
tensor
,
expected_tensor
,
1e-3
)
expected_tensor
=
torch
.
tensor
(
[[
-
1.4856
,
-
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
-
0.0815
],
[
-
2.3262
,
-
3.3065
,
0.4426
,
0.8234
,
0.5325
,
1.0099
,
-
0.1445
],
[
-
2.4593
,
-
2.5870
,
-
0.4321
,
0.8597
,
0.6193
,
1.0204
,
-
0.0723
],
[
-
1.4856
,
-
2.5299
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
-
0.0815
]])
boxes
.
flip
(
bev_direction
=
'vertical'
)
assert
torch
.
allclose
(
boxes
.
tensor
,
expected_tensor
,
1e-3
)
# test box rotation
boxes_rot
=
boxes
.
clone
()
expected_tensor
=
torch
.
tensor
(
[[
-
1.6004
,
-
2.4589
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
-
0.0355
],
[
-
2.4758
,
-
3.1960
,
0.4426
,
0.8234
,
0.5325
,
1.0099
,
-
0.0985
],
[
-
2.5757
,
-
2.4712
,
-
0.4321
,
0.8597
,
0.6193
,
1.0204
,
-
0.0263
],
[
-
1.6004
,
-
2.4589
,
-
0.5570
,
0.9385
,
2.1404
,
0.8954
,
-
0.0355
]])
boxes_rot
.
rotate
(
-
0.04599790655000615
)
assert
torch
.
allclose
(
boxes_rot
.
tensor
,
expected_tensor
,
1e-3
)
th_boxes
=
torch
.
tensor
(
[[
0.61211395
,
0.8129094
,
0.10563634
,
1.497534
,
0.16927195
,
0.27956772
],
[
1.430009
,
0.49797538
,
0.9382923
,
0.07694054
,
0.9312509
,
1.8919173
]],
dtype
=
torch
.
float32
)
boxes
=
DepthInstance3DBoxes
(
th_boxes
,
box_dim
=
6
,
with_yaw
=
False
)
expected_tensor
=
torch
.
tensor
([[
0.64884546
,
0.78390356
,
0.10563634
,
1.50373348
,
0.23795205
,
0.27956772
,
0
],
[
1.45139421
,
0.43169443
,
0.93829232
,
0.11967964
,
0.93380373
,
1.89191735
,
0
]])
boxes_3
=
boxes
.
clone
()
boxes_3
.
rotate
(
-
0.04599790655000615
)
assert
torch
.
allclose
(
boxes_3
.
tensor
,
expected_tensor
)
boxes
.
rotate
(
torch
.
tensor
(
-
0.04599790655000615
))
assert
torch
.
allclose
(
boxes
.
tensor
,
expected_tensor
)
# test bbox in_range_bev
expected_tensor
=
torch
.
tensor
([
1
,
1
],
dtype
=
torch
.
bool
)
mask
=
boxes
.
in_range_bev
([
0.
,
-
40.
,
70.4
,
40.
])
assert
(
mask
==
expected_tensor
).
all
()
mask
=
boxes
.
nonempty
()
assert
(
mask
==
expected_tensor
).
all
()
expected_tensor
=
torch
.
tensor
([[[
-
0.1030
,
0.6649
,
0.1056
],
[
-
0.1030
,
0.6649
,
0.3852
],
[
-
0.1030
,
0.9029
,
0.3852
],
[
-
0.1030
,
0.9029
,
0.1056
],
[
1.4007
,
0.6649
,
0.1056
],
[
1.4007
,
0.6649
,
0.3852
],
[
1.4007
,
0.9029
,
0.3852
],
[
1.4007
,
0.9029
,
0.1056
]],
[[
1.3916
,
-
0.0352
,
0.9383
],
[
1.3916
,
-
0.0352
,
2.8302
],
[
1.3916
,
0.8986
,
2.8302
],
[
1.3916
,
0.8986
,
0.9383
],
[
1.5112
,
-
0.0352
,
0.9383
],
[
1.5112
,
-
0.0352
,
2.8302
],
[
1.5112
,
0.8986
,
2.8302
],
[
1.5112
,
0.8986
,
0.9383
]]])
torch
.
allclose
(
boxes
.
corners
,
expected_tensor
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment