Commit 335a1fae authored by Umang Yadav's avatar Umang Yadav
Browse files

add identity ins for the inbetween case as well

parent d3dbcd62
...@@ -86,6 +86,8 @@ TEST_CASE(single_target_test) ...@@ -86,6 +86,8 @@ TEST_CASE(single_target_test)
TEST_CASE(two_targets_with_ref) TEST_CASE(two_targets_with_ref)
{ {
/* /*
Identity
|
Add (tid = 1) Add (tid = 1)
| |
Mul (tid = 0) Mul (tid = 0)
...@@ -102,10 +104,11 @@ TEST_CASE(two_targets_with_ref) ...@@ -102,10 +104,11 @@ TEST_CASE(two_targets_with_ref)
auto x_param = mm->add_parameter("x", s); auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s); auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s); auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param); auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x_param);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), identity_ins_0, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param); auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), mul_ins); auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), mul_ins);
mm->add_return({identity_ins}); mm->add_return({identity_ins_1});
tass.insert(tass.begin(), std::make_pair(add_ins, 1)); tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0)); tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
} }
...@@ -116,7 +119,7 @@ TEST_CASE(two_targets_with_ref) ...@@ -116,7 +119,7 @@ TEST_CASE(two_targets_with_ref)
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}}); "param:1", migraphx::shape{migraphx::shape::float_type, {8}});
...@@ -135,15 +138,16 @@ TEST_CASE(two_targets_with_ref) ...@@ -135,15 +138,16 @@ TEST_CASE(two_targets_with_ref)
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction( auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0}); {y, identity_ins_0},
{target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction( auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
auto x_7 = mm->add_instruction(migraphx::make_op("identity"), x_6); auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), x_6);
mm->add_return({x_7}); mm->add_return({identity_ins_1});
} }
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
...@@ -151,6 +155,8 @@ TEST_CASE(two_targets_with_ref) ...@@ -151,6 +155,8 @@ TEST_CASE(two_targets_with_ref)
TEST_CASE(two_targets_ref_inbetween) TEST_CASE(two_targets_ref_inbetween)
{ {
/* /*
Identity
|
Add (tid = 1) Add (tid = 1)
| |
Identity Identity
...@@ -167,9 +173,10 @@ TEST_CASE(two_targets_ref_inbetween) ...@@ -167,9 +173,10 @@ TEST_CASE(two_targets_ref_inbetween)
auto x_param = mm->add_parameter("x", s); auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s); auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s); auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param); auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins); auto add_ins = mm->add_instruction(migraphx::make_op("add"), identity_ins_0, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), identity_ins, z_param); auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), identity_ins_1, z_param);
mm->add_return({mul_ins}); mm->add_return({mul_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1)); tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0)); tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
...@@ -181,6 +188,7 @@ TEST_CASE(two_targets_ref_inbetween) ...@@ -181,6 +188,7 @@ TEST_CASE(two_targets_ref_inbetween)
auto z = mm->add_parameter("z", s); auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0"); migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter( auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
...@@ -200,8 +208,9 @@ TEST_CASE(two_targets_ref_inbetween) ...@@ -200,8 +208,9 @@ TEST_CASE(two_targets_ref_inbetween)
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction( auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0}); {y, identity_ins},
{target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3); auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4); auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment