Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
8447a475
Commit
8447a475
authored
Jul 23, 2020
by
Shaoshuai Shi
Browse files
support PointResidualCoder
parent
a2e7d474
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
82 additions
and
0 deletions
+82
-0
pcdet/utils/box_coder_utils.py
pcdet/utils/box_coder_utils.py
+82
-0
No files found.
pcdet/utils/box_coder_utils.py
View file @
8447a475
import
torch
import
torch
import
numpy
as
np
class
ResidualCoder
(
object
):
class
ResidualCoder
(
object
):
...
@@ -123,3 +124,84 @@ class PreviousResidualRoIDecoder(object):
...
@@ -123,3 +124,84 @@ class PreviousResidualRoIDecoder(object):
cgs
=
[
t
+
a
for
t
,
a
in
zip
(
cts
,
cas
)]
cgs
=
[
t
+
a
for
t
,
a
in
zip
(
cts
,
cas
)]
return
torch
.
cat
([
xg
,
yg
,
zg
,
dxg
,
dyg
,
dzg
,
rg
,
*
cgs
],
dim
=-
1
)
return
torch
.
cat
([
xg
,
yg
,
zg
,
dxg
,
dyg
,
dzg
,
rg
,
*
cgs
],
dim
=-
1
)
class
PointResidualCoder
(
object
):
def
__init__
(
self
,
code_size
=
8
,
use_mean_size
=
True
,
**
kwargs
):
super
().
__init__
()
self
.
code_size
=
code_size
self
.
use_mean_size
=
use_mean_size
if
self
.
use_mean_size
:
self
.
mean_size
=
torch
.
from_numpy
(
np
.
array
(
kwargs
[
'mean_size'
])).
cuda
().
float
()
assert
self
.
mean_size
.
min
()
>
0
def
encode_torch
(
self
,
gt_boxes
,
points
,
gt_classes
=
None
):
"""
Args:
gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
points: (N, 3) [x, y, z]
gt_classes: (N) [1, num_classes]
Returns:
box_coding: (N, 8 + C)
"""
gt_boxes
[:,
3
:
6
]
=
torch
.
clamp_min
(
gt_boxes
[:,
3
:
6
],
min
=
1e-5
)
xg
,
yg
,
zg
,
dxg
,
dyg
,
dzg
,
rg
,
*
cgs
=
torch
.
split
(
gt_boxes
,
1
,
dim
=-
1
)
xa
,
ya
,
za
=
torch
.
split
(
points
,
1
,
dim
=-
1
)
if
self
.
use_mean_size
:
assert
gt_classes
.
max
()
<=
self
.
mean_size
.
shape
[
0
]
point_anchor_size
=
self
.
mean_size
[
gt_classes
-
1
]
dxa
,
dya
,
dza
=
torch
.
split
(
point_anchor_size
,
1
,
dim
=-
1
)
diagonal
=
torch
.
sqrt
(
dxa
**
2
+
dya
**
2
)
xt
=
(
xg
-
xa
)
/
diagonal
yt
=
(
yg
-
ya
)
/
diagonal
zt
=
(
zg
-
za
)
/
dza
dxt
=
torch
.
log
(
dxg
/
dxa
)
dyt
=
torch
.
log
(
dyg
/
dya
)
dzt
=
torch
.
log
(
dzg
/
dza
)
else
:
xt
=
(
xg
-
xa
)
yt
=
(
yg
-
ya
)
zt
=
(
zg
-
za
)
dxt
=
torch
.
log
(
dxg
)
dyt
=
torch
.
log
(
dyg
)
dzt
=
torch
.
log
(
dzg
)
cts
=
[
g
for
g
in
cgs
]
return
torch
.
cat
([
xt
,
yt
,
zt
,
dxt
,
dyt
,
dzt
,
torch
.
cos
(
rg
),
torch
.
sin
(
rg
),
*
cts
],
dim
=-
1
)
def
decode_torch
(
self
,
box_encodings
,
points
,
pred_classes
=
None
):
"""
Args:
box_encodings: (N, 8 + C) [x, y, z, dx, dy, dz, cos, sin, ...]
points: [x, y, z]
pred_classes: (N) [1, num_classes]
Returns:
"""
xt
,
yt
,
zt
,
dxt
,
dyt
,
dzt
,
cost
,
sint
,
*
cts
=
torch
.
split
(
box_encodings
,
1
,
dim
=-
1
)
xa
,
ya
,
za
=
torch
.
split
(
points
,
1
,
dim
=-
1
)
if
self
.
use_mean_size
:
assert
pred_classes
.
max
()
<=
self
.
mean_size
.
shape
[
0
]
point_anchor_size
=
self
.
mean_size
[
pred_classes
-
1
]
dxa
,
dya
,
dza
=
torch
.
split
(
point_anchor_size
,
1
,
dim
=-
1
)
diagonal
=
torch
.
sqrt
(
dxa
**
2
+
dya
**
2
)
xg
=
xt
*
diagonal
+
xa
yg
=
yt
*
diagonal
+
ya
zg
=
zt
*
dza
+
za
dxg
=
torch
.
exp
(
dxt
)
*
dxa
dyg
=
torch
.
exp
(
dyt
)
*
dya
dzg
=
torch
.
exp
(
dzt
)
*
dza
else
:
xg
=
xt
+
xa
yg
=
yt
+
ya
zg
=
zt
+
za
dxg
,
dyg
,
dzg
=
torch
.
split
(
torch
.
exp
(
box_encodings
[...,
3
:
6
]),
1
,
dim
=-
1
)
rg
=
torch
.
atan2
(
sint
,
cost
)
cgs
=
[
t
for
t
in
cts
]
return
torch
.
cat
([
xg
,
yg
,
zg
,
dxg
,
dyg
,
dzg
,
rg
,
*
cgs
],
dim
=-
1
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment