Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
ad73abbc
Unverified
Commit
ad73abbc
authored
Jun 29, 2022
by
Charlie Lin
Committed by
GitHub
Jun 29, 2022
Browse files
NMS refactor, enable nonstandard shape (#1257)
Allows PyTorch converted version of SSD-resnet34 to work
parent
ad27d0d6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
191 additions
and
100 deletions
+191
-100
src/include/migraphx/iota_iterator.hpp
src/include/migraphx/iota_iterator.hpp
+2
-1
src/include/migraphx/op/nonmaxsuppression.hpp
src/include/migraphx/op/nonmaxsuppression.hpp
+115
-99
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+74
-0
No files found.
src/include/migraphx/iota_iterator.hpp
View file @
ad73abbc
...
@@ -81,8 +81,9 @@ struct basic_iota_iterator
...
@@ -81,8 +81,9 @@ struct basic_iota_iterator
index
--
;
index
--
;
return
it
;
return
it
;
}
}
// TODO: operator->
reference
operator
*
()
const
{
return
f
(
index
);
}
reference
operator
*
()
const
{
return
f
(
index
);
}
// Member-access operator: returns the address of the value produced by f(index).
// NOTE(review): taking the address requires f(index) to yield an lvalue; confirm this
// holds for every F this iterator is instantiated with.
pointer
operator
->
()
const
{
return
&
f
(
index
);
}
// Subscript operator: returns f(index + n), i.e. the value n positions away from the
// current index. The iterator itself is not modified (const member).
reference
operator
[](
int
n
)
const
{
return
f
(
index
+
n
);
}
};
};
template
<
class
T
,
class
F
>
template
<
class
T
,
class
F
>
...
...
src/include/migraphx/op/nonmaxsuppression.hpp
View file @
ad73abbc
...
@@ -56,14 +56,21 @@ struct nonmaxsuppression
...
@@ -56,14 +56,21 @@ struct nonmaxsuppression
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
// requires at least 2 inputs
// requires at least 2 inputs
check_shapes
{
inputs
,
*
this
}.
standard
();
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
check_shapes
{{
inputs
.
at
(
0
),
inputs
.
at
(
1
)},
*
this
}.
only_dims
(
3
);
auto
lens
=
inputs
.
front
().
lens
();
auto
lens
=
inputs
.
front
().
lens
();
// check input shape
// check input shape
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
if
(
lens
[
1
]
!=
inputs
.
at
(
1
).
lens
()[
2
])
{
{
MIGRAPHX_THROW
(
"NonMaxSuppression: dimension mismatch between first and second input!"
);
MIGRAPHX_THROW
(
"NonMaxSuppression: spatial dimension mismatch between boxes and scores input"
);
}
// check batch sizes
if
(
lens
[
0
]
!=
inputs
.
at
(
1
).
lens
()[
0
])
{
MIGRAPHX_THROW
(
"NonMaxSuppression: number of batches mismatch between boxes and scores input"
);
}
}
std
::
vector
<
int64_t
>
out_lens
(
2
);
std
::
vector
<
int64_t
>
out_lens
(
2
);
...
@@ -74,8 +81,8 @@ struct nonmaxsuppression
...
@@ -74,8 +81,8 @@ struct nonmaxsuppression
struct
box
struct
box
{
{
std
::
array
<
float
,
2
>
x
;
std
::
array
<
double
,
2
>
x
;
std
::
array
<
float
,
2
>
y
;
std
::
array
<
double
,
2
>
y
;
void
sort
()
void
sort
()
{
{
...
@@ -83,9 +90,9 @@ struct nonmaxsuppression
...
@@ -83,9 +90,9 @@ struct nonmaxsuppression
std
::
sort
(
y
.
begin
(),
y
.
end
());
std
::
sort
(
y
.
begin
(),
y
.
end
());
}
}
std
::
array
<
float
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
std
::
array
<
double
,
2
>&
operator
[](
std
::
size_t
i
)
{
return
i
==
0
?
x
:
y
;
}
float
area
()
const
double
area
()
const
{
{
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
()));
assert
(
std
::
is_sorted
(
x
.
begin
(),
x
.
end
()));
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
()));
assert
(
std
::
is_sorted
(
y
.
begin
(),
y
.
end
()));
...
@@ -94,29 +101,29 @@ struct nonmaxsuppression
...
@@ -94,29 +101,29 @@ struct nonmaxsuppression
};
};
template
<
class
T
>
template
<
class
T
>
box
batch_box
(
const
T
*
boxes
,
std
::
size_t
bidx
)
const
box
batch_box
(
T
boxes
,
std
::
size_t
b
ox_
idx
)
const
{
{
box
result
{};
box
result
{};
const
T
*
start
=
boxes
+
4
*
bidx
;
auto
start
=
boxes
+
4
*
b
ox_
idx
;
if
(
center_point_box
)
if
(
center_point_box
)
{
{
float
half_width
=
start
[
2
]
/
2.0
f
;
double
half_width
=
start
[
2
]
/
2.0
;
float
half_height
=
start
[
3
]
/
2.0
f
;
double
half_height
=
start
[
3
]
/
2.0
;
float
x_center
=
start
[
0
];
double
x_center
=
start
[
0
];
float
y_center
=
start
[
1
];
double
y_center
=
start
[
1
];
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
x
=
{
x_center
-
half_width
,
x_center
+
half_width
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
result
.
y
=
{
y_center
-
half_height
,
y_center
+
half_height
};
}
}
else
else
{
{
result
.
x
=
{
start
[
1
],
start
[
3
]};
result
.
x
=
{
static_cast
<
double
>
(
start
[
1
]
)
,
static_cast
<
double
>
(
start
[
3
]
)
};
result
.
y
=
{
start
[
0
],
start
[
2
]};
result
.
y
=
{
static_cast
<
double
>
(
start
[
0
]
)
,
static_cast
<
double
>
(
start
[
2
]
)
};
}
}
return
result
;
return
result
;
}
}
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
float
iou_threshold
)
const
inline
bool
suppress_by_iou
(
box
b1
,
box
b2
,
double
iou_threshold
)
const
{
{
b1
.
sort
();
b1
.
sort
();
b2
.
sort
();
b2
.
sort
();
...
@@ -128,7 +135,7 @@ struct nonmaxsuppression
...
@@ -128,7 +135,7 @@ struct nonmaxsuppression
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
intersection
[
i
][
1
]
=
std
::
min
(
b1
[
i
][
1
],
b2
[
i
][
1
]);
}
}
std
::
vector
<
std
::
array
<
float
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
std
::
vector
<
std
::
array
<
double
,
2
>>
bbox
=
{
intersection
.
x
,
intersection
.
y
};
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
if
(
std
::
any_of
(
bbox
.
begin
(),
bbox
.
end
(),
[](
auto
bx
)
{
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
return
not
std
::
is_sorted
(
bx
.
begin
(),
bx
.
end
());
}))
}))
...
@@ -136,115 +143,124 @@ struct nonmaxsuppression
...
@@ -136,115 +143,124 @@ struct nonmaxsuppression
return
false
;
return
false
;
}
}
const
float
area1
=
b1
.
area
();
const
double
area1
=
b1
.
area
();
const
float
area2
=
b2
.
area
();
const
double
area2
=
b2
.
area
();
const
float
intersection_area
=
intersection
.
area
();
const
double
intersection_area
=
intersection
.
area
();
const
float
union_area
=
area1
+
area2
-
intersection_area
;
const
double
union_area
=
area1
+
area2
-
intersection_area
;
if
(
area1
<=
.0
f
or
area2
<=
.0
f
or
union_area
<=
.0
f
)
if
(
area1
<=
.0
f
or
area2
<=
.0
f
or
union_area
<=
.0
f
)
{
{
return
false
;
return
false
;
}
}
const
float
intersection_over_union
=
intersection_area
/
union_area
;
const
double
intersection_over_union
=
intersection_area
/
union_area
;
return
intersection_over_union
>
iou_threshold
;
return
intersection_over_union
>
iou_threshold
;
}
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
// filter boxes below score_threshold
template
<
class
T
>
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
filter_boxes_by_score
(
T
scores_start
,
std
::
size_t
num_boxes
,
double
score_threshold
)
const
{
{
argument
result
{
output_shape
};
std
::
priority_queue
<
std
::
pair
<
double
,
int64_t
>>
boxes_heap
;
auto
insert_to_boxes_heap
=
result
.
visit
([
&
](
auto
out
)
{
std
::
fill
(
out
.
begin
(),
out
.
end
(),
0
);
});
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
boxes_heap
.
push
(
x
);
});
int64_t
box_idx
=
0
;
std
::
size_t
max_output_boxes_per_class
=
0
;
transform_if
(
float
iou_threshold
=
0.0
f
;
scores_start
,
float
score_threshold
=
0.0
f
;
scores_start
+
num_boxes
,
insert_to_boxes_heap
,
if
(
args
.
size
()
>
2
)
[
&
](
auto
sc
)
{
{
box_idx
++
;
max_output_boxes_per_class
=
args
.
at
(
2
).
at
<
std
::
size_t
>
();
return
sc
>=
score_threshold
;
}
},
// max_output_boxes_per_class is 0, no output
[
&
](
auto
sc
)
{
return
std
::
make_pair
(
sc
,
box_idx
-
1
);
});
if
(
max_output_boxes_per_class
==
0
)
return
boxes_heap
;
{
}
return
result
;
}
if
(
args
.
size
()
>
3
)
{
iou_threshold
=
args
.
at
(
3
).
at
<
float
>
();
}
if
(
args
.
size
()
>
4
)
{
score_threshold
=
args
.
at
(
4
).
at
<
float
>
();
}
const
auto
&
lens
=
args
.
at
(
1
).
get_shape
().
lens
();
auto
batch_num
=
lens
[
0
];
auto
class_num
=
lens
[
1
];
auto
box_num
=
args
.
at
(
0
).
get_shape
().
lens
()[
1
];
std
::
vector
<
std
::
pair
<
float
,
int64_t
>>
selected_boxes_inside_class
;
template
<
class
Output
,
class
Boxes
,
class
Scores
>
void
compute_nms
(
Output
output
,
Boxes
boxes
,
Scores
scores
,
const
shape
&
output_shape
,
std
::
size_t
max_output_boxes_per_class
,
double
iou_threshold
,
double
score_threshold
)
const
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
0
);
const
auto
&
lens
=
scores
.
get_shape
().
lens
();
const
auto
num_batches
=
lens
[
0
];
const
auto
num_classes
=
lens
[
1
];
const
auto
num_boxes
=
lens
[
2
];
// boxes of a class with NMS applied [score, index]
std
::
vector
<
std
::
pair
<
double
,
int64_t
>>
selected_boxes_inside_class
;
std
::
vector
<
int64_t
>
selected_indices
;
std
::
vector
<
int64_t
>
selected_indices
;
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
selected_boxes_inside_class
.
reserve
(
output_shape
.
elements
());
// iterate over batches and classes
auto
scores
=
make_view
<
float
>
(
args
.
at
(
1
).
get_shape
(),
args
.
at
(
1
).
cast
<
float
>
());
shape
comp_s
{
shape
::
double_type
,
{
num_batches
,
num_classes
}};
const
float
*
boxes
=
args
.
at
(
0
).
cast
<
float
>
();
shape
comp_s
{
shape
::
float_type
,
{
batch_num
,
class_num
}};
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
shape_for_each
(
comp_s
,
[
&
](
auto
idx
)
{
auto
bidx
=
idx
[
0
];
auto
batch_idx
=
idx
[
0
];
auto
cidx
=
idx
[
1
];
auto
class_idx
=
idx
[
1
];
// index offset for this class
std
::
size_t
score_offset
=
(
bidx
*
class_num
+
cidx
)
*
box_num
;
auto
scores_start
=
scores
.
begin
()
+
(
batch_idx
*
num_classes
+
class_idx
)
*
num_boxes
;
const
float
*
batch_boxes
=
boxes
+
bidx
*
box_num
*
4
;
// iterator to first value of this batch
std
::
priority_queue
<
std
::
pair
<
float
,
int64_t
>>
sorted_boxes
;
auto
batch_boxes_start
=
boxes
.
begin
()
+
batch_idx
*
num_boxes
*
4
;
auto
insert_to_sorted_boxes
=
auto
boxes_heap
=
filter_boxes_by_score
(
scores_start
,
num_boxes
,
score_threshold
);
make_function_output_iterator
([
&
](
const
auto
&
x
)
{
sorted_boxes
.
push
(
x
);
});
int64_t
box_idx
=
0
;
transform_if
(
scores
.
begin
()
+
score_offset
,
scores
.
begin
()
+
score_offset
+
box_num
,
insert_to_sorted_boxes
,
[
&
](
auto
sc
)
{
box_idx
++
;
return
sc
>=
score_threshold
;
},
[
&
](
auto
sc
)
{
return
std
::
make_pair
(
sc
,
box_idx
-
1
);
});
selected_boxes_inside_class
.
clear
();
selected_boxes_inside_class
.
clear
();
// Get the next box with top score, filter by iou_threshold
// Get the next box with top score, filter by iou_threshold
while
(
!
sorted_
boxes
.
empty
()
&&
while
(
!
boxes
_heap
.
empty
()
&&
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
selected_boxes_inside_class
.
size
()
<
max_output_boxes_per_class
)
{
{
const
std
::
pair
<
float
,
int64_t
>&
next_top_score
=
sorted_boxes
.
top
();
// Check with existing selected boxes for this class, remove box if it
// exceeds the IOU (Intersection Over Union) threshold
// Check with existing selected boxes for this class, suppress if exceed the IOU
const
auto
next_top_score
=
boxes_heap
.
top
();
// (Intersection Over Union) threshold
bool
not_selected
=
bool
not_selected
=
std
::
any_of
(
std
::
any_of
(
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
begin
(),
selected_boxes_inside_class
.
end
(),
selected_boxes_inside_class
.
end
(),
[
&
](
auto
selected_index
)
{
[
&
](
auto
selected_index
)
{
return
this
->
suppress_by_iou
(
return
this
->
suppress_by_iou
(
batch_box
(
batch_boxes
,
next_top_score
.
second
),
batch_box
(
batch_boxes
_start
,
next_top_score
.
second
),
batch_box
(
batch_boxes
,
selected_index
.
second
),
batch_box
(
batch_boxes
_start
,
selected_index
.
second
),
iou_threshold
);
iou_threshold
);
});
});
if
(
not
not_selected
)
if
(
not
not_selected
)
{
{
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_boxes_inside_class
.
push_back
(
next_top_score
);
selected_indices
.
push_back
(
bidx
);
selected_indices
.
push_back
(
b
atch_
idx
);
selected_indices
.
push_back
(
cidx
);
selected_indices
.
push_back
(
c
lass_
idx
);
selected_indices
.
push_back
(
next_top_score
.
second
);
selected_indices
.
push_back
(
next_top_score
.
second
);
}
}
sorted_
boxes
.
pop
();
boxes
_heap
.
pop
();
}
}
});
});
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
output
.
begin
());
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
result
.
visit
([
&
](
auto
out
)
{
std
::
size_t
max_output_boxes_per_class
=
std
::
copy
(
selected_indices
.
begin
(),
selected_indices
.
end
(),
out
.
begin
());
(
args
.
size
()
>
2
)
?
(
args
.
at
(
2
).
at
<
std
::
size_t
>
())
:
0
;
if
(
max_output_boxes_per_class
==
0
)
{
return
result
;
}
double
iou_threshold
=
(
args
.
size
()
>
3
)
?
(
args
.
at
(
3
).
at
<
double
>
())
:
0.0
f
;
double
score_threshold
=
(
args
.
size
()
>
4
)
?
(
args
.
at
(
4
).
at
<
double
>
())
:
0.0
f
;
result
.
visit
([
&
](
auto
output
)
{
visit_all
(
args
[
0
],
args
[
1
])([
&
](
auto
boxes
,
auto
scores
)
{
compute_nms
(
output
,
boxes
,
scores
,
output_shape
,
max_output_boxes_per_class
,
iou_threshold
,
score_threshold
);
});
});
});
return
result
;
return
result
;
...
...
test/ref_ops_test.cpp
View file @
ad73abbc
...
@@ -3187,6 +3187,80 @@ TEST_CASE(nms_test)
...
@@ -3187,6 +3187,80 @@ TEST_CASE(nms_test)
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
EXPECT
(
migraphx
::
verify_range
(
result
,
gold
));
}
}
// NonMaxSuppression on the ref target where the boxes input is produced by a
// transpose: the literal is laid out as {1, 4, 6} and permuted with {0, 2, 1}
// before being passed to the operator, so the boxes argument is not in standard
// layout. The operator is created with center_point_box = 1.
TEST_CASE(nms_transpose1_test)
{
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    // Boxes literal: 4 rows of 6 values each (24 floats).
    const migraphx::shape boxes_shape{migraphx::shape::float_type, {1, 4, 6}};
    const std::vector<float> boxes_data = {
        0.5, 0.5, 0.5, 0.5,  0.5,  0.5,   //
        0.5, 0.6, 0.4, 10.5, 10.6, 100.5, //
        1.0, 1.0, 1.0, 1.0,  1.0,  1.0,   //
        1.0, 1.0, 1.0, 1.0,  1.0,  1.0,   //
    };

    // One batch, one class, six scores.
    const migraphx::shape scores_shape{migraphx::shape::float_type, {1, 1, 6}};
    const std::vector<float> scores_data = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};

    // Literals are added in the same order as the operator's inputs.
    auto boxes_lit     = main_mod->add_literal(migraphx::literal(boxes_shape, boxes_data));
    auto scores_lit    = main_mod->add_literal(migraphx::literal(scores_shape, scores_data));
    auto max_boxes_lit = main_mod->add_literal(int64_t{4});
    auto iou_lit       = main_mod->add_literal(0.5f);
    auto score_lit     = main_mod->add_literal(0.0f);

    // {1, 4, 6} -> permutation {0, 2, 1}
    auto transpose_op = migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}});
    auto boxes_ins    = main_mod->add_instruction(transpose_op, boxes_lit);

    auto nms_op  = migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}});
    auto nms_ins = main_mod->add_instruction(
        nms_op, boxes_ins, scores_lit, max_boxes_lit, iou_lit, score_lit);
    main_mod->add_return({nms_ins});

    prog.compile(migraphx::ref::target{});
    auto evaluated = prog.eval({}).back();

    std::vector<int64_t> selected;
    evaluated.visit([&](auto view) { selected.assign(view.begin(), view.end()); });

    // Expected output as flat triples; entries after the third triple are zero.
    const std::vector<int64_t> expected = {0, 0, 3, //
                                           0, 0, 0, //
                                           0, 0, 5, //
                                           0, 0, 0, //
                                           0, 0, 0, //
                                           0, 0, 0};
    EXPECT(migraphx::verify_range(selected, expected));
}
// NonMaxSuppression on the ref target where the boxes input is produced by a
// transpose: the literal is laid out as {4, 1, 6} and permuted with {1, 2, 0}
// before being passed to the operator, so the boxes argument is not in standard
// layout. The operator is created with center_point_box = 1.
TEST_CASE(nms_transpose2_test)
{
    migraphx::program prog;
    auto* main_mod = prog.get_main_module();

    // Boxes literal: 4 rows of 6 values each (24 floats).
    const migraphx::shape boxes_shape{migraphx::shape::float_type, {4, 1, 6}};
    const std::vector<float> boxes_data = {
        0.5, 0.5, 0.5, 0.5,  0.5,  0.5,   //
        0.5, 0.6, 0.4, 10.5, 10.6, 100.5, //
        1.0, 1.0, 1.0, 1.0,  1.0,  1.0,   //
        1.0, 1.0, 1.0, 1.0,  1.0,  1.0,   //
    };

    // One batch, one class, six scores.
    const migraphx::shape scores_shape{migraphx::shape::float_type, {1, 1, 6}};
    const std::vector<float> scores_data = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};

    // Literals are added in the same order as the operator's inputs.
    auto boxes_lit     = main_mod->add_literal(migraphx::literal(boxes_shape, boxes_data));
    auto scores_lit    = main_mod->add_literal(migraphx::literal(scores_shape, scores_data));
    auto max_boxes_lit = main_mod->add_literal(int64_t{4});
    auto iou_lit       = main_mod->add_literal(0.5f);
    auto score_lit     = main_mod->add_literal(0.0f);

    // {4, 1, 6} -> permutation {1, 2, 0}
    auto transpose_op = migraphx::make_op("transpose", {{"permutation", {1, 2, 0}}});
    auto boxes_ins    = main_mod->add_instruction(transpose_op, boxes_lit);

    auto nms_op  = migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}});
    auto nms_ins = main_mod->add_instruction(
        nms_op, boxes_ins, scores_lit, max_boxes_lit, iou_lit, score_lit);
    main_mod->add_return({nms_ins});

    prog.compile(migraphx::ref::target{});
    auto evaluated = prog.eval({}).back();

    std::vector<int64_t> selected;
    evaluated.visit([&](auto view) { selected.assign(view.begin(), view.end()); });

    // Expected output as flat triples; entries after the third triple are zero.
    const std::vector<int64_t> expected = {0, 0, 3, //
                                           0, 0, 0, //
                                           0, 0, 5, //
                                           0, 0, 0, //
                                           0, 0, 0, //
                                           0, 0, 0};
    EXPECT(migraphx::verify_range(selected, expected));
}
TEST_CASE
(
nonzero_test
)
TEST_CASE
(
nonzero_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment