Commit fa12da23 authored by Umang Yadav's avatar Umang Yadav
Browse files

Changes for the order fix

parent 1e80ceef
...@@ -174,18 +174,51 @@ struct auto_gen_root_modules ...@@ -174,18 +174,51 @@ struct auto_gen_root_modules
*/ */
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid) bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
{ {
const auto inputs = ins->inputs(); const auto inputs = ins->inputs();
if(inputs.size() == 1) size_t in_degree = inputs.size();
if(in_degree == 1)
{ {
return false; return false;
} }
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { size_t input_from_other_tid_module = 0;
return ( size_t num_default_tids = 0;
(this->skip_ins.find(input_ins) != skip_ins.end()) or size_t num_different_tids = 0;
(tass.find(input_ins) != tass.end() and size_t num_same_tid = 0;
tass.at(input_ins) != ins_tid.value_or(std::numeric_limits<std::size_t>::max()))); // std::unordered_map<size_t, size_t> in_tid_freq_map;
}); for(const auto& input_ins : inputs)
{
if(skip_ins.find(input_ins) != skip_ins.end())
{
input_from_other_tid_module++;
}
else if(tass.find(input_ins) == tass.end())
{
num_default_tids++;
}
else if(tass.at(input_ins) != ins_tid)
{
num_different_tids++;
}
else
{
num_same_tid++;
}
}
assert(input_from_other_tid_module + num_default_tids + num_different_tids + num_same_tid ==
in_degree);
if(input_from_other_tid_module > 1)
{
return true;
}
else if(input_from_other_tid_module + num_default_tids == in_degree)
{
return false;
}
else if(num_same_tid + num_default_tids == in_degree)
{
return false;
}
return true;
} }
/* /*
...@@ -200,21 +233,47 @@ struct auto_gen_root_modules ...@@ -200,21 +233,47 @@ struct auto_gen_root_modules
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph. itself then, it is classified as boundary for subgraph.
*/ */
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid) bool is_fork_node(migraphx::instruction_ref ins, std::size_t ins_tid)
{ {
const auto outputs = ins->outputs(); const auto outputs = ins->outputs();
if(outputs.size() == 1) if(outputs.size() == 1)
{ {
return false; return false;
} }
// if all the outputs are for the "default" or with same tid then it is not a fork but
// rather simply a boundary
std::unordered_map<std::size_t, std::size_t> output_tids;
for(const auto& output_ins : outputs)
{
if(tass.find(output_ins) != tass.end())
{
auto out_tid = tass.at(output_ins);
if(output_tids.find(out_tid) == output_tids.end())
{
output_tids[out_tid] = 1;
}
else
{
output_tids[out_tid]++;
}
}
}
if(output_tids.empty())
{
return false;
}
else if(output_tids.size() == 1 and output_tids.cbegin()->second == outputs.size())
{
return false;
}
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(output_ins->name() == "return") if(output_ins->name() == "return")
{ {
return false; return false;
} }
return (tass.find(output_ins) != tass.end() and return (tass.find(output_ins) != tass.end() and tass.at(output_ins) != ins_tid);
tass.at(output_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max()));
}); });
} }
...@@ -262,7 +321,7 @@ struct auto_gen_root_modules ...@@ -262,7 +321,7 @@ struct auto_gen_root_modules
current_tid = std::make_optional<std::size_t>(tass.at(ins)); current_tid = std::make_optional<std::size_t>(tass.at(ins));
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
fork_node = is_fork_node(ins, current_tid); fork_node = is_fork_node(ins, current_tid.value());
} }
} }
else else
...@@ -281,7 +340,8 @@ struct auto_gen_root_modules ...@@ -281,7 +340,8 @@ struct auto_gen_root_modules
{ {
MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur"); MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur");
} }
fork_node = is_fork_node(ins, current_tid); fork_node = is_fork_node(
ins, current_tid.value_or(std::numeric_limits<std::size_t>::max()));
} }
if(not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
...@@ -315,7 +375,8 @@ struct auto_gen_root_modules ...@@ -315,7 +375,8 @@ struct auto_gen_root_modules
return; return;
} }
// gather all parameters // gather all parameters
std::unordered_set<instruction_ref> params; std::unordered_set<instruction_ref> params_set;
std::vector<instruction_ref> params_vec;
// gather all return values // gather all return values
std::vector<instruction_ref> return_ins; std::vector<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec)) for(auto tins : iterator_for(same_tid_ins_vec))
...@@ -325,11 +386,15 @@ struct auto_gen_root_modules ...@@ -325,11 +386,15 @@ struct auto_gen_root_modules
transform_if( transform_if(
inputs.cbegin(), inputs.cbegin(),
inputs.cend(), inputs.cend(),
std::inserter(params, params.end()), std::back_inserter(params_vec),
[&](auto in_param) { [&](auto in_param) {
return (params.count(in_param) == 0 and same_tid_ins_set.count(in_param) == 0); return (params_set.count(in_param) == 0 and
same_tid_ins_set.count(in_param) == 0);
}, },
[&](auto in_param) { return in_param; }); [&](auto in_param) {
params_set.insert(in_param);
return in_param;
});
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) { if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
return same_tid_ins_set.count(out_ins) == 0; return same_tid_ins_set.count(out_ins) == 0;
})) }))
...@@ -340,7 +405,7 @@ struct auto_gen_root_modules ...@@ -340,7 +405,7 @@ struct auto_gen_root_modules
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{})) if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{ {
std::cout << "params ins: \n"; std::cout << "params ins: \n";
for(auto tmp : iterator_for(params)) for(auto tmp : iterator_for(params_vec))
{ {
(*tmp)->debug_print(); (*tmp)->debug_print();
} }
...@@ -357,7 +422,7 @@ struct auto_gen_root_modules ...@@ -357,7 +422,7 @@ struct auto_gen_root_modules
std::unordered_map<instruction_ref, instruction_ref> params_map; std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0; std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params; std::vector<instruction_ref> rot_ins_params;
for(auto pins : iterator_for(params)) for(auto pins : iterator_for(params_vec))
{ {
auto scalar = get_scalar(*pins); auto scalar = get_scalar(*pins);
if(scalar.empty()) if(scalar.empty())
......
...@@ -69,14 +69,14 @@ TEST_CASE(single_target_test) ...@@ -69,14 +69,14 @@ TEST_CASE(single_target_test)
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s); auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction( auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0); migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
target_mod_1_0->add_return({x_target_mod_1_0_2}); target_mod_1_0->add_return({x_target_mod_1_0_2});
auto x_2 = mm->add_instruction( auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0}); migraphx::make_op("run_on_target", {{"target_id", 1}}), {x, y}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2); auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
mm->add_return({x_3}); mm->add_return({x_3});
} }
...@@ -115,36 +115,33 @@ TEST_CASE(two_targets_with_ref) ...@@ -115,36 +115,33 @@ TEST_CASE(two_targets_with_ref)
migraphx::generate_root_modules(p1, tass); migraphx::generate_root_modules(p1, tass);
migraphx::program p2; migraphx::program p2;
{ {
migraphx::module_ref mm = p2.get_main_module(); migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x); auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter( auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}}); migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2}); target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter( auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}}); migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}), auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{y, identity_ins_0}, {identity_ins_0, y},
{target_mod_1_0}); {target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction( auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_4, z}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), x_6); auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), x_6);
mm->add_return({identity_ins_1}); mm->add_return({identity_ins_1});
...@@ -191,37 +188,100 @@ TEST_CASE(two_targets_ref_inbetween) ...@@ -191,37 +188,100 @@ TEST_CASE(two_targets_ref_inbetween)
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), x); auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter( auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}}); migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2}); target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter( auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}}); migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}), auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{y, identity_ins}, {identity_ins, y},
{target_mod_1_0}); {target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4); auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
auto x_6 = mm->add_instruction( auto x_6 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_5}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_5, z}, {target_mod_0_0});
auto x_7 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_6); auto x_7 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_6);
mm->add_return({x_7}); mm->add_return({x_7});
} }
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(multiple_targets_multiple_returns)
{
/*
Add (tid = 1)
|
---------------
| |
Mul Identity
(tid = 0) (tid = 0)
| |
---------------
|
Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
mm->add_return({mul_ins, identity_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
tass.insert(tass.begin(), std::make_pair(identity_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), target_mod_0_0_param_0);
auto x_target_mod_0_0_3 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2, x_target_mod_0_0_3});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {x, y}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto z = mm->add_parameter("z", s);
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_3, z}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
auto x_7 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_5);
mm->add_return({x_7, x_6});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(single_target_multiple_returns) TEST_CASE(single_target_multiple_returns)
{ {
/* /*
...@@ -256,27 +316,28 @@ TEST_CASE(single_target_multiple_returns) ...@@ -256,27 +316,28 @@ TEST_CASE(single_target_multiple_returns)
migraphx::generate_root_modules(p1, tass); migraphx::generate_root_modules(p1, tass);
migraphx::program p2; migraphx::program p2;
{ {
migraphx::module_ref mm = p2.get_main_module(); migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s); auto y = mm->add_parameter("y", s);
auto y = mm->add_parameter("y", s); auto x = mm->add_parameter("x", s);
auto x = mm->add_parameter("x", s); auto z = mm->add_parameter("z", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_2 = target_mod_0_0->add_parameter("param:2", s); auto target_mod_0_0_param_2 = target_mod_0_0->add_parameter("param:2", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s); auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_3 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_2, target_mod_0_0_param_1); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
auto x_target_mod_0_0_3 = auto x_target_mod_0_0_4 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_2); target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_3);
auto x_target_mod_0_0_4 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_5 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), x_target_mod_0_0_2, target_mod_0_0_param_0); migraphx::make_op("mul"), x_target_mod_0_0_3, target_mod_0_0_param_2);
target_mod_0_0->add_return({x_target_mod_0_0_3, x_target_mod_0_0_4}); target_mod_0_0->add_return({x_target_mod_0_0_4, x_target_mod_0_0_5});
auto x_2 = mm->add_instruction( auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y, z}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2); auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_3);
mm->add_return({x_4, x_3}); mm->add_return({x_5, x_4});
} }
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
...@@ -332,38 +393,34 @@ TEST_CASE(if_then_else_program) ...@@ -332,38 +393,34 @@ TEST_CASE(if_then_else_program)
auto cond = mm->add_parameter("cond", cond_s); auto cond = mm->add_parameter("cond", cond_s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter( auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", ds);
"param:1", migraphx::shape{migraphx::shape::float_type, {2, 3}}); auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", ds);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
"param:0", migraphx::shape{migraphx::shape::float_type, {2, 3}}); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
migraphx::module_ref if_gpu_mod = p2.create_module("if_gpu_mod"); migraphx::module_ref if_gpu_mod = p2.create_module("if_gpu_mod");
auto x_if_gpu_mod_0 = if_gpu_mod->add_literal(migraphx::literal(ds, data1)); auto x_if_gpu_mod_0 = if_gpu_mod->add_literal(migraphx::literal(ds, data1));
auto x_if_gpu_mod_1 = auto x_if_gpu_mod_1 =
if_gpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}), if_gpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{x_if_gpu_mod_0, x}, {x, x_if_gpu_mod_0},
{target_mod_0_0}); {target_mod_0_0});
auto x_if_gpu_mod_2 = if_gpu_mod->add_instruction( auto x_if_gpu_mod_2 = if_gpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_if_gpu_mod_1); migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_if_gpu_mod_1);
if_gpu_mod->add_return({x_if_gpu_mod_2}); if_gpu_mod->add_return({x_if_gpu_mod_2});
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", ds);
"param:1", migraphx::shape{migraphx::shape::float_type, {2, 3}}); auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", ds);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter( auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
"param:0", migraphx::shape{migraphx::shape::float_type, {2, 3}}); migraphx::make_op("mul"), target_mod_1_0_param_0, target_mod_1_0_param_1);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("mul"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2}); target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref else_cpu_mod = p2.create_module("else_cpu_mod"); migraphx::module_ref else_cpu_mod = p2.create_module("else_cpu_mod");
auto x_else_cpu_mod_0 = else_cpu_mod->add_literal(migraphx::literal(ds, data2)); auto x_else_cpu_mod_0 = else_cpu_mod->add_literal(migraphx::literal(ds, data2));
auto x_else_cpu_mod_1 = auto x_else_cpu_mod_1 =
else_cpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}), else_cpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{x_else_cpu_mod_0, y}, {y, x_else_cpu_mod_0},
{target_mod_1_0}); {target_mod_1_0});
auto x_else_cpu_mod_2 = else_cpu_mod->add_instruction( auto x_else_cpu_mod_2 = else_cpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_cpu_mod_1); migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_cpu_mod_1);
...@@ -424,25 +481,25 @@ TEST_CASE(merge_case_1) ...@@ -424,25 +481,25 @@ TEST_CASE(merge_case_1)
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_5 = mm->add_instruction( auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1"); migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s); auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction( auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0); migraphx::make_op("mul"), target_mod_0_1_param_0, target_mod_0_1_param_1);
target_mod_0_1->add_return({x_target_mod_0_1_2}); target_mod_0_1->add_return({x_target_mod_0_1_2});
auto x_7 = mm->add_instruction( auto x_7 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_4, x_6}, {target_mod_0_1}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_6, x_4}, {target_mod_0_1});
auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7); auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7);
mm->add_return({x_8}); mm->add_return({x_8});
} }
...@@ -517,14 +574,14 @@ TEST_CASE(merge_case_3) ...@@ -517,14 +574,14 @@ TEST_CASE(merge_case_3)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction( auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4); auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
auto x_6 = mm->add_instruction(migraphx::make_op("mul"), x_5, x_1); auto x_6 = mm->add_instruction(migraphx::make_op("mul"), x_5, x_1);
mm->add_return({x_6}); mm->add_return({x_6});
...@@ -566,14 +623,14 @@ TEST_CASE(merge_case_4) ...@@ -566,14 +623,14 @@ TEST_CASE(merge_case_4)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction( auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4); auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
mm->add_return({x_5, x_1}); mm->add_return({x_5, x_1});
} }
...@@ -617,14 +674,14 @@ TEST_CASE(merge_case_5) ...@@ -617,14 +674,14 @@ TEST_CASE(merge_case_5)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction( auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4); auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
auto x_6 = mm->add_instruction(migraphx::make_op("identity"), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("identity"), x_5);
mm->add_return({x_6, x_1}); mm->add_return({x_6, x_1});
...@@ -676,25 +733,25 @@ TEST_CASE(fork_and_merge_case_1) ...@@ -676,25 +733,25 @@ TEST_CASE(fork_and_merge_case_1)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction( auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1"); migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s); auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction( auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0); migraphx::make_op("mul"), target_mod_0_1_param_0, target_mod_0_1_param_1);
target_mod_0_1->add_return({x_target_mod_0_1_2}); target_mod_0_1->add_return({x_target_mod_0_1_2});
auto x_5 = mm->add_instruction( auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_1}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_4, z}, {target_mod_0_1});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
...@@ -708,14 +765,14 @@ TEST_CASE(fork_and_merge_case_1) ...@@ -708,14 +765,14 @@ TEST_CASE(fork_and_merge_case_1)
auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7); auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7);
migraphx::module_ref target_mod_0_2 = p2.create_module("target_mod_0_2"); migraphx::module_ref target_mod_0_2 = p2.create_module("target_mod_0_2");
auto target_mod_0_2_param_1 = target_mod_0_2->add_parameter("param:1", s);
auto target_mod_0_2_param_0 = target_mod_0_2->add_parameter("param:0", s); auto target_mod_0_2_param_0 = target_mod_0_2->add_parameter("param:0", s);
auto target_mod_0_2_param_1 = target_mod_0_2->add_parameter("param:1", s);
auto x_target_mod_0_2_2 = target_mod_0_2->add_instruction( auto x_target_mod_0_2_2 = target_mod_0_2->add_instruction(
migraphx::make_op("sub"), target_mod_0_2_param_1, target_mod_0_2_param_0); migraphx::make_op("sub"), target_mod_0_2_param_0, target_mod_0_2_param_1);
target_mod_0_2->add_return({x_target_mod_0_2_2}); target_mod_0_2->add_return({x_target_mod_0_2_2});
auto x_9 = mm->add_instruction( auto x_9 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_6, x_8}, {target_mod_0_2}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_8, x_6}, {target_mod_0_2});
auto x_10 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_9); auto x_10 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_9);
mm->add_return({x_10}); mm->add_return({x_10});
} }
...@@ -761,13 +818,13 @@ TEST_CASE(fork_and_return_as_merge_bypass_branch_and_tass_on_other) ...@@ -761,13 +818,13 @@ TEST_CASE(fork_and_return_as_merge_bypass_branch_and_tass_on_other)
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s); auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_0 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_0 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
auto x_target_mod_0_0_1 = auto x_target_mod_0_0_1 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_0); target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_0);
target_mod_0_0->add_return({x_target_mod_0_0_0, x_target_mod_0_0_1}); target_mod_0_0->add_return({x_target_mod_0_0_0, x_target_mod_0_0_1});
auto x_2 = mm->add_instruction( auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2); auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2);
mm->add_return({x_3, x_4}); mm->add_return({x_3, x_4});
...@@ -813,14 +870,14 @@ TEST_CASE(fork_and_return_as_merge_bypass_branch_and_no_tass_on_other) ...@@ -813,14 +870,14 @@ TEST_CASE(fork_and_return_as_merge_bypass_branch_and_no_tass_on_other)
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_2 = mm->add_instruction( auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2); auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(migraphx::make_op("identity"), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("identity"), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4); auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
...@@ -869,14 +926,14 @@ TEST_CASE(fork_and_return_as_merge_different_tass_on_both_branches) ...@@ -869,14 +926,14 @@ TEST_CASE(fork_and_return_as_merge_different_tass_on_both_branches)
auto x_param = mm->add_parameter("x", s); auto x_param = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_add = target_mod_0_0->add_instruction( auto target_mod_0_0_add = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({target_mod_0_0_add}); target_mod_0_0->add_return({target_mod_0_0_add});
auto x_2 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}), auto x_2 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{y_param, x_param}, {x_param, y_param},
{target_mod_0_0}); {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2); auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
...@@ -892,14 +949,14 @@ TEST_CASE(fork_and_return_as_merge_different_tass_on_both_branches) ...@@ -892,14 +949,14 @@ TEST_CASE(fork_and_return_as_merge_different_tass_on_both_branches)
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1"); migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s); auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_mul = target_mod_0_1->add_instruction( auto target_mod_0_1_mul = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0); migraphx::make_op("mul"), target_mod_0_1_param_0, target_mod_0_1_param_1);
target_mod_0_1->add_return({target_mod_0_1_mul}); target_mod_0_1->add_return({target_mod_0_1_mul});
auto x_7 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}), auto x_7 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{z_param, x_3}, {x_3, z_param},
{target_mod_0_1}); {target_mod_0_1});
auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7); auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7);
mm->add_return({x_8, x_6}); mm->add_return({x_8, x_6});
...@@ -980,14 +1037,14 @@ TEST_CASE(fork_and_return_as_merge_no_tass_on_one_branch) ...@@ -980,14 +1037,14 @@ TEST_CASE(fork_and_return_as_merge_no_tass_on_one_branch)
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_5 = mm->add_instruction( auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_2}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_2, z}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
mm->add_return({x_6, x_3}); mm->add_return({x_6, x_3});
} }
...@@ -1095,15 +1152,15 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1095,15 +1152,15 @@ TEST_CASE(nested_if_then_else_program)
auto test_mod_param_0 = test_mod->add_parameter("param:0", ds); auto test_mod_param_0 = test_mod->add_parameter("param:0", ds);
auto test_mod_param_1 = test_mod->add_parameter("param:1", ds); auto test_mod_param_1 = test_mod->add_parameter("param:1", ds);
auto ins1 = test_mod->add_instruction( auto ins1 = test_mod->add_instruction(
migraphx::make_op("add"), test_mod_param_1, test_mod_param_0); migraphx::make_op("add"), test_mod_param_0, test_mod_param_1);
test_mod->add_return({ins1}); test_mod->add_return({ins1});
tass.insert(tass.begin(), std::make_pair(ins1, tid)); tass.insert(tass.begin(), std::make_pair(ins1, tid));
return test_mod; return test_mod;
}; };
migraphx::module_ref target_1_0 = p2.create_module("target_1_0"); migraphx::module_ref target_1_0 = p2.create_module("target_1_0");
auto target_1_0_1_param_0 = target_1_0->add_literal(ds, data); auto target_1_0_1_param_0 = target_1_0->add_parameter("1__param_0", ds);
auto target_1_0_1_param_1 = target_1_0->add_parameter("1__param_0", ds); auto target_1_0_1_param_1 = target_1_0->add_literal(ds, data);
auto x_target_1_0_2 = auto x_target_1_0_2 =
target_1_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}), target_1_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{target_1_0_1_param_0, target_1_0_1_param_1}, {target_1_0_1_param_0, target_1_0_1_param_1},
...@@ -1113,8 +1170,8 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1113,8 +1170,8 @@ TEST_CASE(nested_if_then_else_program)
target_1_0->add_return({x_target_1_0_3}); target_1_0->add_return({x_target_1_0_3});
migraphx::module_ref target_0_0 = p2.create_module("target_0_0"); migraphx::module_ref target_0_0 = p2.create_module("target_0_0");
auto target_0_0_2_param_0 = target_0_0->add_literal(ds, data); auto target_0_0_2_param_0 = target_0_0->add_parameter("2__param_0", ds);
auto target_0_0_2_param_1 = target_0_0->add_parameter("2__param_0", ds); auto target_0_0_2_param_1 = target_0_0->add_literal(ds, data);
auto x_target_0_0_2 = auto x_target_0_0_2 =
target_0_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}), target_0_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{target_0_0_2_param_0, target_0_0_2_param_1}, {target_0_0_2_param_0, target_0_0_2_param_1},
...@@ -1124,10 +1181,10 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1124,10 +1181,10 @@ TEST_CASE(nested_if_then_else_program)
target_0_0->add_return({x_target_0_0_3}); target_0_0->add_return({x_target_0_0_3});
migraphx::module_ref target_3_0 = p2.create_module("target_mod_3_0"); migraphx::module_ref target_3_0 = p2.create_module("target_mod_3_0");
auto target_mod_3_0_param_1 = target_3_0->add_parameter("param:1", ds);
auto target_mod_3_0_param_0 = target_3_0->add_parameter("param:0", ds); auto target_mod_3_0_param_0 = target_3_0->add_parameter("param:0", ds);
auto target_mod_3_0_param_1 = target_3_0->add_parameter("param:1", ds);
auto target_3_add_ins = target_3_0->add_instruction( auto target_3_add_ins = target_3_0->add_instruction(
migraphx::make_op("add"), target_mod_3_0_param_1, target_mod_3_0_param_0); migraphx::make_op("add"), target_mod_3_0_param_0, target_mod_3_0_param_1);
target_3_0->add_return({target_3_add_ins}); target_3_0->add_return({target_3_add_ins});
migraphx::module_ref then_mod = p2.create_module("then_mod"); migraphx::module_ref then_mod = p2.create_module("then_mod");
...@@ -1136,7 +1193,7 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1136,7 +1193,7 @@ TEST_CASE(nested_if_then_else_program)
auto then_mod_then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s); auto then_mod_then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s);
auto x_then_mod_3 = auto x_then_mod_3 =
then_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 3}}), then_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 3}}),
{then_mod_then_mod_param_1, then_mod_then_mod_param_0}, {then_mod_then_mod_param_0, then_mod_then_mod_param_1},
{target_3_0}); {target_3_0});
auto x_then_mod_4 = then_mod->add_instruction( auto x_then_mod_4 = then_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_then_mod_3); migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_then_mod_3);
...@@ -1149,8 +1206,8 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1149,8 +1206,8 @@ TEST_CASE(nested_if_then_else_program)
then_mod->add_return({x_then_mod_6}); then_mod->add_return({x_then_mod_6});
migraphx::module_ref target_0_1 = p2.create_module("target_0_1"); migraphx::module_ref target_0_1 = p2.create_module("target_0_1");
auto target_0_1_1_param_0 = target_0_1->add_literal(ds, data); auto target_0_1_1_param_0 = target_0_1->add_parameter("1__param_0", ds);
auto target_0_1_1_param_1 = target_0_1->add_parameter("1__param_0", ds); auto target_0_1_1_param_1 = target_0_1->add_literal(ds, data);
auto x_target_0_1_2 = auto x_target_0_1_2 =
target_0_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}), target_0_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{target_0_1_1_param_0, target_0_1_1_param_1}, {target_0_1_1_param_0, target_0_1_1_param_1},
...@@ -1160,8 +1217,8 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1160,8 +1217,8 @@ TEST_CASE(nested_if_then_else_program)
target_0_1->add_return({x_target_0_1_3}); target_0_1->add_return({x_target_0_1_3});
migraphx::module_ref target_1_1 = p2.create_module("target_1_1"); migraphx::module_ref target_1_1 = p2.create_module("target_1_1");
auto target_1_1_2_param_0 = target_1_1->add_literal(ds, data); auto target_1_1_2_param_0 = target_1_1->add_parameter("2__param_0", ds);
auto target_1_1_2_param_1 = target_1_1->add_parameter("2__param_0", ds); auto target_1_1_2_param_1 = target_1_1->add_literal(ds, data);
auto x_target_1_1_2 = auto x_target_1_1_2 =
target_1_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}), target_1_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{target_1_1_2_param_0, target_1_1_2_param_1}, {target_1_1_2_param_0, target_1_1_2_param_1},
...@@ -1171,10 +1228,10 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1171,10 +1228,10 @@ TEST_CASE(nested_if_then_else_program)
target_1_1->add_return({x_target_1_1_3}); target_1_1->add_return({x_target_1_1_3});
migraphx::module_ref target_2_0 = p2.create_module("target_mod_2_0"); migraphx::module_ref target_2_0 = p2.create_module("target_mod_2_0");
auto target_mod_2_0_param_1 = target_2_0->add_parameter("param:1", ds);
auto target_mod_2_0_param_0 = target_2_0->add_parameter("param:0", ds); auto target_mod_2_0_param_0 = target_2_0->add_parameter("param:0", ds);
auto target_mod_2_0_param_1 = target_2_0->add_parameter("param:1", ds);
auto target_2_mul_ins = target_2_0->add_instruction( auto target_2_mul_ins = target_2_0->add_instruction(
migraphx::make_op("mul"), target_mod_2_0_param_1, target_mod_2_0_param_0); migraphx::make_op("mul"), target_mod_2_0_param_0, target_mod_2_0_param_1);
target_2_0->add_return({target_2_mul_ins}); target_2_0->add_return({target_2_mul_ins});
migraphx::module_ref else_mod = p2.create_module("else_mod"); migraphx::module_ref else_mod = p2.create_module("else_mod");
...@@ -1183,7 +1240,7 @@ TEST_CASE(nested_if_then_else_program) ...@@ -1183,7 +1240,7 @@ TEST_CASE(nested_if_then_else_program)
auto else_mod_else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s); auto else_mod_else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s);
auto x_else_mod_3 = auto x_else_mod_3 =
else_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 2}}), else_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 2}}),
{else_mod_else_mod_param_1, else_mod_else_mod_param_0}, {else_mod_else_mod_param_0, else_mod_else_mod_param_1},
{target_2_0}); {target_2_0});
auto x_else_mod_4 = else_mod->add_instruction( auto x_else_mod_4 = else_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_mod_3); migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_mod_3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment