Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e4975990
"script/profile_reduce_with_index.sh" did not exist on "12dfba3d03f402c051e2129fa21f33264f4d26e5"
Unverified
Commit
e4975990
authored
Mar 07, 2023
by
Umang Yadav
Committed by
GitHub
Mar 07, 2023
Browse files
Prune candidates in NMS (#1601)
* NMS improvements
parent
d83b8397
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
31 deletions
+34
-31
src/eliminate_data_type.cpp
src/eliminate_data_type.cpp
+1
-0
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+33
-31
No files found.
src/eliminate_data_type.cpp
View file @
e4975990
...
...
@@ -38,6 +38,7 @@ void eliminate_data_type::apply(module& m) const
"if"
,
"loop"
,
"roialign"
,
"nonmaxsuppression"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_none"
};
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
e4975990
...
...
@@ -143,16 +143,22 @@ struct nonmaxsuppression
void
sort
()
{
std
::
sort
(
x
.
begin
(),
x
.
end
());
std
::
sort
(
y
.
begin
(),
y
.
end
());
if
(
x
[
0
]
>
x
[
1
])
{
std
::
swap
(
x
[
0
],
x
[
1
]);
}
if
(
y
[
0
]
>
y
[
1
])
{
std
::
swap
(
y
[
0
],
y
[
1
]);
}
}
std
::
array
<
double
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
double
area
()
const
{
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
())
);
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
())
);
assert
(
x
[
0
]
<=
x
[
1
]
);
assert
(
y
[
0
]
<=
y
[
1
]
);
return
(
x
[
1
]
-
x
[
0
])
*
(
y
[
1
]
-
y
[
0
]);
}
};
...
...
@@ -190,14 +196,10 @@ struct nonmaxsuppression
{
intersection
[
i
][
0
]
=
std
::
max
(
b1
[
i
][
0
],
b2
[
i
][
0
]);
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
}
std
::
vector
<
std
::
array
<
double
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
}))
{
return
false
;
if
(
intersection
[
i
][
0
]
>
intersection
[
i
][
1
])
{
return
false
;
}
}
const
double
area1
=
b1
.
area
();
...
...
@@ -265,31 +267,31 @@ struct nonmaxsuppression
auto
batch_boxes_start
=
boxes
.
begin
()
+
batch_idx
*
num_boxes
*
4
;
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
not
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
//
Check with existing selected boxes for this class, remove box if it
//
exceeds the IOU (Intersection Over Union) threshold
//
select next top scorer box and remove any boxes from boxes_heap that exceeds IOU
//
threshold with the selected box
const
auto
next_top_score
=
boxes_heap
.
top
();
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
batch_box
(
batch_boxes_start
,
selected_index
.
second
),
iou_threshold
);
});
if
(
not
not_selected
)
boxes_heap
.
pop
();
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
remainder_boxes
;
while
(
not
boxes_heap
.
empty
())
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
auto
iou_candidate_box
=
boxes_heap
.
top
();
if
(
not
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
iou_candidate_box
.
second
),
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
iou_threshold
))
{
remainder_boxes
.
push
(
iou_candidate_box
);
}
boxes_heap
.
pop
();
}
boxes_heap
.
pop
()
;
boxes_heap
=
remainder_boxes
;
}
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment