Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6efffa37
Commit
6efffa37
authored
Jun 27, 2022
by
charlie
Browse files
Merge branch 'nonstd_NMS' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_nms
parents
94f2aea8
390b87ae
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
205 additions
and
110 deletions
+205
-110
src/include/migraphx/iota_iterator.hpp
src/include/migraphx/iota_iterator.hpp
+2
-1
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+129
-109
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+74
-0
No files found.
src/include/migraphx/iota_iterator.hpp
View file @
6efffa37
...
@@ -81,8 +81,9 @@ struct basic_iota_iterator
...
@@ -81,8 +81,9 @@ struct basic_iota_iterator
index
--
;
index
--
;
return
it
;
return
it
;
}
}
// TODO: operator->
reference
operator
*
()
const
{
return
f
(
index
);
}
reference
operator
*
()
const
{
return
f
(
index
);
}
pointer
operator
->
()
const
{
return
&
f
(
index
);
}
reference
operator
[](
int
n
)
const
{
return
f
(
index
+
n
);
}
};
};
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
6efffa37
...
@@ -56,14 +56,21 @@ struct nonmaxsuppression
...
@@ -56,14 +56,21 @@ struct nonmaxsuppression
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
// requires at least 2 inputs
// requires at least 2 inputs
check_shapes
{
inputs
,
*
this
}.
standard
();
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
auto
lens
=
inputs
.
front
().
lens
();
auto
lens
=
inputs
.
front
().
lens
();
// check input shape
// check input shape
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
{
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dimension mismatch between first and second input!"
);
MIGRAPHX_THROW
(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
}
// check batch sizes
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
}
}
std
::
vector
<
int64_t
>
out_lens
(
2
);
std
::
vector
<
int64_t
>
out_lens
(
2
);
...
@@ -74,8 +81,8 @@ struct nonmaxsuppression
...
@@ -74,8 +81,8 @@ struct nonmaxsuppression
struct
box
struct
box
{
{
std
::
array
<
float
,
2
>
x
;
std
::
array
<
double
,
2
>
x
;
std
::
array
<
float
,
2
>
y
;
std
::
array
<
double
,
2
>
y
;
void
sort
()
void
sort
()
{
{
...
@@ -83,9 +90,9 @@ struct nonmaxsuppression
...
@@ -83,9 +90,9 @@ struct nonmaxsuppression
std
::
sort
(
y
.
begin
(),
y
.
end
());
std
::
sort
(
y
.
begin
(),
y
.
end
());
}
}
std
::
array
<
float
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
std
::
array
<
double
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
float
area
()
const
double
area
()
const
{
{
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
()));
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
()));
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
()));
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
()));
...
@@ -94,29 +101,29 @@ struct nonmaxsuppression
...
@@ -94,29 +101,29 @@ struct nonmaxsuppression
};
};
template
<
class
T
>
template
<
class
T
>
box
batch_box
(
const
T
*
boxes
,
std
::
size_t
bidx
)
const
box
batch_box
(
T
boxes
,
std
::
size_t
b
ox_
idx
)
const
{
{
box
result
{};
box
result
{};
const
T
*
start
=
boxes
+
4
*
bidx
;
auto
start
=
boxes
+
4
*
b
ox_
idx
;
if
(
center_point_box
)
if
(
center_point_box
)
{
{
float
half_width
=
start
[
2
]
/
2.0
f
;
double
half_width
=
start
[
2
]
/
2.0
;
float
half_height
=
start
[
3
]
/
2.0
f
;
double
half_height
=
start
[
3
]
/
2.0
;
float
x_center
=
start
[
0
];
double
x_center
=
start
[
0
];
float
y_center
=
start
[
1
];
double
y_center
=
start
[
1
];
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
}
}
else
else
{
{
result
.
x
=
{
start
[
1
],
start
[
3
]};
result
.
x
=
{
static_cast
<
double
>
(
start
[
1
]
)
,
static_cast
<
double
>
(
start
[
3
]
)
};
result
.
y
=
{
start
[
0
],
start
[
2
]};
result
.
y
=
{
static_cast
<
double
>
(
start
[
0
]
)
,
static_cast
<
double
>
(
start
[
2
]
)
};
}
}
return
result
;
return
result
;
}
}
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
float
iou_threshold
)
const
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
double
iou_threshold
)
const
{
{
b1
.
sort
();
b1
.
sort
();
b2
.
sort
();
b2
.
sort
();
...
@@ -128,7 +135,7 @@ struct nonmaxsuppression
...
@@ -128,7 +135,7 @@ struct nonmaxsuppression
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
}
}
std
::
vector
<
std
::
array
<
float
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
std
::
vector
<
std
::
array
<
double
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
}))
}))
...
@@ -136,115 +143,128 @@ struct nonmaxsuppression
...
@@ -136,115 +143,128 @@ struct nonmaxsuppression
return
false
;
return
false
;
}
}
const
float
area1
=
b1
.
area
();
const
double
area1
=
b1
.
area
();
const
float
area2
=
b2
.
area
();
const
double
area2
=
b2
.
area
();
const
float
intersection_area
=
intersection
.
area
();
const
double
intersection_area
=
intersection
.
area
();
const
float
union_area
=
area1
+
area2
-
intersection_area
;
const
double
union_area
=
area1
+
area2
-
intersection_area
;
if
(
area1
<=
.0
f
or
area2
<=
.0
f
or
union_area
<=
.0
f
)
if
(
area1
<=
.0
f
or
area2
<=
.0
f
or
union_area
<=
.0
f
)
{
{
return
false
;
return
false
;
}
}
const
float
intersection_over_union
=
intersection_area
/
union_area
;
const
double
intersection_over_union
=
intersection_area
/
union_area
;
return
intersection_over_union
>
iou_threshold
;
return
intersection_over_union
>
iou_threshold
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
// filter boxes below score_threshold
template
<
class
T
>
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
filter_boxes_by_score
(
T
scores_start
,
std
::
size_t
num_boxes
,
double
score_threshold
)
const
{
{
argument
result
{
output_shape
};
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
boxes_heap
;
auto
insert_to_boxes_heap
=
result
.
visit
([
&
](
auto
out
)
{
std
::
fill
(
out
.
begin
(),
out
.
end
(),
0
);
});
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
boxes_heap
.
push
(
x
);
});
int64_t
box_idx
=
0
;
std
::
size_t
max_output_boxes_per_class
=
0
;
transform_if
(
float
iou_threshold
=
0.0
f
;
scores_start
,
float
score_threshold
=
0.0
f
;
scores_start
+
num_boxes
,
insert_to_boxes_heap
,
[
&
](
auto
sc
)
{
box_idx
++
;
return
sc
>=
score_threshold
;
},
[
&
](
auto
sc
)
{
return
std
::
make_pair
(
sc
,
box_idx
-
1
);
});
return
boxes_heap
;
}
if
(
args
.
size
()
>
2
)
template
<
class
H
,
class
S
>
{
void
select_boxes
(
H
&
boxes_heap
,
max_output_boxes_per_class
=
args
.
at
(
2
).
at
<
std
::
size_t
>
();
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>&
selected_boxes_inside_class
,
}
std
::
vector
<
int64_t
>&
selected_indices
,
// max_output_boxes_per_class is 0, no output
S
batch_boxes_start
,
if
(
max_output_boxes_per_class
==
0
)
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
std
::
size_t
batch_idx
,
std
::
size_t
class_idx
)
const
{
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
while
(
!
boxes_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
{
return
result
;
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
const
auto
next_top_score
=
boxes_heap
.
top
();
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes_start
,
next_top_score
.
second
),
batch_box
(
batch_boxes_start
,
selected_index
.
second
),
iou_threshold
);
});
if
(
not
not_selected
)
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
batch_idx
);
selected_indices
.
push_back
(
class_idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
}
boxes_heap
.
pop
();
}
}
}
if
(
args
.
size
()
>
3
)
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
iou_threshold
=
args
.
at
(
3
).
at
<
float
>
();
argument
result
{
output_shape
};
}
if
(
args
.
size
()
>
4
)
std
::
size_t
max_output_boxes_per_class
=
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
if
(
max_output_boxes_per_class
==
0
)
{
{
score_threshold
=
args
.
at
(
4
).
at
<
float
>
()
;
return
result
;
}
}
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
const
auto
&
lens
=
args
.
at
(
1
).
get_shape
().
lens
();
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
auto
batch_num
=
lens
[
0
];
auto
class_num
=
lens
[
1
];
result
.
visit
([
&
](
auto
output
)
{
auto
box_num
=
args
.
at
(
0
).
get_shape
().
lens
()[
1
];
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
std
::
vector
<
std
::
pair
<
float
,
int64_t
>>
selected_boxes_inside_class
;
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
std
::
vector
<
int64_t
>
selected_indices
;
const
auto
num_batches
=
lens
[
0
];
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
const
auto
num_classes
=
lens
[
1
];
const
auto
num_boxes
=
lens
[
2
];
auto
scores
=
make_view
<
float
>
(
args
.
at
(
1
).
get_shape
(),
args
.
at
(
1
).
cast
<
float
>
());
// boxes of a class with NMS applied [score, index]
const
float
*
boxes
=
args
.
at
(
0
).
cast
<
float
>
();
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
shape
comp_s
{
shape
::
float_type
,
{
batch_num
,
class_num
}};
std
::
vector
<
int64_t
>
selected_indices
;
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
auto
bidx
=
idx
[
0
];
// iterate over batches and classes
auto
cidx
=
idx
[
1
];
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
std
::
size_t
score_offset
=
(
bidx
*
class_num
+
cidx
)
*
box_num
;
auto
batch_idx
=
idx
[
0
];
const
float
*
batch_boxes
=
boxes
+
bidx
*
box_num
*
4
;
auto
class_idx
=
idx
[
1
];
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
sorted_boxes
;
// index offset for this class
auto
insert_to_sorted_boxes
=
auto
scores_start
=
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
sorted_boxes
.
push
(
x
);
});
scores
.
begin
()
+
(
batch_idx
*
num_classes
+
class_idx
)
*
num_boxes
;
// iterator to first value of this batch
int64_t
box_idx
=
0
;
auto
batch_boxes_start
=
boxes
.
begin
()
+
batch_idx
*
num_boxes
*
4
;
transform_if
(
auto
boxes_heap
=
scores
.
begin
()
+
score_offset
,
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
scores
.
begin
()
+
score_offset
+
box_num
,
select_boxes
(
boxes_heap
,
insert_to_sorted_boxes
,
selected_boxes_inside_class
,
[
&
](
auto
sc
)
{
selected_indices
,
box_idx
++
;
batch_boxes_start
,
return
sc
>=
score_threshold
;
max_output_boxes_per_class
,
},
iou_threshold
,
[
&
](
auto
sc
)
{
return
std
::
make_pair
(
sc
,
box_idx
-
1
);
});
batch_idx
,
class_idx
);
selected_boxes_inside_class
.
clear
();
});
// Get the next box with top score, filter by iou_threshold
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
while
(
!
sorted_boxes
.
empty
()
&&
});
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
const
std
::
pair
<
float
,
int64_t
>&
next_top_score
=
sorted_boxes
.
top
();
// Check with existing selected boxes for this class, suppress if exceed the IOU
// (Intersection Over Union) threshold
bool
not_selected
=
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes
,
next_top_score
.
second
),
batch_box
(
batch_boxes
,
selected_index
.
second
),
iou_threshold
);
});
if
(
not
not_selected
)
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
bidx
);
selected_indices
.
push_back
(
cidx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
}
sorted_boxes
.
pop
();
}
});
result
.
visit
([
&
](
auto
out
)
{
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
out
.
begin
());
});
});
return
result
;
return
result
;
...
...
test/ref_ops_test.cpp
View file @
6efffa37
...
@@ -3470,6 +3470,80 @@ TEST_CASE(nms_test)
...
@@ -3470,6 +3470,80 @@ TEST_CASE(nms_test)
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
}
TEST_CASE
(
nms_transpose1_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
boxes_s
{
migraphx
::
shape
::
float_type
,
{
1
,
4
,
6
}};
std
::
vector
<
float
>
boxes_vec
=
{
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.6
,
0.4
,
10.5
,
10.6
,
100.5
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
};
migraphx
::
shape
scores_s
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
6
}};
std
::
vector
<
float
>
scores_vec
=
{
0.9
,
0.75
,
0.6
,
0.95
,
0.5
,
0.3
};
auto
t_boxes_l
=
mm
->
add_literal
(
migraphx
::
literal
(
boxes_s
,
boxes_vec
));
auto
scores_l
=
mm
->
add_literal
(
migraphx
::
literal
(
scores_s
,
scores_vec
));
auto
max_out_l
=
mm
->
add_literal
(
int64_t
{
4
});
auto
iou_threshold
=
mm
->
add_literal
(
0.5
f
);
auto
score_threshold
=
mm
->
add_literal
(
0.0
f
);
auto
transpose_boxes
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
t_boxes_l
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonmaxsuppression"
,
{{
"center_point_box"
,
1
}}),
transpose_boxes
,
scores_l
,
max_out_l
,
iou_threshold
,
score_threshold
);
mm
->
add_return
({
r
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
TEST_CASE
(
nms_transpose2_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
boxes_s
{
migraphx
::
shape
::
float_type
,
{
4
,
1
,
6
}};
std
::
vector
<
float
>
boxes_vec
=
{
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.5
,
0.6
,
0.4
,
10.5
,
10.6
,
100.5
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
1.0
,
};
migraphx
::
shape
scores_s
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
6
}};
std
::
vector
<
float
>
scores_vec
=
{
0.9
,
0.75
,
0.6
,
0.95
,
0.5
,
0.3
};
auto
t_boxes_l
=
mm
->
add_literal
(
migraphx
::
literal
(
boxes_s
,
boxes_vec
));
auto
scores_l
=
mm
->
add_literal
(
migraphx
::
literal
(
scores_s
,
scores_vec
));
auto
max_out_l
=
mm
->
add_literal
(
int64_t
{
4
});
auto
iou_threshold
=
mm
->
add_literal
(
0.5
f
);
auto
score_threshold
=
mm
->
add_literal
(
0.0
f
);
auto
transpose_boxes
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
2
,
0
}}}),
t_boxes_l
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"nonmaxsuppression"
,
{{
"center_point_box"
,
1
}}),
transpose_boxes
,
scores_l
,
max_out_l
,
iou_threshold
,
score_threshold
);
mm
->
add_return
({
r
});
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
output
=
p
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result
;
output
.
visit
([
&
](
auto
out
)
{
result
.
assign
(
out
.
begin
(),
out
.
end
());
});
std
::
vector
<
int64_t
>
gold
=
{
0
,
0
,
3
,
0
,
0
,
0
,
0
,
0
,
5
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
TEST_CASE
(
nonzero_test
)
TEST_CASE
(
nonzero_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment