Commit 8e99dbc9 authored by charlie's avatar charlie
Browse files

test dynamic NMS, remove contiguous requirement

parent 6382ff10
...@@ -56,7 +56,7 @@ struct nonmaxsuppression ...@@ -56,7 +56,7 @@ struct nonmaxsuppression
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
// requires at least 2 inputs // requires at least 2 inputs
check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3); check_shapes{{inputs.at(0), inputs.at(1)}, *this}.only_dims(3).same_ndims();
// both boxes and scores will be dynamic if one of them is dynamic // both boxes and scores will be dynamic if one of them is dynamic
if(inputs.at(0).dynamic()) if(inputs.at(0).dynamic())
{ {
......
...@@ -75,7 +75,7 @@ struct parse_generic_op : op_parser<parse_generic_op> ...@@ -75,7 +75,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
bool needs_contiguous(const std::string& op_name) const bool needs_contiguous(const std::string& op_name) const
{ {
return contains({"flatten", "gather", "nonmaxsuppression", "scatter"}, op_name); return contains({"flatten", "gather", "scatter"}, op_name);
} }
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
......
...@@ -3343,7 +3343,33 @@ def nms_test(): ...@@ -3343,7 +3343,33 @@ def nms_test():
st = helper.make_tensor_value_info('score_threshold', TensorProto.FLOAT, st = helper.make_tensor_value_info('score_threshold', TensorProto.FLOAT,
[1]) [1])
out = helper.make_tensor_value_info('selected_indices', TensorProto.INT64, out = helper.make_tensor_value_info('selected_indices', TensorProto.INT64,
[6, 3]) [None, 3])
node = onnx.helper.make_node('NonMaxSuppression',
inputs=[
'boxes', 'scores',
'max_output_boxes_per_class',
'iou_threshold', 'score_threshold'
],
outputs=['selected_indices'],
center_point_box=1)
return ([node], [b, s, mo, iou, st], [out])
@onnx_test
def nms_dynamic_batch_test():
b = helper.make_tensor_value_info('boxes', TensorProto.FLOAT, [None, 6, 4])
s = helper.make_tensor_value_info('scores', TensorProto.FLOAT,
[None, 1, 6])
mo = helper.make_tensor_value_info('max_output_boxes_per_class',
TensorProto.INT64, [1])
iou = helper.make_tensor_value_info('iou_threshold', TensorProto.FLOAT,
[1])
st = helper.make_tensor_value_info('score_threshold', TensorProto.FLOAT,
[1])
out = helper.make_tensor_value_info('selected_indices', TensorProto.INT64,
[None, 3])
node = onnx.helper.make_node('NonMaxSuppression', node = onnx.helper.make_node('NonMaxSuppression',
inputs=[ inputs=[
......
No preview for this file type
...@@ -3118,6 +3118,31 @@ TEST_CASE(nms_test) ...@@ -3118,6 +3118,31 @@ TEST_CASE(nms_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(nms_dynamic_batch_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sb{migraphx::shape::float_type, {{1, 10, 0}, {6, 6, 0}, {4, 4, 0}}};
auto b = mm->add_parameter("boxes", sb);
migraphx::shape ss{migraphx::shape::float_type, {{1, 10, 0}, {1, 1, 0}, {6, 6, 0}}};
auto s = mm->add_parameter("scores", ss);
migraphx::shape smo{migraphx::shape::int64_type, {1}};
auto mo = mm->add_parameter("max_output_boxes_per_class", smo);
migraphx::shape siou{migraphx::shape::float_type, {1}};
auto iou = mm->add_parameter("iou_threshold", siou);
migraphx::shape sst{migraphx::shape::float_type, {1}};
auto st = mm->add_parameter("score_threshold", sst);
auto ret = mm->add_instruction(
migraphx::make_op("nonmaxsuppression", {{"center_point_box", 1}}), b, s, mo, iou, st);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 10, 0};
auto prog = migraphx::parse_onnx("nms_dynamic_batch_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(nonzero_dynamic_test) TEST_CASE(nonzero_dynamic_test)
{ {
migraphx::program p; migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment