Unverified commit 7ab06956, authored by Paul Fultz II, committed by GitHub
Browse files

Compute dominators (#525)



* rename merge_from to merge_to

* refine comments

* code backup

* clang format

* The first version that can reduce scratch memory usage

* code backup

* clang format

* code backup

* clang format

* fixed a bug related to removing gemm copy

* clang format

* code backup

* clang format

* fix review comments

* clang format

* fix unit test failure

* code backup

* clang format

* code base for further investigation

* code with both the forward and backward approach to compute the conflict table

* clang format

* clang format

* backup changes

* remove unnecessary file

* remove unnecessary code

* code backup

* clang format

* code backup

* clang format

* fix a bug in the code

* clang format

* code backup

* clang format

* remove unused code

* remove unused code

* rename some functions

* remove print code

* code backup

* add dominator to scheduling

* add dominator algorithm to remove unnecessary conflicts

* Remove comment

* Use erase_if instead

* Formatting

* Code clean up:

* Formatting

* Add dominator info class

* Formatting

* Add dom_info

* Formatting

* Add test case and fix some bugs

* Formatting

* Add unit test for scheduler

* Formatting

* Use index map instead of distance

* Formatting

* Add memory coloring test

* Check for conflict in memory coloring

* Formatting

* Use 1 stream by default

* Update to use modules

* Formatting

* Skip live on entry check

* Formatting

* Formatting

* Fix tidy warning

* Fix tidy warning

* Formatting

* Add nolint

* Use C++17 to build everything when using clang

* Remove input names

* Formatting

* Remove input names

* Keep order of params

* Formatting
Co-authored-by: Shucai Xiao <Shucai.Xiao@amd.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e8738144
...@@ -95,4 +95,5 @@ ENV LD_LIBRARY_PATH=$PREFIX/lib ...@@ -95,4 +95,5 @@ ENV LD_LIBRARY_PATH=$PREFIX/lib
# Setup ubsan environment to printstacktrace # Setup ubsan environment to printstacktrace
ENV UBSAN_OPTIONS=print_stacktrace=1 ENV UBSAN_OPTIONS=print_stacktrace=1
ENV ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1 ENV ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1
RUN ln -s /opt/rocm/llvm/bin/llvm-symbolizer /usr/bin/llvm-symbolizer
...@@ -15,6 +15,7 @@ add_library(migraphx ...@@ -15,6 +15,7 @@ add_library(migraphx
compile_src.cpp compile_src.cpp
cpp_generator.cpp cpp_generator.cpp
dead_code_elimination.cpp dead_code_elimination.cpp
dom_info.cpp
dynamic_loader.cpp dynamic_loader.cpp
eliminate_allocation.cpp eliminate_allocation.cpp
eliminate_contiguous.cpp eliminate_contiguous.cpp
......
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Returns true when ins1 is found by following the immediate-dominator links
// stored in ins2idom upwards from ins2. An instruction never strictly
// dominates itself, so equal arguments yield false.
bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2)
{
    if(ins1 == ins2)
        return false;
    // Climb from ins2 through successive immediate dominators until ins1 is met
    // or an instruction with no recorded immediate dominator is reached.
    for(auto pos = ins2idom.find(ins2); pos != ins2idom.end();)
    {
        const instruction_ref parent = pos->second;
        if(parent == ins1)
            return true;
        auto next = ins2idom.find(parent);
        // An instruction recorded as its own immediate dominator would make
        // this walk loop forever.
        assert(pos != next);
        pos = next;
    }
    return false;
}
struct module_visitor
{
module* mm;
module& get_nodes() const { return *mm; }
const std::vector<instruction_ref>& get_children(instruction_ref ins) { return ins->inputs(); }
};
// Computes the immediate dominator of every node yielded by v.get_nodes().
//
// Visitor must provide:
//   get_nodes()       - a range usable with iterator_for, yielding instruction_ref
//   get_children(ins) - the nodes that `ins` directly depends on
//
// For each node the complete set of its dominators (the node itself included) is
// kept in instr2_doms. A node's set is read when one of its users is processed, so
// nodes are presumably yielded with children before their users - TODO confirm for
// every visitor. Nodes without children, and nodes whose children share no common
// dominator, get no entry in ins2idom.
//
// NOTE: the dominator sets grow with the depth of the graph, so memory use is
// proportional to (number of nodes) x (dominator-chain length).
template <class Visitor>
dominator_info compute_dominator_generic(Visitor v)
{
    dominator_info info;
    std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> instr2_doms;
    for(instruction_ref ins : iterator_for(v.get_nodes()))
    {
        const std::vector<instruction_ref>& children = v.get_children(ins);
        if(children.size() == 1)
        {
            info.ins2idom[ins] = children.front();
            // Inherit every dominator of the only child. Recording just the child
            // would lose the dominators further up the chain, and a later join of
            // two such chains would then see an empty intersection even though
            // they share a common dominator.
            auto&& doms = instr2_doms[ins];
            doms        = instr2_doms[children.front()];
            doms.insert(children.front());
        }
        else if(children.size() > 1)
        {
            // Dominators of a join are the nodes that dominate every child
            auto&& doms = instr2_doms[ins];
            doms        = instr2_doms[children.front()];
            std::for_each(children.begin() + 1, children.end(), [&](instruction_ref child) {
                auto&& child_doms = instr2_doms[child];
                erase_if(doms, [&](auto x) { return not contains(child_doms, x); });
            });
            // The immediate dominator is the common dominator that does not
            // strictly dominate any other common dominator, i.e. the closest one.
            auto iter = std::find_if(doms.begin(), doms.end(), [&](auto dom1) {
                return std::none_of(doms.begin(), doms.end(), [&](auto dom2) {
                    if(dom1 == dom2)
                        return false;
                    return info.strictly_dominate(dom1, dom2);
                });
            });
            if(iter != doms.end())
                info.ins2idom[ins] = *iter;
        }
        // Every node dominates itself
        instr2_doms[ins].insert(ins);
    }
    return info;
}
// Computes dominator information for module `m`, using each instruction's
// inputs as its children.
dominator_info compute_dominator(module& m)
{
    module_visitor visitor{&m};
    return compute_dominator_generic(visitor);
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -43,6 +43,7 @@ struct outline ...@@ -43,6 +43,7 @@ struct outline
struct param struct param
{ {
std::string parameter; std::string parameter;
uint32_t order = 0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#define MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
// Immediate-dominator table for instructions, filled in by compute_dominator.
struct dominator_info
{
// Returns true if ins1 is reached by following ins2idom links starting from ins2.
// Returns false when ins1 == ins2.
bool strictly_dominate(instruction_ref ins1, instruction_ref ins2);
// Maps an instruction to its immediate dominator; an instruction for which none
// was found has no entry.
std::unordered_map<instruction_ref, instruction_ref> ins2idom;
};
// Builds the dominator_info for the instructions of module m.
dominator_info compute_dominator(module& m);
// dominator_info compute_dominator_naive(const module& m);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -25,12 +25,19 @@ auto erase(R&& r, const T& value) ...@@ -25,12 +25,19 @@ auto erase(R&& r, const T& value)
* *
* @param r The container to erase elements from * @param r The container to erase elements from
* @param pred Predicate function that selects which elements should be erased. * @param pred Predicate function that selects which elements should be erased.
* @return Returns iterator to erased element
*/ */
template <class R, class P> template <class R, class P>
auto erase_if(R&& r, P&& pred) void erase_if(R&& r, P&& pred)
{ {
return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end()); auto first = r.begin();
auto last = r.end();
while(first != last)
{
if(pred(*first))
first = r.erase(first);
else
first++;
}
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -25,8 +25,8 @@ struct module_impl ...@@ -25,8 +25,8 @@ struct module_impl
// A list is used to keep references to an instruction stable // A list is used to keep references to an instruction stable
std::list<instruction> instructions; std::list<instruction> instructions;
std::unordered_set<instruction*> instruction_set; std::unordered_set<instruction*> instruction_set;
std::vector<std::string> input_names;
std::string name; std::string name;
uint32_t nparams = 0;
bool contains(instruction_ref ins) const bool contains(instruction_ref ins) const
{ {
...@@ -110,7 +110,6 @@ void module::assign(const module& m) ...@@ -110,7 +110,6 @@ void module::assign(const module& m)
{ {
impl->instructions.clear(); impl->instructions.clear();
} }
impl->input_names = m.impl->input_names;
impl->name = m.impl->name; impl->name = m.impl->name;
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
...@@ -312,9 +311,8 @@ instruction_ref module::add_outline(const shape& s) ...@@ -312,9 +311,8 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref module::add_parameter(std::string name, shape s) instruction_ref module::add_parameter(std::string name, shape s)
{ {
assert(get_parameter_shape(name) == shape{}); assert(get_parameter_shape(name) == shape{});
impl->input_names.push_back(name); impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
impl->push_front({builtin::param{std::move(name)}, std::move(s), {}});
return impl->instructions.begin(); return impl->instructions.begin();
} }
...@@ -350,17 +348,21 @@ shape module::get_parameter_shape(std::string name) const ...@@ -350,17 +348,21 @@ shape module::get_parameter_shape(std::string name) const
std::vector<std::string> module::get_parameter_names() const std::vector<std::string> module::get_parameter_names() const
{ {
std::vector<std::string> result = impl->input_names; std::vector<std::string> result;
std::unordered_set<std::string> params; std::vector<builtin::param> params;
for(auto&& ins : impl->instructions) for(auto&& ins : impl->instructions)
{ {
if(ins.name() == "@param") if(ins.name() == "@param")
{ {
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter; auto&& param = any_cast<builtin::param>(ins.get_operator());
params.insert(name); params.push_back(param);
} }
} }
erase_if(result, [&](auto&& name) { return params.count(name) == 0; }); std::stable_sort(
params.begin(), params.end(), by(std::less<>{}, [](auto&& p) { return p.order; }));
std::transform(params.begin(), params.end(), std::back_inserter(result), [&](auto&& p) {
return p.parameter;
});
return result; return result;
} }
......
...@@ -249,8 +249,8 @@ void memory_coloring_impl::verify() ...@@ -249,8 +249,8 @@ void memory_coloring_impl::verify()
if(segment.begin == invalid_offset) if(segment.begin == invalid_offset)
{ {
if(!interval.is_live_on_entry) // if(!interval.is_live_on_entry)
MIGRAPHX_THROW("interval is not live on entry"); // MIGRAPHX_THROW("interval is not live on entry");
continue; continue;
} }
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/par_for.hpp> #include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <queue> #include <queue>
...@@ -16,6 +17,7 @@ ...@@ -16,6 +17,7 @@
#include <set> #include <set>
#include <deque> #include <deque>
#include <chrono> #include <chrono>
#include <iomanip>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -88,7 +90,7 @@ struct stream_info ...@@ -88,7 +90,7 @@ struct stream_info
return args.end(); return args.end();
} }
const std::size_t min_partition_threshold = 1; const std::size_t min_partition_threshold = 2;
sort_args_by_weight(args, std::greater<>{}); sort_args_by_weight(args, std::greater<>{});
auto it = std::lower_bound(std::next(args.begin()), auto it = std::lower_bound(std::next(args.begin()),
...@@ -353,6 +355,7 @@ struct stream_info ...@@ -353,6 +355,7 @@ struct stream_info
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p);
result.reserve(p.size()); result.reserve(p.size());
merge_from.reserve(p.size()); merge_from.reserve(p.size());
for(auto ins : reverse_iterator_for(p)) for(auto ins : reverse_iterator_for(p))
...@@ -366,8 +369,13 @@ struct stream_info ...@@ -366,8 +369,13 @@ struct stream_info
merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end()); merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
} }
auto streams = this->get_streams(ins); if(is_split_point(ins))
{
erase_if(merge_from[ins],
[&](auto merge) { return di.strictly_dominate(ins, merge); });
}
auto streams = this->get_streams(ins);
// Collect concur instructions for each merge point. // Collect concur instructions for each merge point.
for(const auto& merge : merge_from[ins]) for(const auto& merge : merge_from[ins])
{ {
...@@ -396,11 +404,18 @@ struct stream_info ...@@ -396,11 +404,18 @@ struct stream_info
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p) get_conflicts(module& p)
{ {
using conflict_table_type = using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table; conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p); auto concur_ins = this->find_concurrent_instructions(p);
// Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0;
for(auto ins : iterator_for(p))
ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables( std::vector<conflict_table_type> thread_conflict_tables(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
std::vector<instruction_ref> index_to_ins; std::vector<instruction_ref> index_to_ins;
...@@ -442,14 +457,13 @@ struct stream_info ...@@ -442,14 +457,13 @@ struct stream_info
for(auto ins1 : ins1_set) for(auto ins1 : ins1_set)
{ {
auto p1 = std::distance(ins1, merge_first); auto p1 = ins2index.at(ins1);
for(auto ins2 : ins2_set) for(auto ins2 : ins2_set)
{ {
if(ins1 == ins2) if(ins1 == ins2)
continue; continue;
auto p2 = std::distance(ins2, merge_first); auto p2 = ins2index.at(ins2);
// The smaller distance means the instruction occurs later if(p2 > p1)
if(p1 > p2)
thrd_table[ins2].insert(ins1); thrd_table[ins2].insert(ins1);
else else
thrd_table[ins1].insert(ins2); thrd_table[ins1].insert(ins2);
......
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
// Graph under test (an edge runs from an input to the instruction using it):
//   entry -> fork
//   fork  -> left, right, leaf
//   left, right -> join
TEST_CASE(dom1)
{
    migraphx::module mm;
    auto entry = mm.add_parameter("entry", {migraphx::shape::float_type});
    auto fork  = mm.add_instruction(pass_op{}, entry);
    auto left  = mm.add_instruction(pass_op{}, fork);
    auto right = mm.add_instruction(pass_op{}, fork);
    auto join  = mm.add_instruction(pass_op{}, left, right);
    auto leaf  = mm.add_instruction(pass_op{}, fork);

    auto dom = migraphx::compute_dominator(mm);

    // Every path to the later instructions passes through entry and then fork
    EXPECT(dom.strictly_dominate(entry, fork));
    EXPECT(dom.strictly_dominate(fork, left));
    EXPECT(dom.strictly_dominate(fork, right));
    EXPECT(dom.strictly_dominate(fork, join));
    EXPECT(dom.strictly_dominate(fork, leaf));
    // Neither branch lies on the path to leaf
    EXPECT(not dom.strictly_dominate(left, leaf));
    EXPECT(not dom.strictly_dominate(right, leaf));
    // join can be reached through either branch, so neither branch dominates it
    EXPECT(not dom.strictly_dominate(left, join));
    EXPECT(not dom.strictly_dominate(right, join));
}
// Entry point of the test executable: forwards the command-line arguments to test::run.
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -79,8 +79,7 @@ struct pass_op ...@@ -79,8 +79,7 @@ struct pass_op
return {}; return {};
return inputs.front(); return inputs.front();
} }
int output_alias(const std::vector<migraphx::shape>& s) const { return s.empty() ? -1 : 0; }
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
}; };
struct mod_pass_op struct mod_pass_op
......
This diff is collapsed.
...@@ -774,6 +774,31 @@ TEST_CASE(inception_resnet) ...@@ -774,6 +774,31 @@ TEST_CASE(inception_resnet)
t.check_conflicts(m, {c1, {i1}}); t.check_conflicts(m, {c1, {i1}});
} }
TEST_CASE(dominate_conflicts)
{
scheduler t{};
migraphx::module m;
auto one = m.add_literal(1);
auto onep1 = m.add_instruction(unary_op{}, one);
auto onep2 = m.add_instruction(unary_op{}, one);
auto binary1 = m.add_instruction(nary_op{}, onep1, onep2);
auto onep3 = m.add_instruction(unary_op{}, binary1);
auto onep4 = m.add_instruction(unary_op{}, binary1);
auto binary2 = m.add_instruction(nary_op{}, onep3, onep4);
t.run_pass(m);
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(onep3) != t.get_stream(onep4));
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(onep1), t.get_stream(onep2)}));
t.check_conflicts(m, {{onep1}, {onep2}});
t.check_conflicts(m, {{onep3}, {onep4}});
t.check_conflicts(m, {{onep1, onep2}, {onep3, onep4}}, false);
t.check_conflicts(m, {{binary1}, {binary2}}, false);
}
TEST_CASE(inception1) TEST_CASE(inception1)
{ {
scheduler t{}; scheduler t{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment