Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
0ac2901c
Commit
0ac2901c
authored
Jul 06, 2020
by
Shaoshuai Shi
Browse files
bugfixed: loss calculation for separted multi-head
parent
19c66f79
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
27 deletions
+48
-27
pcdet/models/dense_heads/anchor_head_multi.py
pcdet/models/dense_heads/anchor_head_multi.py
+47
-26
tools/cfgs/kitti_models/second_multihead.yaml
tools/cfgs/kitti_models/second_multihead.yaml
+1
-1
No files found.
pcdet/models/dense_heads/anchor_head_multi.py
View file @
0ac2901c
...
...
@@ -78,19 +78,26 @@ class SingleHead(BaseBEVBackbone):
class
AnchorHeadMulti
(
AnchorHeadTemplate
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
class_names
,
grid_size
,
point_cloud_range
,
predict_boxes_when_training
=
True
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
class_names
,
grid_size
,
point_cloud_range
,
predict_boxes_when_training
=
True
):
super
().
__init__
(
model_cfg
=
model_cfg
,
num_class
=
num_class
,
class_names
=
class_names
,
grid_size
=
grid_size
,
point_cloud_range
=
point_cloud_range
,
predict_boxes_when_training
=
predict_boxes_when_training
model_cfg
=
model_cfg
,
num_class
=
num_class
,
class_names
=
class_names
,
grid_size
=
grid_size
,
point_cloud_range
=
point_cloud_range
,
predict_boxes_when_training
=
predict_boxes_when_training
)
self
.
model_cfg
=
model_cfg
self
.
seperate_multihead
=
self
.
model_cfg
.
get
(
'SEPERATE_MULTIHEAD'
,
False
)
shared_conv_num_filter
=
self
.
model_cfg
.
SHARED_CONV_NUM_FILTER
self
.
shared_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
input_channels
,
shared_conv_num_filter
,
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
shared_conv_num_filter
,
eps
=
1e-3
,
momentum
=
0.01
),
nn
.
ReLU
(),
)
self
.
separate_multihead
=
self
.
model_cfg
.
get
(
'SEPARATE_MULTIHEAD'
,
False
)
if
self
.
model_cfg
.
get
(
'SHARED_CONV_NUM_FILTER'
,
None
)
is
not
None
:
shared_conv_num_filter
=
self
.
model_cfg
.
SHARED_CONV_NUM_FILTER
self
.
shared_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
input_channels
,
shared_conv_num_filter
,
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
shared_conv_num_filter
,
eps
=
1e-3
,
momentum
=
0.01
),
nn
.
ReLU
(),
)
else
:
self
.
shared_conv
=
None
shared_conv_num_filter
=
input_channels
self
.
rpn_heads
=
None
self
.
make_multihead
(
shared_conv_num_filter
)
def
make_multihead
(
self
,
input_channels
):
...
...
@@ -99,15 +106,22 @@ class AnchorHeadMulti(AnchorHeadTemplate):
class_names
=
[]
for
rpn_head_cfg
in
rpn_head_cfgs
:
class_names
.
extend
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
for
rpn_head_cfg
in
rpn_head_cfgs
:
num_anchors_per_location
=
sum
([
self
.
num_anchors_per_location
[
class_names
.
index
(
head_cls
)]
for
head_cls
in
rpn_head_cfg
[
'HEAD_CLS_NAME'
]])
rpn_head
=
SingleHead
(
self
.
model_cfg
,
input_channels
,
len
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
if
self
.
seperate_multihead
else
self
.
num_class
,
num_anchors_per_location
,
self
.
box_coder
.
code_size
,
rpn_head_cfg
)
num_anchors_per_location
=
sum
([
self
.
num_anchors_per_location
[
class_names
.
index
(
head_cls
)]
for
head_cls
in
rpn_head_cfg
[
'HEAD_CLS_NAME'
]])
rpn_head
=
SingleHead
(
self
.
model_cfg
,
input_channels
,
len
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
if
self
.
separate_multihead
else
self
.
num_class
,
num_anchors_per_location
,
self
.
box_coder
.
code_size
,
rpn_head_cfg
)
rpn_heads
.
append
(
rpn_head
)
self
.
rpn_heads
=
nn
.
ModuleList
(
rpn_heads
)
def
forward
(
self
,
data_dict
):
spatial_features_2d
=
data_dict
[
'spatial_features_2d'
]
spatial_features_2d
=
self
.
shared_conv
(
spatial_features_2d
)
if
self
.
shared_conv
is
not
None
:
spatial_features_2d
=
self
.
shared_conv
(
spatial_features_2d
)
ret_dicts
=
[]
for
rpn_head
in
self
.
rpn_heads
:
...
...
@@ -115,15 +129,14 @@ class AnchorHeadMulti(AnchorHeadTemplate):
cls_preds
=
[
ret_dict
[
'cls_preds'
]
for
ret_dict
in
ret_dicts
]
box_preds
=
[
ret_dict
[
'box_preds'
]
for
ret_dict
in
ret_dicts
]
ret
=
{
'cls_preds'
:
cls_preds
if
self
.
sep
e
rate_multihead
else
torch
.
cat
(
cls_preds
,
dim
=
1
),
'box_preds'
:
box_preds
if
self
.
sep
e
rate_multihead
else
torch
.
cat
(
box_preds
,
dim
=
1
),
'cls_preds'
:
cls_preds
if
self
.
sep
a
rate_multihead
else
torch
.
cat
(
cls_preds
,
dim
=
1
),
'box_preds'
:
box_preds
if
self
.
sep
a
rate_multihead
else
torch
.
cat
(
box_preds
,
dim
=
1
),
}
if
self
.
model_cfg
.
get
(
'USE_DIRECTION_CLASSIFIER'
,
False
):
dir_cls_preds
=
[
ret_dict
[
'dir_cls_preds'
]
for
ret_dict
in
ret_dicts
]
ret
[
'dir_cls_preds'
]
=
dir_cls_preds
if
self
.
sep
e
rate_multihead
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
)
ret
[
'dir_cls_preds'
]
=
dir_cls_preds
if
self
.
sep
a
rate_multihead
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
)
else
:
dir_cls_preds
=
None
...
...
@@ -161,25 +174,33 @@ class AnchorHeadMulti(AnchorHeadTemplate):
# class agnostic
box_cls_labels
[
positives
]
=
1
pos_normalizer
=
positives
.
sum
(
1
,
keepdim
=
True
).
float
()
reg_weights
/=
torch
.
clamp
(
pos_normalizer
,
min
=
1.0
)
cls_weights
/=
torch
.
clamp
(
pos_normalizer
,
min
=
1.0
)
cls_targets
=
box_cls_labels
*
cared
.
type_as
(
box_cls_labels
)
one_hot_target
=
torch
.
zeros
(
*
list
(
cls_targets
.
shape
),
cls_preds
[
0
].
shape
[
-
1
]
+
1
if
self
.
seperate_multihead
else
self
.
num_class
+
1
,
dtype
=
cls_preds
[
0
].
dtype
,
device
=
cls_targets
.
device
)
one_hot_target
.
scatter_
(
-
1
,
cls_targets
.
unsqueeze
(
dim
=-
1
).
long
(),
1.0
)
one_hot_targets
=
one_hot_target
[...,
1
:]
start_idx
=
0
one_hot_target
s
=
torch
.
zeros
(
*
list
(
cls_targets
.
shape
),
self
.
num_class
+
1
,
dtype
=
cls_preds
[
0
].
dtype
,
device
=
cls_targets
.
device
)
one_hot_target
s
.
scatter_
(
-
1
,
cls_targets
.
unsqueeze
(
dim
=-
1
).
long
(),
1.0
)
one_hot_targets
=
one_hot_target
s
[...,
1
:]
start_idx
=
c_idx
=
0
cls_losses
=
0
for
cls_pred
in
cls_preds
:
cls_pred
=
cls_pred
.
view
(
batch_size
,
-
1
,
cls_pred
.
shape
[
-
1
])
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
for
idx
,
cls_pred
in
enumerate
(
cls_preds
):
cur_num_class
=
self
.
rpn_heads
[
idx
].
num_class
cls_pred
=
cls_pred
.
view
(
batch_size
,
-
1
,
cur_num_class
)
if
self
.
separate_multihead
:
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
],
c_idx
:
c_idx
+
cur_num_class
]
c_idx
+=
cur_num_class
else
:
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
cls_weight
=
cls_weights
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
cls_loss_src
=
self
.
cls_loss_func
(
cls_pred
,
one_hot_target
,
weights
=
cls_weight
)
# [N, M]
cls_loss
=
cls_loss_src
.
sum
()
/
batch_size
cls_loss
=
cls_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'cls_weight'
]
cls_losses
+=
cls_loss
start_idx
+=
cls_pred
.
shape
[
1
]
assert
start_idx
==
one_hot_targets
.
shape
[
1
]
tb_dict
=
{
'rpn_loss_cls'
:
cls_losses
.
item
()
}
...
...
tools/cfgs/kitti_models/second_multihead.yaml
View file @
0ac2901c
...
...
@@ -35,7 +35,7 @@ MODEL:
NUM_DIR_BINS
:
2
USE_MULTIHEAD
:
True
SEP
E
RATE_MULTIHEAD
:
True
SEP
A
RATE_MULTIHEAD
:
True
ANCHOR_GENERATOR_CONFIG
:
[
{
'
class_name'
:
'
Car'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment