Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
cf29d887
Commit
cf29d887
authored
Jul 23, 2020
by
Shaoshuai Shi
Browse files
support dense_head PointHeadBox
parent
0b21d8f9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
167 additions
and
3 deletions
+167
-3
pcdet/models/dense_heads/__init__.py
pcdet/models/dense_heads/__init__.py
+2
-0
pcdet/models/dense_heads/point_head_box.py
pcdet/models/dense_heads/point_head_box.py
+114
-0
pcdet/models/dense_heads/point_head_template.py
pcdet/models/dense_heads/point_head_template.py
+51
-3
No files found.
pcdet/models/dense_heads/__init__.py
View file @
cf29d887
...
...
@@ -2,6 +2,7 @@ from .anchor_head_template import AnchorHeadTemplate
from
.anchor_head_single
import
AnchorHeadSingle
from
.point_intra_part_head
import
PointIntraPartOffsetHead
from
.point_head_simple
import
PointHeadSimple
from
.point_head_box
import
PointHeadBox
from
.anchor_head_multi
import
AnchorHeadMulti
__all__
=
{
...
...
@@ -9,5 +10,6 @@ __all__ = {
'AnchorHeadSingle'
:
AnchorHeadSingle
,
'PointIntraPartOffsetHead'
:
PointIntraPartOffsetHead
,
'PointHeadSimple'
:
PointHeadSimple
,
'PointHeadBox'
:
PointHeadBox
,
'AnchorHeadMulti'
:
AnchorHeadMulti
,
}
pcdet/models/dense_heads/point_head_box.py
0 → 100644
View file @
cf29d887
import
torch
from
.point_head_template
import
PointHeadTemplate
from
...utils
import
box_coder_utils
,
box_utils
class
PointHeadBox
(
PointHeadTemplate
):
"""
A simple point-based segmentation head, which are used for PointRCNN.
Reference Paper: https://arxiv.org/abs/1812.04244
PointRCNN: 3D Object Proposal Generation and Detection from Point Cloud
"""
def
__init__
(
self
,
num_class
,
input_channels
,
model_cfg
,
predict_boxes_when_training
=
False
,
**
kwargs
):
super
().
__init__
(
model_cfg
=
model_cfg
,
num_class
=
num_class
)
self
.
predict_boxes_when_training
=
predict_boxes_when_training
self
.
cls_layers
=
self
.
make_fc_layers
(
fc_cfg
=
self
.
model_cfg
.
CLS_FC
,
input_channels
=
input_channels
,
output_channels
=
num_class
)
target_cfg
=
self
.
model_cfg
.
TARGET_CONFIG
self
.
box_coder
=
getattr
(
box_coder_utils
,
target_cfg
.
BOX_CODER
)(
**
target_cfg
.
BOX_CODER_CONFIG
)
self
.
box_layers
=
self
.
make_fc_layers
(
fc_cfg
=
self
.
model_cfg
.
REG_FC
,
input_channels
=
input_channels
,
output_channels
=
self
.
box_coder
.
code_size
)
def
assign_targets
(
self
,
input_dict
):
"""
Args:
input_dict:
point_features: (N1 + N2 + N3 + ..., C)
batch_size:
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
gt_boxes (optional): (B, M, 8)
Returns:
point_cls_labels: (N1 + N2 + N3 + ...), long type, 0:background, -1:ignored
point_part_labels: (N1 + N2 + N3 + ..., 3)
"""
point_coords
=
input_dict
[
'point_coords'
]
gt_boxes
=
input_dict
[
'gt_boxes'
]
assert
gt_boxes
.
shape
.
__len__
()
==
3
,
'gt_boxes.shape=%s'
%
str
(
gt_boxes
.
shape
)
assert
point_coords
.
shape
.
__len__
()
in
[
2
],
'points.shape=%s'
%
str
(
point_coords
.
shape
)
batch_size
=
gt_boxes
.
shape
[
0
]
extend_gt_boxes
=
box_utils
.
enlarge_box3d
(
gt_boxes
.
view
(
-
1
,
gt_boxes
.
shape
[
-
1
]),
extra_width
=
self
.
model_cfg
.
TARGET_CONFIG
.
GT_EXTRA_WIDTH
).
view
(
batch_size
,
-
1
,
gt_boxes
.
shape
[
-
1
])
targets_dict
=
self
.
assign_stack_targets
(
points
=
point_coords
,
gt_boxes
=
gt_boxes
,
extend_gt_boxes
=
extend_gt_boxes
,
set_ignore_flag
=
True
,
use_ball_constraint
=
False
,
ret_part_labels
=
False
,
ret_box_labels
=
True
)
return
targets_dict
def
get_loss
(
self
,
tb_dict
=
None
):
tb_dict
=
{}
if
tb_dict
is
None
else
tb_dict
point_loss_cls
,
tb_dict_1
=
self
.
get_cls_layer_loss
()
point_loss_box
,
tb_dict_2
=
self
.
get_box_layer_loss
()
point_loss
=
point_loss_cls
+
point_loss_box
tb_dict
.
update
(
tb_dict_1
)
tb_dict
.
update
(
tb_dict_2
)
return
point_loss
,
tb_dict
def
forward
(
self
,
batch_dict
):
"""
Args:
batch_dict:
batch_size:
point_features: (N1 + N2 + N3 + ..., C) or (B, N, C)
point_features_before_fusion: (N1 + N2 + N3 + ..., C)
point_coords: (N1 + N2 + N3 + ..., 4) [bs_idx, x, y, z]
point_labels (optional): (N1 + N2 + N3 + ...)
gt_boxes (optional): (B, M, 8)
Returns:
batch_dict:
point_cls_scores: (N1 + N2 + N3 + ..., 1)
point_part_offset: (N1 + N2 + N3 + ..., 3)
"""
if
self
.
model_cfg
.
get
(
'USE_POINT_FEATURES_BEFORE_FUSION'
,
False
):
point_features
=
batch_dict
[
'point_features_before_fusion'
]
else
:
point_features
=
batch_dict
[
'point_features'
]
point_cls_preds
=
self
.
cls_layers
(
point_features
)
# (total_points, num_class)
point_box_preds
=
self
.
box_layers
(
point_features
)
# (total_points, box_code_size)
point_cls_preds_max
,
_
=
point_cls_preds
.
max
(
dim
=-
1
)
batch_dict
[
'point_cls_scores'
]
=
torch
.
sigmoid
(
point_cls_preds_max
)
ret_dict
=
{
'point_cls_preds'
:
point_cls_preds
,
'point_box_preds'
:
point_box_preds
}
if
self
.
training
:
targets_dict
=
self
.
assign_targets
(
batch_dict
)
ret_dict
[
'point_cls_labels'
]
=
targets_dict
[
'point_cls_labels'
]
ret_dict
[
'point_box_labels'
]
=
targets_dict
[
'point_box_labels'
]
if
not
self
.
training
or
self
.
predict_boxes_when_training
:
point_cls_preds
,
point_box_preds
=
self
.
generate_predicted_boxes
(
points
=
batch_dict
[
'point_coords'
][:,
1
:
4
],
point_cls_preds
=
point_cls_preds
,
point_box_preds
=
point_box_preds
)
batch_dict
[
'batch_cls_preds'
]
=
point_cls_preds
batch_dict
[
'batch_box_preds'
]
=
point_box_preds
batch_dict
[
'batch_index'
]
=
batch_dict
[
'point_coords'
][:,
0
]
batch_dict
[
'cls_preds_normalized'
]
=
False
self
.
forward_ret_dict
=
ret_dict
return
batch_dict
pcdet/models/dense_heads/point_head_template.py
View file @
cf29d887
...
...
@@ -19,7 +19,17 @@ class PointHeadTemplate(nn.Module):
'cls_loss_func'
,
loss_utils
.
SigmoidFocalClassificationLoss
(
alpha
=
0.25
,
gamma
=
2.0
)
)
self
.
reg_loss_func
=
F
.
smooth_l1_loss
if
losses_cfg
.
get
(
'LOSS_REG'
,
None
)
==
'smooth-l1'
else
F
.
l1_loss
reg_loss_type
=
losses_cfg
.
get
(
'LOSS_REG'
,
None
)
if
reg_loss_type
==
'smooth-l1'
:
self
.
reg_loss_func
=
F
.
smooth_l1_loss
elif
reg_loss_type
==
'l1'
:
self
.
reg_loss_func
=
F
.
l1_loss
elif
reg_loss_type
==
'WeightedSmoothL1Loss'
:
self
.
reg_loss_func
=
loss_utils
.
WeightedSmoothL1Loss
(
code_weights
=
losses_cfg
.
LOSS_WEIGHTS
.
get
(
'code_weights'
,
None
)
)
else
:
self
.
reg_loss_func
=
F
.
smooth_l1_loss
@
staticmethod
def
make_fc_layers
(
fc_cfg
,
input_channels
,
output_channels
):
...
...
@@ -88,11 +98,15 @@ class PointHeadTemplate(nn.Module):
raise
NotImplementedError
gt_box_of_fg_points
=
gt_boxes
[
k
][
box_idxs_of_pts
[
fg_flag
]]
point_cls_labels_single
[
fg_flag
]
=
1
if
self
.
num_class
==
1
else
gt_box_of_fg_points
[:,
7
].
long
()
point_cls_labels_single
[
fg_flag
]
=
1
if
self
.
num_class
==
1
else
gt_box_of_fg_points
[:,
-
1
].
long
()
point_cls_labels
[
bs_mask
]
=
point_cls_labels_single
if
ret_box_labels
:
point_box_labels_single
=
point_box_labels
.
new_zeros
((
bs_mask
.
sum
(),
8
))
fg_point_box_labels
=
self
.
box_coder
.
encode_torch
(
points_single
[
fg_flag
],
gt_box_of_fg_points
)
fg_point_box_labels
=
self
.
box_coder
.
encode_torch
(
gt_boxes
=
gt_box_of_fg_points
[:,
:
-
1
],
points
=
points_single
[
fg_flag
],
gt_classes
=
gt_box_of_fg_points
[:,
-
1
].
long
()
)
point_box_labels_single
[
fg_flag
]
=
fg_point_box_labels
point_box_labels
[
bs_mask
]
=
point_box_labels_single
...
...
@@ -149,5 +163,39 @@ class PointHeadTemplate(nn.Module):
point_loss_part
=
point_loss_part
*
loss_weights_dict
[
'point_part_weight'
]
return
point_loss_part
,
{
'point_loss_part'
:
point_loss_part
.
item
()}
def
get_box_layer_loss
(
self
):
pos_mask
=
self
.
forward_ret_dict
[
'point_cls_labels'
]
>
0
point_box_labels
=
self
.
forward_ret_dict
[
'point_box_labels'
]
point_box_preds
=
self
.
forward_ret_dict
[
'point_box_preds'
]
reg_weights
=
pos_mask
.
float
()
pos_normalizer
=
pos_mask
.
sum
().
float
()
reg_weights
/=
torch
.
clamp
(
pos_normalizer
,
min
=
1.0
)
point_loss_box_src
=
self
.
reg_loss_func
(
point_box_preds
[
None
,
...],
point_box_labels
[
None
,
...],
weights
=
reg_weights
[
None
,
...]
)
point_loss_box
=
point_loss_box_src
.
sum
()
loss_weights_dict
=
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
point_loss_box
=
point_loss_box
*
loss_weights_dict
[
'point_box_weight'
]
return
point_loss_box
,
{
'point_loss_box'
:
point_loss_box
.
item
()}
def
generate_predicted_boxes
(
self
,
points
,
point_cls_preds
,
point_box_preds
):
"""
Args:
points: (N, 3)
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
Returns:
point_cls_preds: (N, num_class)
point_box_preds: (N, box_code_size)
"""
_
,
pred_classes
=
point_cls_preds
.
max
(
dim
=-
1
)
point_box_preds
=
self
.
box_coder
.
decode_torch
(
point_box_preds
,
points
,
pred_classes
+
1
)
return
point_cls_preds
,
point_box_preds
def
forward
(
self
,
**
kwargs
):
raise
NotImplementedError
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment