Commit d3dbcd62 authored by Umang Yadav's avatar Umang Yadav
Browse files

Fix test case for ref_inbetween

parent ccbdeaa9
...@@ -215,7 +215,7 @@ struct auto_gen_root_modules ...@@ -215,7 +215,7 @@ struct auto_gen_root_modules
if(fork_node) if(fork_node)
{ {
assert(current_tid.has_value()); assert(current_tid.has_value());
generate_run_on_target_modules(mm, p, ins, current_tid.value()); generate_run_on_target_modules(mm, p, ins, current_tid);
if(not same_tid_ins_vec.empty()) if(not same_tid_ins_vec.empty())
{ {
current_tid = nullopt; current_tid = nullopt;
...@@ -246,7 +246,7 @@ struct auto_gen_root_modules ...@@ -246,7 +246,7 @@ struct auto_gen_root_modules
if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or
is_merge_node(ins, current_tid)) is_merge_node(ins, current_tid))
{ {
generate_run_on_target_modules(mm, p, ins, current_tid.value()); generate_run_on_target_modules(mm, p, ins, current_tid);
} }
else if(tass.at(ins) == current_tid.value()) else if(tass.at(ins) == current_tid.value())
{ {
...@@ -282,12 +282,12 @@ struct auto_gen_root_modules ...@@ -282,12 +282,12 @@ struct auto_gen_root_modules
void generate_run_on_target_modules(migraphx::module_ref mm, void generate_run_on_target_modules(migraphx::module_ref mm,
migraphx::program& p, migraphx::program& p,
migraphx::instruction_ref ins, migraphx::instruction_ref ins,
std::size_t& current_tid) std::optional<std::size_t>& current_tid)
{ {
assert(same_tid_ins_vec.size() == same_tid_ins_set.size()); assert(same_tid_ins_vec.size() == same_tid_ins_set.size());
if(same_tid_ins_vec.empty()) if(same_tid_ins_vec.empty())
{ {
assert(current_tid == std::numeric_limits<std::size_t>::max()); assert(not current_tid.has_value());
return; return;
} }
// gather all parameters // gather all parameters
...@@ -327,9 +327,9 @@ struct auto_gen_root_modules ...@@ -327,9 +327,9 @@ struct auto_gen_root_modules
} }
} }
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" + auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid.value()) + "_" +
std::to_string(tid_counter[current_tid])); std::to_string(tid_counter[current_tid.value()]));
update_tid_counter(current_tid); update_tid_counter(current_tid.value());
std::unordered_map<instruction_ref, instruction_ref> params_map; std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0; std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params; std::vector<instruction_ref> rot_ins_params;
...@@ -380,8 +380,11 @@ struct auto_gen_root_modules ...@@ -380,8 +380,11 @@ struct auto_gen_root_modules
} }
// add run_on_target ins // add run_on_target ins
auto rot_ins = mm->insert_instruction( auto rot_ins =
ins, make_op("run_on_target", {{"target_id", current_tid}}), rot_ins_params, {tmod}); mm->insert_instruction(ins,
make_op("run_on_target", {{"target_id", current_tid.value()}}),
rot_ins_params,
{tmod});
skip_ins.insert(rot_ins); skip_ins.insert(rot_ins);
// fetch return instructions from tuple // fetch return instructions from tuple
...@@ -401,13 +404,13 @@ struct auto_gen_root_modules ...@@ -401,13 +404,13 @@ struct auto_gen_root_modules
same_tid_ins_vec.clear(); same_tid_ins_vec.clear();
if(tass.find(ins) != tass.end()) if(tass.find(ins) != tass.end())
{ {
current_tid = tass.at(ins); current_tid = std::make_optional<std::size_t>(tass.at(ins));
same_tid_ins_set.insert(ins); same_tid_ins_set.insert(ins);
same_tid_ins_vec.push_back(ins); same_tid_ins_vec.push_back(ins);
} }
else else
{ {
current_tid = std::numeric_limits<std::size_t>::max(); current_tid = nullopt;
} }
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{})) if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{ {
......
...@@ -44,13 +44,54 @@ ...@@ -44,13 +44,54 @@
#include <migraphx/target_assignments.hpp> #include <migraphx/target_assignments.hpp>
#include <test.hpp> #include <test.hpp>
TEST_CASE(simple_no_branch_test) TEST_CASE(single_target_test)
{
/*
Add (tid = 1)
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
mm->add_return({add_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
mm->add_return({x_3});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(two_targets_with_ref)
{ {
/* /*
Add (tid = 1) Add (tid = 1)
| |
Mul (tid = 0) Mul (tid = 0)
| |
Identity
|
Return Return
*/ */
migraphx::target_assignments tass; migraphx::target_assignments tass;
...@@ -61,11 +102,12 @@ TEST_CASE(simple_no_branch_test) ...@@ -61,11 +102,12 @@ TEST_CASE(simple_no_branch_test)
auto x_param = mm->add_parameter("x", s); auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s); auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s); auto z_param = mm->add_parameter("z", s);
auto cpu_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param); auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto gpu_ins = mm->add_instruction(migraphx::make_op("mul"), cpu_ins, z_param); auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
mm->add_return({gpu_ins}); auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), mul_ins);
tass.insert(tass.begin(), std::make_pair(cpu_ins, 1)); mm->add_return({identity_ins});
tass.insert(tass.begin(), std::make_pair(gpu_ins, 0)); tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
} }
migraphx::generate_root_modules(p1, tass); migraphx::generate_root_modules(p1, tass);
migraphx::program p2; migraphx::program p2;
...@@ -100,7 +142,73 @@ TEST_CASE(simple_no_branch_test) ...@@ -100,7 +142,73 @@ TEST_CASE(simple_no_branch_test)
auto x_5 = mm->add_instruction( auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5); auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
mm->add_return({x_6}); auto x_7 = mm->add_instruction(migraphx::make_op("identity"), x_6);
mm->add_return({x_7});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(two_targets_ref_inbetween)
{
/*
Add (tid = 1)
|
Identity
|
Mul (tid = 0)
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), identity_ins, z_param);
mm->add_return({mul_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
auto x_6 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_5}, {target_mod_0_0});
auto x_7 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_6);
mm->add_return({x_7});
} }
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment