Unverified Commit e46a6a52 authored by Umang Yadav's avatar Umang Yadav Committed by GitHub
Browse files

Add use of implicit deps inside sorting function (#2005)

* use implicit deps for the sorting
* use BFS for sorting program
parent 01d4ae09
...@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module ...@@ -222,7 +222,17 @@ struct MIGRAPHX_EXPORT module
void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const; void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;
std::vector<module_ref> get_sub_modules(bool shallow = false) const; std::vector<module_ref> get_sub_modules(bool shallow = false) const;
/* sorts the module in topological order aka reverse-post order (RPO) DFS order
it takes last instruction or @return as the root and walks back the graph and moves inputs
of the each instruction such that it appears before the instruction itself.
*/
module& sort(); module& sort();
/* Any instruction "X" can have module arguments and those modules inside them can use any other
* instruction "Y" from predecessor modules of the instruction "X". Such instruction "Y" inside
* module args are not listed as input instructions to "X". But those instructions "Y" must be
* evaluted before the instruction "X" can. Therefore such "Y" instructions are considered
* implicit dependency to "X".
*/
ins_dep_map calc_implicit_deps() const; ins_dep_map calc_implicit_deps() const;
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m); MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const module& m);
......
...@@ -1011,9 +1011,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const ...@@ -1011,9 +1011,17 @@ std::vector<module_ref> module::get_sub_modules(bool shallow) const
module& module::sort() module& module::sort()
{ {
auto implicit_deps = calc_implicit_deps();
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin()); this->move_instruction(ins, this->begin());
for(auto child : ins->inputs()) auto ins_inputs = ins->inputs();
if(implicit_deps.find(ins) != implicit_deps.end())
{
auto ins_implict_inputs = implicit_deps.at(ins);
ins_inputs.insert(
ins_inputs.end(), ins_implict_inputs.begin(), ins_implict_inputs.end());
}
for(auto child : ins_inputs)
{ {
if(not contains(this->impl->instructions, child)) if(not contains(this->impl->instructions, child))
{ {
......
...@@ -40,13 +40,14 @@ ...@@ -40,13 +40,14 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/marker.hpp> #include <migraphx/marker.hpp>
#include <migraphx/supported_segments.hpp> #include <migraphx/supported_segments.hpp>
#include <iostream> #include <iostream>
#include <queue>
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <set> #include <set>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <unordered_set> #include <unordered_set>
#include <map> #include <map>
#include <cassert> #include <cassert>
...@@ -1191,11 +1192,19 @@ void program::remove_unused_modules() ...@@ -1191,11 +1192,19 @@ void program::remove_unused_modules()
program& program::sort() program& program::sort()
{ {
for(auto& pp : this->impl->modules) std::queue<migraphx::module_ref> mqueue;
mqueue.push(get_main_module());
while(not mqueue.empty())
{ {
pp.second.sort(); module_ref current_mod = mqueue.front();
current_mod->sort();
mqueue.pop();
auto child_mods = current_mod->get_sub_modules(true);
for(auto& sub_mod : child_mods)
{
mqueue.push(sub_mod);
}
} }
return *this; return *this;
} }
......
...@@ -83,7 +83,7 @@ TEST_CASE(calc_implict_deps) ...@@ -83,7 +83,7 @@ TEST_CASE(calc_implict_deps)
auto* else_mod = p.create_module("If_5_else"); auto* else_mod = p.create_module("If_5_else");
auto l2 = else_mod->add_literal(migraphx::literal(ys, datay)); auto l2 = else_mod->add_literal(migraphx::literal(ys, datay));
auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1}); auto a2 = else_mod->add_instruction(migraphx::make_op("if"), {cond}, {then_mod1, else_mod1});
auto a3 = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2); auto a3 = else_mod->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), a2);
else_mod->add_return({a3, l2}); else_mod->add_return({a3, l2});
auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod});
...@@ -95,6 +95,15 @@ TEST_CASE(calc_implict_deps) ...@@ -95,6 +95,15 @@ TEST_CASE(calc_implict_deps)
EXPECT(migraphx::contains(implicit_deps.at(ret), x1)); EXPECT(migraphx::contains(implicit_deps.at(ret), x1));
EXPECT(migraphx::contains(implicit_deps.at(ret), x2)); EXPECT(migraphx::contains(implicit_deps.at(ret), x2));
EXPECT(migraphx::contains(implicit_deps.at(ret), y2)); EXPECT(migraphx::contains(implicit_deps.at(ret), y2));
EXPECT(migraphx::contains(implicit_deps.at(ret), lx));
EXPECT(migraphx::contains(implicit_deps.at(ret), ly));
// test for sorting
p.sort();
auto ret_inputs = ret->inputs();
ret_inputs.insert(ret_inputs.end(), implicit_deps.at(ret).begin(), implicit_deps.at(ret).end());
EXPECT(std::all_of(ret_inputs.begin(), ret_inputs.end(), [&](const auto i) {
return std::distance(mm->begin(), i) < std::distance(mm->begin(), ret);
}));
} }
TEST_CASE(module_annotate) TEST_CASE(module_annotate)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment