Unverified Commit ca15cd37 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Parse if inline constant args (#1533)

Allows MIGraphX to inline the If operator when the condition can be evaluated at compile time: instead of emitting an `if` instruction and its submodules, the parser inserts the selected branch's instructions directly into the enclosing module.
parent 2c93aa87
......@@ -113,7 +113,8 @@ struct onnx_parser
void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size);
void parse_graph(module* mod, const onnx::GraphProto& graph);
std::vector<instruction_ref>
parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
......
......@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
else
......@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if(model.has_graph())
{
this->parse_graph(mm, model.graph());
(void)this->parse_graph(mm, model.graph());
}
}
else
......@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version;
}
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::vector<instruction_ref>
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{
std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer())
......@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; });
// add the return instruction
mod->add_return(output_ins);
if(not inlining)
{
// add the return instruction
mod->add_return(output_ins);
// Remove instructions added in module (this is turned off for subgraph inlining)
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
// remove instructions added in this mod
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
return output_ins;
}
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
......@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!");
}
// Fold instruction if condition is constant thus can be evaled
// prior to inference
if(args.front()->can_eval())
{
auto cond_arg = args.front()->eval();
auto* mod = info.mod;
// then branch
if(cond_arg.at<bool>())
{
return parser.parse_graph(mod, then_graph, true);
}
// else branch
else
{
return parser.parse_graph(mod, else_graph, true);
}
}
std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name);
......@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref else_mdl = parser.prog.create_module(else_name);
// parse the then sub_graph
parser.parse_graph(then_mdl, then_graph);
(void)parser.parse_graph(then_mdl, then_graph);
// parse the else sub_graph
parser.parse_graph(else_mdl, else_graph);
(void)parser.parse_graph(else_mdl, else_graph);
auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes();
......
......@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph
parser.parse_graph(sub_mod, sub_graph);
(void)parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
......
......@@ -2498,6 +2498,58 @@ def if_else_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
@onnx_test()
def if_else_test_inlined():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
......@@ -2547,6 +2599,149 @@ def if_else_test():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test()
def if_then_else_multi_output_shapes_inlined_test():
    # Builds an If node whose condition is a constant (true) initializer and
    # whose two branches produce two outputs each with DIFFERENT shapes
    # ([2, 3, 1] for then, [2, 3] for else).  Because the condition is a
    # compile-time constant, the parser is expected to fold the If node and
    # inline the then-branch directly.
    # Returns (nodes, inputs, outputs, initializers) for the onnx_test harness.
    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
                                           [2, 3, 1])
    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
    then_out = onnx.helper.make_tensor_value_info('then_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3, 1])
    then_out2 = onnx.helper.make_tensor_value_info('then_out2',
                                                   onnx.TensorProto.FLOAT,
                                                   [2, 3, 1])
    else_out = onnx.helper.make_tensor_value_info('else_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3])
    else_out2 = onnx.helper.make_tensor_value_info('else_out2',
                                                   onnx.TensorProto.FLOAT,
                                                   [2, 3])

    # Use the builtin float/bool dtypes: the np.float and np.bool aliases were
    # deprecated in NumPy 1.20 and removed in 1.24 (they were aliases of the
    # builtins, so this is behavior-identical).
    xt = np.ones((2, 3, 1)).astype(float)
    xt_tensor = helper.make_tensor(name='xt',
                                   data_type=TensorProto.FLOAT,
                                   dims=xt.shape,
                                   vals=xt.flatten().astype(np.float32))
    yt = np.random.randn(2, 3).astype(float)
    yt_tensor = helper.make_tensor(name='yt',
                                   data_type=TensorProto.FLOAT,
                                   dims=yt.shape,
                                   vals=yt.flatten().astype(np.float32))

    then_add_node = onnx.helper.make_node('Add',
                                          inputs=['x', 'xt'],
                                          outputs=['then_out'])
    then_add_node2 = onnx.helper.make_node('Add',
                                           inputs=['x', 'x'],
                                           outputs=['then_out2'])
    else_mul_node = onnx.helper.make_node('Mul',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out'])
    else_sub_node = onnx.helper.make_node('Sub',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out2'])

    then_body = onnx.helper.make_graph([then_add_node, then_add_node2],
                                       'then_body', [], [then_out, then_out2])
    else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
                                       'else_body', [], [else_out, else_out2])

    # Condition is supplied as an initializer (constant) so the parser can
    # evaluate it at parse time and inline the selected branch.
    cond = np.array([1]).astype(bool)
    cond_tensor = helper.make_tensor(name="cond",
                                     data_type=TensorProto.BOOL,
                                     dims=cond.shape,
                                     vals=cond.astype(bool))
    res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
    node = onnx.helper.make_node('If',
                                 inputs=['cond'],
                                 outputs=['res1', 'res2'],
                                 then_branch=then_body,
                                 else_branch=else_body)
    return ([node], [x, y], [res1, res2], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test()
def if_then_else_multi_output_shapes_test():
    # Builds an If node with a RUNTIME condition ('cond' is a graph input, not
    # an initializer), so the parser cannot fold it and must emit an `if`
    # instruction with then/else submodules.  Both branches have two outputs,
    # all shaped [2, 3, 1].
    # Returns (nodes, inputs, outputs, initializers) for the onnx_test harness.
    x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
                                           [2, 3, 1])
    y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT,
                                           [2, 3, 1])
    then_out = onnx.helper.make_tensor_value_info('then_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3, 1])
    then_out2 = onnx.helper.make_tensor_value_info('then_out2',
                                                   onnx.TensorProto.FLOAT,
                                                   [2, 3, 1])
    else_out = onnx.helper.make_tensor_value_info('else_out',
                                                  onnx.TensorProto.FLOAT,
                                                  [2, 3, 1])
    else_out2 = onnx.helper.make_tensor_value_info('else_out2',
                                                   onnx.TensorProto.FLOAT,
                                                   [2, 3, 1])

    # Use builtin float: the np.float alias was deprecated in NumPy 1.20 and
    # removed in 1.24 (it was an alias of the builtin, so behavior-identical).
    xt = np.ones((2, 3, 1)).astype(float)
    xt_tensor = helper.make_tensor(name='xt',
                                   data_type=TensorProto.FLOAT,
                                   dims=xt.shape,
                                   vals=xt.flatten().astype(np.float32))
    yt = np.random.randn(2, 3, 1).astype(float)
    yt_tensor = helper.make_tensor(name='yt',
                                   data_type=TensorProto.FLOAT,
                                   dims=yt.shape,
                                   vals=yt.flatten().astype(np.float32))

    then_add_node = onnx.helper.make_node('Add',
                                          inputs=['x', 'xt'],
                                          outputs=['then_out'])
    then_add_node2 = onnx.helper.make_node('Add',
                                           inputs=['x', 'x'],
                                           outputs=['then_out2'])
    else_mul_node = onnx.helper.make_node('Mul',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out'])
    else_sub_node = onnx.helper.make_node('Sub',
                                          inputs=['y', 'yt'],
                                          outputs=['else_out2'])

    then_body = onnx.helper.make_graph([then_add_node, then_add_node2],
                                       'then_body', [], [then_out, then_out2])
    else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
                                       'else_body', [], [else_out, else_out2])

    # 'cond' is declared as a model input so its value is only known at
    # inference time -- this keeps the If from being inlined.
    cond_tensor = onnx.helper.make_tensor_value_info("cond",
                                                     onnx.TensorProto.BOOL,
                                                     [1])
    res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
    res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
    node = onnx.helper.make_node('If',
                                 inputs=['cond'],
                                 outputs=['res1', 'res2'],
                                 then_branch=then_body,
                                 else_branch=else_body)
    return ([node], [x, y, cond_tensor], [res1, res2], [xt_tensor, yt_tensor])
@onnx_test()
def if_literal_test():
then_out = onnx.helper.make_tensor_value_info('then_out',
......@@ -2807,6 +3002,59 @@ def if_then_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
@onnx_test()
def if_then_test_inlined():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
......
......@@ -2672,14 +2672,16 @@ TEST_CASE(if_else_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
std::vector<float> rand = {1.3865, -0.494756, -0.283504, 0.200491, -0.490031, 1.32388};
auto l1 = mm->add_literal(s, ones);
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
......@@ -2693,15 +2695,32 @@ TEST_CASE(if_else_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r});
std::ifstream ifs("if_else_test.onnx", std::ios::binary);
ifs.seekg(0, std::ios::end);
auto length = ifs.tellg();
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length);
ifs.read(onnx_buffer.data(), length);
ifs.close();
auto prog = migraphx::parse_onnx("if_else_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_else_test_inlined)
{
    // Expected program: with a constant-false condition the parser folds the
    // If node and inlines the else-branch (mul) directly into the main
    // module -- no "if" instruction is emitted.
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape sc{migraphx::shape::bool_type, {1}};
    // the condition literal is still added to the module even though branch
    // selection already happened at parse time
    mm->add_literal(migraphx::literal(sc, {0}));
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    std::vector<float> ones(s.elements(), 1.0f);
    // then-branch literal: present but unused after inlining the else-branch
    mm->add_literal(s, ones);
    std::vector<float> rand = {0.811412, -0.949771, -0.169276, 0.36552, -0.14801, 2.07061};
    auto l2 = mm->add_literal(s, rand);
    mm->add_parameter("x", s);
    auto y = mm->add_parameter("y", s);
    auto re = mm->add_instruction(migraphx::make_op("mul"), y, l2);
    mm->add_return({re});

    // Fix: removed a stale leftover line that redeclared `prog` via
    // parse_onnx_buffer with undeclared `onnx_buffer`/`length` variables.
    auto prog = migraphx::parse_onnx("if_else_test_inlined.onnx");
    EXPECT(p == prog);
}
......@@ -2774,6 +2793,70 @@ TEST_CASE(if_param_test)
EXPECT(p == prog);
}
TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
{
    // With a constant-true condition the parser folds the If node and splices
    // the then-branch's two outputs straight into the main module.
    migraphx::program expected;
    auto* mmod = expected.get_main_module();
    migraphx::shape cond_shape{migraphx::shape::bool_type, {1}};
    // condition literal is parsed into the module even though the branch was
    // chosen at parse time
    mmod->add_literal(migraphx::literal(cond_shape, {1}));
    migraphx::shape flat_shape{migraphx::shape::float_type, {2, 3}};
    migraphx::shape trail_shape{migraphx::shape::float_type, {2, 3, 1}};
    std::vector<float> one_vals(flat_shape.elements(), 1.0f);
    auto ones_lit = mmod->add_literal(trail_shape, one_vals);
    std::vector<float> rand_vals = {-1.01837, -0.305541, -0.254105, 0.892955, 1.38714, -0.584205};
    // else-branch literal: present in the module but dead after inlining
    mmod->add_literal(flat_shape, rand_vals);
    auto px = mmod->add_parameter("x", trail_shape);
    mmod->add_parameter("y", flat_shape);
    auto first_out  = mmod->add_instruction(migraphx::make_op("add"), px, ones_lit);
    auto second_out = mmod->add_instruction(migraphx::make_op("add"), px, px);
    mmod->add_return({first_out, second_out});

    auto parsed = migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx");
    EXPECT(expected == parsed);
}
TEST_CASE(if_then_else_multi_output_shapes_test)
{
    // Runtime condition ("cond" is a parameter): the parser must emit an "if"
    // instruction with then/else submodules and two tuple outputs.
    migraphx::program expected;
    auto* mmod = expected.get_main_module();
    migraphx::shape cond_shape{migraphx::shape::bool_type, {1}};
    migraphx::shape val_shape{migraphx::shape::float_type, {2, 3, 1}};
    migraphx::shape trail_shape{migraphx::shape::float_type, {2, 3, 1}};
    std::vector<float> one_vals(val_shape.elements(), 1.0f);
    auto ones_lit = mmod->add_literal(trail_shape, one_vals);
    std::vector<float> rand_vals = {-0.753997, 0.707831, -0.865795, 2.49574, 0.464937, -0.168745};
    auto rand_lit = mmod->add_literal(val_shape, rand_vals);
    auto px    = mmod->add_parameter("x", trail_shape);
    auto py    = mmod->add_parameter("y", val_shape);
    auto pcond = mmod->add_parameter("cond", cond_shape);

    // then-branch submodule: {x + 1, x + x}
    auto* then_body = expected.create_module("If_5_if");
    auto t0 = then_body->add_instruction(migraphx::make_op("add"), px, ones_lit);
    auto t1 = then_body->add_instruction(migraphx::make_op("add"), px, px);
    then_body->add_return({t0, t1});

    // else-branch submodule: {y * r, y - r}
    auto* else_body = expected.create_module("If_5_else");
    auto e0 = else_body->add_instruction(migraphx::make_op("mul"), py, rand_lit);
    auto e1 = else_body->add_instruction(migraphx::make_op("sub"), py, rand_lit);
    else_body->add_return({e0, e1});

    // the if returns a tuple; each result is extracted with get_tuple_elem
    auto if_ins = mmod->add_instruction(migraphx::make_op("if"), {pcond}, {then_body, else_body});
    auto out0 = mmod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), if_ins);
    auto out1 = mmod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), if_ins);
    mmod->add_return({out0, out1});

    auto parsed = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx");
    EXPECT(expected == parsed);
}
TEST_CASE(if_pl_test)
{
migraphx::program p;
......@@ -2814,14 +2897,16 @@ TEST_CASE(if_then_test)
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
std::vector<float> rand = {-0.266913, -0.180328, -0.124268, -1.23768, 0.312334, 1.18475};
auto l1 = mm->add_literal(s, ones);
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
......@@ -2839,6 +2924,32 @@ TEST_CASE(if_then_test)
EXPECT(p == prog);
}
TEST_CASE(if_then_test_inlined)
{
    // Constant-true condition: the then-branch (x + ones) is inlined into the
    // main module and no "if" instruction is emitted.
    migraphx::program expected;
    auto* mmod = expected.get_main_module();
    migraphx::shape cond_shape{migraphx::shape::bool_type, {1}};
    mmod->add_literal(migraphx::literal(cond_shape, {1}));
    migraphx::shape val_shape{migraphx::shape::float_type, {2, 3}};
    std::vector<float> one_vals(val_shape.elements(), 1.0f);
    auto ones_lit = mmod->add_literal(val_shape, one_vals);
    std::vector<float> rand_vals = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
    // else-branch literal remains in the module but is unused
    mmod->add_literal(val_shape, rand_vals);
    auto px = mmod->add_parameter("x", val_shape);
    mmod->add_parameter("y", val_shape);
    auto sum_ins = mmod->add_instruction(migraphx::make_op("add"), px, ones_lit);
    mmod->add_return({sum_ins});

    auto parsed = migraphx::parse_onnx("if_then_test_inlined.onnx");
    EXPECT(expected == parsed);
}
TEST_CASE(if_tuple_test)
{
migraphx::program p;
......
......@@ -590,6 +590,28 @@ TEST_CASE(if_else_test)
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape bool_data{migraphx::shape::bool_type, {1}};
bool b_data = false;
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
pp["cond"] = migraphx::argument(bool_data, &b_data);
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_else_test_inlined)
{
migraphx::program p = migraphx::parse_onnx("if_else_test_inlined.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
......@@ -599,8 +621,49 @@ TEST_CASE(if_else_test)
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {
-0.0364609435, 0.475317657, -0.00417715637, -0.0599277429, 0.0755792186, -0.0218581557};
std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_then_test)
{
    // Evaluate with cond == true: the then-branch adds a tensor of ones to x.
    auto prog = migraphx::parse_onnx("if_then_test.onnx");
    prog.compile(migraphx::ref::target{});

    migraphx::shape in_shape{migraphx::shape::float_type, {2, 3}};
    std::vector<float> input = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
    migraphx::shape cond_shape{migraphx::shape::bool_type, {1}};
    bool cond_val = true;

    migraphx::parameter_map params;
    params["x"]    = migraphx::argument(in_shape, input.data());
    params["y"]    = migraphx::argument(in_shape, input.data());
    params["cond"] = migraphx::argument(cond_shape, &cond_val);

    auto result = prog.eval(params).back();
    std::vector<float> output;
    result.visit([&](auto vals) { output.assign(vals.begin(), vals.end()); });

    // then-branch adds ones, so every element is input + 1.0
    std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
    EXPECT(migraphx::verify_range(output, gold));
}
TEST_CASE(if_then_test_inlined)
{
    // Inlined variant: no "cond" parameter exists because the If was folded
    // at parse time; the result must match the then-branch (input + 1.0).
    auto prog = migraphx::parse_onnx("if_then_test_inlined.onnx");
    prog.compile(migraphx::ref::target{});

    migraphx::shape in_shape{migraphx::shape::float_type, {2, 3}};
    std::vector<float> input = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};

    migraphx::parameter_map params;
    params["x"] = migraphx::argument(in_shape, input.data());
    params["y"] = migraphx::argument(in_shape, input.data());

    auto result = prog.eval(params).back();
    std::vector<float> output;
    result.visit([&](auto vals) { output.assign(vals.begin(), vals.end()); });

    std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
    EXPECT(migraphx::verify_range(output, gold));
}
......@@ -637,6 +700,67 @@ TEST_CASE(if_literal_test)
}
}
TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
{
    // Inlined multi-output variant: both then-branch results come back from
    // eval; concatenate them and compare against the combined gold values.
    auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx");
    prog.compile(migraphx::ref::target{});

    migraphx::shape x_shape{migraphx::shape::float_type, {2, 3, 1}};
    migraphx::shape y_shape{migraphx::shape::float_type, {2, 3}};
    std::vector<float> input = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};

    migraphx::parameter_map params;
    params["x"] = migraphx::argument(x_shape, input.data());
    params["y"] = migraphx::argument(y_shape, input.data());

    auto outputs  = prog.eval(params);
    auto head_arg = outputs.front();
    auto tail_arg = outputs.back();
    std::vector<float> combined;
    head_arg.visit([&](auto vals) { combined.assign(vals.begin(), vals.end()); });
    std::vector<float> tail_vals;
    tail_arg.visit([&](auto vals) { tail_vals.assign(vals.begin(), vals.end()); });
    combined.insert(combined.end(), tail_vals.begin(), tail_vals.end());

    // first six values: x + 1; last six: x + x
    std::vector<float> gold = {
        1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
    EXPECT(migraphx::verify_range(combined, gold));
}
TEST_CASE(if_then_else_multi_output_shapes_test)
{
    // Non-inlined multi-output variant: cond is passed at runtime (true), so
    // the then-branch executes and both of its outputs are returned.
    auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx");
    prog.compile(migraphx::ref::target{});

    migraphx::shape in_shape{migraphx::shape::float_type, {2, 3, 1}};
    std::vector<float> input = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
    migraphx::shape cond_shape{migraphx::shape::bool_type, {1}};
    bool cond_val = true;

    migraphx::parameter_map params;
    params["x"]    = migraphx::argument(in_shape, input.data());
    params["y"]    = migraphx::argument(in_shape, input.data());
    params["cond"] = migraphx::argument(cond_shape, &cond_val);

    auto outputs  = prog.eval(params);
    auto head_arg = outputs.front();
    auto tail_arg = outputs.back();
    std::vector<float> combined;
    head_arg.visit([&](auto vals) { combined.assign(vals.begin(), vals.end()); });
    std::vector<float> tail_vals;
    tail_arg.visit([&](auto vals) { tail_vals.assign(vals.begin(), vals.end()); });
    combined.insert(combined.end(), tail_vals.begin(), tail_vals.end());

    // first six values: x + 1; last six: x + x
    std::vector<float> gold = {
        1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
    EXPECT(migraphx::verify_range(combined, gold));
}
TEST_CASE(if_pl_test)
{
auto run_prog = [](bool cond) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.