Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenPCDet
Commits
a236428b
Commit
a236428b
authored
Jul 25, 2020
by
Shaoshuai Shi
Browse files
support multi-classes nms for multi-head, not checked
parent
6901df66
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
93 additions
and
26 deletions
+93
-26
pcdet/models/dense_heads/anchor_head_multi.py
pcdet/models/dense_heads/anchor_head_multi.py
+5
-13
pcdet/models/detectors/detector3d_template.py
pcdet/models/detectors/detector3d_template.py
+45
-13
pcdet/models/model_utils/model_nms_utils.py
pcdet/models/model_utils/model_nms_utils.py
+43
-0
No files found.
pcdet/models/dense_heads/anchor_head_multi.py
View file @
a236428b
...
@@ -229,19 +229,11 @@ class AnchorHeadMulti(AnchorHeadTemplate):
...
@@ -229,19 +229,11 @@ class AnchorHeadMulti(AnchorHeadTemplate):
)
)
if
isinstance
(
batch_cls_preds
,
list
):
if
isinstance
(
batch_cls_preds
,
list
):
all_pred_labels
=
[]
multihead_label_mapping
=
[]
all_cls_preds
=
[]
for
idx
in
range
(
len
(
batch_cls_preds
)):
for
idx
,
cls_pred
in
enumerate
(
batch_cls_preds
):
multihead_label_mapping
.
append
(
self
.
rpn_heads
[
idx
].
head_label_indices
)
pred_score
,
pred_head_label
=
torch
.
max
(
cls_pred
,
dim
=-
1
)
pred_label
=
self
.
rpn_heads
[
idx
].
head_label_indices
[
pred_head_label
]
data_dict
[
'multihead_label_mapping'
]
=
multihead_label_mapping
all_pred_labels
.
append
(
pred_label
)
all_cls_preds
.
append
(
pred_score
[:,
:,
None
])
batch_cls_preds
=
torch
.
cat
(
all_cls_preds
,
dim
=
1
)
batch_pred_labels
=
torch
.
cat
(
all_pred_labels
,
dim
=
1
)
data_dict
[
'batch_pred_labels'
]
=
batch_pred_labels
data_dict
[
'has_class_labels'
]
=
True
data_dict
[
'batch_cls_preds'
]
=
batch_cls_preds
data_dict
[
'batch_cls_preds'
]
=
batch_cls_preds
data_dict
[
'batch_box_preds'
]
=
batch_box_preds
data_dict
[
'batch_box_preds'
]
=
batch_box_preds
...
...
pcdet/models/detectors/detector3d_template.py
View file @
a236428b
...
@@ -4,7 +4,7 @@ import torch.nn as nn
...
@@ -4,7 +4,7 @@ import torch.nn as nn
from
..
import
backbones_3d
,
backbones_2d
,
dense_heads
,
roi_heads
from
..
import
backbones_3d
,
backbones_2d
,
dense_heads
,
roi_heads
from
..backbones_3d
import
vfe
,
pfe
from
..backbones_3d
import
vfe
,
pfe
from
..backbones_2d
import
map_to_bev
from
..backbones_2d
import
map_to_bev
from
..model_utils
.
model_nms_utils
import
class_agnostic_nms
from
..model_utils
import
model_nms_utils
from
...ops.iou3d_nms
import
iou3d_nms_utils
from
...ops.iou3d_nms
import
iou3d_nms_utils
...
@@ -169,6 +169,8 @@ class Detector3DTemplate(nn.Module):
...
@@ -169,6 +169,8 @@ class Detector3DTemplate(nn.Module):
batch_dict:
batch_dict:
batch_size:
batch_size:
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
or [(B, num_boxes, num_class1), (B, num_boxes, num_class2) ...]
multihead_label_mapping: [(num_class1), (num_class2), ...]
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
cls_preds_normalized: indicate whether batch_cls_preds is normalized
cls_preds_normalized: indicate whether batch_cls_preds is normalized
batch_index: optional (N1+N2+...)
batch_index: optional (N1+N2+...)
...
@@ -184,32 +186,62 @@ class Detector3DTemplate(nn.Module):
...
@@ -184,32 +186,62 @@ class Detector3DTemplate(nn.Module):
pred_dicts
=
[]
pred_dicts
=
[]
for
index
in
range
(
batch_size
):
for
index
in
range
(
batch_size
):
if
batch_dict
.
get
(
'batch_index'
,
None
)
is
not
None
:
if
batch_dict
.
get
(
'batch_index'
,
None
)
is
not
None
:
assert
batch_dict
[
'batch_
cls
_preds'
].
shape
.
__len__
()
==
2
assert
batch_dict
[
'batch_
box
_preds'
].
shape
.
__len__
()
==
2
batch_mask
=
(
batch_dict
[
'batch_index'
]
==
index
)
batch_mask
=
(
batch_dict
[
'batch_index'
]
==
index
)
else
:
else
:
assert
batch_dict
[
'batch_
cls
_preds'
].
shape
.
__len__
()
==
3
assert
batch_dict
[
'batch_
box
_preds'
].
shape
.
__len__
()
==
3
batch_mask
=
index
batch_mask
=
index
box_preds
=
batch_dict
[
'batch_box_preds'
][
batch_mask
]
box_preds
=
batch_dict
[
'batch_box_preds'
][
batch_mask
]
cls_preds
=
batch_dict
[
'batch_cls_preds'
][
batch_mask
]
src_cls_preds
=
cls_preds
src_box_preds
=
box_preds
src_box_preds
=
box_preds
assert
cls_preds
.
shape
[
1
]
in
[
1
,
self
.
num_class
]
if
not
batch_dict
[
'cls_preds_normalized'
]:
if
not
isinstance
(
batch_dict
[
'batch_cls_preds'
],
list
):
cls_preds
=
torch
.
sigmoid
(
cls_preds
)
cls_preds
=
batch_dict
[
'batch_cls_preds'
][
batch_mask
]
src_cls_preds
=
cls_preds
assert
cls_preds
.
shape
[
1
]
in
[
1
,
self
.
num_class
]
if
not
batch_dict
[
'cls_preds_normalized'
]:
cls_preds
=
torch
.
sigmoid
(
cls_preds
)
else
:
cls_preds
=
[
x
[
batch_mask
]
for
x
in
batch_dict
[
'batch_cls_preds'
]]
src_cls_preds
=
cls_preds
if
not
batch_dict
[
'cls_preds_normalized'
]:
cls_preds
=
[
torch
.
sigmoid
(
x
)
for
x
in
cls_preds
]
if
post_process_cfg
.
NMS_CONFIG
.
MULTI_CLASSES_NMS
:
if
post_process_cfg
.
NMS_CONFIG
.
MULTI_CLASSES_NMS
:
raise
NotImplementedError
if
not
isinstance
(
cls_preds
,
list
):
cls_preds
=
[
cls_preds
]
multihead_label_mapping
=
[
torch
.
arange
(
1
,
self
.
num_class
,
device
=
cls_preds
[
0
].
device
)]
else
:
multihead_label_mapping
=
batch_dict
[
'multihead_label_mapping'
]
cur_start_idx
=
0
pred_scores
,
pred_labels
,
pred_boxes
=
[],
[],
[]
for
cur_cls_preds
,
cur_label_mapping
in
zip
(
cls_preds
,
multihead_label_mapping
):
assert
cur_cls_preds
.
shape
[
1
]
==
len
(
cur_label_mapping
)
cur_box_preds
=
box_preds
[
cur_start_idx
:
cur_start_idx
+
cur_cls_preds
.
shape
[
0
]]
cur_pred_scores
,
cur_pred_labels
,
cur_pred_boxes
=
model_nms_utils
.
multi_classes_nms
(
cls_scores
=
cur_cls_preds
,
box_preds
=
cur_box_preds
,
nms_config
=
post_process_cfg
.
NMS_CONFIG
,
score_thresh
=
post_process_cfg
.
SCORE_THRESH
)
cur_pred_labels
=
cur_label_mapping
[
cur_pred_labels
]
pred_scores
.
append
(
cur_pred_scores
)
pred_labels
.
append
(
cur_pred_labels
)
pred_boxes
.
append
(
cur_pred_boxes
)
final_scores
=
torch
.
cat
(
pred_scores
,
dim
=
0
)
final_labels
=
torch
.
cat
(
pred_labels
,
dim
=
0
)
final_boxes
=
torch
.
cat
(
pred_boxes
,
dim
=
0
)
else
:
else
:
cls_preds
,
label_preds
=
torch
.
max
(
cls_preds
,
dim
=-
1
)
cls_preds
,
label_preds
=
torch
.
max
(
cls_preds
,
dim
=-
1
)
if
batch_dict
.
get
(
'has_class_labels'
,
False
):
if
batch_dict
.
get
(
'has_class_labels'
,
False
):
label_key
=
'roi_labels'
if
'roi_labels'
in
batch_dict
else
'batch_pred_labels'
label_key
=
'roi_labels'
if
'roi_labels'
in
batch_dict
else
'batch_pred_labels'
label_preds
=
batch_dict
[
label_key
][
index
]
label_preds
=
batch_dict
[
label_key
][
index
]
else
:
else
:
label_preds
+
1
label_preds
=
label_preds
+
1
selected
,
selected_scores
=
model_nms_utils
.
class_agnostic_nms
(
selected
,
selected_scores
=
class_agnostic_nms
(
box_scores
=
cls_preds
,
box_preds
=
box_preds
,
box_scores
=
cls_preds
,
box_preds
=
box_preds
,
nms_config
=
post_process_cfg
.
NMS_CONFIG
,
nms_config
=
post_process_cfg
.
NMS_CONFIG
,
score_thresh
=
post_process_cfg
.
SCORE_THRESH
score_thresh
=
post_process_cfg
.
SCORE_THRESH
...
...
pcdet/models/model_utils/model_nms_utils.py
View file @
a236428b
...
@@ -22,3 +22,46 @@ def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
...
@@ -22,3 +22,46 @@ def class_agnostic_nms(box_scores, box_preds, nms_config, score_thresh=None):
original_idxs
=
scores_mask
.
nonzero
().
view
(
-
1
)
original_idxs
=
scores_mask
.
nonzero
().
view
(
-
1
)
selected
=
original_idxs
[
selected
]
selected
=
original_idxs
[
selected
]
return
selected
,
src_box_scores
[
selected
]
return
selected
,
src_box_scores
[
selected
]
def
multi_classes_nms
(
cls_scores
,
box_preds
,
nms_config
,
score_thresh
=
None
):
"""
Args:
cls_scores: (N, num_class)
box_preds: (N, 7 + C)
nms_config:
score_thresh:
Returns:
"""
pred_scores
,
pred_labels
,
pred_boxes
=
[],
[],
[]
for
k
in
range
(
cls_scores
.
shape
[
0
]):
if
score_thresh
is
not
None
:
scores_mask
=
(
cls_scores
[:,
k
]
>=
score_thresh
)
box_scores
=
cls_scores
[
scores_mask
,
k
]
box_preds
=
box_preds
[
scores_mask
]
else
:
box_scores
=
cls_scores
[:,
k
]
selected
=
[]
if
box_scores
.
shape
[
0
]
>
0
:
box_scores_nms
,
indices
=
torch
.
topk
(
box_scores
,
k
=
min
(
nms_config
.
NMS_PRE_MAXSIZE
,
box_scores
.
shape
[
0
]))
boxes_for_nms
=
box_preds
[
indices
]
keep_idx
,
selected_scores
=
getattr
(
iou3d_nms_utils
,
nms_config
.
NMS_TYPE
)(
boxes_for_nms
[:,
0
:
7
],
box_scores_nms
,
nms_config
.
NMS_THRESH
,
**
nms_config
)
selected
=
indices
[
keep_idx
[:
nms_config
.
NMS_POST_MAXSIZE
]]
if
score_thresh
is
not
None
:
selected
=
scores_mask
.
nonzero
().
view
(
-
1
)
pred_scores
.
append
(
box_scores
[
selected
])
pred_labels
.
append
(
box_scores
.
new_ones
(
selected
.
shape
[
0
])
*
k
)
pred_boxes
.
append
(
box_preds
[
selected
])
pred_scores
=
torch
.
cat
(
pred_scores
,
dim
=
0
)
pred_labels
=
torch
.
cat
(
pred_labels
,
dim
=
0
)
pred_boxes
=
torch
.
cat
(
pred_boxes
,
dim
=
0
)
return
pred_scores
,
pred_labels
,
pred_boxes
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment