Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
fbea17d7
Commit
fbea17d7
authored
Jun 21, 2022
by
charlie
Browse files
NMS refactor and nonstd shape
parent
c650e2a4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
81 additions
and
98 deletions
+81
-98
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+81
-98
No files found.
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
fbea17d7
...
@@ -33,7 +33,6 @@ struct nonmaxsuppression
...
@@ -33,7 +33,6 @@ struct nonmaxsuppression
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
// requires at least 2 inputs
// requires at least 2 inputs
check_shapes
{
inputs
,
*
this
};
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
auto
lens
=
inputs
.
front
().
lens
();
auto
lens
=
inputs
.
front
().
lens
();
...
@@ -71,28 +70,29 @@ struct nonmaxsuppression
...
@@ -71,28 +70,29 @@ struct nonmaxsuppression
};
};
template
<
class
T
>
template
<
class
T
>
box
batch_box
(
const
T
*
boxes
,
std
::
size_t
b
idx
)
const
box
batch_box
(
const
migraphx
::
tensor_view
<
T
>&
boxes
,
std
::
size_t
box_ind
,
std
::
size_t
box_
idx
)
const
{
{
box
result
{};
box
result
{};
const
T
*
start
=
box
es
+
4
*
bidx
;
auto
start
=
box
_ind
+
4
*
b
ox_
idx
;
if
(
center_point_box
)
if
(
center_point_box
)
{
{
float
half_width
=
start
[
2
]
/
2.0
f
;
float
half_width
=
boxes
[
start
+
2
]
/
2.0
;
float
half_height
=
start
[
3
]
/
2.0
f
;
float
half_height
=
boxes
[
start
+
3
]
/
2.0
;
float
x_center
=
start
[
0
];
float
x_center
=
boxes
[
start
+
0
];
float
y_center
=
start
[
1
];
float
y_center
=
boxes
[
start
+
1
];
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
}
}
else
else
{
{
result
.
x
=
{
sta
rt
[
1
],
sta
rt
[
3
]};
result
.
x
=
{
sta
tic_cast
<
float
>
(
boxes
[
start
+
1
]
)
,
sta
tic_cast
<
float
>
(
boxes
[
start
+
3
]
)
};
result
.
y
=
{
sta
rt
[
0
],
sta
rt
[
2
]};
result
.
y
=
{
sta
tic_cast
<
float
>
(
boxes
[
start
+
0
]
)
,
sta
tic_cast
<
float
>
(
boxes
[
start
+
2
]
)
};
}
}
return
result
;
return
result
;
}
}
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
float
iou_threshold
)
const
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
float
iou_threshold
)
const
{
{
b1
.
sort
();
b1
.
sort
();
...
@@ -128,63 +128,45 @@ struct nonmaxsuppression
...
@@ -128,63 +128,45 @@ struct nonmaxsuppression
return
intersection_over_union
>
iou_threshold
;
return
intersection_over_union
>
iou_threshold
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
out
)
{
std
::
fill
(
out
.
begin
(),
out
.
end
(),
0
);
});
std
::
size_t
max_output_boxes_per_class
=
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
if
(
max_output_boxes_per_class
==
0
)
{
return
result
;
}
std
::
size_t
max_output_boxes_per_class
=
0
;
float
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
float
>
())
:
0.0
f
;
float
iou_threshold
=
0.0
f
;
float
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
float
>
())
:
0.0
f
;
float
score_threshold
=
0.0
f
;
result
.
visit
([
&
](
auto
output
){
if
(
args
.
size
()
>
2
)
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
){
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
max_output_boxes_per_class
=
args
.
at
(
2
).
at
<
std
::
size_t
>
();
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
}
const
auto
num_batches
=
lens
[
0
];
// max_output_boxes_per_class is 0, no output
const
auto
num_classes
=
lens
[
1
];
if
(
max_output_boxes_per_class
==
0
)
const
auto
num_boxes
=
boxes
.
get_shape
().
lens
()[
1
];
{
// boxes of a class with NMS applied [score, index]
return
result
;
}
if
(
args
.
size
()
>
3
)
{
iou_threshold
=
args
.
at
(
3
).
at
<
float
>
();
}
if
(
args
.
size
()
>
4
)
{
score_threshold
=
args
.
at
(
4
).
at
<
float
>
();
}
const
auto
&
lens
=
args
.
at
(
1
).
get_shape
().
lens
();
auto
batch_num
=
lens
[
0
];
auto
class_num
=
lens
[
1
];
auto
box_num
=
args
.
at
(
0
).
get_shape
().
lens
()[
1
];
std
::
vector
<
std
::
pair
<
float
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
std
::
pair
<
float
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
int64_t
>
selected_indices
;
std
::
vector
<
int64_t
>
selected_indices
;
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
// iterate over batches and classes
auto
scores
=
make_view
<
float
>
(
args
.
at
(
1
).
get_shape
(),
args
.
at
(
1
).
cast
<
float
>
());
shape
comp_s
{
shape
::
float_type
,
{
num_batches
,
num_classes
}};
const
float
*
boxes
=
args
.
at
(
0
).
cast
<
float
>
();
shape
comp_s
{
shape
::
float_type
,
{
batch_num
,
class_num
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
auto
bidx
=
idx
[
0
];
auto
batch_idx
=
idx
[
0
];
auto
cidx
=
idx
[
1
];
auto
class_idx
=
idx
[
1
];
// index offset for this class
std
::
size_t
score_offset
=
(
bidx
*
class_num
+
cidx
)
*
box_num
;
std
::
size_t
score_offset_ind
=
(
batch_idx
*
num_classes
+
class_idx
)
*
num_boxes
;
const
float
*
batch_boxes
=
boxes
+
bidx
*
box_num
*
4
;
// index to first value of this batch
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
sorted_boxes
;
std
::
size_t
batch_boxes_ind
=
batch_idx
*
num_boxes
*
4
;
auto
insert_to_sorted_boxes
=
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
boxes_heap
;
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
sorted_boxes
.
push
(
x
);
});
auto
insert_to_boxes_heap
=
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
boxes_heap
.
push
(
x
);
});
int64_t
box_idx
=
0
;
int64_t
box_idx
=
0
;
// filter boxes below score_threshold
transform_if
(
transform_if
(
scores
.
begin
()
+
score_offset
,
scores
.
begin
()
+
score_offset
_ind
,
scores
.
begin
()
+
score_offset
+
box_num
,
scores
.
begin
()
+
score_offset
_ind
+
num_boxes
,
insert_to_
sorted_
boxes
,
insert_to_boxes
_heap
,
[
&
](
auto
sc
)
{
[
&
](
auto
sc
)
{
box_idx
++
;
box_idx
++
;
return
sc
>=
score_threshold
;
return
sc
>=
score_threshold
;
...
@@ -193,35 +175,36 @@ struct nonmaxsuppression
...
@@ -193,35 +175,36 @@ struct nonmaxsuppression
selected_boxes_inside_class
.
clear
();
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
// Get the next box with top score, filter by iou_threshold
while
(
!
sorted_
boxes
.
empty
()
&&
while
(
!
boxes
_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
{
const
std
::
pair
<
float
,
int64_t
>&
next_top_score
=
sorted_
boxes
.
top
();
const
std
::
pair
<
float
,
int64_t
>&
next_top_score
=
boxes
_heap
.
top
();
// Check with existing selected boxes for this class,
suppress
if exceed the IOU
// Check with existing selected boxes for this class,
remove box
if
it
exceed
s
the IOU
// (Intersection Over Union) threshold
// (Intersection Over Union) threshold
bool
not_selected
=
std
::
any_of
(
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes
,
next_top_score
.
second
),
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes
,
selected_index
.
second
),
batch_box
(
boxes
,
batch_boxes_ind
,
next_top_score
.
second
),
batch_box
(
boxes
,
batch_boxes_ind
,
selected_index
.
second
),
iou_threshold
);
iou_threshold
);
});
});
if
(
not
not_selected
)
if
(
not
not_selected
)
{
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
bidx
);
selected_indices
.
push_back
(
b
atch_
idx
);
selected_indices
.
push_back
(
cidx
);
selected_indices
.
push_back
(
c
lass_
idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
selected_indices
.
push_back
(
next_top_score
.
second
);
}
}
sorted_boxes
.
pop
();
boxes_heap
.
pop
();
}
}
});
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
result
.
visit
([
&
](
auto
out
)
{
}
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
out
.
begin
()
);
);
});
});
return
result
;
return
result
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment