Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
764ce52e
Commit
764ce52e
authored
Jul 07, 2020
by
Shaoshuai Shi
Browse files
Merge branch 'dev_multihead' into dev_nuscene
parents
c45634a3
eb074cf2
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
79 additions
and
77 deletions
+79
-77
pcdet/models/backbones_2d/base_bev_backbone.py
pcdet/models/backbones_2d/base_bev_backbone.py
+14
-7
pcdet/models/dense_heads/anchor_head_multi.py
pcdet/models/dense_heads/anchor_head_multi.py
+47
-24
pcdet/models/dense_heads/anchor_head_template.py
pcdet/models/dense_heads/anchor_head_template.py
+6
-3
pcdet/models/detectors/detector3d_template.py
pcdet/models/detectors/detector3d_template.py
+12
-28
tools/cfgs/kitti_models/second_multihead.yaml
tools/cfgs/kitti_models/second_multihead.yaml
+0
-15
No files found.
pcdet/models/backbones_2d/base_bev_backbone.py
View file @
764ce52e
...
@@ -7,13 +7,20 @@ class BaseBEVBackbone(nn.Module):
...
@@ -7,13 +7,20 @@ class BaseBEVBackbone(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
model_cfg
=
model_cfg
self
.
model_cfg
=
model_cfg
assert
len
(
self
.
model_cfg
.
LAYER_NUMS
)
==
len
(
self
.
model_cfg
.
LAYER_STRIDES
)
==
len
(
self
.
model_cfg
.
NUM_FILTERS
)
if
self
.
model_cfg
.
get
(
'LAYER_NUMS'
,
None
)
is
not
None
:
assert
len
(
self
.
model_cfg
.
UPSAMPLE_STRIDES
)
==
len
(
self
.
model_cfg
.
NUM_UPSAMPLE_FILTERS
)
assert
len
(
self
.
model_cfg
.
LAYER_NUMS
)
==
len
(
self
.
model_cfg
.
LAYER_STRIDES
)
==
len
(
self
.
model_cfg
.
NUM_FILTERS
)
layer_nums
=
self
.
model_cfg
.
LAYER_NUMS
layer_nums
=
self
.
model_cfg
.
LAYER_NUMS
layer_strides
=
self
.
model_cfg
.
LAYER_STRIDES
layer_strides
=
self
.
model_cfg
.
LAYER_STRIDES
num_filters
=
self
.
model_cfg
.
NUM_FILTERS
num_filters
=
self
.
model_cfg
.
NUM_FILTERS
num_upsample_filters
=
self
.
model_cfg
.
NUM_UPSAMPLE_FILTERS
else
:
upsample_strides
=
self
.
model_cfg
.
UPSAMPLE_STRIDES
layer_nums
=
layer_strides
=
num_filters
=
[]
if
self
.
model_cfg
.
get
(
'UPSAMPLE_STRIDES'
,
None
)
is
not
None
:
assert
len
(
self
.
model_cfg
.
UPSAMPLE_STRIDES
)
==
len
(
self
.
model_cfg
.
NUM_UPSAMPLE_FILTERS
)
num_upsample_filters
=
self
.
model_cfg
.
NUM_UPSAMPLE_FILTERS
upsample_strides
=
self
.
model_cfg
.
UPSAMPLE_STRIDES
else
:
upsample_strides
=
num_upsample_filters
=
[]
num_levels
=
len
(
layer_nums
)
num_levels
=
len
(
layer_nums
)
c_in_list
=
[
input_channels
,
*
num_filters
[:
-
1
]]
c_in_list
=
[
input_channels
,
*
num_filters
[:
-
1
]]
...
...
pcdet/models/dense_heads/anchor_head_multi.py
View file @
764ce52e
...
@@ -6,13 +6,15 @@ import torch
...
@@ -6,13 +6,15 @@ import torch
class
SingleHead
(
BaseBEVBackbone
):
class
SingleHead
(
BaseBEVBackbone
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
num_anchors_per_location
,
code_size
,
encode_conv_cfg
=
None
):
def
__init__
(
self
,
model_cfg
,
input_channels
,
num_class
,
num_anchors_per_location
,
code_size
,
encode_conv_cfg
=
None
,
head_label_indices
=
None
):
super
().
__init__
(
encode_conv_cfg
,
input_channels
)
super
().
__init__
(
encode_conv_cfg
,
input_channels
)
self
.
num_anchors_per_location
=
num_anchors_per_location
self
.
num_anchors_per_location
=
num_anchors_per_location
self
.
num_class
=
num_class
self
.
num_class
=
num_class
self
.
code_size
=
code_size
self
.
code_size
=
code_size
self
.
model_cfg
=
model_cfg
self
.
model_cfg
=
model_cfg
self
.
register_buffer
(
'head_label_indices'
,
head_label_indices
)
self
.
conv_cls
=
nn
.
Conv2d
(
self
.
conv_cls
=
nn
.
Conv2d
(
input_channels
,
self
.
num_anchors_per_location
*
self
.
num_class
,
input_channels
,
self
.
num_anchors_per_location
*
self
.
num_class
,
...
@@ -57,12 +59,13 @@ class SingleHead(BaseBEVBackbone):
...
@@ -57,12 +59,13 @@ class SingleHead(BaseBEVBackbone):
self
.
num_class
,
H
,
W
).
permute
(
0
,
1
,
3
,
4
,
2
).
contiguous
()
self
.
num_class
,
H
,
W
).
permute
(
0
,
1
,
3
,
4
,
2
).
contiguous
()
box_preds
=
box_preds
.
view
(
batch_size
,
-
1
,
self
.
code_size
)
box_preds
=
box_preds
.
view
(
batch_size
,
-
1
,
self
.
code_size
)
cls_preds
=
cls_preds
.
view
(
batch_size
,
-
1
,
self
.
num_class
)
cls_preds
=
cls_preds
.
view
(
batch_size
,
-
1
,
self
.
num_class
)
if
self
.
conv_dir_cls
is
not
None
:
if
self
.
conv_dir_cls
is
not
None
:
dir_cls_preds
=
self
.
conv_dir_cls
(
spatial_features_2d
)
dir_cls_preds
=
self
.
conv_dir_cls
(
spatial_features_2d
)
if
self
.
use_multihead
:
if
self
.
use_multihead
:
dir_cls_preds
=
dir_cls_preds
.
view
(
dir_cls_preds
=
dir_cls_preds
.
view
(
-
1
,
self
.
num_anchors_per_location
,
self
.
model_cfg
.
NUM_DIR_BINS
,
H
,
W
).
permute
(
0
,
1
,
3
,
4
,
2
).
contiguous
()
-
1
,
self
.
num_anchors_per_location
,
self
.
model_cfg
.
NUM_DIR_BINS
,
H
,
W
).
permute
(
0
,
1
,
3
,
4
,
2
).
contiguous
()
dir_cls_preds
=
dir_cls_preds
.
view
(
batch_size
,
-
1
,
self
.
model_cfg
.
NUM_DIR_BINS
)
dir_cls_preds
=
dir_cls_preds
.
view
(
batch_size
,
-
1
,
self
.
model_cfg
.
NUM_DIR_BINS
)
else
:
else
:
dir_cls_preds
=
dir_cls_preds
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
dir_cls_preds
=
dir_cls_preds
.
permute
(
0
,
2
,
3
,
1
).
contiguous
()
...
@@ -90,10 +93,10 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -90,10 +93,10 @@ class AnchorHeadMulti(AnchorHeadTemplate):
if
self
.
model_cfg
.
get
(
'SHARED_CONV_NUM_FILTER'
,
None
)
is
not
None
:
if
self
.
model_cfg
.
get
(
'SHARED_CONV_NUM_FILTER'
,
None
)
is
not
None
:
shared_conv_num_filter
=
self
.
model_cfg
.
SHARED_CONV_NUM_FILTER
shared_conv_num_filter
=
self
.
model_cfg
.
SHARED_CONV_NUM_FILTER
self
.
shared_conv
=
nn
.
Sequential
(
self
.
shared_conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
input_channels
,
shared_conv_num_filter
,
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
),
nn
.
Conv2d
(
input_channels
,
shared_conv_num_filter
,
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
shared_conv_num_filter
,
eps
=
1e-3
,
momentum
=
0.01
),
nn
.
BatchNorm2d
(
shared_conv_num_filter
,
eps
=
1e-3
,
momentum
=
0.01
),
nn
.
ReLU
(),
nn
.
ReLU
(),
)
)
else
:
else
:
self
.
shared_conv
=
None
self
.
shared_conv
=
None
shared_conv_num_filter
=
input_channels
shared_conv_num_filter
=
input_channels
...
@@ -110,10 +113,15 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -110,10 +113,15 @@ class AnchorHeadMulti(AnchorHeadTemplate):
for
rpn_head_cfg
in
rpn_head_cfgs
:
for
rpn_head_cfg
in
rpn_head_cfgs
:
num_anchors_per_location
=
sum
([
self
.
num_anchors_per_location
[
class_names
.
index
(
head_cls
)]
num_anchors_per_location
=
sum
([
self
.
num_anchors_per_location
[
class_names
.
index
(
head_cls
)]
for
head_cls
in
rpn_head_cfg
[
'HEAD_CLS_NAME'
]])
for
head_cls
in
rpn_head_cfg
[
'HEAD_CLS_NAME'
]])
head_label_indices
=
torch
.
from_numpy
(
np
.
array
([
self
.
class_names
.
index
(
cur_name
)
+
1
for
cur_name
in
rpn_head_cfg
[
'HEAD_CLS_NAME'
]
]))
rpn_head
=
SingleHead
(
rpn_head
=
SingleHead
(
self
.
model_cfg
,
input_channels
,
self
.
model_cfg
,
input_channels
,
len
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
if
self
.
separate_multihead
else
self
.
num_class
,
len
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
if
self
.
separate_multihead
else
self
.
num_class
,
num_anchors_per_location
,
self
.
box_coder
.
code_size
,
rpn_head_cfg
num_anchors_per_location
,
self
.
box_coder
.
code_size
,
rpn_head_cfg
,
head_label_indices
=
head_label_indices
)
)
rpn_heads
.
append
(
rpn_head
)
rpn_heads
.
append
(
rpn_head
)
self
.
rpn_heads
=
nn
.
ModuleList
(
rpn_heads
)
self
.
rpn_heads
=
nn
.
ModuleList
(
rpn_heads
)
...
@@ -126,7 +134,7 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -126,7 +134,7 @@ class AnchorHeadMulti(AnchorHeadTemplate):
ret_dicts
=
[]
ret_dicts
=
[]
for
rpn_head
in
self
.
rpn_heads
:
for
rpn_head
in
self
.
rpn_heads
:
ret_dicts
.
append
(
rpn_head
(
spatial_features_2d
))
ret_dicts
.
append
(
rpn_head
(
spatial_features_2d
))
cls_preds
=
[
ret_dict
[
'cls_preds'
]
for
ret_dict
in
ret_dicts
]
cls_preds
=
[
ret_dict
[
'cls_preds'
]
for
ret_dict
in
ret_dicts
]
box_preds
=
[
ret_dict
[
'box_preds'
]
for
ret_dict
in
ret_dicts
]
box_preds
=
[
ret_dict
[
'box_preds'
]
for
ret_dict
in
ret_dicts
]
ret
=
{
ret
=
{
...
@@ -137,11 +145,9 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -137,11 +145,9 @@ class AnchorHeadMulti(AnchorHeadTemplate):
if
self
.
model_cfg
.
get
(
'USE_DIRECTION_CLASSIFIER'
,
False
):
if
self
.
model_cfg
.
get
(
'USE_DIRECTION_CLASSIFIER'
,
False
):
dir_cls_preds
=
[
ret_dict
[
'dir_cls_preds'
]
for
ret_dict
in
ret_dicts
]
dir_cls_preds
=
[
ret_dict
[
'dir_cls_preds'
]
for
ret_dict
in
ret_dicts
]
ret
[
'dir_cls_preds'
]
=
dir_cls_preds
if
self
.
separate_multihead
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
)
ret
[
'dir_cls_preds'
]
=
dir_cls_preds
if
self
.
separate_multihead
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
)
else
:
dir_cls_preds
=
None
self
.
forward_ret_dict
.
update
(
ret
)
self
.
forward_ret_dict
.
update
(
ret
)
if
self
.
training
:
if
self
.
training
:
targets_dict
=
self
.
assign_targets
(
targets_dict
=
self
.
assign_targets
(
gt_boxes
=
data_dict
[
'gt_boxes'
]
gt_boxes
=
data_dict
[
'gt_boxes'
]
...
@@ -150,8 +156,24 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -150,8 +156,24 @@ class AnchorHeadMulti(AnchorHeadTemplate):
else
:
else
:
batch_cls_preds
,
batch_box_preds
=
self
.
generate_predicted_boxes
(
batch_cls_preds
,
batch_box_preds
=
self
.
generate_predicted_boxes
(
batch_size
=
data_dict
[
'batch_size'
],
batch_size
=
data_dict
[
'batch_size'
],
cls_preds
=
cls_preds
,
box_preds
=
box_preds
,
dir_cls_preds
=
dir_cls_preds
cls_preds
=
ret
[
'
cls_preds
'
]
,
box_preds
=
ret
[
'
box_preds
'
]
,
dir_cls_preds
=
ret
[
'
dir_cls_preds
'
]
)
)
if
isinstance
(
batch_cls_preds
,
list
):
all_pred_labels
=
[]
all_cls_preds
=
[]
for
idx
,
cls_pred
in
enumerate
(
batch_cls_preds
):
pred_score
,
pred_head_label
=
torch
.
max
(
cls_pred
,
dim
=-
1
)
pred_label
=
self
.
rpn_heads
[
idx
].
head_label_indices
[
pred_head_label
]
all_pred_labels
.
append
(
pred_label
)
all_cls_preds
.
append
(
pred_score
[:,
:,
None
])
batch_cls_preds
=
torch
.
cat
(
all_cls_preds
,
dim
=
1
)
batch_pred_labels
=
torch
.
cat
(
all_pred_labels
,
dim
=
1
)
data_dict
[
'batch_pred_labels'
]
=
batch_pred_labels
data_dict
[
'has_class_labels'
]
=
True
data_dict
[
'batch_cls_preds'
]
=
batch_cls_preds
data_dict
[
'batch_cls_preds'
]
=
batch_cls_preds
data_dict
[
'batch_box_preds'
]
=
batch_box_preds
data_dict
[
'batch_box_preds'
]
=
batch_box_preds
data_dict
[
'cls_preds_normalized'
]
=
False
data_dict
[
'cls_preds_normalized'
]
=
False
...
@@ -190,11 +212,12 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -190,11 +212,12 @@ class AnchorHeadMulti(AnchorHeadTemplate):
cur_num_class
=
self
.
rpn_heads
[
idx
].
num_class
cur_num_class
=
self
.
rpn_heads
[
idx
].
num_class
cls_pred
=
cls_pred
.
view
(
batch_size
,
-
1
,
cur_num_class
)
cls_pred
=
cls_pred
.
view
(
batch_size
,
-
1
,
cur_num_class
)
if
self
.
separate_multihead
:
if
self
.
separate_multihead
:
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
],
c_idx
:
c_idx
+
cur_num_class
]
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
],
c_idx
:
c_idx
+
cur_num_class
]
c_idx
+=
cur_num_class
c_idx
+=
cur_num_class
else
:
else
:
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
one_hot_target
=
one_hot_targets
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
cls_weight
=
cls_weights
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
cls_weight
=
cls_weights
[:,
start_idx
:
start_idx
+
cls_pred
.
shape
[
1
]]
cls_loss_src
=
self
.
cls_loss_func
(
cls_pred
,
one_hot_target
,
weights
=
cls_weight
)
# [N, M]
cls_loss_src
=
self
.
cls_loss_func
(
cls_pred
,
one_hot_target
,
weights
=
cls_weight
)
# [N, M]
cls_loss
=
cls_loss_src
.
sum
()
/
batch_size
cls_loss
=
cls_loss_src
.
sum
()
/
batch_size
cls_loss
=
cls_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'cls_weight'
]
cls_loss
=
cls_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'cls_weight'
]
...
@@ -237,10 +260,10 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -237,10 +260,10 @@ class AnchorHeadMulti(AnchorHeadTemplate):
tb_dict
=
{}
tb_dict
=
{}
for
idx
,
box_pred
in
enumerate
(
box_preds
):
for
idx
,
box_pred
in
enumerate
(
box_preds
):
box_pred
=
box_pred
.
view
(
batch_size
,
-
1
,
box_pred
=
box_pred
.
view
(
batch_size
,
-
1
,
box_pred
.
shape
[
-
1
]
//
self
.
num_anchors_per_location
if
not
self
.
use_multihead
else
box_pred
.
shape
[
-
1
]
//
self
.
num_anchors_per_location
if
not
self
.
use_multihead
else
box_pred
.
shape
[
-
1
])
box_pred
.
shape
[
-
1
])
box_reg_target
=
box_reg_targets
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
box_reg_target
=
box_reg_targets
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
reg_weight
=
reg_weights
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
reg_weight
=
reg_weights
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
# sin(a - b) = sinacosb-cosasinb
# sin(a - b) = sinacosb-cosasinb
box_pred_sin
,
reg_target_sin
=
self
.
add_sin_difference
(
box_pred
,
box_reg_target
)
box_pred_sin
,
reg_target_sin
=
self
.
add_sin_difference
(
box_pred
,
box_reg_target
)
loc_loss_src
=
self
.
reg_loss_func
(
box_pred_sin
,
reg_target_sin
,
weights
=
reg_weight
)
# [N, M]
loc_loss_src
=
self
.
reg_loss_func
(
box_pred_sin
,
reg_target_sin
,
weights
=
reg_weight
)
# [N, M]
...
@@ -262,9 +285,9 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -262,9 +285,9 @@ class AnchorHeadMulti(AnchorHeadTemplate):
dir_logit
=
box_dir_cls_pred
.
view
(
batch_size
,
-
1
,
self
.
model_cfg
.
NUM_DIR_BINS
)
dir_logit
=
box_dir_cls_pred
.
view
(
batch_size
,
-
1
,
self
.
model_cfg
.
NUM_DIR_BINS
)
weights
=
positives
.
type_as
(
dir_logit
)
weights
=
positives
.
type_as
(
dir_logit
)
weights
/=
torch
.
clamp
(
weights
.
sum
(
-
1
,
keepdim
=
True
),
min
=
1.0
)
weights
/=
torch
.
clamp
(
weights
.
sum
(
-
1
,
keepdim
=
True
),
min
=
1.0
)
weight
=
weights
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
weight
=
weights
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
dir_target
=
dir_targets
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
dir_target
=
dir_targets
[:,
start_idx
:
start_idx
+
box_pred
.
shape
[
1
]]
dir_loss
=
self
.
dir_loss_func
(
dir_logit
,
dir_target
,
weights
=
weight
)
dir_loss
=
self
.
dir_loss_func
(
dir_logit
,
dir_target
,
weights
=
weight
)
dir_loss
=
dir_loss
.
sum
()
/
batch_size
dir_loss
=
dir_loss
.
sum
()
/
batch_size
dir_loss
=
dir_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'dir_weight'
]
dir_loss
=
dir_loss
*
self
.
model_cfg
.
LOSS_CONFIG
.
LOSS_WEIGHTS
[
'dir_weight'
]
...
...
pcdet/models/dense_heads/anchor_head_template.py
View file @
764ce52e
...
@@ -244,14 +244,17 @@ class AnchorHeadTemplate(nn.Module):
...
@@ -244,14 +244,17 @@ class AnchorHeadTemplate(nn.Module):
anchors
=
self
.
anchors
anchors
=
self
.
anchors
num_anchors
=
anchors
.
view
(
-
1
,
anchors
.
shape
[
-
1
]).
shape
[
0
]
num_anchors
=
anchors
.
view
(
-
1
,
anchors
.
shape
[
-
1
]).
shape
[
0
]
batch_anchors
=
anchors
.
view
(
1
,
-
1
,
anchors
.
shape
[
-
1
]).
repeat
(
batch_size
,
1
,
1
)
batch_anchors
=
anchors
.
view
(
1
,
-
1
,
anchors
.
shape
[
-
1
]).
repeat
(
batch_size
,
1
,
1
)
batch_cls_preds
=
cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
).
float
()
if
not
isinstance
(
cls_preds
,
list
)
else
cls_preds
batch_cls_preds
=
cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
).
float
()
\
batch_box_preds
=
box_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
if
not
isinstance
(
box_preds
,
list
)
else
torch
.
cat
(
box_preds
,
dim
=
1
).
view
(
batch_size
,
num_anchors
,
-
1
)
if
not
isinstance
(
cls_preds
,
list
)
else
cls_preds
batch_box_preds
=
box_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
if
not
isinstance
(
box_preds
,
list
)
\
else
torch
.
cat
(
box_preds
,
dim
=
1
).
view
(
batch_size
,
num_anchors
,
-
1
)
batch_box_preds
=
self
.
box_coder
.
decode_torch
(
batch_box_preds
,
batch_anchors
)
batch_box_preds
=
self
.
box_coder
.
decode_torch
(
batch_box_preds
,
batch_anchors
)
if
dir_cls_preds
is
not
None
:
if
dir_cls_preds
is
not
None
:
dir_offset
=
self
.
model_cfg
.
DIR_OFFSET
dir_offset
=
self
.
model_cfg
.
DIR_OFFSET
dir_limit_offset
=
self
.
model_cfg
.
DIR_LIMIT_OFFSET
dir_limit_offset
=
self
.
model_cfg
.
DIR_LIMIT_OFFSET
dir_cls_preds
=
dir_cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
if
not
isinstance
(
dir_cls_preds
,
list
)
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
).
view
(
batch_size
,
num_anchors
,
-
1
)
dir_cls_preds
=
dir_cls_preds
.
view
(
batch_size
,
num_anchors
,
-
1
)
if
not
isinstance
(
dir_cls_preds
,
list
)
\
else
torch
.
cat
(
dir_cls_preds
,
dim
=
1
).
view
(
batch_size
,
num_anchors
,
-
1
)
dir_labels
=
torch
.
max
(
dir_cls_preds
,
dim
=-
1
)[
1
]
dir_labels
=
torch
.
max
(
dir_cls_preds
,
dim
=-
1
)[
1
]
period
=
(
2
*
np
.
pi
/
self
.
model_cfg
.
NUM_DIR_BINS
)
period
=
(
2
*
np
.
pi
/
self
.
model_cfg
.
NUM_DIR_BINS
)
...
...
pcdet/models/detectors/detector3d_template.py
View file @
764ce52e
...
@@ -6,7 +6,7 @@ from ..backbones_3d import vfe, pfe
...
@@ -6,7 +6,7 @@ from ..backbones_3d import vfe, pfe
from
..backbones_2d
import
map_to_bev
from
..backbones_2d
import
map_to_bev
from
..model_utils.model_nms_utils
import
class_agnostic_nms
from
..model_utils.model_nms_utils
import
class_agnostic_nms
from
...ops.iou3d_nms
import
iou3d_nms_utils
from
...ops.iou3d_nms
import
iou3d_nms_utils
import
numpy
as
np
class
Detector3DTemplate
(
nn
.
Module
):
class
Detector3DTemplate
(
nn
.
Module
):
def
__init__
(
self
,
model_cfg
,
num_class
,
dataset
):
def
__init__
(
self
,
model_cfg
,
num_class
,
dataset
):
...
@@ -170,7 +170,9 @@ class Detector3DTemplate(nn.Module):
...
@@ -170,7 +170,9 @@ class Detector3DTemplate(nn.Module):
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
batch_index: optional (N1+N2+...)
has_class_labels: True/False
roi_labels: (B, num_rois) 1 .. num_classes
roi_labels: (B, num_rois) 1 .. num_classes
batch_pred_labels: (B, num_boxes, 1)
Returns:
Returns:
"""
"""
...
@@ -183,45 +185,27 @@ class Detector3DTemplate(nn.Module):
...
@@ -183,45 +185,27 @@ class Detector3DTemplate(nn.Module):
assert
batch_dict
[
'batch_cls_preds'
].
shape
.
__len__
()
==
2
assert
batch_dict
[
'batch_cls_preds'
].
shape
.
__len__
()
==
2
batch_mask
=
(
batch_dict
[
'batch_index'
]
==
index
)
batch_mask
=
(
batch_dict
[
'batch_index'
]
==
index
)
else
:
else
:
if
isinstance
(
batch_dict
[
'batch_cls_preds'
],
list
):
assert
batch_dict
[
'batch_cls_preds'
].
shape
.
__len__
()
==
3
assert
batch_dict
[
'batch_cls_preds'
][
0
].
shape
.
__len__
()
==
3
else
:
assert
batch_dict
[
'batch_cls_preds'
].
shape
.
__len__
()
==
3
batch_mask
=
index
batch_mask
=
index
box_preds
=
batch_dict
[
'batch_box_preds'
][
batch_mask
]
box_preds
=
batch_dict
[
'batch_box_preds'
][
batch_mask
]
cls_preds
=
batch_dict
[
'batch_cls_preds'
][
batch_mask
]
if
not
isinstance
(
batch_dict
[
'batch_cls_preds'
],
list
)
else
[
batch_cls_pred
[
batch_mask
]
for
batch_cls_pred
in
batch_dict
[
'batch_cls_preds'
]]
cls_preds
=
batch_dict
[
'batch_cls_preds'
][
batch_mask
]
src_cls_preds
=
cls_preds
src_cls_preds
=
cls_preds
src_box_preds
=
box_preds
src_box_preds
=
box_preds
if
isinstance
(
cls_preds
,
list
):
assert
cls_preds
.
shape
[
1
]
in
[
1
,
self
.
num_class
]
assert
cls_preds
[
0
].
shape
[
1
]
in
[
1
,
self
.
num_class
]
else
:
assert
cls_preds
.
shape
[
1
]
in
[
1
,
self
.
num_class
]
if
not
batch_dict
[
'cls_preds_normalized'
]:
if
not
batch_dict
[
'cls_preds_normalized'
]:
cls_preds
=
torch
.
sigmoid
(
cls_preds
)
if
not
isinstance
(
cls_preds
,
list
)
else
[
torch
.
sigmoid
(
cls_pred
)
for
cls_pred
in
cls_preds
]
cls_preds
=
torch
.
sigmoid
(
cls_preds
)
if
post_process_cfg
.
NMS_CONFIG
.
MULTI_CLASSES_NMS
:
if
post_process_cfg
.
NMS_CONFIG
.
MULTI_CLASSES_NMS
:
raise
NotImplementedError
raise
NotImplementedError
else
:
else
:
if
isinstance
(
cls_preds
,
list
):
cls_preds
,
label_preds
=
torch
.
max
(
cls_preds
,
dim
=-
1
)
all_cls_preds
=
[]
if
batch_dict
.
get
(
'has_class_labels'
,
False
):
label_preds
=
[]
label_key
=
'roi_labels'
if
'roi_labels'
in
batch_dict
else
'batch_pred_labels'
rpn_head_cfgs
=
self
.
model_cfg
.
DENSE_HEAD
.
RPN_HEAD_CFGS
label_preds
=
batch_dict
[
label_key
][
index
]
head_cls_names
=
[
np
.
array
(
rpn_head_cfg
[
'HEAD_CLS_NAME'
])
for
rpn_head_cfg
in
rpn_head_cfgs
]
for
idx
,
cls_pred
in
enumerate
(
cls_preds
):
pred_score
,
pred_head_label
=
torch
.
max
(
cls_pred
,
dim
=-
1
)
pred_class_names
=
head_cls_names
[
idx
][
pred_head_label
.
cpu
().
numpy
().
astype
(
int
)]
label_pred
=
[
self
.
class_names
.
index
(
cls_name
)
+
1
for
cls_name
in
pred_class_names
]
label_pred
=
torch
.
from_numpy
(
np
.
array
(
label_pred
)).
to
(
cls_pred
.
device
).
int
()
all_cls_preds
.
append
(
pred_score
)
label_preds
.
append
(
label_pred
)
cls_preds
=
torch
.
cat
(
all_cls_preds
,
dim
=
0
)
label_preds
=
torch
.
cat
(
label_preds
,
dim
=
0
)
else
:
else
:
cls_preds
,
label_preds
=
torch
.
max
(
cls_preds
,
dim
=-
1
)
label_preds
+
1
label_preds
=
batch_dict
[
'roi_labels'
][
index
]
if
batch_dict
.
get
(
'has_class_labels'
,
False
)
else
label_preds
+
1
selected
,
selected_scores
=
class_agnostic_nms
(
selected
,
selected_scores
=
class_agnostic_nms
(
box_scores
=
cls_preds
,
box_preds
=
box_preds
,
box_scores
=
cls_preds
,
box_preds
=
box_preds
,
...
...
tools/cfgs/kitti_models/second_multihead.yaml
View file @
764ce52e
...
@@ -74,27 +74,12 @@ MODEL:
...
@@ -74,27 +74,12 @@ MODEL:
RPN_HEAD_CFGS
:
[
RPN_HEAD_CFGS
:
[
{
{
'
HEAD_CLS_NAME'
:
[
'
Car'
],
'
HEAD_CLS_NAME'
:
[
'
Car'
],
'
LAYER_NUMS'
:
[],
'
LAYER_STRIDES'
:
[],
'
NUM_FILTERS'
:
[],
'
UPSAMPLE_STRIDES'
:
[],
'
NUM_UPSAMPLE_FILTERS'
:
[]
},
},
{
{
'
HEAD_CLS_NAME'
:
[
'
Pedestrian'
],
'
HEAD_CLS_NAME'
:
[
'
Pedestrian'
],
'
LAYER_NUMS'
:
[],
'
LAYER_STRIDES'
:
[],
'
NUM_FILTERS'
:
[],
'
UPSAMPLE_STRIDES'
:
[],
'
NUM_UPSAMPLE_FILTERS'
:
[]
},
},
{
{
'
HEAD_CLS_NAME'
:
[
'
Cyclist'
],
'
HEAD_CLS_NAME'
:
[
'
Cyclist'
],
'
LAYER_NUMS'
:
[],
'
LAYER_STRIDES'
:
[],
'
NUM_FILTERS'
:
[],
'
UPSAMPLE_STRIDES'
:
[],
'
NUM_UPSAMPLE_FILTERS'
:
[]
}
}
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment