Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
55d12ff5
"git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "6f80c436769a40c80ba1e9d7bfe7397015bbbcf0"
Commit
55d12ff5
authored
Dec 26, 2021
by
Shaoshuai Shi
Browse files
add PVRCNNPlusPlus detector
parent
13789796
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
0 deletions
+53
-0
pcdet/models/detectors/pv_rcnn_plusplus.py
pcdet/models/detectors/pv_rcnn_plusplus.py
+53
-0
No files found.
pcdet/models/detectors/pv_rcnn_plusplus.py
0 → 100644
View file @
55d12ff5
from
.detector3d_template
import
Detector3DTemplate
class
PVRCNNPlusPlus
(
Detector3DTemplate
):
def
__init__
(
self
,
model_cfg
,
num_class
,
dataset
):
super
().
__init__
(
model_cfg
=
model_cfg
,
num_class
=
num_class
,
dataset
=
dataset
)
self
.
module_list
=
self
.
build_networks
()
def
forward
(
self
,
batch_dict
):
batch_dict
=
self
.
vfe
(
batch_dict
)
batch_dict
=
self
.
backbone_3d
(
batch_dict
)
batch_dict
=
self
.
map_to_bev_module
(
batch_dict
)
batch_dict
=
self
.
backbone_2d
(
batch_dict
)
batch_dict
=
self
.
dense_head
(
batch_dict
)
batch_dict
=
self
.
roi_head
.
proposal_layer
(
batch_dict
,
nms_config
=
self
.
roi_head
.
model_cfg
.
NMS_CONFIG
[
'TRAIN'
if
self
.
training
else
'TEST'
]
)
if
self
.
training
:
targets_dict
=
self
.
roi_head
.
assign_targets
(
batch_dict
)
batch_dict
[
'rois'
]
=
targets_dict
[
'rois'
]
batch_dict
[
'roi_labels'
]
=
targets_dict
[
'roi_labels'
]
batch_dict
[
'roi_targets_dict'
]
=
targets_dict
num_rois_per_scene
=
targets_dict
[
'rois'
].
shape
[
1
]
if
'roi_valid_num'
in
batch_dict
:
batch_dict
[
'roi_valid_num'
]
=
[
num_rois_per_scene
for
_
in
range
(
batch_dict
[
'batch_size'
])]
batch_dict
=
self
.
pfe
(
batch_dict
)
batch_dict
=
self
.
point_head
(
batch_dict
)
batch_dict
=
self
.
roi_head
(
batch_dict
)
if
self
.
training
:
loss
,
tb_dict
,
disp_dict
=
self
.
get_training_loss
()
ret_dict
=
{
'loss'
:
loss
}
return
ret_dict
,
tb_dict
,
disp_dict
else
:
pred_dicts
,
recall_dicts
=
self
.
post_processing
(
batch_dict
)
return
pred_dicts
,
recall_dicts
def
get_training_loss
(
self
):
disp_dict
=
{}
loss_rpn
,
tb_dict
=
self
.
dense_head
.
get_loss
()
if
self
.
point_head
is
not
None
:
loss_point
,
tb_dict
=
self
.
point_head
.
get_loss
(
tb_dict
)
else
:
loss_point
=
0
loss_rcnn
,
tb_dict
=
self
.
roi_head
.
get_loss
(
tb_dict
)
loss
=
loss_rpn
+
loss_point
+
loss_rcnn
return
loss
,
tb_dict
,
disp_dict
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment