Unverified Commit ca15cd37 authored by Ted Themistokleous's avatar Ted Themistokleous Committed by GitHub
Browse files

Parse if inline constant args (#1533)

Allows migraphx to inline the IF operator when we run into an IF that can be evaluated at compile time, thus avoiding us injecting IF and just inserting the instructions directly.
parent 2c93aa87
...@@ -113,7 +113,8 @@ struct onnx_parser ...@@ -113,7 +113,8 @@ struct onnx_parser
void parse_from(std::istream& is, std::string name = ""); void parse_from(std::istream& is, std::string name = "");
void parse_from(const void* data, std::size_t size); void parse_from(const void* data, std::size_t size);
void parse_graph(module* mod, const onnx::GraphProto& graph); std::vector<instruction_ref>
parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining = false);
literal parse_value(const onnx::AttributeProto& attr) const; literal parse_value(const onnx::AttributeProto& attr) const;
literal parse_tensor(const onnx::TensorProto& t) const; literal parse_tensor(const onnx::TensorProto& t) const;
shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const; shape parse_type(const onnx::TypeProto& t, const std::vector<std::size_t>& input_dims) const;
......
...@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name) ...@@ -220,7 +220,7 @@ void onnx_parser::parse_from(std::istream& is, std::string name)
if(model.has_graph()) if(model.has_graph())
{ {
this->parse_graph(mm, model.graph()); (void)this->parse_graph(mm, model.graph());
} }
} }
else else
...@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size) ...@@ -240,7 +240,7 @@ void onnx_parser::parse_from(const void* data, std::size_t size)
if(model.has_graph()) if(model.has_graph())
{ {
this->parse_graph(mm, model.graph()); (void)this->parse_graph(mm, model.graph());
} }
} }
else else
...@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model) ...@@ -264,7 +264,8 @@ int64_t onnx_parser::get_opset_version(const onnx::ModelProto& model)
return version; return version;
} }
void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) std::vector<instruction_ref>
onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlining)
{ {
std::unordered_map<std::string, instruction_ref> mod_insts; std::unordered_map<std::string, instruction_ref> mod_insts;
for(auto&& f : graph.initializer()) for(auto&& f : graph.initializer())
...@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph) ...@@ -372,11 +373,16 @@ void onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph)
std::back_inserter(output_ins), std::back_inserter(output_ins),
[&](const auto& name) { return instructions[name]; }); [&](const auto& name) { return instructions[name]; });
if(not inlining)
{
// add the return instuction // add the return instuction
mod->add_return(output_ins); mod->add_return(output_ins);
// remove instructions added in this mod // Remove instructions added in module (this is turned off for subgraph inlining)
erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); }); erase_if(instructions, [&](auto&& p) { return mod->has_instruction(p.second); });
}
return output_ins;
} }
literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const literal onnx_parser::parse_value(const onnx::AttributeProto& attr) const
......
...@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if> ...@@ -51,6 +51,24 @@ struct parse_if : op_parser<parse_if>
" condition input can have only one element!"); " condition input can have only one element!");
} }
// Fold instruction if condition is constant thus can be evaled
// prior to inference
if(args.front()->can_eval())
{
auto cond_arg = args.front()->eval();
auto* mod = info.mod;
// then branch
if(cond_arg.at<bool>())
{
return parser.parse_graph(mod, then_graph, true);
}
// else branch
else
{
return parser.parse_graph(mod, else_graph, true);
}
}
std::string then_name = info.name + "_if"; std::string then_name = info.name + "_if";
module_ref then_mdl = parser.prog.create_module(then_name); module_ref then_mdl = parser.prog.create_module(then_name);
...@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if> ...@@ -58,10 +76,10 @@ struct parse_if : op_parser<parse_if>
module_ref else_mdl = parser.prog.create_module(else_name); module_ref else_mdl = parser.prog.create_module(else_name);
// parse the then sub_graph // parse the then sub_graph
parser.parse_graph(then_mdl, then_graph); (void)parser.parse_graph(then_mdl, then_graph);
// parse_the else sub_graph // parse_the else sub_graph
parser.parse_graph(else_mdl, else_graph); (void)parser.parse_graph(else_mdl, else_graph);
auto then_out_shapes = then_mdl->get_output_shapes(); auto then_out_shapes = then_mdl->get_output_shapes();
auto else_out_shapes = else_mdl->get_output_shapes(); auto else_out_shapes = else_mdl->get_output_shapes();
......
...@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop> ...@@ -71,7 +71,7 @@ struct parse_loop : op_parser<parse_loop>
module_ref sub_mod = parser.prog.create_module(mod_name); module_ref sub_mod = parser.prog.create_module(mod_name);
// parse the sub_graph // parse the sub_graph
parser.parse_graph(sub_mod, sub_graph); (void)parser.parse_graph(sub_mod, sub_graph);
auto ret = info.add_instruction( auto ret = info.add_instruction(
make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod}); make_op("loop", {{"max_iterations", max_iterations}}), args, {sub_mod});
......
...@@ -2498,6 +2498,58 @@ def if_else_test(): ...@@ -2498,6 +2498,58 @@ def if_else_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
@onnx_test()
def if_else_test_inlined():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 3]) [2, 3])
...@@ -2547,6 +2599,149 @@ def if_else_test(): ...@@ -2547,6 +2599,149 @@ def if_else_test():
return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor]) return ([node], [x, y], [res], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test()
def if_then_else_multi_output_shapes_inlined_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3, 1])
then_out2 = onnx.helper.make_tensor_value_info('then_out2',
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out2 = onnx.helper.make_tensor_value_info('else_out2',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
then_add_node2 = onnx.helper.make_node('Add',
inputs=['x', 'x'],
outputs=['then_out2'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
else_sub_node = onnx.helper.make_node('Sub',
inputs=['y', 'yt'],
outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node, then_add_node2],
'then_body', [], [then_out, then_out2])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
'else_body', [], [else_out, else_out2])
cond = np.array([1]).astype(np.bool)
cond_tensor = helper.make_tensor(name="cond",
data_type=TensorProto.BOOL,
dims=cond.shape,
vals=cond.astype(bool))
res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res1', 'res2'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y], [res1, res2], [cond_tensor, xt_tensor, yt_tensor])
@onnx_test()
def if_then_else_multi_output_shapes_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT,
[2, 3, 1])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT,
[2, 3, 1])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3, 1])
then_out2 = onnx.helper.make_tensor_value_info('then_out2',
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3, 1])
else_out2 = onnx.helper.make_tensor_value_info('else_out2',
onnx.TensorProto.FLOAT,
[2, 3, 1])
xt = np.ones((2, 3, 1)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3, 1).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
then_add_node2 = onnx.helper.make_node('Add',
inputs=['x', 'x'],
outputs=['then_out2'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
else_sub_node = onnx.helper.make_node('Sub',
inputs=['y', 'yt'],
outputs=['else_out2'])
then_body = onnx.helper.make_graph([then_add_node, then_add_node2],
'then_body', [], [then_out, then_out2])
else_body = onnx.helper.make_graph([else_mul_node, else_sub_node],
'else_body', [], [else_out, else_out2])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res1 = onnx.helper.make_tensor_value_info('res1', TensorProto.FLOAT, [])
res2 = onnx.helper.make_tensor_value_info('res2', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res1', 'res2'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res1, res2], [xt_tensor, yt_tensor])
@onnx_test() @onnx_test()
def if_literal_test(): def if_literal_test():
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
...@@ -2807,6 +3002,59 @@ def if_then_test(): ...@@ -2807,6 +3002,59 @@ def if_then_test():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3]) x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3]) y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT,
[2, 3])
else_out = onnx.helper.make_tensor_value_info('else_out',
onnx.TensorProto.FLOAT,
[2, 3])
xt = np.ones((2, 3)).astype(np.float)
xt_tensor = helper.make_tensor(name='xt',
data_type=TensorProto.FLOAT,
dims=xt.shape,
vals=xt.flatten().astype(np.float32))
yt = np.random.randn(2, 3).astype(np.float)
yt_tensor = helper.make_tensor(name='yt',
data_type=TensorProto.FLOAT,
dims=yt.shape,
vals=yt.flatten().astype(np.float32))
then_add_node = onnx.helper.make_node('Add',
inputs=['x', 'xt'],
outputs=['then_out'])
else_mul_node = onnx.helper.make_node('Mul',
inputs=['y', 'yt'],
outputs=['else_out'])
then_body = onnx.helper.make_graph([then_add_node], 'then_body', [],
[then_out])
else_body = onnx.helper.make_graph([else_mul_node], 'else_body', [],
[else_out])
cond_tensor = onnx.helper.make_tensor_value_info("cond",
onnx.TensorProto.BOOL,
[1])
res = onnx.helper.make_tensor_value_info('res', TensorProto.FLOAT, [])
node = onnx.helper.make_node('If',
inputs=['cond'],
outputs=['res'],
then_branch=then_body,
else_branch=else_body)
return ([node], [x, y, cond_tensor], [res], [xt_tensor, yt_tensor])
@onnx_test()
def if_then_test_inlined():
x = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [2, 3])
y = onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, [2, 3])
then_out = onnx.helper.make_tensor_value_info('then_out', then_out = onnx.helper.make_tensor_value_info('then_out',
onnx.TensorProto.FLOAT, onnx.TensorProto.FLOAT,
[2, 3]) [2, 3])
......
...@@ -2672,14 +2672,16 @@ TEST_CASE(if_else_test) ...@@ -2672,14 +2672,16 @@ TEST_CASE(if_else_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}}; migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {0}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f); std::vector<float> ones(s.elements(), 1.0f);
std::vector<float> rand = {1.3865, -0.494756, -0.283504, 0.200491, -0.490031, 1.32388};
auto l1 = mm->add_literal(s, ones); auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-0.583375, 0.633757, 0.0668345, -0.479422, -0.604634, 0.0388589};
auto l2 = mm->add_literal(s, rand); auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
...@@ -2693,15 +2695,32 @@ TEST_CASE(if_else_test) ...@@ -2693,15 +2695,32 @@ TEST_CASE(if_else_test)
auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret); auto r = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
mm->add_return({r}); mm->add_return({r});
std::ifstream ifs("if_else_test.onnx", std::ios::binary); auto prog = migraphx::parse_onnx("if_else_test.onnx");
ifs.seekg(0, std::ios::end); EXPECT(p == prog);
auto length = ifs.tellg(); }
ifs.seekg(0, std::ios::beg);
std::vector<char> onnx_buffer(length); TEST_CASE(if_else_test_inlined)
ifs.read(onnx_buffer.data(), length); {
ifs.close(); migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {0}));
auto prog = migraphx::parse_onnx_buffer(onnx_buffer.data(), length, {}); migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
mm->add_literal(s, ones);
std::vector<float> rand = {0.811412, -0.949771, -0.169276, 0.36552, -0.14801, 2.07061};
auto l2 = mm->add_literal(s, rand);
mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto re = mm->add_instruction(migraphx::make_op("mul"), y, l2);
mm->add_return({re});
auto prog = migraphx::parse_onnx("if_else_test_inlined.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -2774,6 +2793,70 @@ TEST_CASE(if_param_test) ...@@ -2774,6 +2793,70 @@ TEST_CASE(if_param_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::shape s_trail{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s_trail, ones);
std::vector<float> rand = {-1.01837, -0.305541, -0.254105, 0.892955, 1.38714, -0.584205};
mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s_trail);
mm->add_parameter("y", s);
auto rt = mm->add_instruction(migraphx::make_op("add"), x, l1);
auto rt2 = mm->add_instruction(migraphx::make_op("add"), x, x);
mm->add_return({rt, rt2});
auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_then_else_multi_output_shapes_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
migraphx::shape s{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape s_trail{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s_trail, ones);
std::vector<float> rand = {-0.753997, 0.707831, -0.865795, 2.49574, 0.464937, -0.168745};
auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s_trail);
auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
auto rt2 = then_mod->add_instruction(migraphx::make_op("add"), x, x);
then_mod->add_return({rt, rt2});
auto* else_mod = p.create_module("If_5_else");
auto re = else_mod->add_instruction(migraphx::make_op("mul"), y, l2);
auto re2 = else_mod->add_instruction(migraphx::make_op("sub"), y, l2);
else_mod->add_return({re, re2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
auto r1 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), ret);
auto r2 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), ret);
mm->add_return({r1, r2});
auto prog = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -2814,14 +2897,16 @@ TEST_CASE(if_then_test) ...@@ -2814,14 +2897,16 @@ TEST_CASE(if_then_test)
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}}; migraphx::shape sc{migraphx::shape::bool_type, {1}};
auto cond = mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f); std::vector<float> ones(s.elements(), 1.0f);
std::vector<float> rand = {-0.266913, -0.180328, -0.124268, -1.23768, 0.312334, 1.18475};
auto l1 = mm->add_literal(s, ones); auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
auto l2 = mm->add_literal(s, rand); auto l2 = mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto cond = mm->add_parameter("cond", sc);
auto* then_mod = p.create_module("If_5_if"); auto* then_mod = p.create_module("If_5_if");
auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1); auto rt = then_mod->add_instruction(migraphx::make_op("add"), x, l1);
...@@ -2839,6 +2924,32 @@ TEST_CASE(if_then_test) ...@@ -2839,6 +2924,32 @@ TEST_CASE(if_then_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(if_then_test_inlined)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sc{migraphx::shape::bool_type, {1}};
mm->add_literal(migraphx::literal(sc, {1}));
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
std::vector<float> ones(s.elements(), 1.0f);
auto l1 = mm->add_literal(s, ones);
std::vector<float> rand = {-1.26487, -2.42279, 0.990835, 1.63072, 0.812238, -0.174946};
mm->add_literal(s, rand);
auto x = mm->add_parameter("x", s);
mm->add_parameter("y", s);
auto rt = mm->add_instruction(migraphx::make_op("add"), x, l1);
mm->add_return({rt});
auto prog = migraphx::parse_onnx("if_then_test_inlined.onnx");
EXPECT(p == prog);
}
TEST_CASE(if_tuple_test) TEST_CASE(if_tuple_test)
{ {
migraphx::program p; migraphx::program p;
......
...@@ -590,17 +590,80 @@ TEST_CASE(if_else_test) ...@@ -590,17 +590,80 @@ TEST_CASE(if_else_test)
p.compile(migraphx::ref::target{}); p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}}; migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625}; std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape bool_data{migraphx::shape::bool_type, {1}};
bool b_data = false;
migraphx::parameter_map pp; migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data()); pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data()); pp["y"] = migraphx::argument(s_data, data.data());
pp["cond"] = migraphx::argument(bool_data, &b_data);
auto result = p.eval(pp).back(); auto result = p.eval(pp).back();
std::vector<float> result_vector; std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = { std::vector<float> gold = {0.0866565, -0.371067, 0.017719, 0.0250614, 0.0612539, -0.744683};
-0.0364609435, 0.475317657, -0.00417715637, -0.0599277429, 0.0755792186, -0.0218581557}; EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_else_test_inlined)
{
migraphx::program p = migraphx::parse_onnx("if_else_test_inlined.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0507132, -0.712328, 0.0105797, 0.04569, 0.0185013, -1.16472};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_then_test)
{
migraphx::program p = migraphx::parse_onnx("if_then_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape bool_data{migraphx::shape::bool_type, {1}};
bool b_data = true;
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
pp["cond"] = migraphx::argument(bool_data, &b_data);
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
// onnx adds ones so result should be just + 1.0
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_then_test_inlined)
{
migraphx::program p = migraphx::parse_onnx("if_then_test_inlined.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375};
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
...@@ -637,6 +700,67 @@ TEST_CASE(if_literal_test) ...@@ -637,6 +700,67 @@ TEST_CASE(if_literal_test)
} }
} }
TEST_CASE(if_then_else_multi_output_shapes_inlined_test)
{
migraphx::program p =
migraphx::parse_onnx("if_then_else_multi_output_shapes_inlined_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape x_data{migraphx::shape::float_type, {2, 3, 1}};
migraphx::shape y_data{migraphx::shape::float_type, {2, 3}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(x_data, data.data());
pp["y"] = migraphx::argument(y_data, data.data());
auto result_args = p.eval(pp);
auto result = result_args.front();
auto result_b = result_args.back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> result_vector_back;
result_b.visit([&](auto output) { result_vector_back.assign(output.begin(), output.end()); });
result_vector.insert(result_vector.end(), result_vector_back.begin(), result_vector_back.end());
std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_then_else_multi_output_shapes_test)
{
migraphx::program p = migraphx::parse_onnx("if_then_else_multi_output_shapes_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_data{migraphx::shape::float_type, {2, 3, 1}};
std::vector<float> data = {0.0625, 0.75, -0.0625, 0.125, -0.125, -0.5625};
migraphx::shape bool_data{migraphx::shape::bool_type, {1}};
bool b_data = true;
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(s_data, data.data());
pp["y"] = migraphx::argument(s_data, data.data());
pp["cond"] = migraphx::argument(bool_data, &b_data);
auto result_args = p.eval(pp);
auto result = result_args.front();
auto result_b = result_args.back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> result_vector_back;
result_b.visit([&](auto output) { result_vector_back.assign(output.begin(), output.end()); });
result_vector.insert(result_vector.end(), result_vector_back.begin(), result_vector_back.end());
std::vector<float> gold = {
1.0625, 1.75, 0.9375, 1.125, 0.875, 0.4375, 0.125, 1.50, -0.125, 0.250, -0.250, -1.125};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(if_pl_test) TEST_CASE(if_pl_test)
{ {
auto run_prog = [](bool cond) { auto run_prog = [](bool cond) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment