"src/vscode:/vscode.git/clone" did not exist on "1793cc54463f86a55afe7bf6f8b15ecaedb65ef3"
Commit 8e812ae2 authored by Umang Yadav's avatar Umang Yadav
Browse files

Fork and merge tests

parent 6e392848
...@@ -537,7 +537,9 @@ TEST_CASE(fork_case_5) ...@@ -537,7 +537,9 @@ TEST_CASE(fork_case_5)
/* /*
**** Fork node returning **** **** Fork node returning ****
Add (no target assignment) Add (tid = 0)
|
Identity (no target_assignment)
| |
--------------------------- ---------------------------
| | | |
...@@ -554,11 +556,31 @@ TEST_CASE(fork_case_5) ...@@ -554,11 +556,31 @@ TEST_CASE(fork_case_5)
auto x_param = mm->add_parameter("x", s); auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s); auto y_param = mm->add_parameter("y", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param); auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins); auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), add_ins);
mm->add_return({add_ins, identity_ins}); auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), identity_ins_0);
mm->add_return({identity_ins_0, identity_ins_1});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
} }
migraphx::program p2 = p1;
migraphx::generate_root_modules(p1, tass); migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(migraphx::make_op("identity"), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
mm->add_return({x_4, x_5});
}
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
...@@ -595,13 +617,12 @@ TEST_CASE(merge_case_1) ...@@ -595,13 +617,12 @@ TEST_CASE(merge_case_1)
migraphx::program p2; migraphx::program p2;
{ {
migraphx::module_ref mm = p2.get_main_module(); migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8}}); auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {8}}); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {8}}); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_1_0_1 = auto x_target_mod_1_0_1 =
target_mod_1_0->add_instruction(migraphx::make_op("identity"), target_mod_1_0_param_0); target_mod_1_0->add_instruction(migraphx::make_op("identity"), target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_1}); target_mod_1_0->add_return({x_target_mod_1_0_1});
...@@ -611,10 +632,8 @@ TEST_CASE(merge_case_1) ...@@ -611,10 +632,8 @@ TEST_CASE(merge_case_1)
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter( auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
...@@ -624,10 +643,8 @@ TEST_CASE(merge_case_1) ...@@ -624,10 +643,8 @@ TEST_CASE(merge_case_1)
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1"); migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter( auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction( auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0); migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0);
target_mod_0_1->add_return({x_target_mod_0_1_2}); target_mod_0_1->add_return({x_target_mod_0_1_2});
...@@ -708,10 +725,8 @@ TEST_CASE(merge_case_3) ...@@ -708,10 +725,8 @@ TEST_CASE(merge_case_3)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter( auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
...@@ -727,8 +742,104 @@ TEST_CASE(merge_case_3) ...@@ -727,8 +742,104 @@ TEST_CASE(merge_case_3)
TEST_CASE(merge_case_4) TEST_CASE(merge_case_4)
{ {
// return as the merge node /*
**** "Return" as the Merge Node ****
Add Identity
(tid=0) (no assignment)
| |
-----------------
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), z_param);
mm->add_return({add_ins, identity_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto x_1 = mm->add_instruction(migraphx::make_op("identity"), z);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
mm->add_return({x_5, x_1});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(merge_case_5)
{
/*
**** "Return" as the Merge Node ****
Add (tid = 0)
|
Identity Identity
(no assignment) (no assignment)
| |
-----------------
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), z_param);
mm->add_return({identity_ins_0, identity_ins_1});
tass.insert(tass.begin(), std::make_pair(add_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto x_1 = mm->add_instruction(migraphx::make_op("identity"), z);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
auto x_6 = mm->add_instruction(migraphx::make_op("identity"), x_5);
mm->add_return({x_6, x_1});
}
EXPECT(p1.sort() == p2.sort());
} }
TEST_CASE(fork_and_merge_case) TEST_CASE(fork_and_merge_case)
{ {
/* /*
...@@ -768,15 +879,13 @@ TEST_CASE(fork_and_merge_case) ...@@ -768,15 +879,13 @@ TEST_CASE(fork_and_merge_case)
migraphx::program p2; migraphx::program p2;
{ {
migraphx::module_ref mm = p2.get_main_module(); migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", migraphx::shape{migraphx::shape::float_type, {8}}); auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {8}}); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {8}}); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter( auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
...@@ -786,10 +895,8 @@ TEST_CASE(fork_and_merge_case) ...@@ -786,10 +895,8 @@ TEST_CASE(fork_and_merge_case)
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1"); migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter( auto target_mod_0_1_param_1 = target_mod_0_1->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction( auto x_target_mod_0_1_2 = target_mod_0_1->add_instruction(
migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0); migraphx::make_op("mul"), target_mod_0_1_param_1, target_mod_0_1_param_0);
target_mod_0_1->add_return({x_target_mod_0_1_2}); target_mod_0_1->add_return({x_target_mod_0_1_2});
...@@ -799,8 +906,7 @@ TEST_CASE(fork_and_merge_case) ...@@ -799,8 +906,7 @@ TEST_CASE(fork_and_merge_case)
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_1_0_1 = auto x_target_mod_1_0_1 =
target_mod_1_0->add_instruction(migraphx::make_op("identity"), target_mod_1_0_param_0); target_mod_1_0->add_instruction(migraphx::make_op("identity"), target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_1}); target_mod_1_0->add_return({x_target_mod_1_0_1});
...@@ -810,10 +916,8 @@ TEST_CASE(fork_and_merge_case) ...@@ -810,10 +916,8 @@ TEST_CASE(fork_and_merge_case)
auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7); auto x_8 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_7);
migraphx::module_ref target_mod_0_2 = p2.create_module("target_mod_0_2"); migraphx::module_ref target_mod_0_2 = p2.create_module("target_mod_0_2");
auto target_mod_0_2_param_1 = target_mod_0_2->add_parameter( auto target_mod_0_2_param_1 = target_mod_0_2->add_parameter("param:1", s);
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); auto target_mod_0_2_param_0 = target_mod_0_2->add_parameter("param:0", s);
auto target_mod_0_2_param_0 = target_mod_0_2->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_2_2 = target_mod_0_2->add_instruction( auto x_target_mod_0_2_2 = target_mod_0_2->add_instruction(
migraphx::make_op("sub"), target_mod_0_2_param_1, target_mod_0_2_param_0); migraphx::make_op("sub"), target_mod_0_2_param_1, target_mod_0_2_param_0);
target_mod_0_2->add_return({x_target_mod_0_2_2}); target_mod_0_2->add_return({x_target_mod_0_2_2});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment