Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
MMCV
Commits
9b8dd083
Unverified
Commit
9b8dd083
authored
May 13, 2021
by
Shilong Zhang
Committed by
GitHub
May 13, 2021
Browse files
[fix]: fix softnms (#1019)
* fix basemodule * fix typo * fix unitest
parent
da071146
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
3 deletions
+25
-3
mmcv/ops/nms.py
mmcv/ops/nms.py
+6
-3
tests/test_ops/test_nms.py
tests/test_ops/test_nms.py
+19
-0
No files found.
mmcv/ops/nms.py
View file @
9b8dd083
...
@@ -282,18 +282,21 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
...
@@ -282,18 +282,21 @@ def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
# This assumes `dets` has 5 dimensions where
# This assumes `dets` has 5 dimensions where
# the last dimension is score.
# the last dimension is score.
# TODO: more elegant way to handle the dimension issue.
# TODO: more elegant way to handle the dimension issue.
# Some type of nms would reweight the score, such as SoftNMS
scores
=
dets
[:,
4
]
scores
=
dets
[:,
4
]
else
:
else
:
total_mask
=
scores
.
new_zeros
(
scores
.
size
(),
dtype
=
torch
.
bool
)
total_mask
=
scores
.
new_zeros
(
scores
.
size
(),
dtype
=
torch
.
bool
)
# Some type of nms would reweight the score, such as SoftNMS
scores_after_nms
=
scores
.
new_zeros
(
scores
.
size
())
for
id
in
torch
.
unique
(
idxs
):
for
id
in
torch
.
unique
(
idxs
):
mask
=
(
idxs
==
id
).
nonzero
(
as_tuple
=
False
).
view
(
-
1
)
mask
=
(
idxs
==
id
).
nonzero
(
as_tuple
=
False
).
view
(
-
1
)
dets
,
keep
=
nms_op
(
boxes_for_nms
[
mask
],
scores
[
mask
],
**
nms_cfg_
)
dets
,
keep
=
nms_op
(
boxes_for_nms
[
mask
],
scores
[
mask
],
**
nms_cfg_
)
total_mask
[
mask
[
keep
]]
=
True
total_mask
[
mask
[
keep
]]
=
True
scores_after_nms
[
mask
[
keep
]]
=
dets
[:,
-
1
]
keep
=
total_mask
.
nonzero
(
as_tuple
=
False
).
view
(
-
1
)
keep
=
total_mask
.
nonzero
(
as_tuple
=
False
).
view
(
-
1
)
keep
=
keep
[
scores
[
keep
].
argsort
(
descending
=
True
)]
scores
,
inds
=
scores_after_nms
[
keep
].
sort
(
descending
=
True
)
keep
=
keep
[
inds
]
boxes
=
boxes
[
keep
]
boxes
=
boxes
[
keep
]
scores
=
scores
[
keep
]
return
torch
.
cat
([
boxes
,
scores
[:,
None
]],
-
1
),
keep
return
torch
.
cat
([
boxes
,
scores
[:,
None
]],
-
1
),
keep
...
...
tests/test_ops/test_nms.py
View file @
9b8dd083
...
@@ -157,3 +157,22 @@ class Testnms(object):
...
@@ -157,3 +157,22 @@ class Testnms(object):
assert
torch
.
equal
(
keep
,
seq_keep
)
assert
torch
.
equal
(
keep
,
seq_keep
)
assert
torch
.
equal
(
boxes
,
seq_boxes
)
assert
torch
.
equal
(
boxes
,
seq_boxes
)
assert
torch
.
equal
(
keep
,
torch
.
from_numpy
(
results
[
'keep'
]))
assert
torch
.
equal
(
keep
,
torch
.
from_numpy
(
results
[
'keep'
]))
nms_cfg
=
dict
(
type
=
'soft_nms'
,
iou_threshold
=
0.7
)
boxes
,
keep
=
batched_nms
(
torch
.
from_numpy
(
results
[
'boxes'
]),
torch
.
from_numpy
(
results
[
'scores'
]),
torch
.
from_numpy
(
results
[
'idxs'
]),
nms_cfg
,
class_agnostic
=
False
)
nms_cfg
.
update
(
split_thr
=
100
)
seq_boxes
,
seq_keep
=
batched_nms
(
torch
.
from_numpy
(
results
[
'boxes'
]),
torch
.
from_numpy
(
results
[
'scores'
]),
torch
.
from_numpy
(
results
[
'idxs'
]),
nms_cfg
,
class_agnostic
=
False
)
assert
torch
.
equal
(
keep
,
seq_keep
)
assert
torch
.
equal
(
boxes
,
seq_boxes
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment