Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
b16e0891
Commit
b16e0891
authored
Jul 06, 2020
by
Gus-Guo
Committed by
Shaoshuai Shi
Jul 06, 2020
Browse files
support multiheads that predict seperate classes
parent
5de3373d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
197 additions
and
42 deletions
+197
-42
pcdet/models/dense_heads/anchor_head_multi.py
pcdet/models/dense_heads/anchor_head_multi.py
+129
-13
pcdet/models/dense_heads/anchor_head_template.py
pcdet/models/dense_heads/anchor_head_template.py
+9
-11
pcdet/models/dense_heads/target_assigner/axis_aligned_target_assigner.py
...nse_heads/target_assigner/axis_aligned_target_assigner.py
+30
-10
pcdet/models/detectors/detector3d_template.py
pcdet/models/detectors/detector3d_template.py
+29
-8
No files found.
pcdet/models/dense_heads/anchor_head_multi.py
View file @
b16e0891
...
...
@@ -30,7 +30,7 @@ class SingleHead(BaseBEVBackbone):
)
else
:
self
.
conv_dir_cls
=
None
self
.
use_multihead
=
self
.
model_cfg
.
get
(
'USE_MULTI
_
HEAD'
,
False
)
self
.
use_multihead
=
self
.
model_cfg
.
get
(
'USE_MULTIHEAD'
,
False
)
self
.
init_weights
()
def
init_weights
(
self
):
...
...
@@ -55,7 +55,7 @@ class SingleHead(BaseBEVBackbone):
cls_preds
=
cls_preds
.
view
(
-
1
,
self
.
num_anchors_per_location
,
self
.
num_class
,
H
,
W
).
permute
(
0
,
1
,
3
,
4
,
2
).
contiguous
()
box_preds
=
box_preds
.
view
(
batch_size
,
-
1
,
self
.
code_size
)
cls_preds
=
cls_preds
.
view
(
batch_size
,
-
1
,
self
.
num_class
)
.
unsqueeze
(
-
1
)
cls_preds
=
cls_preds
.
view
(
batch_size
,
-
1
,
self
.
num_class
)
if
self
.
conv_dir_cls
is
not
None
:
dir_cls_preds
=
self
.
conv_dir_cls
(
spatial_features_2d
)
...
...
@@ -81,8 +81,15 @@ class AnchorHeadMulti(AnchorHeadTemplate):
model_cfg
=
model_cfg
,
num_class
=
num_class
,
class_names
=
class_names
,
grid_size
=
grid_size
,
point_cloud_range
=
point_cloud_range
,
predict_boxes_when_training
=
predict_boxes_when_training
)
self
.
model_cfg
=
model_cfg
self
.
mak
e_multihead
(
input_channe
ls
)
self
.
seperat
e_multihead
=
self
.
model_cfg
.
get
(
'SEPERATE_MULTIHEAD'
,
Fa
ls
e
)
shared_conv_num_filter
=
self
.
model_cfg
.
SHARED_CONV_NUM_FILTER
self
.
shared_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
input_channels
,
shared_conv_num_filter
,
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
shared_conv_num_filter
,
eps
=
1e-3
,
momentum
=
0.01
),
nn
.
ReLU
(),
)
self
.
make_multihead
(
shared_conv_num_filter
)
def
make_multihead
(
self
,
input_channels
):
rpn_head_cfgs
=
self
.
model_cfg
.
RPN_HEAD_CFGS
...
...
@@ -92,27 +99,29 @@ class AnchorHeadMulti(AnchorHeadTemplate):
class_names
.
extend
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
for
rpn_head_cfg
in
rpn_head_cfgs
:
num_anchors_per_location
=
sum
([
self
.
num_anchors_per_location
[
class_names
.
index
(
head_cls
)]
for
head_cls
in
rpn_head_cfg
[
'HEAD_CLS_NAME'
]])
rpn_head
=
SingleHead
(
self
.
model_cfg
,
input_channels
,
self
.
num_class
,
num_anchors_per_location
,
self
.
box_coder
.
code_size
,
rpn_head_cfg
)
rpn_head
=
SingleHead
(
self
.
model_cfg
,
input_channels
,
len
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
if
self
.
seperate_multihead
else
self
.
num_class
,
num_anchors_per_location
,
self
.
box_coder
.
code_size
,
rpn_head_cfg
)
rpn_heads
.
append
(
rpn_head
)
self
.
rpn_heads
=
nn
.
ModuleList
(
rpn_heads
)
def
forward
(
self
,
data_dict
):
spatial_features_2d
=
data_dict
[
'spatial_features_2d'
]
spatial_features_2d
=
self
.
shared_conv
(
spatial_features_2d
)
ret_dicts
=
[]
for
rpn_head
in
self
.
rpn_heads
:
ret_dicts
.
append
(
rpn_head
(
spatial_features_2d
))
cls_preds
=
torch
.
cat
([
ret_dict
[
'cls_preds'
]
for
ret_dict
in
ret_dicts
],
dim
=
1
)
box_preds
=
torch
.
cat
([
ret_dict
[
'box_preds'
]
for
ret_dict
in
ret_dicts
],
dim
=
1
)
ret
=
{
'cls_preds'
:
cls_preds
,
'box_preds'
:
box_preds
,
cls_preds
=
[
ret_dict
[
'cls_preds'
]
for
ret_dict
in
ret_dicts
]
box_preds
=
[
ret_dict
[
'box_preds'
]
for
ret_dict
in
ret_dicts
]
ret
=
{
'cls_preds'
:
cls_preds
if
self
.
seperate_multihead
else
torch
.
cat
(
cls_preds
,
dim
=
1
),
'box_preds'
:
box_preds
if
self
.
seperate_multihead
else
torch
.
cat
(
box_preds
,
dim
=
1
),
}
if
self
.
model_cfg
.
get
(
'USE_DIRECTION_CLASSIFIER'
,
False
):
dir_cls_preds
=
torch
.
cat
(
[
ret_dict
[
'dir_cls_preds'
]
for
ret_dict
in
ret_dicts
]
,
dim
=
1
)
ret
[
'dir_cls_preds'
]
=
dir_cls_preds
dir_cls_preds
=
[
ret_dict
[
'dir_cls_preds'
]
for
ret_dict
in
ret_dicts
]
ret
[
'dir_cls_preds'
]
=
dir_cls_preds
if
self
.
seperate_multihead
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
)
else
:
dir_cls_preds
=
None
...
...
@@ -133,3 +142,110 @@ class AnchorHeadMulti(AnchorHeadTemplate):
data_dict
[
'cls_preds_normalized'
]
=
False
return
data_dict
def
get_cls_layer_loss
(
self
):
cls_preds
=
self
.
forward_ret_dict
[
'cls_preds'
]
box_cls_labels
=
self
.
forward_ret_dict
[
'box_cls_labels'
]
if
not
isinstance
(
cls_preds
,
list
):
cls_preds
=
[
cls_preds
]
batch_size
=
int
(
cls_preds
[
0
].
shape
[
0
])
cared
=
box_cls_labels
>=
0
# [N, num_anchors]
positives
=
box_cls_labels
>
0
negatives
=
box_cls_labels
==
0
negative_cls_weights
=
negatives
*
1.0
cls_weights
=
(
negative_cls_weights
+
1.0
*
positives
).
float
()
reg_weights
=
positives
.
float
()
if
self
.
num_class
==
1
:
# class agnostic
box_cls_labels
[
positive
]
=
1
pos_normalizer
=
positives
.
sum
(
1
,
keepdim
=
True
).
float
()
reg_weights
/=
torch
.
clamp
(
pos_normalizer
,
min
=
1.0
)
cls_weights
/=
torch
.
clamp
(
pos_normalizer
,
min
=
1.0
)
cls_targets
=
box_cls_labels
*
cared
.
type_as
(
box_cls_labels
)
one_hot_target
=
torch
.
zeros
(
*
list
(
cls_targets
.
shape
),
cls_preds
[
0
].
shape
[
-
1
]
+
1
if
self
.
seperate_multihead
else
self
.
num_class
+
1
,
dtype
=
cls_preds
[
0
].
dtype
,
device
=
cls_targets
.
device
)
one_hot_target
.
scatter_
(
-
1
,
cls_targets
.
unsqueeze
(
dim
=-
1
).
long
(),
1.0
)
one_hot_targets
=
one_hot_target
[...,
1
:]
start_idx
=
0
cls_losses
=
0
for
cls_pred
in
cls_preds
:
cls_pred
=
cls_pred
.
view
(
batch_size
,
-
1
,
cls_pred
.
shape
[
-
1
])
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
cls_weight
=
cls_weights
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
cls_loss_src
=
self
.
cls_loss_func
(
cls_pred
,
one_hot_target
,
weights
=
cls_weight
)
# [N, M]
cls_loss
=
cls_loss_src
.
sum
()
/
batch_size
cls_loss
=
cls_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'cls_weight'
]
cls_losses
+=
cls_loss
start_idx
+=
cls_pred
.
shape
[
1
]
tb_dict
=
{
'rpn_loss_cls'
:
cls_losses
.
item
()
}
return
cls_losses
,
tb_dict
def
get_box_reg_layer_loss
(
self
):
box_preds
=
self
.
forward_ret_dict
[
'box_preds'
]
box_dir_cls_preds
=
self
.
forward_ret_dict
.
get
(
'dir_cls_preds'
,
None
)
box_reg_targets
=
self
.
forward_ret_dict
[
'box_reg_targets'
]
box_cls_labels
=
self
.
forward_ret_dict
[
'box_cls_labels'
]
positives
=
box_cls_labels
>
0
reg_weights
=
positives
.
float
()
pos_normalizer
=
positives
.
sum
(
1
,
keepdim
=
True
).
float
()
reg_weights
/=
torch
.
clamp
(
pos_normalizer
,
min
=
1.0
)
if
not
isinstance
(
box_preds
,
list
):
box_preds
=
[
box_preds
]
batch_size
=
int
(
box_preds
[
0
].
shape
[
0
])
if
isinstance
(
self
.
anchors
,
list
):
if
self
.
use_multihead
:
anchors
=
torch
.
cat
(
[
anchor
.
permute
(
3
,
4
,
0
,
1
,
2
,
5
).
contiguous
().
view
(
-
1
,
anchor
.
shape
[
-
1
])
for
anchor
in
self
.
anchors
],
dim
=
0
)
else
:
anchors
=
torch
.
cat
(
self
.
anchors
,
dim
=-
3
)
else
:
anchors
=
self
.
anchors
anchors
=
anchors
.
view
(
1
,
-
1
,
anchors
.
shape
[
-
1
]).
repeat
(
batch_size
,
1
,
1
)
start_idx
=
0
box_losses
=
0
tb_dict
=
{}
for
idx
,
box_pred
in
enumerate
(
box_preds
):
box_pred
=
box_pred
.
view
(
batch_size
,
-
1
,
box_pred
.
shape
[
-
1
]
//
self
.
num_anchors_per_location
if
not
self
.
use_multihead
else
box_pred
.
shape
[
-
1
])
box_reg_target
=
box_reg_targets
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
reg_weight
=
reg_weights
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
# sin(a - b) = sinacosb-cosasinb
box_pred_sin
,
reg_target_sin
=
self
.
add_sin_difference
(
box_pred
,
box_reg_target
)
loc_loss_src
=
self
.
reg_loss_func
(
box_pred_sin
,
reg_target_sin
,
weights
=
reg_weight
)
# [N, M]
loc_loss
=
loc_loss_src
.
sum
()
/
batch_size
loc_loss
=
loc_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'loc_weight'
]
box_losses
+=
loc_loss
tb_dict
[
'rpn_loss_loc'
]
=
tb_dict
.
get
(
'rpn_loss_loc'
,
0
)
+
loc_loss
if
box_dir_cls_preds
is
not
None
:
if
not
isinstance
(
box_dir_cls_preds
,
list
):
box_dir_cls_preds
=
[
box_dir_cls_preds
]
dir_targets
=
self
.
get_direction_target
(
anchors
,
box_reg_targets
,
dir_offset
=
self
.
model_cfg
.
DIR_OFFSET
,
num_bins
=
self
.
model_cfg
.
NUM_DIR_BINS
)
box_dir_cls_pred
=
box_dir_cls_preds
[
idx
]
dir_logit
=
box_dir_cls_pred
.
view
(
batch_size
,
-
1
,
self
.
model_cfg
.
NUM_DIR_BINS
)
weights
=
positives
.
type_as
(
dir_logit
)
weights
/=
torch
.
clamp
(
weights
.
sum
(
-
1
,
keepdim
=
True
),
min
=
1.0
)
weight
=
weights
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
dir_target
=
dir_targets
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
dir_loss
=
self
.
dir_loss_func
(
dir_logit
,
dir_target
,
weights
=
weight
)
dir_loss
=
dir_loss
.
sum
()
/
batch_size
dir_loss
=
dir_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'dir_weight'
]
box_losses
+=
dir_loss
tb_dict
[
'rpn_loss_dir'
]
=
tb_dict
.
get
(
'rpn_loss_dir'
,
0
)
+
dir_loss
.
item
()
start_idx
+=
box_pred
.
shape
[
1
]
return
box_losses
,
tb_dict
pcdet/models/dense_heads/anchor_head_template.py
View file @
b16e0891
...
...
@@ -14,7 +14,7 @@ class AnchorHeadTemplate(nn.Module):
self
.
num_class
=
num_class
self
.
class_names
=
class_names
self
.
predict_boxes_when_training
=
predict_boxes_when_training
self
.
use_multihead
=
self
.
model_cfg
.
get
(
'USE_MULTI
_
HEAD'
,
False
)
self
.
use_multihead
=
self
.
model_cfg
.
get
(
'USE_MULTIHEAD'
,
False
)
anchor_target_cfg
=
self
.
model_cfg
.
TARGET_ASSIGNER_CONFIG
self
.
box_coder
=
getattr
(
box_coder_utils
,
anchor_target_cfg
.
BOX_CODER
)(
...
...
@@ -26,7 +26,7 @@ class AnchorHeadTemplate(nn.Module):
anchor_generator_cfg
,
grid_size
=
grid_size
,
point_cloud_range
=
point_cloud_range
)
self
.
anchors
=
[
x
.
cuda
()
for
x
in
anchors
]
self
.
target_assigner
=
self
.
get_target_assigner
(
anchor_target_cfg
,
anchor_generator_cfg
)
self
.
target_assigner
=
self
.
get_target_assigner
(
anchor_target_cfg
)
self
.
forward_ret_dict
=
{}
self
.
build_losses
(
self
.
model_cfg
.
LOSS_CONFIG
)
...
...
@@ -41,17 +41,17 @@ class AnchorHeadTemplate(nn.Module):
anchors_list
,
num_anchors_per_location_list
=
anchor_generator
.
generate_anchors
(
feature_map_size
)
return
anchors_list
,
num_anchors_per_location_list
def
get_target_assigner
(
self
,
anchor_target_cfg
,
anchor_generator_cfg
):
def
get_target_assigner
(
self
,
anchor_target_cfg
):
if
anchor_target_cfg
.
NAME
==
'ATSS'
:
target_assigner
=
ATSSTargetAssigner
(
topk
=
anchor_target_cfg
.
TOPK
,
box_coder
=
self
.
box_coder
,
use_multihead
=
self
.
use_multihead
,
match_height
=
anchor_target_cfg
.
MATCH_HEIGHT
)
elif
anchor_target_cfg
.
NAME
==
'AxisAlignedTargetAssigner'
:
target_assigner
=
AxisAlignedTargetAssigner
(
anchor_target_cfg
=
anchor_target_cfg
,
anchor_generator_cfg
=
anchor_generator_cfg
,
model_cfg
=
self
.
model_cfg
,
class_names
=
self
.
class_names
,
box_coder
=
self
.
box_coder
,
match_height
=
anchor_target_cfg
.
MATCH_HEIGHT
...
...
@@ -82,7 +82,7 @@ class AnchorHeadTemplate(nn.Module):
"""
targets_dict
=
self
.
target_assigner
.
assign_targets
(
self
.
anchors
,
gt_boxes
,
self
.
use_multihead
self
.
anchors
,
gt_boxes
)
return
targets_dict
...
...
@@ -113,8 +113,6 @@ class AnchorHeadTemplate(nn.Module):
one_hot_targets
.
scatter_
(
-
1
,
cls_targets
.
unsqueeze
(
dim
=-
1
).
long
(),
1.0
)
cls_preds
=
cls_preds
.
view
(
batch_size
,
-
1
,
self
.
num_class
)
one_hot_targets
=
one_hot_targets
[...,
1
:]
# import pdb
# pdb.set_trace()
cls_loss_src
=
self
.
cls_loss_func
(
cls_preds
,
one_hot_targets
,
weights
=
cls_weights
)
# [N, M]
cls_loss
=
cls_loss_src
.
sum
()
/
batch_size
...
...
@@ -235,14 +233,14 @@ class AnchorHeadTemplate(nn.Module):
anchors
=
self
.
anchors
num_anchors
=
anchors
.
view
(
-
1
,
anchors
.
shape
[
-
1
]).
shape
[
0
]
batch_anchors
=
anchors
.
view
(
1
,
-
1
,
anchors
.
shape
[
-
1
]).
repeat
(
batch_size
,
1
,
1
)
batch_cls_preds
=
cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
).
float
()
batch_box_preds
=
box_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
batch_cls_preds
=
cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
).
float
()
if
not
isinstance
(
cls_preds
,
list
)
else
cls_preds
batch_box_preds
=
box_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
if
not
isinstance
(
box_preds
,
list
)
else
torch
.
cat
(
box_preds
,
dim
=
1
).
view
(
batch_size
,
num_anchors
,
-
1
)
batch_box_preds
=
self
.
box_coder
.
decode_torch
(
batch_box_preds
,
batch_anchors
)
if
dir_cls_preds
is
not
None
:
dir_offset
=
self
.
model_cfg
.
DIR_OFFSET
dir_limit_offset
=
self
.
model_cfg
.
DIR_LIMIT_OFFSET
dir_cls_preds
=
dir_cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
dir_cls_preds
=
dir_cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
if
not
isinstance
(
dir_cls_preds
,
list
)
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
).
view
(
batch_size
,
num_anchors
,
-
1
)
dir_labels
=
torch
.
max
(
dir_cls_preds
,
dim
=-
1
)[
1
]
period
=
(
2
*
np
.
pi
/
self
.
model_cfg
.
NUM_DIR_BINS
)
...
...
pcdet/models/dense_heads/target_assigner/axis_aligned_target_assigner.py
View file @
b16e0891
...
...
@@ -4,8 +4,11 @@ from ....ops.iou3d_nms import iou3d_nms_utils
class
AxisAlignedTargetAssigner
(
object
):
def
__init__
(
self
,
anchor_target_cfg
,
anchor_generator
_cfg
,
class_names
,
box_coder
,
match_height
=
False
):
def
__init__
(
self
,
model
_cfg
,
class_names
,
box_coder
,
match_height
=
False
):
super
().
__init__
()
anchor_generator_cfg
=
model_cfg
.
ANCHOR_GENERATOR_CONFIG
anchor_target_cfg
=
model_cfg
.
TARGET_ASSIGNER_CONFIG
self
.
box_coder
=
box_coder
self
.
match_height
=
match_height
self
.
class_names
=
class_names
...
...
@@ -19,7 +22,16 @@ class AxisAlignedTargetAssigner(object):
self
.
matched_thresholds
[
config
[
'class_name'
]]
=
config
[
'matched_threshold'
]
self
.
unmatched_thresholds
[
config
[
'class_name'
]]
=
config
[
'unmatched_threshold'
]
def
assign_targets
(
self
,
all_anchors
,
gt_boxes_with_classes
,
use_multihead
=
False
):
self
.
use_multihead
=
model_cfg
.
get
(
'USE_MULTIHEAD'
,
False
)
self
.
seperate_multihead
=
model_cfg
.
get
(
'SEPERATE_MULTIHEAD'
,
False
)
if
self
.
seperate_multihead
:
rpn_head_cfgs
=
model_cfg
.
RPN_HEAD_CFGS
self
.
gt_remapping
=
{}
for
rpn_head_cfg
in
rpn_head_cfgs
:
for
idx
,
name
in
enumerate
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
]):
self
.
gt_remapping
[
name
]
=
idx
+
1
def
assign_targets
(
self
,
all_anchors
,
gt_boxes_with_classes
):
"""
Args:
all_anchors: [(N, 7), ...]
...
...
@@ -48,21 +60,30 @@ class AxisAlignedTargetAssigner(object):
for
anchor_class_name
,
anchors
in
zip
(
self
.
anchor_class_names
,
all_anchors
):
mask
=
torch
.
tensor
([
self
.
class_names
[
c
-
1
]
==
anchor_class_name
for
c
in
cur_gt_classes
],
dtype
=
torch
.
bool
)
if
use_multihead
:
if
self
.
use_multihead
:
anchors
=
anchors
.
permute
(
3
,
4
,
0
,
1
,
2
,
5
).
contiguous
().
view
(
-
1
,
anchors
.
shape
[
-
1
])
if
self
.
seperate_multihead
:
selected_classes
=
cur_gt_classes
[
mask
].
clone
()
if
len
(
selected_classes
)
>
0
:
new_cls_id
=
self
.
gt_remapping
[
anchor_class_name
]
selected_classes
[:]
=
new_cls_id
else
:
selected_classes
=
cur_gt_classes
[
mask
]
else
:
feature_map_size
=
anchors
.
shape
[:
3
]
anchors
=
anchors
.
view
(
-
1
,
anchors
.
shape
[
-
1
])
selected_classes
=
cur_gt_classes
[
mask
]
single_target
=
self
.
assign_targets_single
(
anchors
,
cur_gt
[
mask
],
gt_classes
=
cur_gt
_classes
[
mask
]
,
gt_classes
=
selected
_classes
,
matched_threshold
=
self
.
matched_thresholds
[
anchor_class_name
],
unmatched_threshold
=
self
.
unmatched_thresholds
[
anchor_class_name
]
)
target_list
.
append
(
single_target
)
if
use_multihead
:
if
self
.
use_multihead
:
target_dict
=
{
'box_cls_labels'
:
[
t
[
'box_cls_labels'
].
view
(
-
1
)
for
t
in
target_list
],
'box_reg_targets'
:
[
t
[
'box_reg_targets'
].
view
(
-
1
,
self
.
box_coder
.
code_size
)
for
t
in
target_list
],
...
...
@@ -89,7 +110,6 @@ class AxisAlignedTargetAssigner(object):
reg_weights
.
append
(
target_dict
[
'reg_weights'
])
bbox_targets
=
torch
.
stack
(
bbox_targets
,
dim
=
0
)
cls_labels
=
torch
.
stack
(
cls_labels
,
dim
=
0
)
reg_weights
=
torch
.
stack
(
reg_weights
,
dim
=
0
)
all_targets_dict
=
{
...
...
pcdet/models/detectors/detector3d_template.py
View file @
b16e0891
...
...
@@ -6,7 +6,7 @@ from ..backbones_3d import vfe, pfe
from
..backbones_2d
import
map_to_bev
from
..model_utils.model_nms_utils
import
class_agnostic_nms
from
...ops.iou3d_nms
import
iou3d_nms_utils
import
numpy
as
np
class
Detector3DTemplate
(
nn
.
Module
):
def
__init__
(
self
,
model_cfg
,
num_class
,
dataset
):
...
...
@@ -182,22 +182,43 @@ class Detector3DTemplate(nn.Module):
if
batch_dict
.
get
(
'batch_index'
,
None
)
is
not
None
:
assert
batch_dict
[
'batch_cls_preds'
].
shape
.
__len__
()
==
2
batch_mask
=
(
batch_dict
[
'batch_index'
]
==
index
)
else
:
if
isinstance
(
batch_dict
[
'batch_cls_preds'
],
list
):
assert
batch_dict
[
'batch_cls_preds'
][
0
].
shape
.
__len__
()
==
3
else
:
assert
batch_dict
[
'batch_cls_preds'
].
shape
.
__len__
()
==
3
batch_mask
=
index
box_preds
=
batch_dict
[
'batch_box_preds'
][
batch_mask
]
cls_preds
=
batch_dict
[
'batch_cls_preds'
][
batch_mask
]
cls_preds
=
batch_dict
[
'batch_cls_preds'
][
batch_mask
]
if
not
isinstance
(
batch_dict
[
'batch_cls_preds'
],
list
)
else
[
batch_cls_pred
[
batch_mask
]
for
batch_cls_pred
in
batch_dict
[
'batch_cls_preds'
]]
src_cls_preds
=
cls_preds
src_box_preds
=
box_preds
if
isinstance
(
cls_preds
,
list
):
assert
cls_preds
[
0
].
shape
[
1
]
in
[
1
,
self
.
num_class
]
else
:
assert
cls_preds
.
shape
[
1
]
in
[
1
,
self
.
num_class
]
if
not
batch_dict
[
'cls_preds_normalized'
]:
cls_preds
=
torch
.
sigmoid
(
cls_preds
)
cls_preds
=
torch
.
sigmoid
(
cls_preds
)
if
not
isinstance
(
cls_preds
,
list
)
else
[
torch
.
sigmoid
(
cls_pred
)
for
cls_pred
in
cls_preds
]
if
post_process_cfg
.
NMS_CONFIG
.
MULTI_CLASSES_NMS
:
raise
NotImplementedError
else
:
if
isinstance
(
cls_preds
,
list
):
all_cls_preds
=
[]
label_preds
=
[]
rpn_head_cfgs
=
self
.
model_cfg
.
DENSE_HEAD
.
RPN_HEAD_CFGS
head_cls_names
=
[
np
.
array
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
for
rpn_head_cfg
in
rpn_head_cfgs
]
for
idx
,
cls_pred
in
enumerate
(
cls_preds
):
pred_score
,
pred_head_label
=
torch
.
max
(
cls_pred
,
dim
=-
1
)
pred_class_names
=
head_cls_names
[
idx
][
pred_head_label
.
cpu
().
numpy
().
astype
(
int
)]
label_pred
=
[
self
.
class_names
.
index
(
cls_name
)
+
1
for
cls_name
in
pred_class_names
]
label_pred
=
torch
.
from_numpy
(
np
.
array
(
label_pred
)).
to
(
cls_pred
.
device
).
int
()
all_cls_preds
.
append
(
pred_score
)
label_preds
.
append
(
label_pred
)
cls_preds
=
torch
.
cat
(
all_cls_preds
,
dim
=
0
)
label_preds
=
torch
.
cat
(
label_preds
,
dim
=
0
)
else
:
cls_preds
,
label_preds
=
torch
.
max
(
cls_preds
,
dim
=-
1
)
label_preds
=
batch_dict
[
'roi_labels'
][
index
]
if
batch_dict
.
get
(
'has_class_labels'
,
False
)
else
label_preds
+
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment