"README_ORIGIN.md" did not exist on "ecfa4fc6b7c27c8024e21c55a3e7bc47ffe524ea"
Commit ccbdeaa9 authored by Umang Yadav's avatar Umang Yadav
Browse files

refactor

parent a88074c2
......@@ -73,9 +73,11 @@ Given target assignments (tass) for the instructions, generate run_on_target mod
automatically. Input graph should be uncompiled migraphx program. target assignments (tass) map
should have a map of instruction to target_id. Instructions that are not inside tass map are
considered to be targeted for the "Ref" by default. params, literals and other builtins shouldn't be
part of the tass, only compute and reshaper instructions should be part of tass. Copy, sync and
alloc instructions would be generated by compiler at later stage, so those shouldn't be considered.
(TODO): CustomOps may require special handling.
part of the tass, only compute and certain reshaper instructions should be part of tass. Copy, sync
and alloc instructions would be generated by compiler at later stage, so those shouldn't be
considered. (TODO): CustomOps may require special handling.
Ref is used as default target for instructions that do not have assignments.
Step 1:
Identify subgraph boundaries:
......@@ -84,7 +86,6 @@ assignment as the node itself.
(b) Boundaries can happen when any output of any node doesn't have all its inputs with same target
assignment as the node itself.
Ref is used for instructions that do not have assignments.
For example graphs like following:
1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target Y
......@@ -124,6 +125,7 @@ struct auto_gen_root_modules
: tass(target_assignments)
{
auto* mm = p.get_main_module();
// initialize tid_counter, it is used to create meaningful names for the target modules
for(const auto& i : tass)
{
if(tid_counter.find(i.second) == tid_counter.end())
......@@ -163,17 +165,15 @@ struct auto_gen_root_modules
{
return false;
}
if(std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
if((skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max())))
{
return true;
}
return false;
}))
return true;
return false;
return std::any_of(inputs.begin(), inputs.end(), [&](auto input_ins) {
if((skip_ins.find(input_ins) != skip_ins.end()) or
(tass.find(input_ins) != tass.end() and
tass.at(ins) != tid.value_or(std::numeric_limits<std::size_t>::max())))
{
return true;
}
return false;
});
}
bool is_fork_node(migraphx::instruction_ref ins, std::optional<std::size_t> tid)
......@@ -183,17 +183,15 @@ struct auto_gen_root_modules
{
return false;
}
if(std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(tass.find(output_ins) != tass.end() and
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return")
{
return true;
}
return false;
}))
return true;
return false;
return std::any_of(outputs.begin(), outputs.end(), [&](auto output_ins) {
if(tass.find(output_ins) != tass.end() and
tass.at(output_ins) != tid.value_or(std::numeric_limits<std::size_t>::max()) and
output_ins->name() != "@return")
{
return true;
}
return false;
});
}
void find_subgraphs(migraphx::module_ref mm, migraphx::program& p)
......
......@@ -453,7 +453,6 @@ TEST_CASE(fork_and_merge_case)
auto x_10 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), x_9);
mm->add_return({x_10});
}
p1.print_cpp(std::cout);
EXPECT(p1.sort() == p2.sort());
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment