Commit 335a1fae authored by Umang Yadav's avatar Umang Yadav
Browse files

add identity ins for the inbetween case as well

parent d3dbcd62
......@@ -86,6 +86,8 @@ TEST_CASE(single_target_test)
TEST_CASE(two_targets_with_ref)
{
/*
Identity
|
Add (tid = 1)
|
Mul (tid = 0)
......@@ -102,10 +104,11 @@ TEST_CASE(two_targets_with_ref)
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x_param);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), identity_ins_0, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), mul_ins);
mm->add_return({identity_ins});
auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), mul_ins);
mm->add_return({identity_ins_1});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
}
......@@ -116,7 +119,7 @@ TEST_CASE(two_targets_with_ref)
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
......@@ -135,15 +138,16 @@ TEST_CASE(two_targets_with_ref)
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{y, identity_ins_0},
{target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
auto x_7 = mm->add_instruction(migraphx::make_op("identity"), x_6);
mm->add_return({x_7});
auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), x_6);
mm->add_return({identity_ins_1});
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -151,6 +155,8 @@ TEST_CASE(two_targets_with_ref)
TEST_CASE(two_targets_ref_inbetween)
{
/*
Identity
|
Add (tid = 1)
|
Identity
......@@ -167,9 +173,10 @@ TEST_CASE(two_targets_ref_inbetween)
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), identity_ins, z_param);
auto identity_ins_0 = mm->add_instruction(migraphx::make_op("identity"), x_param);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), identity_ins_0, y_param);
auto identity_ins_1 = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), identity_ins_1, z_param);
mm->add_return({mul_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
......@@ -181,6 +188,7 @@ TEST_CASE(two_targets_ref_inbetween)
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), x);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
......@@ -200,8 +208,9 @@ TEST_CASE(two_targets_ref_inbetween)
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("run_on_target", {{"target_id", 1}}),
{y, identity_ins},
{target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment