Commit dc23d605 authored by Shucai Xiao's avatar Shucai Xiao Committed by mvermeulen
Browse files

Accelerate calculating conflict table (#382)

* accelerate conflict table computation

* removed an unnecessary comma
parent a797f890
...@@ -27,13 +27,25 @@ struct joinable_thread : std::thread ...@@ -27,13 +27,25 @@ struct joinable_thread : std::thread
} }
}; };
/// Invoke a work callable that accepts both the work-item index and a worker id.
///
/// This overload is viable only when `f(i, tid)` is a well-formed expression
/// (expression SFINAE through the trailing return type).
///
/// @param i   Index of the work item to process.
/// @param tid Worker identifier, passed through to `f` unchanged.
/// @param f   Callable invoked as `f(i, tid)`.
/// @return The value produced by `f(i, tid)`; nothing when that call is `void`.
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
    // The declared return type is decltype(f(i, tid)); flowing off the end
    // would be undefined behaviour for callables returning non-void, so the
    // result is forwarded. `return <void expression>;` is valid for void too.
    return f(i, tid);
}
/// Invoke a work callable that accepts only the work-item index.
///
/// This overload is viable only when `f(i)` is a well-formed expression
/// (expression SFINAE through the trailing return type). The second
/// parameter is accepted and ignored.
///
/// @param i Index of the work item to process.
/// @param f Callable invoked as `f(i)`.
/// @return The value produced by `f(i)`; nothing when that call is `void`.
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
    // The declared return type is decltype(f(i)); flowing off the end would be
    // undefined behaviour for callables returning non-void, so the result is
    // forwarded. `return <void expression>;` is valid for void too.
    return f(i);
}
template <class F> template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f) void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{ {
if(threadsize <= 1) if(threadsize <= 1)
{ {
for(std::size_t i = 0; i < n; i++) for(std::size_t i = 0; i < n; i++)
f(i); thread_invoke(i, 0, f);
} }
else else
{ {
...@@ -45,16 +57,18 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f) ...@@ -45,16 +57,18 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size()); std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0; std::size_t work = 0;
std::generate(threads.begin(), threads.end(), [=, &work] { std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] { auto result = joinable_thread([=] {
std::size_t start = work; std::size_t start = work;
std::size_t last = std::min(n, work + grainsize); std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++) for(std::size_t i = start; i < last; i++)
{ {
f(i); thread_invoke(i, tid, f);
} }
}); });
work += grainsize; work += grainsize;
++tid;
return result; return result;
}); });
assert(work >= n); assert(work >= n);
......
...@@ -4,12 +4,16 @@ ...@@ -4,12 +4,16 @@
#include <migraphx/op/identity.hpp> #include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp> #include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <thread>
#include <mutex>
#include <set> #include <set>
#include <deque> #include <deque>
#include <chrono>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -303,31 +307,78 @@ struct stream_info ...@@ -303,31 +307,78 @@ struct stream_info
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(program& p) get_conflicts(program& p)
{ {
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> conflict_table; using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p); auto concur_ins = this->find_concurrent_instructions(p);
for(auto&& merge : concur_ins)
std::vector<conflict_table_type> thread_conflict_tables(
std::thread::hardware_concurrency());
std::vector<instruction_ref> index_to_ins;
index_to_ins.reserve(concur_ins.size());
std::transform(concur_ins.begin(),
concur_ins.end(),
std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; });
par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first);
// ensure there are enough elements for different threads
assert(tid < thread_conflict_tables.size());
auto& thrd_table = thread_conflict_tables.at(tid);
std::unordered_set<instruction_ref> checked_ins_set;
auto range_i = range(merge_second.begin(), std::prev(merge_second.end()));
for(auto it_i : iterator_for(range_i))
{ {
dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) { std::unordered_set<instruction_ref> ins1_set;
if(i == j) std::copy_if(it_i->begin(),
return; it_i->end(),
for(auto ins1 : merge.second[i]) std::inserter(ins1_set, ins1_set.end()),
[&](auto i) { return not contains(checked_ins_set, i); });
checked_ins_set.insert(ins1_set.begin(), ins1_set.end());
auto range_j = range(std::next(it_i), merge_second.end());
std::unordered_set<instruction_ref> ins2_set;
for(auto it_j : iterator_for(range_j))
{
std::copy_if(it_j->begin(),
it_j->end(),
std::inserter(ins2_set, ins2_set.end()),
[&](auto i) { return not contains(checked_ins_set, i); });
}
for(auto ins1 : ins1_set)
{ {
auto p1 = std::distance(ins1, merge.first); auto p1 = std::distance(ins1, merge_first);
for(auto ins2 : merge.second[j]) for(auto ins2 : ins2_set)
{ {
if(ins1 == ins2) if(ins1 == ins2)
continue; continue;
auto p2 = std::distance(ins2, merge.first); auto p2 = std::distance(ins2, merge_first);
// The smaller distance means the instruction occurs later // The smaller distance means the instruction occurs later
if(p1 > p2) if(p1 > p2)
conflict_table[ins2].insert(ins1); thrd_table[ins2].insert(ins1);
else else
conflict_table[ins1].insert(ins2); thrd_table[ins1].insert(ins2);
}
} }
} }
}); });
// merge thread_conflict_tables together
for(auto& tbl : thread_conflict_tables)
{
for(auto& it : tbl)
{
conflict_table[it.first].insert(it.second.begin(), it.second.end());
}
} }
// Remove duplicates
// Remove instructions from the conflict table of an earlier instruction
for(auto&& ip : conflict_table) for(auto&& ip : conflict_table)
{ {
auto ins1 = ip.first; auto ins1 = ip.first;
...@@ -335,6 +386,7 @@ struct stream_info ...@@ -335,6 +386,7 @@ struct stream_info
if(contains(conflict_table[ins2], ins1)) if(contains(conflict_table[ins2], ins1))
conflict_table[ins2].erase(ins1); conflict_table[ins2].erase(ins1);
} }
return conflict_table; return conflict_table;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment