"vscode:/vscode.git/clone" did not exist on "3eed20a33c97838ad15b85c28027922bc23255f8"
Unverified Commit 7ab06956 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Compute dominators (#525)



* rename merge_from to merge_to

* refine comments

* code backup

* clang format

* The first version that can reduce scratch memory usage

* code backup

* clang format

* code backup

* clang format

* fixed a bug related to removing gemm copy

* clang format

* code backup

* clang format

* fix review comments

* clang format

* fix unit test failure

* code backup

* clang format

* code base for further investigation

* code with both the forward and backward approach to compute the conflict table

* clang format

* clang format

* backup changes

* remove unnecessary file

* remove unnecessary code

* code backup

* clang format

* code backup

* clang format'

* fix a bug in the code

* clang format

* code backup

* clang format

* remove unused code

* remove unused code

* rename some functions

* remove print code

* code backup

* add dominator to scheduling

* add dominator algorithm to remove unnecessary conflicts

* Remove comment

* Use erase_if instead

* Formatting

* Code clean up:

* Formatting

* Add dominator info class

* Formatting

* Add dom_info

* Formatting

* Add test case and fix some bugs

* Formatting

* Add unit test for scheduler

* Formatting

* Use index map instead of distance

* Formatting

* Add memory coloring test

* Check for conflict in memory coloring

* Formatting

* Use 1 stream by default

* Update to use modules

* Formatting

* Skip live on entry check

* Formatting

* Formatting

* Fix tidy warning

* Fix tidy warning

* Formatting

* Add nolint

* Use C++17 to build everything when using clang

* Remove input names

* Formatting

* Remove input names

* Keep order of params

* Formatting
Co-authored-by: default avatarShucai Xiao <Shucai.Xiao@amd.com>
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent e8738144
......@@ -95,4 +95,5 @@ ENV LD_LIBRARY_PATH=$PREFIX/lib
# Setup ubsan environment to printstacktrace
ENV UBSAN_OPTIONS=print_stacktrace=1
ENV ASAN_OPTIONS=detect_stack_use_after_return=1:check_initialization_order=1:strict_init_order=1
RUN ln -s /opt/rocm/llvm/bin/llvm-symbolizer /usr/bin/llvm-symbolizer
......@@ -15,6 +15,7 @@ add_library(migraphx
compile_src.cpp
cpp_generator.cpp
dead_code_elimination.cpp
dom_info.cpp
dynamic_loader.cpp
eliminate_allocation.cpp
eliminate_contiguous.cpp
......
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/erase.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_set>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins2)
{
if(ins1 == ins2)
return false;
auto iter = ins2idom.find(ins2);
while(iter != ins2idom.end())
{
if(ins1 == iter->second)
return true;
assert(iter != ins2idom.find(iter->second));
iter = ins2idom.find(iter->second);
}
return false;
}
struct module_visitor
{
module* mm;
module& get_nodes() const { return *mm; }
const std::vector<instruction_ref>& get_children(instruction_ref ins) { return ins->inputs(); }
};
template <class Visitor>
dominator_info compute_dominator_generic(Visitor v)
{
dominator_info info;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> instr2_doms;
for(instruction_ref ins : iterator_for(v.get_nodes()))
{
const std::vector<instruction_ref>& children = v.get_children(ins);
if(children.size() == 1)
{
info.ins2idom[ins] = children.front();
instr2_doms[ins].insert(children.front());
}
else if(children.size() > 1)
{
auto&& doms = instr2_doms[ins];
doms = instr2_doms[children.front()];
std::for_each(children.begin() + 1, children.end(), [&](instruction_ref child) {
auto&& child_doms = instr2_doms[child];
erase_if(doms, [&](auto x) { return not contains(child_doms, x); });
});
auto iter = std::find_if(doms.begin(), doms.end(), [&](auto dom1) {
return std::none_of(doms.begin(), doms.end(), [&](auto dom2) {
if(dom1 == dom2)
return false;
return info.strictly_dominate(dom1, dom2);
});
});
if(iter != doms.end())
info.ins2idom[ins] = *iter;
}
instr2_doms[ins].insert(ins);
}
return info;
}
dominator_info compute_dominator(module& m)
{
return compute_dominator_generic(module_visitor{&m});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -43,6 +43,7 @@ struct outline
struct param
{
std::string parameter;
uint32_t order = 0;
template <class Self, class F>
static auto reflect(Self& self, F f)
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#define MIGRAPHX_GUARD_RTGLIB_DOM_INFO_HPP
#include <migraphx/config.hpp>
#include <migraphx/instruction.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
struct dominator_info
{
bool strictly_dominate(instruction_ref ins1, instruction_ref ins2);
std::unordered_map<instruction_ref, instruction_ref> ins2idom;
};
dominator_info compute_dominator(module& m);
// dominator_info compute_dominator_naive(const module& m);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -25,12 +25,19 @@ auto erase(R&& r, const T& value)
*
* @param r The container to erase elements from
* @param pred Predicate function that selects which elements should be erased.
* @return Returns iterator to erased element
*/
template <class R, class P>
auto erase_if(R&& r, P&& pred)
void erase_if(R&& r, P&& pred)
{
return r.erase(std::remove_if(r.begin(), r.end(), pred), r.end());
auto first = r.begin();
auto last = r.end();
while(first != last)
{
if(pred(*first))
first = r.erase(first);
else
first++;
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -25,8 +25,8 @@ struct module_impl
// A list is used to keep references to an instruction stable
std::list<instruction> instructions;
std::unordered_set<instruction*> instruction_set;
std::vector<std::string> input_names;
std::string name;
uint32_t nparams = 0;
bool contains(instruction_ref ins) const
{
......@@ -110,8 +110,7 @@ void module::assign(const module& m)
{
impl->instructions.clear();
}
impl->input_names = m.impl->input_names;
impl->name = m.impl->name;
impl->name = m.impl->name;
std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(m))
......@@ -312,9 +311,8 @@ instruction_ref module::add_outline(const shape& s)
instruction_ref module::add_parameter(std::string name, shape s)
{
assert(get_parameter_shape(name) == shape{});
impl->input_names.push_back(name);
impl->push_front({builtin::param{std::move(name)}, std::move(s), {}});
impl->push_front({builtin::param{std::move(name), impl->nparams}, std::move(s), {}});
impl->nparams++;
return impl->instructions.begin();
}
......@@ -350,17 +348,21 @@ shape module::get_parameter_shape(std::string name) const
std::vector<std::string> module::get_parameter_names() const
{
std::vector<std::string> result = impl->input_names;
std::unordered_set<std::string> params;
std::vector<std::string> result;
std::vector<builtin::param> params;
for(auto&& ins : impl->instructions)
{
if(ins.name() == "@param")
{
auto&& name = any_cast<builtin::param>(ins.get_operator()).parameter;
params.insert(name);
auto&& param = any_cast<builtin::param>(ins.get_operator());
params.push_back(param);
}
}
erase_if(result, [&](auto&& name) { return params.count(name) == 0; });
std::stable_sort(
params.begin(), params.end(), by(std::less<>{}, [](auto&& p) { return p.order; }));
std::transform(params.begin(), params.end(), std::back_inserter(result), [&](auto&& p) {
return p.parameter;
});
return result;
}
......
......@@ -249,8 +249,8 @@ void memory_coloring_impl::verify()
if(segment.begin == invalid_offset)
{
if(!interval.is_live_on_entry)
MIGRAPHX_THROW("interval is not live on entry");
// if(!interval.is_live_on_entry)
// MIGRAPHX_THROW("interval is not live on entry");
continue;
}
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
......@@ -6,6 +6,7 @@
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dom_info.hpp>
#include <unordered_map>
#include <unordered_set>
#include <queue>
......@@ -16,6 +17,7 @@
#include <set>
#include <deque>
#include <chrono>
#include <iomanip>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -88,7 +90,7 @@ struct stream_info
return args.end();
}
const std::size_t min_partition_threshold = 1;
const std::size_t min_partition_threshold = 2;
sort_args_by_weight(args, std::greater<>{});
auto it = std::lower_bound(std::next(args.begin()),
......@@ -353,6 +355,7 @@ struct stream_info
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
dominator_info di = compute_dominator(p);
result.reserve(p.size());
merge_from.reserve(p.size());
for(auto ins : reverse_iterator_for(p))
......@@ -366,8 +369,13 @@ struct stream_info
merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
}
auto streams = this->get_streams(ins);
if(is_split_point(ins))
{
erase_if(merge_from[ins],
[&](auto merge) { return di.strictly_dominate(ins, merge); });
}
auto streams = this->get_streams(ins);
// Collect concur instructions for each merge point.
for(const auto& merge : merge_from[ins])
{
......@@ -396,11 +404,18 @@ struct stream_info
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(module& p)
{
using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p);
// Compute an index for each instruction
std::unordered_map<instruction_ref, std::size_t> ins2index;
std::size_t index_total = 0;
for(auto ins : iterator_for(p))
ins2index[ins] = index_total++;
std::vector<conflict_table_type> thread_conflict_tables(
std::thread::hardware_concurrency());
std::vector<instruction_ref> index_to_ins;
......@@ -442,14 +457,13 @@ struct stream_info
for(auto ins1 : ins1_set)
{
auto p1 = std::distance(ins1, merge_first);
auto p1 = ins2index.at(ins1);
for(auto ins2 : ins2_set)
{
if(ins1 == ins2)
continue;
auto p2 = std::distance(ins2, merge_first);
// The smaller distance means the instruction occurs later
if(p1 > p2)
auto p2 = ins2index.at(ins2);
if(p2 > p1)
thrd_table[ins2].insert(ins1);
else
thrd_table[ins1].insert(ins2);
......
#include <migraphx/dom_info.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
TEST_CASE(dom1)
{
migraphx::module mm;
auto ins1 = mm.add_parameter("entry", {migraphx::shape::float_type});
auto ins2 = mm.add_instruction(pass_op{}, ins1);
auto ins3 = mm.add_instruction(pass_op{}, ins2);
auto ins4 = mm.add_instruction(pass_op{}, ins2);
auto ins5 = mm.add_instruction(pass_op{}, ins3, ins4);
auto ins6 = mm.add_instruction(pass_op{}, ins2);
auto dom = migraphx::compute_dominator(mm);
EXPECT(dom.strictly_dominate(ins1, ins2));
EXPECT(dom.strictly_dominate(ins2, ins3));
EXPECT(dom.strictly_dominate(ins2, ins4));
EXPECT(dom.strictly_dominate(ins2, ins5));
EXPECT(dom.strictly_dominate(ins2, ins6));
EXPECT(not dom.strictly_dominate(ins3, ins6));
EXPECT(not dom.strictly_dominate(ins4, ins6));
EXPECT(not dom.strictly_dominate(ins3, ins5));
EXPECT(not dom.strictly_dominate(ins4, ins5));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -79,8 +79,7 @@ struct pass_op
return {};
return inputs.front();
}
int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
int output_alias(const std::vector<migraphx::shape>& s) const { return s.empty() ? -1 : 0; }
};
struct mod_pass_op
......
This diff is collapsed.
......@@ -774,6 +774,31 @@ TEST_CASE(inception_resnet)
t.check_conflicts(m, {c1, {i1}});
}
TEST_CASE(dominate_conflicts)
{
scheduler t{};
migraphx::module m;
auto one = m.add_literal(1);
auto onep1 = m.add_instruction(unary_op{}, one);
auto onep2 = m.add_instruction(unary_op{}, one);
auto binary1 = m.add_instruction(nary_op{}, onep1, onep2);
auto onep3 = m.add_instruction(unary_op{}, binary1);
auto onep4 = m.add_instruction(unary_op{}, binary1);
auto binary2 = m.add_instruction(nary_op{}, onep3, onep4);
t.run_pass(m);
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(onep3) != t.get_stream(onep4));
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(onep1), t.get_stream(onep2)}));
t.check_conflicts(m, {{onep1}, {onep2}});
t.check_conflicts(m, {{onep3}, {onep4}});
t.check_conflicts(m, {{onep1, onep2}, {onep3, onep4}}, false);
t.check_conflicts(m, {{binary1}, {binary2}}, false);
}
TEST_CASE(inception1)
{
scheduler t{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment