Commit fa12da23 authored by Umang Yadav's avatar Umang Yadav
Browse files

Changes for the order fix

parent 1e80ceef
......@@ -174,18 +174,51 @@ struct auto_gen_root_modules
*/
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
{
const auto inputs = ins->inputs();
if(inputs.size() == 1)
size_t in_degree = inputs.size();
if(in_degree == 1)
{
return false;
}
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
return (
(this->skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and
tass.at(input_ins) != ins_tid.value_or(std::numeric_limits<std::size_t>::max())));
});
size_t input_from_other_tid_module = 0;
size_t num_default_tids = 0;
size_t num_different_tids = 0;
size_t num_same_tid = 0;
// std::unordered_map<size_t, size_t> in_tid_freq_map;
for(const auto& input_ins : inputs)
{
if(skip_ins.find(input_ins) != skip_ins.end())
{
input_from_other_tid_module++;
}
else if(tass.find(input_ins) == tass.end())
{
num_default_tids++;
}
else if(tass.at(input_ins) != ins_tid)
{
num_different_tids++;
}
else
{
num_same_tid++;
}
}
assert(input_from_other_tid_module + num_default_tids + num_different_tids + num_same_tid ==
in_degree);
if(input_from_other_tid_module > 1)
{
return true;
}
else if(input_from_other_tid_module + num_default_tids == in_degree)
{
return false;
}
else if(num_same_tid + num_default_tids == in_degree)
{
return false;
}
return true;
}
/*
......@@ -200,21 +233,47 @@ struct auto_gen_root_modules
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph.
*/
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
bool is_fork_node(migraphx::instruction_ref ins, std::size_t ins_tid)
{
const auto outputs = ins->outputs();
if(outputs.size() == 1)
{
return false;
}
// if all the outputs are for the "default" or with same tid then it is not a fork but
// rather simply a boundary
std::unordered_map<std::size_t, std::size_t> output_tids;
for(const auto& output_ins : outputs)
{
if(tass.find(output_ins) != tass.end())
{
auto out_tid = tass.at(output_ins);
if(output_tids.find(out_tid) == output_tids.end())
{
output_tids[out_tid] = 1;
}
else
{
output_tids[out_tid]++;
}
}
}
if(output_tids.empty())
{
return false;
}
else if(output_tids.size() == 1 and output_tids.cbegin()->second == outputs.size())
{
return false;
}
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(output_ins->name() == "return")
{
return false;
}
return (tass.find(output_ins) != tass.end() and
tass.at(output_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max()));
return (tass.find(output_ins) != tass.end() and tass.at(output_ins) != ins_tid);
});
}
......@@ -262,7 +321,7 @@ struct auto_gen_root_modules
current_tid = std::make_optional<std::size_t>(tass.at(ins));
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
fork_node = is_fork_node(ins, current_tid);
fork_node = is_fork_node(ins, current_tid.value());
}
}
else
......@@ -281,7 +340,8 @@ struct auto_gen_root_modules
{
MIGRAPHX_THROW("GenerateRootModules: this case shouldn't occur");
}
fork_node = is_fork_node(ins, current_tid);
fork_node = is_fork_node(
ins, current_tid.value_or(std::numeric_limits<std::size_t>::max()));
}
if(not ins->module_inputs().empty())
......@@ -315,7 +375,8 @@ struct auto_gen_root_modules
return;
}
// gather all parameters
std::unordered_set<instruction_ref> params;
std::unordered_set<instruction_ref> params_set;
std::vector<instruction_ref> params_vec;
// gather all return values
std::vector<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec))
......@@ -325,11 +386,15 @@ struct auto_gen_root_modules
transform_if(
inputs.cbegin(),
inputs.cend(),
std::inserter(params, params.end()),
std::back_inserter(params_vec),
[&](auto in_param) {
return (params.count(in_param) == 0 and same_tid_ins_set.count(in_param) == 0);
return (params_set.count(in_param) == 0 and
same_tid_ins_set.count(in_param) == 0);
},
[&](auto in_param) { return in_param; });
[&](auto in_param) {
params_set.insert(in_param);
return in_param;
});
if(std::any_of(outputs.begin(), outputs.end(), [&](const auto out_ins) {
return same_tid_ins_set.count(out_ins) == 0;
}))
......@@ -340,7 +405,7 @@ struct auto_gen_root_modules
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{
std::cout << "params ins: \n";
for(auto tmp : iterator_for(params))
for(auto tmp : iterator_for(params_vec))
{
(*tmp)->debug_print();
}
......@@ -357,7 +422,7 @@ struct auto_gen_root_modules
std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params;
for(auto pins : iterator_for(params))
for(auto pins : iterator_for(params_vec))
{
auto scalar = get_scalar(*pins);
if(scalar.empty())
......
......@@ -69,14 +69,14 @@ TEST_CASE(single_target_test)
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
target_mod_1_0->add_return({x_target_mod_1_0_2});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
migraphx::make_op("run_on_target", {{"target_id", 1}}), {x, y}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
mm->add_return({x_3});
}
......@@ -115,36 +115,33 @@ TEST_CASE(two_targets_with_ref)
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{y, identity_ins_0},
{identity_ins_0, y},
{target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_4, z}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), x_6);
mm->add_return({identity_ins_1});
......@@ -191,37 +188,100 @@ TEST_CASE(two_targets_ref_inbetween)
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{y, identity_ins},
{identity_ins, y},
{target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
auto x_6 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_5}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_5, z}, {target_mod_0_0});
auto x_7 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_6);
mm->add_return({x_7});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(multiple_targets_multiple_returns)
{
/*
Add (tid = 1)
|
---------------
| |
Mul Identity
(tid = 0) (tid = 0)
| |
---------------
|
Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
mm->add_return({mul_ins, identity_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
tass.insert(tass.begin(), std::make_pair(identity_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_0, target_mod_1_0_param_1);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), target_mod_0_0_param_0);
auto x_target_mod_0_0_3 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2, x_target_mod_0_0_3});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {x, y}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto z = mm->add_parameter("z", s);
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_3, z}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
auto x_7 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_5);
mm->add_return({x_7, x_6});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(single_target_multiple_returns)
{
/*
......@@ -256,27 +316,28 @@ TEST_CASE(single_target_multiple_returns)
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
auto z = mm->add_parameter("z", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_2 = target_mod_0_0->add_parameter("param:2", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_2, target_mod_0_0_param_1);
auto x_target_mod_0_0_3 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_2);
auto x_target_mod_0_0_4 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), x_target_mod_0_0_2, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_3, x_target_mod_0_0_4});
auto x_target_mod_0_0_3 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
auto x_target_mod_0_0_4 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_3);
auto x_target_mod_0_0_5 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), x_target_mod_0_0_3, target_mod_0_0_param_2);
target_mod_0_0->add_return({x_target_mod_0_0_4, x_target_mod_0_0_5});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, y, x}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2);
mm->add_return({x_4, x_3});
auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y, z}, {target_mod_0_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_3);
mm->add_return({x_5, x_4});
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -332,38 +393,34 @@ TEST_CASE(if_then_else_program)
auto cond = mm->add_parameter("cond", cond_s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", ds);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", ds);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
migraphx::module_ref if_gpu_mod = p2.create_module("if_gpu_mod");
auto x_if_gpu_mod_0 = if_gpu_mod->add_literal(migraphx::literal(ds, data1));
auto x_if_gpu_mod_1 =
if_gpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{x_if_gpu_mod_0, x},
{x, x_if_gpu_mod_0},
{target_mod_0_0});
auto x_if_gpu_mod_2 = if_gpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_if_gpu_mod_1);
if_gpu_mod->add_return({x_if_gpu_mod_2});
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {2, 3}});
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("mul"), target_mod_1_0_param_1, target_mod_1_0_param_0);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", ds);
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", ds);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("mul"), target_mod_1_0_param_0, target_mod_1_0_param_1);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref else_cpu_mod = p2.create_module("else_cpu_mod");
auto x_else_cpu_mod_0 = else_cpu_mod->add_literal(migraphx::literal(ds, data2));
auto x_else_cpu_mod_1 =
else_cpu_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{x_else_cpu_mod_0, y},
{y, x_else_cpu_mod_0},
{target_mod_1_0});
auto x_else_cpu_mod_2 = else_cpu_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_cpu_mod_1);
......@@ -424,25 +481,25 @@ TEST_CASE(merge_case_1)
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0);
migraphx::make_op("mul"), target_mod_0_1_param_0, target_mod_0_1_param_1);
target_mod_0_1->add_return({x_target_mod_0_1_2});
auto x_7 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_4, x_6}, {target_mod_0_1});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_6, x_4}, {target_mod_0_1});
auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7);
mm->add_return({x_8});
}
......@@ -517,14 +574,14 @@ TEST_CASE(merge_case_3)
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
auto x_6 = mm->add_instruction(migraphx::make_op("mul"), x_5, x_1);
mm->add_return({x_6});
......@@ -566,14 +623,14 @@ TEST_CASE(merge_case_4)
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
mm->add_return({x_5, x_1});
}
......@@ -617,14 +674,14 @@ TEST_CASE(merge_case_5)
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
auto x_6 = mm->add_instruction(migraphx::make_op("identity"), x_5);
mm->add_return({x_6, x_1});
......@@ -676,25 +733,25 @@ TEST_CASE(fork_and_merge_case_1)
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0);
migraphx::make_op("mul"), target_mod_0_1_param_0, target_mod_0_1_param_1);
target_mod_0_1->add_return({x_target_mod_0_1_2});
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_1});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_4, z}, {target_mod_0_1});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
......@@ -708,14 +765,14 @@ TEST_CASE(fork_and_merge_case_1)
auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7);
migraphx::module_ref target_mod_0_2 = p2.create_module("target_mod_0_2");
auto target_mod_0_2_param_1 = target_mod_0_2->add_parameter("param:1", s);
auto target_mod_0_2_param_0 = target_mod_0_2->add_parameter("param:0", s);
auto target_mod_0_2_param_1 = target_mod_0_2->add_parameter("param:1", s);
auto x_target_mod_0_2_2 = target_mod_0_2->add_instruction(
migraphx::make_op("sub"), target_mod_0_2_param_1, target_mod_0_2_param_0);
migraphx::make_op("sub"), target_mod_0_2_param_0, target_mod_0_2_param_1);
target_mod_0_2->add_return({x_target_mod_0_2_2});
auto x_9 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_6, x_8}, {target_mod_0_2});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_8, x_6}, {target_mod_0_2});
auto x_10 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_9);
mm->add_return({x_10});
}
......@@ -761,13 +818,13 @@ TEST_CASE(fork_and_return_as_merge_bypass_branch_and_tass_on_other)
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_0 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
auto x_target_mod_0_0_1 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_0);
target_mod_0_0->add_return({x_target_mod_0_0_0, x_target_mod_0_0_1});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2);
mm->add_return({x_3, x_4});
......@@ -813,14 +870,14 @@ TEST_CASE(fork_and_return_as_merge_bypass_branch_and_no_tass_on_other)
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x, y}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(migraphx::make_op("identity"), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
......@@ -869,14 +926,14 @@ TEST_CASE(fork_and_return_as_merge_different_tass_on_both_branches)
auto x_param = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_add = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("add"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({target_mod_0_0_add});
auto x_2 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{y_param, x_param},
{x_param, y_param},
{target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
......@@ -892,14 +949,14 @@ TEST_CASE(fork_and_return_as_merge_different_tass_on_both_branches)
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
auto target_mod_0_1_mul = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0);
migraphx::make_op("mul"), target_mod_0_1_param_0, target_mod_0_1_param_1);
target_mod_0_1->add_return({target_mod_0_1_mul});
auto x_7 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{z_param, x_3},
{x_3, z_param},
{target_mod_0_1});
auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7);
mm->add_return({x_8, x_6});
......@@ -980,14 +1037,14 @@ TEST_CASE(fork_and_return_as_merge_no_tass_on_one_branch)
auto z = mm->add_parameter("z", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
migraphx::make_op("mul"), target_mod_0_0_param_0, target_mod_0_0_param_1);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_2}, {target_mod_0_0});
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_2, z}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
mm->add_return({x_6, x_3});
}
......@@ -1095,15 +1152,15 @@ TEST_CASE(nested_if_then_else_program)
auto test_mod_param_0 = test_mod->add_parameter("param:0", ds);
auto test_mod_param_1 = test_mod->add_parameter("param:1", ds);
auto ins1 = test_mod->add_instruction(
migraphx::make_op("add"), test_mod_param_1, test_mod_param_0);
migraphx::make_op("add"), test_mod_param_0, test_mod_param_1);
test_mod->add_return({ins1});
tass.insert(tass.begin(), std::make_pair(ins1, tid));
return test_mod;
};
migraphx::module_ref target_1_0 = p2.create_module("target_1_0");
auto target_1_0_1_param_0 = target_1_0->add_literal(ds, data);
auto target_1_0_1_param_1 = target_1_0->add_parameter("1__param_0", ds);
auto target_1_0_1_param_0 = target_1_0->add_parameter("1__param_0", ds);
auto target_1_0_1_param_1 = target_1_0->add_literal(ds, data);
auto x_target_1_0_2 =
target_1_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{target_1_0_1_param_0, target_1_0_1_param_1},
......@@ -1113,8 +1170,8 @@ TEST_CASE(nested_if_then_else_program)
target_1_0->add_return({x_target_1_0_3});
migraphx::module_ref target_0_0 = p2.create_module("target_0_0");
auto target_0_0_2_param_0 = target_0_0->add_literal(ds, data);
auto target_0_0_2_param_1 = target_0_0->add_parameter("2__param_0", ds);
auto target_0_0_2_param_0 = target_0_0->add_parameter("2__param_0", ds);
auto target_0_0_2_param_1 = target_0_0->add_literal(ds, data);
auto x_target_0_0_2 =
target_0_0->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{target_0_0_2_param_0, target_0_0_2_param_1},
......@@ -1124,10 +1181,10 @@ TEST_CASE(nested_if_then_else_program)
target_0_0->add_return({x_target_0_0_3});
migraphx::module_ref target_3_0 = p2.create_module("target_mod_3_0");
auto target_mod_3_0_param_1 = target_3_0->add_parameter("param:1", ds);
auto target_mod_3_0_param_0 = target_3_0->add_parameter("param:0", ds);
auto target_mod_3_0_param_1 = target_3_0->add_parameter("param:1", ds);
auto target_3_add_ins = target_3_0->add_instruction(
migraphx::make_op("add"), target_mod_3_0_param_1, target_mod_3_0_param_0);
migraphx::make_op("add"), target_mod_3_0_param_0, target_mod_3_0_param_1);
target_3_0->add_return({target_3_add_ins});
migraphx::module_ref then_mod = p2.create_module("then_mod");
......@@ -1136,7 +1193,7 @@ TEST_CASE(nested_if_then_else_program)
auto then_mod_then_mod_cond = then_mod->add_parameter("then_mod_cond", cond_s);
auto x_then_mod_3 =
then_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 3}}),
{then_mod_then_mod_param_1, then_mod_then_mod_param_0},
{then_mod_then_mod_param_0, then_mod_then_mod_param_1},
{target_3_0});
auto x_then_mod_4 = then_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_then_mod_3);
......@@ -1149,8 +1206,8 @@ TEST_CASE(nested_if_then_else_program)
then_mod->add_return({x_then_mod_6});
migraphx::module_ref target_0_1 = p2.create_module("target_0_1");
auto target_0_1_1_param_0 = target_0_1->add_literal(ds, data);
auto target_0_1_1_param_1 = target_0_1->add_parameter("1__param_0", ds);
auto target_0_1_1_param_0 = target_0_1->add_parameter("1__param_0", ds);
auto target_0_1_1_param_1 = target_0_1->add_literal(ds, data);
auto x_target_0_1_2 =
target_0_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 0}}),
{target_0_1_1_param_0, target_0_1_1_param_1},
......@@ -1160,8 +1217,8 @@ TEST_CASE(nested_if_then_else_program)
target_0_1->add_return({x_target_0_1_3});
migraphx::module_ref target_1_1 = p2.create_module("target_1_1");
auto target_1_1_2_param_0 = target_1_1->add_literal(ds, data);
auto target_1_1_2_param_1 = target_1_1->add_parameter("2__param_0", ds);
auto target_1_1_2_param_0 = target_1_1->add_parameter("2__param_0", ds);
auto target_1_1_2_param_1 = target_1_1->add_literal(ds, data);
auto x_target_1_1_2 =
target_1_1->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{target_1_1_2_param_0, target_1_1_2_param_1},
......@@ -1171,10 +1228,10 @@ TEST_CASE(nested_if_then_else_program)
target_1_1->add_return({x_target_1_1_3});
migraphx::module_ref target_2_0 = p2.create_module("target_mod_2_0");
auto target_mod_2_0_param_1 = target_2_0->add_parameter("param:1", ds);
auto target_mod_2_0_param_0 = target_2_0->add_parameter("param:0", ds);
auto target_mod_2_0_param_1 = target_2_0->add_parameter("param:1", ds);
auto target_2_mul_ins = target_2_0->add_instruction(
migraphx::make_op("mul"), target_mod_2_0_param_1, target_mod_2_0_param_0);
migraphx::make_op("mul"), target_mod_2_0_param_0, target_mod_2_0_param_1);
target_2_0->add_return({target_2_mul_ins});
migraphx::module_ref else_mod = p2.create_module("else_mod");
......@@ -1183,7 +1240,7 @@ TEST_CASE(nested_if_then_else_program)
auto else_mod_else_mod_cond = else_mod->add_parameter("else_mod_cond", cond_s);
auto x_else_mod_3 =
else_mod->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 2}}),
{else_mod_else_mod_param_1, else_mod_else_mod_param_0},
{else_mod_else_mod_param_0, else_mod_else_mod_param_1},
{target_2_0});
auto x_else_mod_4 = else_mod->add_instruction(
migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_else_mod_3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment