Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
d3179c0f
Commit
d3179c0f
authored
Nov 26, 2021
by
Shaoshuai Shi
Browse files
support CenterHead for RCNN training/testing
parent
2d269538
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
2 deletions
+31
-2
pcdet/models/dense_heads/center_head.py
pcdet/models/dense_heads/center_head.py
+28
-2
pcdet/models/roi_heads/roi_head_template.py
pcdet/models/roi_heads/roi_head_template.py
+3
-0
No files found.
pcdet/models/dense_heads/center_head.py
View file @
d3179c0f
...
...
@@ -302,6 +302,24 @@ class CenterHead(nn.Module):
return
ret_dict
@
staticmethod
def
reorder_rois_for_refining
(
batch_size
,
pred_dicts
):
num_max_rois
=
max
([
len
(
cur_dict
[
'pred_boxes'
])
for
cur_dict
in
pred_dicts
])
num_max_rois
=
max
(
1
,
num_max_rois
)
# at least one faked rois to avoid error
pred_boxes
=
pred_dicts
[
0
][
'pred_boxes'
]
rois
=
pred_boxes
.
new_zeros
((
batch_size
,
num_max_rois
,
pred_boxes
.
shape
[
-
1
]))
roi_scores
=
pred_boxes
.
new_zeros
((
batch_size
,
num_max_rois
))
roi_labels
=
pred_boxes
.
new_zeros
((
batch_size
,
num_max_rois
)).
long
()
for
bs_idx
in
range
(
batch_size
):
num_boxes
=
len
(
pred_dicts
[
bs_idx
][
'pred_boxes'
])
rois
[
bs_idx
,
:
num_boxes
,
:]
=
pred_dicts
[
bs_idx
][
'pred_boxes'
]
roi_scores
[
bs_idx
,
:
num_boxes
]
=
pred_dicts
[
bs_idx
][
'pred_scores'
]
roi_labels
[
bs_idx
,
:
num_boxes
]
=
pred_dicts
[
bs_idx
][
'pred_labels'
]
return
rois
,
roi_scores
,
roi_labels
def
forward
(
self
,
data_dict
):
spatial_features_2d
=
data_dict
[
'spatial_features_2d'
]
x
=
self
.
shared_conv
(
spatial_features_2d
)
...
...
@@ -320,9 +338,17 @@ class CenterHead(nn.Module):
self
.
forward_ret_dict
[
'pred_dicts'
]
=
pred_dicts
if
not
self
.
training
or
self
.
predict_boxes_when_training
:
final_box
_dicts
=
self
.
generate_predicted_boxes
(
pred
_dicts
=
self
.
generate_predicted_boxes
(
data_dict
[
'batch_size'
],
pred_dicts
)
data_dict
[
'final_box_dicts'
]
=
final_box_dicts
if
self
.
predict_boxes_when_training
:
rois
,
roi_scores
,
roi_labels
=
self
.
reorder_rois_for_refining
(
data_dict
[
'batch_size'
],
pred_dicts
)
data_dict
[
'rois'
]
=
rois
data_dict
[
'roi_scores'
]
=
roi_scores
data_dict
[
'roi_labels'
]
=
roi_labels
data_dict
[
'has_class_labels'
]
=
True
else
:
data_dict
[
'final_box_dicts'
]
=
pred_dicts
return
data_dict
pcdet/models/roi_heads/roi_head_template.py
View file @
d3179c0f
...
...
@@ -61,6 +61,9 @@ class RoIHeadTemplate(nn.Module):
roi_labels: (B, num_rois)
"""
if
batch_dict
.
get
(
'rois'
,
None
)
is
not
None
:
return
batch_dict
batch_size
=
batch_dict
[
'batch_size'
]
batch_box_preds
=
batch_dict
[
'batch_box_preds'
]
batch_cls_preds
=
batch_dict
[
'batch_cls_preds'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment