"...include/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "88bdd75aa3fc907dae7112702d65748d97572114"
Commit d3dbcd62 authored by Umang Yadav's avatar Umang Yadav
Browse files

Fix test case for ref_inbetween

parent ccbdeaa9
......@@ -215,7 +215,7 @@ struct auto_gen_root_modules
if(fork_node)
{
assert(current_tid.has_value());
generate_run_on_target_modules(mm, p, ins, current_tid.value());
generate_run_on_target_modules(mm, p, ins, current_tid);
if(not same_tid_ins_vec.empty())
{
current_tid = nullopt;
......@@ -246,7 +246,7 @@ struct auto_gen_root_modules
if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or
is_merge_node(ins, current_tid))
{
generate_run_on_target_modules(mm, p, ins, current_tid.value());
generate_run_on_target_modules(mm, p, ins, current_tid);
}
else if(tass.at(ins) == current_tid.value())
{
......@@ -282,12 +282,12 @@ struct auto_gen_root_modules
void generate_run_on_target_modules(migraphx::module_ref mm,
migraphx::program& p,
migraphx::instruction_ref ins,
std::size_t& current_tid)
std::optional<std::size_t>& current_tid)
{
assert(same_tid_ins_vec.size() == same_tid_ins_set.size());
if(same_tid_ins_vec.empty())
{
assert(current_tid == std::numeric_limits<std::size_t>::max());
assert(not current_tid.has_value());
return;
}
// gather all parameters
......@@ -327,9 +327,9 @@ struct auto_gen_root_modules
}
}
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid) + "_" +
std::to_string(tid_counter[current_tid]));
update_tid_counter(current_tid);
auto* tmod = p.create_module("target_mod_" + std::to_string(current_tid.value()) + "_" +
std::to_string(tid_counter[current_tid.value()]));
update_tid_counter(current_tid.value());
std::unordered_map<instruction_ref, instruction_ref> params_map;
std::size_t param_counter = 0;
std::vector<instruction_ref> rot_ins_params;
......@@ -380,8 +380,11 @@ struct auto_gen_root_modules
}
// add run_on_target ins
auto rot_ins = mm->insert_instruction(
ins, make_op("run_on_target", {{"target_id", current_tid}}), rot_ins_params, {tmod});
auto rot_ins =
mm->insert_instruction(ins,
make_op("run_on_target", {{"target_id", current_tid.value()}}),
rot_ins_params,
{tmod});
skip_ins.insert(rot_ins);
// fetch return instructions from tuple
......@@ -401,13 +404,13 @@ struct auto_gen_root_modules
same_tid_ins_vec.clear();
if(tass.find(ins) != tass.end())
{
current_tid = tass.at(ins);
current_tid = std::make_optional<std::size_t>(tass.at(ins));
same_tid_ins_set.insert(ins);
same_tid_ins_vec.push_back(ins);
}
else
{
current_tid = std::numeric_limits<std::size_t>::max();
current_tid = nullopt;
}
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
{
......
......@@ -44,13 +44,11 @@
#include <migraphx/target_assignments.hpp>
#include <test.hpp>
TEST_CASE(simple_no_branch_test)
TEST_CASE(single_target_test)
{
/*
Add (tid = 1)
|
Mul (tid = 0)
|
Return
*/
migraphx::target_assignments tass;
......@@ -60,12 +58,56 @@ TEST_CASE(simple_no_branch_test)
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto cpu_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto gpu_ins = mm->add_instruction(migraphx::make_op("mul"), cpu_ins, z_param);
mm->add_return({gpu_ins});
tass.insert(tass.begin(), std::make_pair(cpu_ins, 1));
tass.insert(tass.begin(), std::make_pair(gpu_ins, 0));
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
mm->add_return({add_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter("param:1", s);
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter("param:0", s);
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
mm->add_return({x_3});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(two_targets_with_ref)
{
/*
Add (tid = 1)
|
Mul (tid = 0)
|
Identity
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), add_ins, z_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), mul_ins);
mm->add_return({identity_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
......@@ -100,7 +142,73 @@ TEST_CASE(simple_no_branch_test)
auto x_5 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_4}, {target_mod_0_0});
auto x_6 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_5);
mm->add_return({x_6});
auto x_7 = mm->add_instruction(migraphx::make_op("identity"), x_6);
mm->add_return({x_7});
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(two_targets_ref_inbetween)
{
/*
Add (tid = 1)
|
Identity
|
Mul (tid = 0)
|
Return
*/
migraphx::target_assignments tass;
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x_param = mm->add_parameter("x", s);
auto y_param = mm->add_parameter("y", s);
auto z_param = mm->add_parameter("z", s);
auto add_ins = mm->add_instruction(migraphx::make_op("add"), x_param, y_param);
auto identity_ins = mm->add_instruction(migraphx::make_op("identity"), add_ins);
auto mul_ins = mm->add_instruction(migraphx::make_op("mul"), identity_ins, z_param);
mm->add_return({mul_ins});
tass.insert(tass.begin(), std::make_pair(add_ins, 1));
tass.insert(tass.begin(), std::make_pair(mul_ins, 0));
}
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto z = mm->add_parameter("z", s);
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_1_0 = p2.create_module("target_mod_1_0");
auto target_mod_1_0_param_1 = target_mod_1_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_1_0_param_0 = target_mod_1_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_1_0_2 = target_mod_1_0->add_instruction(
migraphx::make_op("add"), target_mod_1_0_param_1, target_mod_1_0_param_0);
target_mod_1_0->add_return({x_target_mod_1_0_2});
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter(
"param:1", migraphx::shape{migraphx::shape::float_type, {8}});
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter(
"param:0", migraphx::shape{migraphx::shape::float_type, {8}});
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
migraphx::make_op("mul"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
auto x_3 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 1}}), {y, x}, {target_mod_1_0});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_3);
auto x_5 = mm->add_instruction(migraphx::make_op("identity"), x_4);
auto x_6 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {z, x_5}, {target_mod_0_0});
auto x_7 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_6);
mm->add_return({x_7});
}
EXPECT(p1.sort() == p2.sort());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment