Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e4759983
Commit
e4759983
authored
Jun 21, 2022
by
charlie
Browse files
formatting
parent
fbea17d7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
63 deletions
+67
-63
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+67
-63
No files found.
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
e4759983
...
@@ -70,7 +70,8 @@ struct nonmaxsuppression
...
@@ -70,7 +70,8 @@ struct nonmaxsuppression
};
};
template
<
class
T
>
template
<
class
T
>
box
batch_box
(
const
migraphx
::
tensor_view
<
T
>&
boxes
,
std
::
size_t
box_ind
,
std
::
size_t
box_idx
)
const
box
batch_box
(
const
migraphx
::
tensor_view
<
T
>&
boxes
,
std
::
size_t
box_ind
,
std
::
size_t
box_idx
)
const
{
{
box
result
{};
box
result
{};
auto
start
=
box_ind
+
4
*
box_idx
;
auto
start
=
box_ind
+
4
*
box_idx
;
...
@@ -92,7 +93,6 @@ struct nonmaxsuppression
...
@@ -92,7 +93,6 @@ struct nonmaxsuppression
return
result
;
return
result
;
}
}
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
float
iou_threshold
)
const
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
float
iou_threshold
)
const
{
{
b1
.
sort
();
b1
.
sort
();
...
@@ -128,18 +128,21 @@ struct nonmaxsuppression
...
@@ -128,18 +128,21 @@ struct nonmaxsuppression
return
intersection_over_union
>
iou_threshold
;
return
intersection_over_union
>
iou_threshold
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
std
::
size_t
max_output_boxes_per_class
=
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
std
::
size_t
max_output_boxes_per_class
=
if
(
max_output_boxes_per_class
==
0
)
{
return
result
;
}
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
if
(
max_output_boxes_per_class
==
0
)
{
return
result
;
}
float
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
float
>
())
:
0.0
f
;
float
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
float
>
())
:
0.0
f
;
float
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
float
>
())
:
0.0
f
;
float
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
float
>
())
:
0.0
f
;
result
.
visit
([
&
](
auto
output
){
result
.
visit
([
&
](
auto
output
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
){
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
const
auto
num_batches
=
lens
[
0
];
const
auto
num_batches
=
lens
[
0
];
...
@@ -155,11 +158,13 @@ struct nonmaxsuppression
...
@@ -155,11 +158,13 @@ struct nonmaxsuppression
auto
batch_idx
=
idx
[
0
];
auto
batch_idx
=
idx
[
0
];
auto
class_idx
=
idx
[
1
];
auto
class_idx
=
idx
[
1
];
// index offset for this class
// index offset for this class
std
::
size_t
score_offset_ind
=
(
batch_idx
*
num_classes
+
class_idx
)
*
num_boxes
;
std
::
size_t
score_offset_ind
=
(
batch_idx
*
num_classes
+
class_idx
)
*
num_boxes
;
// index to first value of this batch
// index to first value of this batch
std
::
size_t
batch_boxes_ind
=
batch_idx
*
num_boxes
*
4
;
std
::
size_t
batch_boxes_ind
=
batch_idx
*
num_boxes
*
4
;
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
boxes_heap
;
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
boxes_heap
;
auto
insert_to_boxes_heap
=
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
boxes_heap
.
push
(
x
);
});
auto
insert_to_boxes_heap
=
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
boxes_heap
.
push
(
x
);
});
int64_t
box_idx
=
0
;
int64_t
box_idx
=
0
;
// filter boxes below score_threshold
// filter boxes below score_threshold
...
@@ -180,8 +185,8 @@ struct nonmaxsuppression
...
@@ -180,8 +185,8 @@ struct nonmaxsuppression
{
{
const
std
::
pair
<
float
,
int64_t
>&
next_top_score
=
boxes_heap
.
top
();
const
std
::
pair
<
float
,
int64_t
>&
next_top_score
=
boxes_heap
.
top
();
// Check with existing selected boxes for this class, remove box if it
exceeds the IOU
// Check with existing selected boxes for this class, remove box if it
//
(Intersection Over Union) threshold
// exceeds the IOU
(Intersection Over Union) threshold
bool
not_selected
=
std
::
any_of
(
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
selected_boxes_inside_class
.
end
(),
...
@@ -203,8 +208,7 @@ struct nonmaxsuppression
...
@@ -203,8 +208,7 @@ struct nonmaxsuppression
}
}
});
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
}
});
);
});
});
return
result
;
return
result
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment