Unverified Commit a0ae2f79 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Use `add_common_op` for handling types and broadcast in Clip Onnx parsing (#1121)

add_common_op for parse_clip
Should fix #1119
parent 8f184d4a
...@@ -32,9 +32,20 @@ struct onnx_parser ...@@ -32,9 +32,20 @@ struct onnx_parser
instruction_ref add_bias(const std::vector<instruction_ref>& args, instruction_ref add_bias(const std::vector<instruction_ref>& args,
instruction_ref curr_ins, instruction_ref curr_ins,
uint64_t axis) const; uint64_t axis) const;
instruction_ref add_broadcastable_binary_op(const std::string& op_name, instruction_ref add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const; instruction_ref arg1) const;
instruction_ref add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const;
template <class... Ts>
instruction_ref add_common_op(const std::string& op_name, Ts... xs) const
{
return add_common_op(op_name, {xs...});
}
instruction_ref add_instruction(const operation& op, instruction_ref add_instruction(const operation& op,
const std::vector<instruction_ref>& args) const; const std::vector<instruction_ref>& args) const;
......
...@@ -98,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s ...@@ -98,7 +98,13 @@ instruction_ref onnx_parser::node_info::add_broadcastable_binary_op(const std::s
instruction_ref arg0, instruction_ref arg0,
instruction_ref arg1) const instruction_ref arg1) const
{ {
return add_common_op(*mod, make_op(op_name), {arg0, arg1}); return this->add_common_op(op_name, arg0, arg1);
}
instruction_ref onnx_parser::node_info::add_common_op(const std::string& op_name,
std::vector<instruction_ref> inputs) const
{
return migraphx::add_common_op(*mod, make_op(op_name), std::move(inputs));
} }
instruction_ref instruction_ref
......
...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip> ...@@ -16,7 +16,6 @@ struct parse_clip : op_parser<parse_clip>
onnx_parser::node_info info, onnx_parser::node_info info,
std::vector<instruction_ref> args) const std::vector<instruction_ref> args) const
{ {
auto input_lens = args[0]->get_shape().lens();
instruction_ref min_arg; instruction_ref min_arg;
instruction_ref max_arg; instruction_ref max_arg;
bool min_used = false; bool min_used = false;
...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip> ...@@ -45,29 +44,17 @@ struct parse_clip : op_parser<parse_clip>
max_used = true; max_used = true;
} }
if(min_used)
{
min_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
min_arg);
}
if(max_used)
{
max_arg = info.add_instruction(make_op("multibroadcast", {{"out_lens", input_lens}}),
max_arg);
}
if(min_used and max_used) if(min_used and max_used)
{ {
return info.add_instruction(make_op("clip"), args[0], min_arg, max_arg); return info.add_common_op("clip", args[0], min_arg, max_arg);
} }
else if(max_used) else if(max_used)
{ {
return info.add_instruction(make_op("min"), args[0], max_arg); return info.add_broadcastable_binary_op("min", args[0], max_arg);
} }
else if(min_used) else if(min_used)
{ {
return info.add_instruction(make_op("max"), args[0], min_arg); return info.add_broadcastable_binary_op("max", args[0], min_arg);
} }
else else
{ {
......
...@@ -426,6 +426,22 @@ def clip_test_op11_no_args1(): ...@@ -426,6 +426,22 @@ def clip_test_op11_no_args1():
return ([node], [x], [y]) return ([node], [x], [y])
@onnx_test
def clip_test_args_type_mismatch():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3, 3])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 3])
min_val = helper.make_tensor('min', TensorProto.FLOAT, [1, 3],
[1.5, 2.5, 3.5])
max_val = helper.make_tensor('max', TensorProto.INT64, [3, 1], [2, 3, 4])
node = onnx.helper.make_node('Clip',
inputs=['0', 'min', 'max'],
outputs=['1'])
return ([node], [x], [y], [min_val, max_val])
@onnx_test @onnx_test
def concat_test(): def concat_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 4, 3])
......
...@@ -470,6 +470,28 @@ TEST_CASE(clip_test_op11_no_args1) ...@@ -470,6 +470,28 @@ TEST_CASE(clip_test_op11_no_args1)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(clip_test_args_type_mismatch)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto min_val = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {1, 3}}, {1.5, 2.5, 3.5}});
auto max_val = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {3, 1}}, {2, 3, 4}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 3}});
min_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), min_val);
max_val =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), max_val);
max_val = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), max_val);
auto r = mm->add_instruction(migraphx::make_op("clip"), l0, min_val, max_val);
mm->add_return({r});
auto prog = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx");
EXPECT(p == prog);
}
TEST_CASE(concat_test) TEST_CASE(concat_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -45,6 +45,22 @@ TEST_CASE(averagepool_nt_cip_test) ...@@ -45,6 +45,22 @@ TEST_CASE(averagepool_nt_cip_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(clip_args_type_mismatch)
{
auto p = migraphx::parse_onnx("clip_test_args_type_mismatch.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_0{migraphx::shape::float_type, {3, 3}};
migraphx::parameter_map pp;
std::vector<float> data_0 = {0.9, 1.2, 1.7, 1.9, 2.2, 2.7, 2.9, 3.2, 3.7};
pp["0"] = migraphx::argument(s_0, data_0.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.5, 2, 2, 1.9, 2.5, 3, 2.9, 3.2, 3.7};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(depthtospace_simple_test) TEST_CASE(depthtospace_simple_test)
{ {
auto p = migraphx::parse_onnx("depthtospace_simple_test.onnx"); auto p = migraphx::parse_onnx("depthtospace_simple_test.onnx");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment