Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
mmdetection3d
Commits
eab8cfa2
Unverified
Commit
eab8cfa2
authored
Apr 25, 2022
by
Danila Rukhovich
Committed by
GitHub
Apr 25, 2022
Browse files
[Fix] NMS for Point RCNN (#1418)
* fix nms for point rcnn * add else case
parent
023c88b5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
6 deletions
+11
-6
mmdet3d/models/dense_heads/point_rpn_head.py
mmdet3d/models/dense_heads/point_rpn_head.py
+11
-6
No files found.
mmdet3d/models/dense_heads/point_rpn_head.py
View file @
eab8cfa2
...
@@ -3,6 +3,7 @@ import torch
...
@@ -3,6 +3,7 @@ import torch
from
mmcv.runner
import
BaseModule
,
force_fp32
from
mmcv.runner
import
BaseModule
,
force_fp32
from
torch
import
nn
as
nn
from
torch
import
nn
as
nn
from
mmdet3d.core
import
xywhr2xyxyr
from
mmdet3d.core.bbox.structures
import
(
DepthInstance3DBoxes
,
from
mmdet3d.core.bbox.structures
import
(
DepthInstance3DBoxes
,
LiDARInstance3DBoxes
)
LiDARInstance3DBoxes
)
from
mmdet3d.core.post_processing
import
nms_bev
,
nms_normal_bev
from
mmdet3d.core.post_processing
import
nms_bev
,
nms_normal_bev
...
@@ -320,29 +321,33 @@ class PointRPNHead(BaseModule):
...
@@ -320,29 +321,33 @@ class PointRPNHead(BaseModule):
else
:
else
:
raise
NotImplementedError
(
'Unsupported bbox type!'
)
raise
NotImplementedError
(
'Unsupported bbox type!'
)
bbox
=
bbox
.
tensor
[
nonempty_box_mask
]
bbox
=
bbox
[
nonempty_box_mask
]
if
self
.
test_cfg
.
score_thr
is
not
None
:
if
self
.
test_cfg
.
score_thr
is
not
None
:
score_thr
=
self
.
test_cfg
.
score_thr
score_thr
=
self
.
test_cfg
.
score_thr
keep
=
(
obj_scores
>=
score_thr
)
keep
=
(
obj_scores
>=
score_thr
)
obj_scores
=
obj_scores
[
keep
]
obj_scores
=
obj_scores
[
keep
]
sem_scores
=
sem_scores
[
keep
]
sem_scores
=
sem_scores
[
keep
]
bbox
=
bbox
[
keep
]
bbox
=
bbox
.
tensor
[
keep
]
if
obj_scores
.
shape
[
0
]
>
0
:
if
obj_scores
.
shape
[
0
]
>
0
:
topk
=
min
(
nms_cfg
.
nms_pre
,
obj_scores
.
shape
[
0
])
topk
=
min
(
nms_cfg
.
nms_pre
,
obj_scores
.
shape
[
0
])
obj_scores_nms
,
indices
=
torch
.
topk
(
obj_scores
,
k
=
topk
)
obj_scores_nms
,
indices
=
torch
.
topk
(
obj_scores
,
k
=
topk
)
bbox_for_nms
=
bbox
[
indices
]
bbox_for_nms
=
xywhr2xyxyr
(
bbox
[
indices
]
.
bev
)
sem_scores_nms
=
sem_scores
[
indices
]
sem_scores_nms
=
sem_scores
[
indices
]
keep
=
nms_func
(
bbox_for_nms
[:,
0
:
7
],
obj_scores_nms
,
keep
=
nms_func
(
bbox_for_nms
,
obj_scores_nms
,
nms_cfg
.
iou_thr
)
nms_cfg
.
iou_thr
)
keep
=
keep
[:
nms_cfg
.
nms_post
]
keep
=
keep
[:
nms_cfg
.
nms_post
]
bbox_selected
=
bbox
_for_nms
[
keep
]
bbox_selected
=
bbox
.
tensor
[
indices
]
[
keep
]
score_selected
=
obj_scores_nms
[
keep
]
score_selected
=
obj_scores_nms
[
keep
]
cls_preds
=
sem_scores_nms
[
keep
]
cls_preds
=
sem_scores_nms
[
keep
]
labels
=
torch
.
argmax
(
cls_preds
,
-
1
)
labels
=
torch
.
argmax
(
cls_preds
,
-
1
)
else
:
bbox_selected
=
bbox
.
tensor
score_selected
=
obj_scores
.
new_zeros
([
0
])
labels
=
obj_scores
.
new_zeros
([
0
])
cls_preds
=
obj_scores
.
new_zeros
([
0
,
sem_scores
.
shape
[
-
1
]])
return
bbox_selected
,
score_selected
,
labels
,
cls_preds
return
bbox_selected
,
score_selected
,
labels
,
cls_preds
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment