Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
16202272
Commit
16202272
authored
May 15, 2020
by
zhangwenwei
Browse files
Support iou calculation in box structures
parent
82cd4892
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
187 additions
and
166 deletions
+187
-166
mmdet3d/core/anchor/anchor_3d_generator.py
mmdet3d/core/anchor/anchor_3d_generator.py
+2
-0
mmdet3d/core/bbox/samplers/iou_neg_piecewise_sampler.py
mmdet3d/core/bbox/samplers/iou_neg_piecewise_sampler.py
+1
-1
mmdet3d/core/bbox/structures/base_box3d.py
mmdet3d/core/bbox/structures/base_box3d.py
+93
-6
mmdet3d/core/bbox/structures/cam_box3d.py
mmdet3d/core/bbox/structures/cam_box3d.py
+48
-64
mmdet3d/core/bbox/structures/lidar_box3d.py
mmdet3d/core/bbox/structures/lidar_box3d.py
+3
-72
mmdet3d/core/bbox/structures/utils.py
mmdet3d/core/bbox/structures/utils.py
+13
-0
mmdet3d/datasets/pipelines/indoor_loading.py
mmdet3d/datasets/pipelines/indoor_loading.py
+1
-1
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py
...els/roi_heads/roi_extractors/single_roiaware_extractor.py
+1
-1
tests/test_box3d.py
tests/test_box3d.py
+25
-21
No files found.
mmdet3d/core/anchor/anchor_3d_generator.py
View file @
16202272
...
...
@@ -137,6 +137,7 @@ class Anchor3DRangeGenerator(object):
rotations
=
[
0
,
1.5707963
],
device
=
'cuda'
):
"""Generate anchors in a single range
Args:
feature_size: list [D, H, W](zyx)
sizes: [N, 3] list of list or array, size of anchors, xyz
...
...
@@ -221,6 +222,7 @@ class AlignedAnchor3DRangeGenerator(Anchor3DRangeGenerator):
rotations
=
[
0
,
1.5707963
],
device
=
'cuda'
):
"""Generate anchors in a single range
Args:
feature_size: list [D, H, W](zyx)
sizes: [N, 3] list of list or array, size of anchors, xyz
...
...
mmdet3d/core/bbox/samplers/iou_neg_piecewise_sampler.py
View file @
16202272
...
...
@@ -4,7 +4,7 @@ from mmdet.core.bbox.builder import BBOX_SAMPLERS
from
.
import
RandomSampler
,
SamplingResult
@
BBOX_SAMPLERS
.
register_module
@
BBOX_SAMPLERS
.
register_module
()
class
IoUNegPiecewiseSampler
(
RandomSampler
):
"""IoU Piece-wise Sampling
...
...
mmdet3d/core/bbox/structures/base_box3d.py
View file @
16202272
...
...
@@ -3,7 +3,8 @@ from abc import abstractmethod
import
numpy
as
np
import
torch
from
.utils
import
limit_period
from
mmdet3d.ops.iou3d
import
iou3d_cuda
from
.utils
import
limit_period
,
xywhr2xyxyr
class
BaseInstance3DBoxes
(
object
):
...
...
@@ -59,6 +60,24 @@ class BaseInstance3DBoxes(object):
"""
return
self
.
tensor
[:,
5
]
@
property
def
top_height
(
self
):
"""Obtain the top height of all the boxes.
Returns:
torch.Tensor: a vector with the top height of each box.
"""
return
self
.
bottom_height
+
self
.
height
@
property
def
bottom_height
(
self
):
"""Obtain the bottom's height of all the boxes.
Returns:
torch.Tensor: a vector with bottom's height of each box.
"""
return
self
.
tensor
[:,
2
]
@
property
def
center
(
self
):
"""Calculate the center of all the boxes.
...
...
@@ -286,17 +305,85 @@ class BaseInstance3DBoxes(object):
yield
from
self
.
tensor
@
classmethod
def
overlaps
(
cls
,
boxes1
,
boxes2
,
mode
=
'iou'
,
aligned
=
False
):
"""Calculate overlaps of two boxes
def
height_overlaps
(
cls
,
boxes1
,
boxes2
,
mode
=
'iou'
):
"""Calculate height overlaps of two boxes
Note:
This function calculate the height overlaps between boxes1 and
boxes2, boxes1 and boxes2 should be in the same type.
Args:
boxes1 (:obj:BaseInstanceBoxes): boxes 1 contain N boxes
boxes2 (:obj:BaseInstanceBoxes): boxes 2 contain M boxes
mode (str, optional): mode of iou calculation. Defaults to 'iou'.
aligned (bool, optional): Whether the boxes are aligned.
Defaults to False.
Returns:
torch.Tensor: Calculated iou of boxes
"""
pass
assert
isinstance
(
boxes1
,
BaseInstance3DBoxes
)
assert
isinstance
(
boxes2
,
BaseInstance3DBoxes
)
assert
type
(
boxes1
)
==
type
(
boxes2
),
'"boxes1" and "boxes2" should'
\
f
'be in the same type, got
{
type
(
boxes1
)
}
and
{
type
(
boxes2
)
}
.'
boxes1_top_height
=
boxes1
.
top_height
.
view
(
-
1
,
1
)
boxes1_bottom_height
=
boxes1
.
bottom_height
.
view
(
-
1
,
1
)
boxes2_top_height
=
boxes2
.
top_height
.
view
(
1
,
-
1
)
boxes2_bottom_height
=
boxes2
.
bottom_height
.
view
(
1
,
-
1
)
heighest_of_bottom
=
torch
.
max
(
boxes1_bottom_height
,
boxes2_bottom_height
)
lowest_of_top
=
torch
.
min
(
boxes1_top_height
,
boxes2_top_height
)
overlaps_h
=
torch
.
clamp
(
lowest_of_top
-
heighest_of_bottom
,
min
=
0
)
return
overlaps_h
@
classmethod
def
overlaps
(
cls
,
boxes1
,
boxes2
,
mode
=
'iou'
):
"""Calculate 3D overlaps of two boxes
Note:
This function calculate the overlaps between boxes1 and boxes2,
boxes1 and boxes2 are not necessarily to be in the same type.
Args:
boxes1 (:obj:BaseInstanceBoxes): boxes 1 contain N boxes
boxes2 (:obj:BaseInstanceBoxes): boxes 2 contain M boxes
mode (str, optional): mode of iou calculation. Defaults to 'iou'.
Returns:
torch.Tensor: Calculated iou of boxes
"""
assert
isinstance
(
boxes1
,
BaseInstance3DBoxes
)
assert
isinstance
(
boxes2
,
BaseInstance3DBoxes
)
assert
type
(
boxes1
)
==
type
(
boxes2
),
'"boxes1" and "boxes2" should'
\
f
'be in the same type, got
{
type
(
boxes1
)
}
and
{
type
(
boxes2
)
}
.'
assert
mode
in
[
'iou'
,
'iof'
]
# height overlap
overlaps_h
=
cls
.
height_overlaps
(
boxes1
,
boxes2
)
# obtain BEV boxes in XYXYR format
boxes1_bev
=
xywhr2xyxyr
(
boxes1
.
bev
)
boxes2_bev
=
xywhr2xyxyr
(
boxes2
.
bev
)
# bev overlap
overlaps_bev
=
boxes1_bev
.
new_zeros
(
(
boxes1_bev
.
shape
[
0
],
boxes2_bev
.
shape
[
0
])).
cuda
()
# (N, M)
iou3d_cuda
.
boxes_overlap_bev_gpu
(
boxes1_bev
.
contiguous
().
cuda
(),
boxes2_bev
.
contiguous
().
cuda
(),
overlaps_bev
)
# 3d overlaps
overlaps_3d
=
overlaps_bev
.
to
(
boxes1
.
device
)
*
overlaps_h
volume1
=
boxes1
.
volume
.
view
(
-
1
,
1
)
volume2
=
boxes2
.
volume
.
view
(
1
,
-
1
)
if
mode
==
'iou'
:
# the clamp func is used to avoid division of 0
iou3d
=
overlaps_3d
/
torch
.
clamp
(
volume1
+
volume2
-
overlaps_3d
,
min
=
1e-8
)
else
:
iou3d
=
overlaps_3d
/
torch
.
clamp
(
volume1
,
min
=
1e-8
)
return
iou3d
mmdet3d/core/bbox/structures/cam_box3d.py
View file @
16202272
import
numpy
as
np
import
torch
from
mmdet3d.ops.iou3d
import
iou3d_cuda
from
.base_box3d
import
BaseInstance3DBoxes
from
.utils
import
limit_period
,
rotation_3d_in_axis
...
...
@@ -35,10 +34,29 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
"""Obtain the height of all the boxes.
Returns:
torch.Tensor: a vector with
volume
of each box.
torch.Tensor: a vector with
height
of each box.
"""
return
self
.
tensor
[:,
4
]
@
property
def
top_height
(
self
):
"""Obtain the top height of all the boxes.
Returns:
torch.Tensor: a vector with the top height of each box.
"""
# the positive direction is down rather than up
return
self
.
bottom_height
-
self
.
height
@
property
def
bottom_height
(
self
):
"""Obtain the bottom's height of all the boxes.
Returns:
torch.Tensor: a vector with bottom's height of each box.
"""
return
self
.
tensor
[:,
1
]
@
property
def
gravity_center
(
self
):
"""Calculate the gravity center of all the boxes.
...
...
@@ -95,30 +113,14 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
return
corners
@
property
def
bev
(
self
,
mode
=
'XYWHR'
):
def
bev
(
self
):
"""Calculate the 2D bounding boxes in BEV with rotation
Args:
mode (str): The mode of BEV boxes. Default to 'XYWHR'.
Returns:
torch.Tensor: a nx5 tensor of 2D BEV box of each box.
The box is in XYWHR format.
"""
boxes_xywhr
=
self
.
tensor
[:,
[
0
,
2
,
3
,
5
,
6
]]
if
mode
==
'XYWHR'
:
return
boxes_xywhr
elif
mode
==
'XYXYR'
:
boxes
=
torch
.
zeros_like
(
boxes_xywhr
)
boxes
[:,
0
]
=
boxes_xywhr
[:,
0
]
-
boxes_xywhr
[
2
]
boxes
[:,
1
]
=
boxes_xywhr
[:,
1
]
-
boxes_xywhr
[
3
]
boxes
[:,
2
]
=
boxes_xywhr
[:,
0
]
+
boxes_xywhr
[
2
]
boxes
[:,
3
]
=
boxes_xywhr
[:,
1
]
+
boxes_xywhr
[
3
]
boxes
[:,
4
]
=
boxes_xywhr
[:,
4
]
return
boxes
else
:
raise
ValueError
(
'Only support mode to be either "XYWHR" or "XYXYR",'
f
'got
{
mode
}
'
)
return
self
.
tensor
[:,
[
0
,
2
,
3
,
5
,
6
]]
@
property
def
nearset_bev
(
self
):
...
...
@@ -196,53 +198,35 @@ class CameraInstance3DBoxes(BaseInstance3DBoxes):
return
in_range_flags
@
classmethod
def
overlaps
(
cls
,
boxes1
,
boxes2
,
mode
=
'iou'
):
"""Calculate overlaps of two boxes
def
height_overlaps
(
cls
,
boxes1
,
boxes2
,
mode
=
'iou'
):
"""Calculate height overlaps of two boxes
Note:
This function calculate the height overlaps between boxes1 and
boxes2, boxes1 and boxes2 should be in the same type.
Args:
boxes1 (:obj:
Camera
Instance
3D
Boxes): boxes 1 contain N boxes
boxes2 (:obj:
Camera
Instance
3D
Boxes): boxes 2 contain M boxes
boxes1 (:obj:
Base
InstanceBoxes): boxes 1 contain N boxes
boxes2 (:obj:
Base
InstanceBoxes): boxes 2 contain M boxes
mode (str, optional): mode of iou calculation. Defaults to 'iou'.
Returns:
torch.Tensor: Calculated iou of boxes
"""
assert
isinstance
(
boxes1
,
CameraInstance3DBoxes
)
assert
isinstance
(
boxes2
,
CameraInstance3DBoxes
)
assert
mode
in
[
'iou'
,
'iof'
]
# height overlap
boxes1_height_max
=
(
boxes1
.
tensor
[:,
1
]
+
boxes1
.
height
).
view
(
-
1
,
1
)
boxes1_height_min
=
boxes1
.
tensor
[:,
1
].
view
(
-
1
,
1
)
boxes2_height_max
=
(
boxes2
.
tensor
[:,
1
]
+
boxes2
.
height
).
view
(
1
,
-
1
)
boxes2_height_min
=
boxes2
.
tensor
[:,
1
].
view
(
1
,
-
1
)
max_of_min
=
torch
.
max
(
boxes1_height_min
,
boxes2_height_min
)
min_of_max
=
torch
.
min
(
boxes1_height_max
,
boxes2_height_max
)
overlaps_h
=
torch
.
clamp
(
min_of_max
-
max_of_min
,
min
=
0
)
# obtain BEV boxes in XYXYR format
boxes1_bev
=
boxes1
.
bev
(
mode
=
'XYXYR'
)
boxes2_bev
=
boxes2
.
bev
(
mode
=
'XYXYR'
)
# bev overlap
overlaps_bev
=
boxes1_bev
.
new_zeros
(
(
boxes1_bev
.
shape
[
0
],
boxes2_bev
.
shape
[
0
])).
cuda
()
# (N, M)
iou3d_cuda
.
boxes_overlap_bev_gpu
(
boxes1_bev
.
contiguous
().
cuda
(),
boxes2_bev
.
contiguous
().
cuda
(),
overlaps_bev
)
# 3d iou
overlaps_3d
=
overlaps_bev
.
to
(
boxes1
.
device
)
*
overlaps_h
volume1
=
boxes1
.
volume
.
view
(
-
1
,
1
)
volume2
=
boxes2
.
volume
.
view
(
1
,
-
1
)
if
mode
==
'iou'
:
# the clamp func is used to avoid division of 0
iou3d
=
overlaps_3d
/
torch
.
clamp
(
volume1
+
volume2
-
overlaps_3d
,
min
=
1e-8
)
else
:
iou3d
=
overlaps_3d
/
torch
.
clamp
(
volume1
,
min
=
1e-8
)
return
iou3d
assert
isinstance
(
boxes1
,
BaseInstance3DBoxes
)
assert
isinstance
(
boxes2
,
BaseInstance3DBoxes
)
assert
type
(
boxes1
)
==
type
(
boxes2
),
'"boxes1" and "boxes2" should'
\
f
'be in the same type, got
{
type
(
boxes1
)
}
and
{
type
(
boxes2
)
}
.'
boxes1_top_height
=
boxes1
.
top_height
.
view
(
-
1
,
1
)
boxes1_bottom_height
=
boxes1
.
bottom_height
.
view
(
-
1
,
1
)
boxes2_top_height
=
boxes2
.
top_height
.
view
(
1
,
-
1
)
boxes2_bottom_height
=
boxes2
.
bottom_height
.
view
(
1
,
-
1
)
# In camera coordinate system
# from up to down is the positive direction
heighest_of_bottom
=
torch
.
min
(
boxes1_bottom_height
,
boxes2_bottom_height
)
lowest_of_top
=
torch
.
max
(
boxes1_top_height
,
boxes2_top_height
)
overlaps_h
=
torch
.
clamp
(
heighest_of_bottom
-
lowest_of_top
,
min
=
0
)
return
overlaps_h
mmdet3d/core/bbox/structures/lidar_box3d.py
View file @
16202272
import
numpy
as
np
import
torch
from
mmdet3d.ops.iou3d
import
iou3d_cuda
from
.base_box3d
import
BaseInstance3DBoxes
from
.utils
import
limit_period
,
rotation_3d_in_axis
...
...
@@ -81,30 +80,14 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
return
corners
@
property
def
bev
(
self
,
mode
=
'XYWHR'
):
def
bev
(
self
):
"""Calculate the 2D bounding boxes in BEV with rotation
Args:
mode (str): The mode of BEV boxes. Default to 'XYWHR'.
Returns:
torch.Tensor: a nx5 tensor of 2D BEV box of each box.
The box is in XYWHR format
"""
boxes_xywhr
=
self
.
tensor
[:,
[
0
,
1
,
3
,
4
,
6
]]
if
mode
==
'XYWHR'
:
return
boxes_xywhr
elif
mode
==
'XYXYR'
:
boxes
=
torch
.
zeros_like
(
boxes_xywhr
)
boxes
[:,
0
]
=
boxes_xywhr
[:,
0
]
-
boxes_xywhr
[
2
]
boxes
[:,
1
]
=
boxes_xywhr
[:,
1
]
-
boxes_xywhr
[
3
]
boxes
[:,
2
]
=
boxes_xywhr
[:,
0
]
+
boxes_xywhr
[
2
]
boxes
[:,
3
]
=
boxes_xywhr
[:,
1
]
+
boxes_xywhr
[
3
]
boxes
[:,
4
]
=
boxes_xywhr
[:,
4
]
return
boxes
else
:
raise
ValueError
(
'Only support mode to be either "XYWHR" or "XYXYR",'
f
'got
{
mode
}
'
)
return
self
.
tensor
[:,
[
0
,
1
,
3
,
4
,
6
]]
@
property
def
nearset_bev
(
self
):
...
...
@@ -180,55 +163,3 @@ class LiDARInstance3DBoxes(BaseInstance3DBoxes):
&
(
self
.
tensor
[:,
0
]
<
box_range
[
2
])
&
(
self
.
tensor
[:,
1
]
<
box_range
[
3
]))
return
in_range_flags
@
classmethod
def
overlaps
(
cls
,
boxes1
,
boxes2
,
mode
=
'iou'
):
"""Calculate overlaps of two boxes
Args:
boxes1 (:obj:LiDARInstanceBoxes): boxes 1 contain N boxes
boxes2 (:obj:LiDARInstanceBoxes): boxes 2 contain M boxes
mode (str, optional): mode of iou calculation. Defaults to 'iou'.
Returns:
torch.Tensor: Calculated iou of boxes
"""
assert
isinstance
(
boxes1
,
LiDARInstance3DBoxes
)
assert
isinstance
(
boxes2
,
LiDARInstance3DBoxes
)
assert
mode
in
[
'iou'
,
'iof'
]
# height overlap
boxes1_height_max
=
(
boxes1
.
tensor
[:,
2
]
+
boxes1
.
height
).
view
(
-
1
,
1
)
boxes1_height_min
=
boxes1
.
tensor
[:,
2
].
view
(
-
1
,
1
)
boxes2_height_max
=
(
boxes2
.
tensor
[:,
2
]
+
boxes2
.
height
).
view
(
1
,
-
1
)
boxes2_height_min
=
boxes2
.
tensor
[:,
2
].
view
(
1
,
-
1
)
max_of_min
=
torch
.
max
(
boxes1_height_min
,
boxes2_height_min
)
min_of_max
=
torch
.
min
(
boxes1_height_max
,
boxes2_height_max
)
overlaps_h
=
torch
.
clamp
(
min_of_max
-
max_of_min
,
min
=
0
)
# obtain BEV boxes in XYXYR format
boxes1_bev
=
boxes1
.
bev
(
mode
=
'XYXYR'
)
boxes2_bev
=
boxes2
.
bev
(
mode
=
'XYXYR'
)
# bev overlap
overlaps_bev
=
boxes1_bev
.
new_zeros
(
(
boxes1_bev
.
shape
[
0
],
boxes2_bev
.
shape
[
0
])).
cuda
()
# (N, M)
iou3d_cuda
.
boxes_overlap_bev_gpu
(
boxes1_bev
.
contiguous
().
cuda
(),
boxes2_bev
.
contiguous
().
cuda
(),
overlaps_bev
)
# 3d iou
overlaps_3d
=
overlaps_bev
.
to
(
boxes1
.
device
)
*
overlaps_h
volume1
=
boxes1
.
volume
.
view
(
-
1
,
1
)
volume2
=
boxes2
.
volume
.
view
(
1
,
-
1
)
if
mode
==
'iou'
:
# the clamp func is used to avoid division of 0
iou3d
=
overlaps_3d
/
torch
.
clamp
(
volume1
+
volume2
-
overlaps_3d
,
min
=
1e-8
)
else
:
iou3d
=
overlaps_3d
/
torch
.
clamp
(
volume1
,
min
=
1e-8
)
return
iou3d
mmdet3d/core/bbox/structures/utils.py
View file @
16202272
...
...
@@ -59,3 +59,16 @@ def rotation_3d_in_axis(points, angles, axis=0):
raise
ValueError
(
f
'axis should in range [0, 1, 2], got
{
axis
}
'
)
return
torch
.
einsum
(
'aij,jka->aik'
,
(
points
,
rot_mat_T
))
def
xywhr2xyxyr
(
boxes_xywhr
):
boxes
=
torch
.
zeros_like
(
boxes_xywhr
)
half_w
=
boxes_xywhr
[:,
2
]
/
2
half_h
=
boxes_xywhr
[:,
3
]
/
2
boxes
[:,
0
]
=
boxes_xywhr
[:,
0
]
-
half_w
boxes
[:,
1
]
=
boxes_xywhr
[:,
1
]
-
half_h
boxes
[:,
2
]
=
boxes_xywhr
[:,
0
]
+
half_w
boxes
[:,
3
]
=
boxes_xywhr
[:,
1
]
+
half_h
boxes
[:,
4
]
=
boxes_xywhr
[:,
4
]
return
boxes
mmdet3d/datasets/pipelines/indoor_loading.py
View file @
16202272
...
...
@@ -76,7 +76,7 @@ class IndoorLoadPointsFromFile(object):
return
repr_str
@
PIPELINES
.
register_module
@
PIPELINES
.
register_module
()
class
IndoorLoadAnnotations3D
(
object
):
"""Indoor Load Annotations3D.
...
...
mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py
View file @
16202272
...
...
@@ -5,7 +5,7 @@ from mmdet3d import ops
from
mmdet.models.builder
import
ROI_EXTRACTORS
@
ROI_EXTRACTORS
.
register_module
@
ROI_EXTRACTORS
.
register_module
()
class
Single3DRoIAwareExtractor
(
nn
.
Module
):
"""Point-wise roi-aware Extractor
...
...
tests/test_box3d.py
View file @
16202272
...
...
@@ -284,6 +284,14 @@ def test_boxes_conversion():
[
31.31978
,
8.162144
,
-
1.6217787
,
1.74
,
3.77
,
1.48
,
2.79
]])
cam_box_tensor
=
Box3DMode
.
convert
(
lidar_boxes
.
tensor
,
Box3DMode
.
LIDAR
,
Box3DMode
.
CAM
)
# Some properties should be the same
cam_boxes
=
CameraInstance3DBoxes
(
cam_box_tensor
)
assert
torch
.
equal
(
cam_boxes
.
height
,
lidar_boxes
.
height
)
assert
torch
.
equal
(
cam_boxes
.
top_height
,
lidar_boxes
.
top_height
)
assert
torch
.
equal
(
cam_boxes
.
bottom_height
,
lidar_boxes
.
bottom_height
)
assert
torch
.
equal
(
cam_boxes
.
volume
,
lidar_boxes
.
volume
)
lidar_box_tensor
=
Box3DMode
.
convert
(
cam_box_tensor
,
Box3DMode
.
CAM
,
Box3DMode
.
LIDAR
)
expected_tensor
=
torch
.
tensor
(
...
...
@@ -601,54 +609,50 @@ def test_camera_boxes3d():
def
test_boxes3d_overlaps
():
"""Test the iou calculation of boxes in different modes.
ComandLine:
xdoctest tests/test_box3d.py::test_boxes3d_overlaps zero
"""
if
not
torch
.
cuda
.
is_available
():
pytest
.
skip
(
'test requires GPU and torch+cuda'
)
# Test LiDAR boxes 3D overlaps
boxes1_tensor
=
torch
.
tensor
(
[[
1.8
,
-
2.5
-
1.8
,
1.75
,
3.39
,
1.65
,
1.6615927
],
[[
1.8
,
-
2.5
,
-
1.8
,
1.75
,
3.39
,
1.65
,
1.6615927
],
[
8.9
,
-
2.5
,
-
1.6
,
1.54
,
4.01
,
1.57
,
1.5215927
],
[
28.3
,
0.5
,
-
1.3
,
1.47
,
2.23
,
1.48
,
4.7115927
],
[
31.3
,
-
8.2
,
-
1.6
,
1.74
,
3.77
,
1.48
,
0.35
159278
]],
[
31.3
,
-
8.2
,
-
1.6
,
1.74
,
3.77
,
1.48
,
0.35
]],
device
=
'cuda'
)
boxes1
=
LiDARInstance3DBoxes
(
boxes1_tensor
)
boxes2_tensor
=
torch
.
tensor
([[
1.2
,
-
3.0
,
-
1.9
,
1.8
,
3.4
,
1.7
,
1.9
],
[
8.1
,
-
2.9
,
-
1.8
,
1.5
,
4.1
,
1.6
,
1.8
],
[
20.1
,
-
2
8.
5
,
-
1.
9
,
1.
6
,
3.
5
,
1.4
,
5.1
],
[
2
8.2
,
-
16
.5
,
-
1.
8
,
1.
7
,
3.
8
,
1.
5
,
0.6
]],
[
31.3
,
-
8.
2
,
-
1.
6
,
1.
74
,
3.
77
,
1.4
8
,
0.35
],
[
2
0.1
,
-
28
.5
,
-
1.
9
,
1.
6
,
3.
5
,
1.
4
,
5.1
]],
device
=
'cuda'
)
boxes2
=
LiDARInstance3DBoxes
(
boxes2_tensor
)
from
mmdet3d.ops.iou3d
import
boxes3d_to_bev_torch_lidar
expected_tensor
=
boxes3d_to_bev_torch_lidar
(
boxes1_tensor
,
boxes2_tensor
)
expected_tensor
=
torch
.
tensor
(
[[
0.3710
,
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.3322
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
0.0000
,
0.0000
],
[
0.0000
,
0.0000
,
1.0000
,
0.0000
]],
device
=
'cuda'
)
overlaps_3d
=
boxes1
.
overlaps
(
boxes1
,
boxes2
)
assert
torch
.
allclose
(
expected_tensor
,
overlaps_3d
)
assert
torch
.
allclose
(
expected_tensor
,
overlaps_3d
,
rtol
=
1e-4
,
atol
=
1e-7
)
# Test camera boxes 3D overlaps
boxes1_tensor
=
torch
.
tensor
(
[[
1.8
,
-
2.5
-
1.8
,
1.75
,
3.39
,
1.65
,
1.6615927
],
[
8.9
,
-
2.5
,
-
1.6
,
1.54
,
4.01
,
1.57
,
1.5215927
],
[
28.3
,
0.5
,
-
1.3
,
1.47
,
2.23
,
1.48
,
4.7115927
],
[
31.3
,
-
8.2
,
-
1.6
,
1.74
,
3.77
,
1.48
,
0.35159278
]],
device
=
'cuda'
)
cam_boxes1_tensor
=
Box3DMode
.
convert
(
boxes1_tensor
,
Box3DMode
.
LIDAR
,
Box3DMode
.
CAM
)
cam_boxes1
=
CameraInstance3DBoxes
(
cam_boxes1_tensor
)
boxes2_tensor
=
torch
.
tensor
([[
1.2
,
-
3.0
,
-
1.9
,
1.8
,
3.4
,
1.7
,
1.9
],
[
8.1
,
-
2.9
,
-
1.8
,
1.5
,
4.1
,
1.6
,
1.8
],
[
20.1
,
-
28.5
,
-
1.9
,
1.6
,
3.5
,
1.4
,
5.1
],
[
28.2
,
-
16.5
,
-
1.8
,
1.7
,
3.8
,
1.5
,
0.6
]],
device
=
'cuda'
)
cam_boxes2_tensor
=
Box3DMode
.
convert
(
boxes2_tensor
,
Box3DMode
.
LIDAR
,
Box3DMode
.
CAM
)
cam_boxes2
=
CameraInstance3DBoxes
(
cam_boxes2_tensor
)
cam_overlaps_3d
=
cam_boxes1
.
overlaps
(
cam_boxes1
,
cam_boxes2
)
from
mmdet3d.ops.iou3d
import
boxes3d_to_bev_torch_camera
expected_tensor
=
boxes3d_to_bev_torch_camera
(
boxes1_tensor
,
boxes2_tensor
)
assert
torch
.
allclose
(
expected_tensor
,
cam_overlaps_3d
)
# same boxes under different coordinates should have the same iou
assert
torch
.
allclose
(
expected_tensor
,
cam_overlaps_3d
,
rtol
=
1e-4
,
atol
=
1e-7
)
assert
torch
.
allclose
(
cam_overlaps_3d
,
overlaps_3d
)
with
pytest
.
raises
(
AssertionError
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment