"docs/zh_cn/vscode:/vscode.git/clone" did not exist on "f45977008a52baaf97640a0e9b2bbe5ea1c4be34"
Commit 3cc4be9e authored by Umang Yadav's avatar Umang Yadav
Browse files

Fix conditions for fork-merge on return case

parent 44369c8e
...@@ -149,7 +149,7 @@ struct auto_gen_root_modules ...@@ -149,7 +149,7 @@ struct auto_gen_root_modules
} }
} }
bool is_different_subgraph(migraphx::instruction_ref current_ins, bool has_different_tass(migraphx::instruction_ref current_ins,
std::optional<std::size_t> previous_tid) std::optional<std::size_t> previous_tid)
{ {
if(tass.find(current_ins) == tass.end()) if(tass.find(current_ins) == tass.end())
...@@ -162,11 +162,13 @@ struct auto_gen_root_modules ...@@ -162,11 +162,13 @@ struct auto_gen_root_modules
/* /*
Merge node is defined as node where two or more branches converge. Merge node is defined as node where two or more branches converge.
NodeX NodeY NodeX NodeY
| | | |
--------- ---------
| |
NodeZ NodeZ
For the partitioner, if any of the merge node's input doesn't have same tid as the merge node For the partitioner, if any of the merge node's input doesn't have same tid as the merge node
itself then, it is classified as boundary for subgraph. itself then, it is classified as boundary for subgraph.
*/ */
...@@ -179,17 +181,22 @@ struct auto_gen_root_modules ...@@ -179,17 +181,22 @@ struct auto_gen_root_modules
return false; return false;
} }
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
return is_different_subgraph(input_ins, ins_tid); return (
(this->skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and
tass.at(input_ins) != ins_tid.value_or(std::numeric_limits<std::size_t>::max())));
}); });
} }
/* /*
Fork node is defined as node where graph forks into two or more branches Fork node is defined as node where graph forks into two or more branches
NodeX NodeX
| |
------------ ------------
| | | |
NodeY NodeZ NodeY NodeZ
For the partitioner, if any of the fork node's output doesn't have same tid as the fork node For the partitioner, if any of the fork node's output doesn't have same tid as the fork node
itself then, it is classified as boundary for subgraph. itself then, it is classified as boundary for subgraph.
*/ */
...@@ -201,7 +208,13 @@ struct auto_gen_root_modules ...@@ -201,7 +208,13 @@ struct auto_gen_root_modules
return false; return false;
} }
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
return is_different_subgraph(output_ins, ins_tid); if(output_ins->name() == "return")
{
return false;
}
return (tass.find(output_ins) != tass.end() and
tass.at(output_ins) !=
ins_tid.value_or(std::numeric_limits<std::size_t>::max()));
}); });
} }
...@@ -254,7 +267,7 @@ struct auto_gen_root_modules ...@@ -254,7 +267,7 @@ struct auto_gen_root_modules
} }
else else
{ {
if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or if(ins->name() == "@return" or has_different_tass(ins, current_tid) or
is_merge_node(ins, current_tid)) is_merge_node(ins, current_tid))
{ {
generate_run_on_target_modules(mm, p, ins, current_tid); generate_run_on_target_modules(mm, p, ins, current_tid);
...@@ -304,7 +317,7 @@ struct auto_gen_root_modules ...@@ -304,7 +317,7 @@ struct auto_gen_root_modules
// gather all parameters // gather all parameters
std::unordered_set<instruction_ref> params; std::unordered_set<instruction_ref> params;
// gather all return values // gather all return values
std::unordered_set<instruction_ref> return_ins; std::vector<instruction_ref> return_ins;
for(auto tins : iterator_for(same_tid_ins_vec)) for(auto tins : iterator_for(same_tid_ins_vec))
{ {
auto inputs = (*tins)->inputs(); auto inputs = (*tins)->inputs();
...@@ -321,7 +334,7 @@ struct auto_gen_root_modules ...@@ -321,7 +334,7 @@ struct auto_gen_root_modules
return same_tid_ins_set.count(out_ins) == 0; return same_tid_ins_set.count(out_ins) == 0;
})) }))
{ {
return_ins.insert(*tins); return_ins.push_back(*tins);
} }
} }
if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{})) if(enabled(MIGRAPHX_DEBUG_ROOT_GENERATOR{}))
...@@ -380,7 +393,7 @@ struct auto_gen_root_modules ...@@ -380,7 +393,7 @@ struct auto_gen_root_modules
for(auto ritr : iterator_for(return_ins)) for(auto ritr : iterator_for(return_ins))
{ {
rins.push_back(params_map.at(*ritr)); rins.push_back(params_map.at(*ritr));
return_ins_idx_map[*ritr] = std::distance(ritr, return_ins.begin()); return_ins_idx_map[*ritr] = std::distance(return_ins.begin(), ritr);
} }
tmod->add_return(rins); tmod->add_return(rins);
......
...@@ -485,7 +485,9 @@ TEST_CASE(fork_case_4) ...@@ -485,7 +485,9 @@ TEST_CASE(fork_case_4)
| |
--------------------------- ---------------------------
| | | |
Identity (tid = 0) Return Identity (tid = 0) |
| |
--------------------------
| |
Return Return
*/ */
...@@ -508,26 +510,21 @@ TEST_CASE(fork_case_4) ...@@ -508,26 +510,21 @@ TEST_CASE(fork_case_4)
migraphx::module_ref mm = p2.get_main_module(); migraphx::module_ref mm = p2.get_main_module();
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0"); migraphx::module_ref target_mod_0_0 = p2.create_module("target_mod_0_0");
auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s); auto target_mod_0_0_param_0 = target_mod_0_0->add_parameter("param:0", s);
auto x_target_mod_0_0_2 = target_mod_0_0->add_instruction( auto target_mod_0_0_param_1 = target_mod_0_0->add_parameter("param:1", s);
auto x_target_mod_0_0_0 = target_mod_0_0->add_instruction(
migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0); migraphx::make_op("add"), target_mod_0_0_param_1, target_mod_0_0_param_0);
target_mod_0_0->add_return({x_target_mod_0_0_2}); auto x_target_mod_0_0_1 =
target_mod_0_0->add_instruction(migraphx::make_op("identity"), x_target_mod_0_0_0);
migraphx::module_ref target_mod_0_1 = p2.create_module("target_mod_0_1"); target_mod_0_0->add_return({x_target_mod_0_0_0, x_target_mod_0_0_1});
auto target_mod_0_1_param_0 = target_mod_0_1->add_parameter("param:0", s);
auto x_target_mod_0_1_1 =
target_mod_0_1->add_instruction(migraphx::make_op("identity"), target_mod_0_1_param_0);
target_mod_0_1->add_return({x_target_mod_0_1_1});
auto x_2 = mm->add_instruction( auto x_2 = mm->add_instruction(
migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0}); migraphx::make_op("run_on_target", {{"target_id", 0}}), {y, x}, {target_mod_0_0});
auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2); auto x_3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_2);
auto x_4 = mm->add_instruction( auto x_4 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), x_2);
migraphx::make_op("run_on_target", {{"target_id", 0}}), {x_3}, {target_mod_0_1}); mm->add_return({x_3, x_4});
auto x_5 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_4);
mm->add_return({x_3, x_5});
} }
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment