Commit 20524b8e authored by Umang Yadav's avatar Umang Yadav
Browse files

Add fork and merge cases

parent 359bac6d
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/instruction_ref.hpp"
#include <cstddef> #include <cstddef>
#include <limits> #include <limits>
#include <iterator> #include <iterator>
...@@ -78,14 +79,34 @@ alloc instructions would be generated by compiler at later stage, so those shoul ...@@ -78,14 +79,34 @@ alloc instructions would be generated by compiler at later stage, so those shoul
(TODO): CustomOps may require special handling. (TODO): CustomOps may require special handling.
Step 1: Step 1:
Identify subgraph boundaries Identify subgraph boundaries:
(a) Boundaries can happen when any output of a node doesn't have the same target
assignment as the node itself.
(b) Boundaries can happen when any output of a node doesn't have all of its inputs with the same target
assignment as the node itself.
Ref is used for instructions that do not have assignments. Ref is used for instructions that do not have assignments.
Boundaries can happen in following cases. For example graphs like following:
1. Ref --> Target X --> Ref 1. Ref --> Target X --> Ref
2. Ref --> Target X --> Target Y 2. Ref --> Target X --> Target Y
3. Target X --> Target Y --> Target Z , in this case target X and target Z can be same 3. Target X --> Target Y --> Target Z , in this case target X and target Z can be same
4. Target X --> "@return" 4. Target X --> "@return"
5. Target X --> Ref --> "@return" 5. Target X --> Ref --> "@return"
6. When there is a fork in the graph:
Ref
|
-------------
| |
| |
Target X Ref
7. When there is a merge in the graph:
Target X Ref
| |
---------------
|
Target X
Each of those identified regions could have further nested sub modules which need to be handled Each of those identified regions could have further nested sub modules which need to be handled
separately. separately.
...@@ -121,6 +142,43 @@ struct auto_gen_root_modules ...@@ -121,6 +142,43 @@ struct auto_gen_root_modules
} }
} }
// Returns true when `ins` does not belong to the subgraph identified by `tid`:
// either `ins` has no entry in `tass`, or its entry holds a different target id.
bool is_different_subgraph(migraphx::instruction_ref ins, size_t tid)
{
    // `and` short-circuits, so tass.at(ins) is only evaluated after the lookup succeeded.
    const bool in_same_subgraph = tass.find(ins) != tass.end() and tass.at(ins) == tid;
    return not in_same_subgraph;
}
// Returns true when at least one input of `ins` lies outside the subgraph identified by
// `tid`, as decided by is_different_subgraph (the input has no entry in `tass`, or it is
// assigned a different target id). An instruction without inputs is never a merge node.
// NOTE(review): inputs with no target assignment at all also count as "different" here,
// so any instruction fed by such an input is reported as a merge node — confirm intended.
bool is_merge_node(migraphx::instruction_ref ins, size_t tid)
{
    // Bind by const reference: avoids copying the input list if inputs() returns a
    // reference, and still works (lifetime extension) if it returns by value.
    const auto& inputs = ins->inputs();
    return std::any_of(inputs.begin(), inputs.end(), [&](const auto& input_ins) {
        return is_different_subgraph(input_ins, tid);
    });
}
// Returns true when at least one output of `ins` lies outside the subgraph identified by
// `tid`, as decided by is_different_subgraph (the output has no entry in `tass`, or it is
// assigned a different target id). Outputs named "@return" are ignored, so an instruction
// whose only out-of-subgraph consumer is "@return" is not reported as a fork.
// An instruction without outputs is never a fork node.
bool is_fork_node(migraphx::instruction_ref ins, size_t tid)
{
    // Bind by const reference: avoids copying the output list if outputs() returns a
    // reference, and still works (lifetime extension) if it returns by value.
    const auto& outputs = ins->outputs();
    return std::any_of(outputs.begin(), outputs.end(), [&](const auto& output_ins) {
        return is_different_subgraph(output_ins, tid) and output_ins->name() != "@return";
    });
}
void find_subgraphs(migraphx::module_ref mm, migraphx::program& p) void find_subgraphs(migraphx::module_ref mm, migraphx::program& p)
{ {
// sort the graph in reverse post order DFS order // sort the graph in reverse post order DFS order
...@@ -138,7 +196,6 @@ struct auto_gen_root_modules ...@@ -138,7 +196,6 @@ struct auto_gen_root_modules
{ {
std::cout << "looking at instruction: \n"; std::cout << "looking at instruction: \n";
ins->debug_print(); ins->debug_print();
std::cout << "\n";
} }
// skip all params, literal and builtins other than return, skip "run_on_target_mod" // skip all params, literal and builtins other than return, skip "run_on_target_mod"
...@@ -148,11 +205,17 @@ struct auto_gen_root_modules ...@@ -148,11 +205,17 @@ struct auto_gen_root_modules
{ {
continue; continue;
} }
else if(ins->name() == "@return" or tass.find(ins) == tass.end() or else if(ins->name() == "@return" or is_different_subgraph(ins, current_tid) or
tass.at(ins) != current_tid) is_merge_node(ins, current_tid))
{ {
generate_run_on_target_modules(mm, p, ins, current_tid); generate_run_on_target_modules(mm, p, ins, current_tid);
} }
else if(is_fork_node(ins, current_tid))
{
same_tid_ins_vec.push_back(ins);
same_tid_ins_set.insert(ins);
generate_run_on_target_modules(mm, p, std::next(ins), current_tid);
}
else if(current_tid == std::numeric_limits<std::size_t>::max()) else if(current_tid == std::numeric_limits<std::size_t>::max())
{ {
current_tid = tass.at(ins); current_tid = tass.at(ins);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment