Commit 1e80ceef authored by Umang Yadav's avatar Umang Yadav
Browse files

add single target multiple returns

parent 1796d3e3
...@@ -114,7 +114,7 @@ struct module_impl ...@@ -114,7 +114,7 @@ struct module_impl
const operation& get_operation(instruction_ref ins) { return ins->get_operator(); } const operation& get_operation(instruction_ref ins) { return ins->get_operator(); }
module::module(const std::string& name) : impl(std::make_unique<module_impl>()) module::module(const std::string& name) :impl(std::make_unique<module_impl>())
{ {
impl->name = name; impl->name = name;
} }
...@@ -165,7 +165,7 @@ void module::assign(const module& m) ...@@ -165,7 +165,7 @@ void module::assign(const module& m)
auto order = any_cast<builtin::param>(ins->get_operator()).order; auto order = any_cast<builtin::param>(ins->get_operator()).order;
auto s = ins->get_shape(); auto s = ins->get_shape();
copy_ins = impl->insert(impl->instructions.end(), copy_ins = impl->insert(impl->instructions.end(),
{builtin::param{name, order}, std::move(s), {}}); {builtin::param{name, order}, std::move(s), {}});
impl->nparams++; impl->nparams++;
} }
else if(ins->name() == "@outline") else if(ins->name() == "@outline")
...@@ -800,8 +800,10 @@ static std::string cpp_var_name(const std::string& name) ...@@ -800,8 +800,10 @@ static std::string cpp_var_name(const std::string& name)
{ {
std::string prefix = "x_"; std::string prefix = "x_";
if(not contains(name, "@")) if(not contains(name, "@"))
prefix = "p_"; {
return to_c_id(prefix + replace_string(name, ":", "_module_")); return to_c_id(name);
}
return to_c_id(prefix + name);
} }
static void print_py_op(std::ostream& os, const operation& op) static void print_py_op(std::ostream& os, const operation& op)
...@@ -825,8 +827,11 @@ static void print_make_op(std::ostream& os, const operation& op) ...@@ -825,8 +827,11 @@ static void print_make_op(std::ostream& os, const operation& op)
auto v = op.to_value(); auto v = op.to_value();
if(not v.empty()) if(not v.empty())
{ {
os << "migraphx::make_json_op(" << enclose_name(op.name()); os << "migraphx::make_op(" << enclose_name(op.name());
os << ", " << enclose_name(to_json_string(v)); auto rname = "{" + replace_string(to_json_string(v), "\"", "\\\"") + "}";
rname = replace_string(rname, ":", ", ");
rname = replace_string(rname, "\\", "");
os << ", " << rname;
} }
else else
{ {
......
...@@ -222,6 +222,65 @@ TEST_CASE(two_targets_ref_inbetween) ...@@ -222,6 +222,65 @@ TEST_CASE(two_targets_ref_inbetween)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(single_target_multiple_returns)
{
/*
Add (tid = 0)
|
---------------
| |
Mul Identity
(tid = 0) (tid = 0)
| |
---------------
|
Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
mm->add_return({mul_ins, identity_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
tass.insert(tass.begin(), std::make_pair(identity_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_2 = target_mod_0_0->add_parameter("param:2", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_2, target_mod_0_0_param_1);
auto x_target_mod_0_0_3 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_2);
auto x_target_mod_0_0_4 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), x_target_mod_0_0_2, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_3, x_target_mod_0_0_4});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, y, x}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2);
mm->add_return({x_4, x_3});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(if_then_else_program) TEST_CASE(if_then_else_program)
{ {
/* /*
...@@ -663,7 +722,7 @@ TEST_CASE(fork_and_merge_case_1) ...@@ -663,7 +722,7 @@ TEST_CASE(fork_and_merge_case_1)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
}; };
TEST_CASE(fork_and_merge_case_2) TEST_CASE(fork_and_return_as_merge_bypass_branch_and_tass_on_other)
{ {
/* /*
**** Fork node returning **** **** Fork node returning ****
...@@ -716,7 +775,7 @@ TEST_CASE(fork_and_merge_case_2) ...@@ -716,7 +775,7 @@ TEST_CASE(fork_and_merge_case_2)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(fork_and_merge_case_3) TEST_CASE(fork_and_return_as_merge_bypass_branch_and_no_tass_on_other)
{ {
/* /*
**** Fork node returning **** **** Fork node returning ****
...@@ -770,7 +829,7 @@ TEST_CASE(fork_and_merge_case_3) ...@@ -770,7 +829,7 @@ TEST_CASE(fork_and_merge_case_3)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(fork_and_merge_case_4) TEST_CASE(fork_and_return_as_merge_different_tass_on_both_branches)
{ {
/* /*
Add (tid = 0) Add (tid = 0)
...@@ -848,7 +907,7 @@ TEST_CASE(fork_and_merge_case_4) ...@@ -848,7 +907,7 @@ TEST_CASE(fork_and_merge_case_4)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
}; };
TEST_CASE(fork_and_merge_case_5) TEST_CASE(fork_and_return_as_merge_no_tass_on_both_branch)
{ {
/* /*
Add (no assignment) Add (no assignment)
...@@ -881,7 +940,7 @@ TEST_CASE(fork_and_merge_case_5) ...@@ -881,7 +940,7 @@ TEST_CASE(fork_and_merge_case_5)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(fork_and_merge_case_6) TEST_CASE(fork_and_return_as_merge_no_tass_on_one_branch)
{ {
/* /*
Add (no assignment) Add (no assignment)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment