Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
9f5d201e
Unverified
Commit
9f5d201e
authored
Jul 16, 2020
by
Shaoshuai Shi
Committed by
GitHub
Jul 16, 2020
Browse files
add torch.no_grad context for proposal_layer (#160)
parent
2400fdf2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
2 deletions
+5
-2
pcdet/models/dense_heads/anchor_head_multi.py
pcdet/models/dense_heads/anchor_head_multi.py
+4
-2
pcdet/models/roi_heads/roi_head_template.py
pcdet/models/roi_heads/roi_head_template.py
+1
-0
No files found.
pcdet/models/dense_heads/anchor_head_multi.py
View file @
9f5d201e
...
@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
...
@@ -4,6 +4,7 @@ from .anchor_head_template import AnchorHeadTemplate
from
..backbones_2d
import
BaseBEVBackbone
from
..backbones_2d
import
BaseBEVBackbone
import
torch
import
torch
class
SingleHead
(
BaseBEVBackbone
):
class
SingleHead
(
BaseBEVBackbone
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
num_anchors_per_location
,
code_size
,
encode_conv_cfg
=
None
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
num_anchors_per_location
,
code_size
,
encode_conv_cfg
=
None
):
super
().
__init__
(
encode_conv_cfg
,
input_channels
)
super
().
__init__
(
encode_conv_cfg
,
input_channels
)
...
@@ -75,6 +76,7 @@ class SingleHead(BaseBEVBackbone):
...
@@ -75,6 +76,7 @@ class SingleHead(BaseBEVBackbone):
return
ret_dict
return
ret_dict
class
AnchorHeadMulti
(
AnchorHeadTemplate
):
class
AnchorHeadMulti
(
AnchorHeadTemplate
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
class_names
,
grid_size
,
point_cloud_range
,
predict_boxes_when_training
=
True
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
class_names
,
grid_size
,
point_cloud_range
,
predict_boxes_when_training
=
True
):
super
().
__init__
(
super
().
__init__
(
...
@@ -82,7 +84,6 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -82,7 +84,6 @@ class AnchorHeadMulti(AnchorHeadTemplate):
)
)
self
.
model_cfg
=
model_cfg
self
.
model_cfg
=
model_cfg
self
.
make_multihead
(
input_channels
)
self
.
make_multihead
(
input_channels
)
def
make_multihead
(
self
,
input_channels
):
def
make_multihead
(
self
,
input_channels
):
rpn_head_cfgs
=
self
.
model_cfg
.
RPN_HEAD_CFGS
rpn_head_cfgs
=
self
.
model_cfg
.
RPN_HEAD_CFGS
...
@@ -123,7 +124,8 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -123,7 +124,8 @@ class AnchorHeadMulti(AnchorHeadTemplate):
gt_boxes
=
data_dict
[
'gt_boxes'
]
gt_boxes
=
data_dict
[
'gt_boxes'
]
)
)
self
.
forward_ret_dict
.
update
(
targets_dict
)
self
.
forward_ret_dict
.
update
(
targets_dict
)
else
:
if
not
self
.
training
or
self
.
predict_boxes_when_training
:
batch_cls_preds
,
batch_box_preds
=
self
.
generate_predicted_boxes
(
batch_cls_preds
,
batch_box_preds
=
self
.
generate_predicted_boxes
(
batch_size
=
data_dict
[
'batch_size'
],
batch_size
=
data_dict
[
'batch_size'
],
cls_preds
=
cls_preds
,
box_preds
=
box_preds
,
dir_cls_preds
=
dir_cls_preds
cls_preds
=
cls_preds
,
box_preds
=
box_preds
,
dir_cls_preds
=
dir_cls_preds
...
...
pcdet/models/roi_heads/roi_head_template.py
View file @
9f5d201e
...
@@ -39,6 +39,7 @@ class RoIHeadTemplate(nn.Module):
...
@@ -39,6 +39,7 @@ class RoIHeadTemplate(nn.Module):
fc_layers
=
nn
.
Sequential
(
*
fc_layers
)
fc_layers
=
nn
.
Sequential
(
*
fc_layers
)
return
fc_layers
return
fc_layers
@
torch
.
no_grad
()
def
proposal_layer
(
self
,
batch_dict
,
nms_config
):
def
proposal_layer
(
self
,
batch_dict
,
nms_config
):
"""
"""
Args:
Args:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment