Commit 3cc4be9e authored by Umang Yadav's avatar Umang Yadav
Browse files

Fix conditions for fork-merge on return case

parent 44369c8e
......@@ -149,8 +149,8 @@ struct auto_gen_root_modules
}
}
bool is_different_subgraph(migraphx::instruction_ref current_ins,
std::optional<std::size_t> previous_tid)
bool has_different_tass(migraphx::instruction_ref current_ins,
std::optional<std::size_t> previous_tid)
{
if(tass.find(current_ins) == tass.end())
{
......@@ -162,11 +162,13 @@ struct auto_gen_root_modules
/*
Merge node is defined as node where two or more branches converge.
NodeX NodeY
| |
---------
|
NodeZ
For the partitioner, if any of the merge node's input doesn't have same tid as the merge node
itself then, it is classified as boundary for subgraph.
*/
......@@ -179,17 +181,22 @@ struct auto_gen_root_modules
return false;
}
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
return is_different_subgraph(input_ins, ins_tid);
return (
(this->skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and
tass.at(input_ins) != ins_tid.value_or(std::numeric_limits<std::size_t>::max())));
});
}
/*
Fork node is defined as node where graph forks into two or more branches
NodeX
|
------------
| |
NodeY NodeZ
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph.
*/
......@@ -201,7 +208,13 @@ struct auto_gen_root_modules
return false;
}
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
return is_different_subgraph(output_ins, ins_tid);
if(output_ins->name() == "return")
{
return false;
}
return (tass.find(output_ins) != tass.end() and
tass.at(output_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max()));
});
}
......@@ -254,7 +267,7 @@ struct auto_gen_root_modules
}
else
{
if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or
if(ins->name() == "@return" or has_different_tass(ins, current_tid) or
is_merge_node(ins, current_tid))
{
generate_run_on_target_modules(mm, p, ins, current_tid);
......@@ -304,7 +317,7 @@ struct auto_gen_root_modules
// gather all parameters
std::unordered_set<instruction_ref> params;
// gather all return values
std::unordered_set<instruction_ref> return_ins;
std::vector<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec))
{
auto inputs = (*tins)->inputs();
......@@ -321,7 +334,7 @@ struct auto_gen_root_modules
return same_tid_ins_set.count(out_ins) == 0;
}))
{
return_ins.insert(*tins);
return_ins.push_back(*tins);
}
}
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
......@@ -380,7 +393,7 @@ struct auto_gen_root_modules
for(auto ritr : iterator_for(return_ins))
{
rins.push_back(params_map.at(*ritr));
return_ins_idx_map[*ritr] = std::distance(ritr, return_ins.begin());
return_ins_idx_map[*ritr] = std::distance(return_ins.begin(), ritr);
}
tmod->add_return(rins);
......
......@@ -485,9 +485,11 @@ TEST_CASE(fork_case_4)
|
---------------------------
| |
Identity (tid = 0) Return
|
Return
Identity (tid = 0) |
| |
--------------------------
|
Return
*/
auto s = migraphx::shape{migraphx::shape::float_type, {8}};
migraphx::target_assignments tass;
......@@ -505,29 +507,24 @@ TEST_CASE(fork_case_4)
migraphx::generate_root_modules(p1, tass);
migraphx::program p2;
{
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction(
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_0 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2});
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1");
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto x_target_mod_0_1_1 =
target_mod_0_1->add_instruction(migraphx::make_op("identity"), target_mod_0_1_param_0);
target_mod_0_1->add_return({x_target_mod_0_1_1});
auto x_target_mod_0_0_1 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_0);
target_mod_0_0->add_return({x_target_mod_0_0_0, x_target_mod_0_0_1});
auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_3}, {target_mod_0_1});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
mm->add_return({x_3, x_5});
auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2);
mm->add_return({x_3, x_4});
}
EXPECT(p1.sort() == p2.sort());
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment