Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
12c3b19d
Unverified
Commit
12c3b19d
authored
Jan 19, 2023
by
SekiroRong
Committed by
GitHub
Jan 19, 2023
Browse files
[fix] fix bug in projects/PETR (#2212)
parent
8bf2f5a4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
48 deletions
+27
-48
projects/PETR/petr/petr.py
projects/PETR/petr/petr.py
+4
-21
projects/PETR/petr/transforms_3d.py
projects/PETR/petr/transforms_3d.py
+5
-3
projects/PETR/petr/utils.py
projects/PETR/petr/utils.py
+18
-24
No files found.
projects/PETR/petr/petr.py
View file @
12c3b19d
...
@@ -8,14 +8,11 @@
...
@@ -8,14 +8,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------
# ------------------------------------------------------------------------
import
numpy
as
np
import
torch
import
torch
from
mmengine.structures
import
InstanceData
from
mmengine.structures
import
InstanceData
import
mmdet3d
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
mmdet3d.models.detectors.mvx_two_stage
import
MVXTwoStageDetector
from
mmdet3d.registry
import
MODELS
from
mmdet3d.registry
import
MODELS
from
mmdet3d.structures.bbox_3d
import
LiDARInstance3DBoxes
,
limit_period
from
mmdet3d.structures.ops
import
bbox3d2result
from
mmdet3d.structures.ops
import
bbox3d2result
from
.grid_mask
import
GridMask
from
.grid_mask
import
GridMask
...
@@ -174,8 +171,6 @@ class PETR(MVXTwoStageDetector):
...
@@ -174,8 +171,6 @@ class PETR(MVXTwoStageDetector):
gt_labels_3d
=
[
gt
.
labels_3d
for
gt
in
batch_gt_instances_3d
]
gt_labels_3d
=
[
gt
.
labels_3d
for
gt
in
batch_gt_instances_3d
]
gt_bboxes_ignore
=
None
gt_bboxes_ignore
=
None
gt_bboxes_3d
=
self
.
LidarBox3dVersionTransfrom
(
gt_bboxes_3d
)
batch_img_metas
=
self
.
add_lidar2img
(
img
,
batch_img_metas
)
batch_img_metas
=
self
.
add_lidar2img
(
img
,
batch_img_metas
)
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
batch_img_metas
)
img_feats
=
self
.
extract_feat
(
img
=
img
,
img_metas
=
batch_img_metas
)
...
@@ -265,8 +260,8 @@ class PETR(MVXTwoStageDetector):
...
@@ -265,8 +260,8 @@ class PETR(MVXTwoStageDetector):
Returns:
Returns:
batch_input_metas (list[dict]): Meta info with lidar2img added
batch_input_metas (list[dict]): Meta info with lidar2img added
"""
"""
lidar2img_rts
=
[]
for
meta
in
batch_input_metas
:
for
meta
in
batch_input_metas
:
lidar2img_rts
=
[]
# obtain lidar to image transformation matrix
# obtain lidar to image transformation matrix
for
i
in
range
(
len
(
meta
[
'cam2img'
])):
for
i
in
range
(
len
(
meta
[
'cam2img'
])):
lidar2cam_rt
=
torch
.
tensor
(
meta
[
'lidar2cam'
][
i
]).
double
()
lidar2cam_rt
=
torch
.
tensor
(
meta
[
'lidar2cam'
][
i
]).
double
()
...
@@ -281,19 +276,7 @@ class PETR(MVXTwoStageDetector):
...
@@ -281,19 +276,7 @@ class PETR(MVXTwoStageDetector):
# and LoadMultiViewImageFromMultiSweepsFiles.
# and LoadMultiViewImageFromMultiSweepsFiles.
lidar2img_rts
.
append
(
lidar2img_rt
)
lidar2img_rts
.
append
(
lidar2img_rt
)
meta
[
'lidar2img'
]
=
lidar2img_rts
meta
[
'lidar2img'
]
=
lidar2img_rts
meta
[
'img_shape'
]
=
[
i
.
shape
for
i
in
img
[
0
]
]
img_shape
=
meta
[
'img_shape'
]
[:
3
]
return
batch_input_metas
meta
[
'img_shape'
]
=
[
img_shape
]
*
len
(
img
[
0
])
def
LidarBox3dVersionTransfrom
(
self
,
gt_bboxes_3d
):
return
batch_input_metas
if
int
(
mmdet3d
.
__version__
[
0
])
>=
1
:
# Begin hack adaptation to mmdet3d v1.0 ####
gt_bboxes_3d
=
gt_bboxes_3d
[
0
].
tensor
gt_bboxes_3d
[:,
[
3
,
4
]]
=
gt_bboxes_3d
[:,
[
4
,
3
]]
gt_bboxes_3d
[:,
6
]
=
-
gt_bboxes_3d
[:,
6
]
-
np
.
pi
/
2
gt_bboxes_3d
[:,
6
]
=
limit_period
(
gt_bboxes_3d
[:,
6
],
period
=
np
.
pi
*
2
)
gt_bboxes_3d
=
LiDARInstance3DBoxes
(
gt_bboxes_3d
,
box_dim
=
9
)
gt_bboxes_3d
=
[
gt_bboxes_3d
]
return
gt_bboxes_3d
projects/PETR/petr/transforms_3d.py
View file @
12c3b19d
...
@@ -185,8 +185,8 @@ class GlobalRotScaleTransImage(BaseTransform):
...
@@ -185,8 +185,8 @@ class GlobalRotScaleTransImage(BaseTransform):
num_view
=
len
(
results
[
'lidar2cam'
])
num_view
=
len
(
results
[
'lidar2cam'
])
for
view
in
range
(
num_view
):
for
view
in
range
(
num_view
):
results
[
'lidar2cam'
][
view
]
=
(
results
[
'lidar2cam'
][
view
]
=
(
torch
.
tensor
(
np
.
array
(
results
[
'lidar2cam'
][
view
])).
float
()
torch
.
tensor
(
np
.
array
(
results
[
'lidar2cam'
][
view
])
.
T
).
float
()
@
rot_mat_inv
).
numpy
()
@
rot_mat_inv
).
T
.
numpy
()
return
return
...
@@ -203,5 +203,7 @@ class GlobalRotScaleTransImage(BaseTransform):
...
@@ -203,5 +203,7 @@ class GlobalRotScaleTransImage(BaseTransform):
num_view
=
len
(
results
[
'lidar2cam'
])
num_view
=
len
(
results
[
'lidar2cam'
])
for
view
in
range
(
num_view
):
for
view
in
range
(
num_view
):
results
[
'lidar2cam'
][
view
]
=
(
torch
.
tensor
(
results
[
'lidar2cam'
][
view
]
=
(
torch
.
tensor
(
rot_mat_inv
.
T
@
results
[
'lidar2cam'
][
view
]).
float
()).
numpy
()
rot_mat_inv
.
T
@
results
[
'lidar2cam'
][
view
].
T
).
float
()).
T
.
numpy
()
return
return
projects/PETR/petr/utils.py
View file @
12c3b19d
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
mmdet3d
from
mmdet3d.structures.bbox_3d.utils
import
limit_period
from
mmdet3d.structures.bbox_3d.utils
import
limit_period
...
@@ -11,19 +10,21 @@ def normalize_bbox(bboxes, pc_range):
...
@@ -11,19 +10,21 @@ def normalize_bbox(bboxes, pc_range):
cx
=
bboxes
[...,
0
:
1
]
cx
=
bboxes
[...,
0
:
1
]
cy
=
bboxes
[...,
1
:
2
]
cy
=
bboxes
[...,
1
:
2
]
cz
=
bboxes
[...,
2
:
3
]
cz
=
bboxes
[...,
2
:
3
]
w
=
bboxes
[...,
3
:
4
].
log
()
length
=
bboxes
[...,
3
:
4
].
log
()
leng
th
=
bboxes
[...,
4
:
5
].
log
()
wid
th
=
bboxes
[...,
4
:
5
].
log
()
h
=
bboxes
[...,
5
:
6
].
log
()
h
eight
=
bboxes
[...,
5
:
6
].
log
()
rot
=
bboxes
[...,
6
:
7
]
rot
=
-
bboxes
[...,
6
:
7
]
-
np
.
pi
/
2
rot
=
limit_period
(
rot
,
period
=
np
.
pi
*
2
)
if
bboxes
.
size
(
-
1
)
>
7
:
if
bboxes
.
size
(
-
1
)
>
7
:
vx
=
bboxes
[...,
7
:
8
]
vx
=
bboxes
[...,
7
:
8
]
vy
=
bboxes
[...,
8
:
9
]
vy
=
bboxes
[...,
8
:
9
]
normalized_bboxes
=
torch
.
cat
(
normalized_bboxes
=
torch
.
cat
(
(
cx
,
cy
,
w
,
length
,
cz
,
h
,
rot
.
sin
(),
rot
.
cos
(),
vx
,
vy
),
dim
=-
1
)
(
cx
,
cy
,
length
,
width
,
cz
,
height
,
rot
.
sin
(),
rot
.
cos
(),
vx
,
vy
),
dim
=-
1
)
else
:
else
:
normalized_bboxes
=
torch
.
cat
(
normalized_bboxes
=
torch
.
cat
(
(
cx
,
cy
,
w
,
length
,
cz
,
h
,
rot
.
sin
(),
rot
.
cos
()),
dim
=-
1
)
(
cx
,
cy
,
length
,
width
,
cz
,
height
,
rot
.
sin
(),
rot
.
cos
()),
dim
=-
1
)
return
normalized_bboxes
return
normalized_bboxes
...
@@ -33,6 +34,8 @@ def denormalize_bbox(normalized_bboxes, pc_range):
...
@@ -33,6 +34,8 @@ def denormalize_bbox(normalized_bboxes, pc_range):
rot_cosine
=
normalized_bboxes
[...,
7
:
8
]
rot_cosine
=
normalized_bboxes
[...,
7
:
8
]
rot
=
torch
.
atan2
(
rot_sine
,
rot_cosine
)
rot
=
torch
.
atan2
(
rot_sine
,
rot_cosine
)
rot
=
-
rot
-
np
.
pi
/
2
rot
=
limit_period
(
rot
,
period
=
np
.
pi
*
2
)
# center in the bev
# center in the bev
cx
=
normalized_bboxes
[...,
0
:
1
]
cx
=
normalized_bboxes
[...,
0
:
1
]
...
@@ -40,30 +43,21 @@ def denormalize_bbox(normalized_bboxes, pc_range):
...
@@ -40,30 +43,21 @@ def denormalize_bbox(normalized_bboxes, pc_range):
cz
=
normalized_bboxes
[...,
4
:
5
]
cz
=
normalized_bboxes
[...,
4
:
5
]
# size
# size
w
=
normalized_bboxes
[...,
2
:
3
]
length
=
normalized_bboxes
[...,
2
:
3
]
leng
th
=
normalized_bboxes
[...,
3
:
4
]
wid
th
=
normalized_bboxes
[...,
3
:
4
]
h
=
normalized_bboxes
[...,
5
:
6
]
h
eight
=
normalized_bboxes
[...,
5
:
6
]
w
=
w
.
exp
()
w
idth
=
width
.
exp
()
length
=
length
.
exp
()
length
=
length
.
exp
()
h
=
h
.
exp
()
h
eight
=
height
.
exp
()
if
normalized_bboxes
.
size
(
-
1
)
>
8
:
if
normalized_bboxes
.
size
(
-
1
)
>
8
:
# velocity
# velocity
vx
=
normalized_bboxes
[:,
8
:
9
]
vx
=
normalized_bboxes
[:,
8
:
9
]
vy
=
normalized_bboxes
[:,
9
:
10
]
vy
=
normalized_bboxes
[:,
9
:
10
]
denormalized_bboxes
=
torch
.
cat
(
denormalized_bboxes
=
torch
.
cat
(
[
cx
,
cy
,
cz
,
w
,
length
,
h
,
rot
,
vx
,
vy
],
dim
=-
1
)
[
cx
,
cy
,
cz
,
length
,
width
,
height
,
rot
,
vx
,
vy
],
dim
=-
1
)
else
:
else
:
denormalized_bboxes
=
torch
.
cat
(
[
cx
,
cy
,
cz
,
w
,
length
,
h
,
rot
],
denormalized_bboxes
=
torch
.
cat
(
dim
=-
1
)
[
cx
,
cy
,
cz
,
length
,
width
,
height
,
rot
],
dim
=-
1
)
if
int
(
mmdet3d
.
__version__
[
0
])
>=
1
:
denormalized_bboxes_clone
=
denormalized_bboxes
.
clone
()
denormalized_bboxes
[:,
3
]
=
denormalized_bboxes_clone
[:,
4
]
denormalized_bboxes
[:,
4
]
=
denormalized_bboxes_clone
[:,
3
]
# change yaw
denormalized_bboxes
[:,
6
]
=
-
denormalized_bboxes_clone
[:,
6
]
-
np
.
pi
/
2
denormalized_bboxes
[:,
6
]
=
limit_period
(
denormalized_bboxes
[:,
6
],
period
=
np
.
pi
*
2
)
return
denormalized_bboxes
return
denormalized_bboxes
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment