Commit ccbdeaa9 authored by Umang Yadav's avatar Umang Yadav
Browse files

refactor

parent a88074c2
...@@ -73,9 +73,11 @@ Given target assignments (tass) for the instructions, generate run_on_target mod ...@@ -73,9 +73,11 @@ Given target assignments (tass) for the instructions, generate run_on_target mod
automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map
should have a map of instruction to target_id. Instructions that are not inside tass map are should have a map of instruction to target_id. Instructions that are not inside tass map are
considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be
part of the tass, only compute and reshaper instructions should be part of tass. Copy, sync and part of the tass, only compute and certain reshaper instructions should be part of tass. Copy, sync
alloc instructions would be generated by compiler at later stage, so those shouldn't be considered. and alloc instructions would be generated by compiler at later stage, so those shouldn't be
(TODO): CustomOps may require special handling. considered. (TODO): CustomOps may require special handling.
Ref is used as default target for instructions that do not have assignments.
Step 1: Step 1:
Identify subgraph boundaries: Identify subgraph boundaries:
...@@ -84,7 +86,6 @@ assignment as the node itself. ...@@ -84,7 +86,6 @@ assignment as the node itself.
(b) Boundaries can happen when any output of any node doesn't have all its inputs with same target (b) Boundaries can happen when any output of any node doesn't have all its inputs with same target
assignment as the node itself. assignment as the node itself.
Ref is used for instructions that do not have assignments.
For example graphs like following: For example graphs like following:
1. Ref --> Target X --> Ref 1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target Y 2. Ref --> Target X --> Target Y
...@@ -124,6 +125,7 @@ struct auto_gen_root_modules ...@@ -124,6 +125,7 @@ struct auto_gen_root_modules
: tass(target_assignments) : tass(target_assignments)
{ {
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
// initialize tid_counter, it is used to create meaningful names for the target modules
for(const auto& i : tass) for(const auto& i : tass)
{ {
if(tid_counter.find(i.second) == tid_counter.end()) if(tid_counter.find(i.second) == tid_counter.end())
...@@ -163,7 +165,7 @@ struct auto_gen_root_modules ...@@ -163,7 +165,7 @@ struct auto_gen_root_modules
{ {
return false; return false;
} }
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) { return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
if((skip_ins.find(input_ins) != skip_ins.end()) or if((skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and (tass.find(input_ins) != tass.end() and
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max()))) tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max())))
...@@ -171,9 +173,7 @@ struct auto_gen_root_modules ...@@ -171,9 +173,7 @@ struct auto_gen_root_modules
return true; return true;
} }
return false; return false;
})) });
return true;
return false;
} }
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid) bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
...@@ -183,7 +183,7 @@ struct auto_gen_root_modules ...@@ -183,7 +183,7 @@ struct auto_gen_root_modules
{ {
return false; return false;
} }
if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) { return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(tass.find(output_ins) != tass.end() and if(tass.find(output_ins) != tass.end() and
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return") output_ins->name() != "@return")
...@@ -191,9 +191,7 @@ struct auto_gen_root_modules ...@@ -191,9 +191,7 @@ struct auto_gen_root_modules
return true; return true;
} }
return false; return false;
})) });
return true;
return false;
} }
void find_subgraphs(migraphx::module_ref mm, migraphx::program& p) void find_subgraphs(migraphx::module_ref mm, migraphx::program& p)
......
...@@ -453,7 +453,6 @@ TEST_CASE(fork_and_merge_case) ...@@ -453,7 +453,6 @@ TEST_CASE(fork_and_merge_case)
auto x_10 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_9); auto x_10 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_9);
mm->add_return({x_10}); mm->add_return({x_10});
} }
p1.print_cpp(std::cout);
EXPECT(p1.sort() == p2.sort()); EXPECT(p1.sort() == p2.sort());
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment