Commit 6e392848 authored by Umang Yadav's avatar Umang Yadav
Browse files

Additional Tests

parent b6ed1d6e
...@@ -149,30 +149,58 @@ struct auto_gen_root_modules ...@@ -149,30 +149,58 @@ struct auto_gen_root_modules
} }
} }
bool is_different_subgraph(migraphx::instruction_ref ins, std::optional<std::size_t> tid) bool is_different_subgraph(migraphx::instruction_ref ins,
std::optional<std::size_t> previous_tid)
{ {
if(tass.find(ins) == tass.end()) if(tass.find(ins) == tass.end())
{ {
return tid.has_value(); return previous_tid.has_value();
} }
return tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max()); return tass.at(ins) != previous_tid.value_or(std::numeric_limits<std::size_t>::max());
} }
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid) /*
Merge node is defined as node where two or more branches converge.
NodeX NodeY
| |
---------
|
NodeZ
For the partitioner, if any of the merge node's input doesn't have same tid as the merge node
itself then, it is classified as boundary for subgraph.
*/
bool is_merge_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
{ {
const auto inputs = ins->inputs(); const auto inputs = ins->inputs();
if(inputs.size() == 1) if(inputs.size() == 1)
{ {
return false; return false;
} }
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
return ((skip_ins.find(input_ins) != skip_ins.end()) or if(tass.find(input_ins) == tass.end())
(tass.find(input_ins) != tass.end() and {
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max()))); return ins_tid.has_value();
}
else
{
return tass.at(input_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max());
}
}); });
} }
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid) /*
Fork node is defined as node where graph forks into two or more branches
NodeX
|
------------
| |
NodeY NodeZ
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph.
*/
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> ins_tid)
{ {
const auto outputs = ins->outputs(); const auto outputs = ins->outputs();
if(outputs.size() == 1) if(outputs.size() == 1)
...@@ -180,9 +208,15 @@ struct auto_gen_root_modules ...@@ -180,9 +208,15 @@ struct auto_gen_root_modules
return false; return false;
} }
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
return (tass.find(output_ins) != tass.end() and if(tass.find(output_ins) == tass.end())
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and {
output_ins->name() != "@return"); return ins_tid.has_value();
}
else
{
return tass.at(output_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max());
};
}); });
} }
......
...@@ -317,7 +317,7 @@ TEST_CASE(if_then_else_program) ...@@ -317,7 +317,7 @@ TEST_CASE(if_then_else_program)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(fork_case) TEST_CASE(fork_case_1)
{ {
/* /*
Add (tid = 0) Add (tid = 0)
...@@ -351,14 +351,12 @@ TEST_CASE(fork_case) ...@@ -351,14 +351,12 @@ TEST_CASE(fork_case)
migraphx::program p2; migraphx::program p2;
{ {
migraphx::module_ref mm = p2.get_main_module(); migraphx::module_ref mm = p2.get_main_module();
auto y_param = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {8}}); auto y_param = mm->add_parameter("y", s);
auto x_param = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {8}}); auto x_param = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter( auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_0_add = target_mod_0_0->add_instruction( auto target_mod_0_0_add = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({target_mod_0_0_add}); target_mod_0_0->add_return({target_mod_0_0_add});
...@@ -368,10 +366,9 @@ TEST_CASE(fork_case) ...@@ -368,10 +366,9 @@ TEST_CASE(fork_case)
{target_mod_0_0}); {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2); auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto z_param = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8}}); auto z_param = mm->add_parameter("z", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_1_0_identity = auto target_mod_1_0_identity =
target_mod_1_0->add_instruction(migraphx::make_op("identity"), target_mod_1_0_param_0); target_mod_1_0->add_instruction(migraphx::make_op("identity"), target_mod_1_0_param_0);
target_mod_1_0->add_return({target_mod_1_0_identity}); target_mod_1_0->add_return({target_mod_1_0_identity});
...@@ -381,10 +378,8 @@ TEST_CASE(fork_case) ...@@ -381,10 +378,8 @@ TEST_CASE(fork_case)
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1"); migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter( auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_1_mul = target_mod_0_1->add_instruction( auto target_mod_0_1_mul = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0); migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0);
target_mod_0_1->add_return({target_mod_0_1_mul}); target_mod_0_1->add_return({target_mod_0_1_mul});
...@@ -398,7 +393,176 @@ TEST_CASE(fork_case) ...@@ -398,7 +393,176 @@ TEST_CASE(fork_case)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
}; };
TEST_CASE(merge_case) TEST_CASE(fork_case_2)
{
/*
Add (no assignment)
|
---------------
| |
Mul Identity
(no assignment) (no assignment)
| |
Return Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
mm->add_return({mul_ins, identity_ins});
}
migraphx::program p2 = p1;
migraphx::generate_root_modules(p1, tass);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(fork_case_3)
{
/*
Add (no assignment)
|
---------------
| |
Mul Identity
(tid = 0) (no assignment)
| |
Return Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
mm->add_return({mul_ins, identity_ins});
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
auto x_2 = mm->add_instruction(migraphx::make_op("add"), x, y);
auto x_3 = mm->add_instruction(migraphx::make_op("identity"), x_2);
auto z = mm->add_parameter("z", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_2}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
mm->add_return({x_6, x_3});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(fork_case_4)
{
/*
**** Fork node returning ****
Add (tid = 0)
|
---------------------------
| |
Identity (tid = 0) Return
|
Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
mm->add_return({add_ins, identity_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
tass.insert(tass.begin(), std::make_pair(identity_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto x_target_mod_0_1_1 =
target_mod_0_1->add_instruction(migraphx::make_op("identity"), target_mod_0_1_param_0);
target_mod_0_1->add_return({x_target_mod_0_1_1});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_3}, {target_mod_0_1});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
mm->add_return({x_3, x_5});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(fork_case_5)
{
/*
**** Fork node returning ****
Add (no target assignment)
|
---------------------------
| |
Identity Return
(no target assignment)
|
Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
mm->add_return({add_ins, identity_ins});
}
migraphx::program p2 = p1;
migraphx::generate_root_modules(p1, tass);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(merge_case_1)
{ {
/* /*
Add Identity Add Identity
...@@ -476,6 +640,95 @@ TEST_CASE(merge_case) ...@@ -476,6 +640,95 @@ TEST_CASE(merge_case)
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
}; };
TEST_CASE(merge_case_2)
{
/*
Add Identity
(no assignment) (no assignment)
| |
-----------------
|
Mul (no assignment)
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), z_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, identity_ins);
mm->add_return({mul_ins});
}
migraphx::program p2 = p1;
migraphx::generate_root_modules(p1, tass);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(merge_case_3)
{
/*
Add Identity
(tid=0) (no assignment)
| |
-----------------
|
Mul (no assignment)
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), z_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, identity_ins);
mm->add_return({mul_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto x_1 = mm->add_instruction(migraphx::make_op("identity"), z);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
auto x_6 = mm->add_instruction(migraphx::make_op("mul"), x_5, x_1);
mm->add_return({x_6});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(merge_case_4)
{
// return as the merge node
}
TEST_CASE(fork_and_merge_case) TEST_CASE(fork_and_merge_case)
{ {
/* /*
......
...@@ -301,7 +301,7 @@ TEST_CASE(multitarget_compile_nested_if_then_else) ...@@ -301,7 +301,7 @@ TEST_CASE(multitarget_compile_nested_if_then_else)
auto y = mm->add_parameter("y", ds); auto y = mm->add_parameter("y", ds);
auto z = mm->add_parameter("z", ds); auto z = mm->add_parameter("z", ds);
auto create_test_module = auto create_test_module =
[&](migraphx::program& prog, std::size_t tid, std::string param_prefix) { [&](migraphx::program& prog, std::size_t tid, const std::string& param_prefix) {
std::string mod_name = std::string mod_name =
"target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++); "target_" + std::to_string(tid) + "_" + std::to_string(counter_map[tid]++);
auto* test_mod = prog.create_module(mod_name); auto* test_mod = prog.create_module(mod_name);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment