Commit dc23d605 authored by Shucai Xiao, committed by mvermeulen
Browse files

Accelerate calculating conflict table (#382)

* accelerate conflict table computation

* removed an unnecessary comma
parent a797f890
......@@ -27,13 +27,25 @@ struct joinable_thread : std::thread
}
};
// Calls `f(i, tid)` for callables that accept both arguments and returns
// whatever `f` returns.
//
// The trailing return type `decltype(f(i, tid))` removes this overload from
// overload resolution when `f(i, tid)` is not a valid expression, so callables
// taking only one argument fall through to the other overload.
//
// The result must be returned explicitly: the declared return type is that of
// `f(i, tid)`, and flowing off the end of a function with a non-void return
// type is undefined behaviour. `return <void expression>;` is well-formed, so
// this also works when `f` returns void.
template <class F>
auto thread_invoke(std::size_t i, std::size_t tid, F f) -> decltype(f(i, tid))
{
    return f(i, tid);
}
// Calls `f(i)` for callables that accept only the first argument and returns
// whatever `f` returns. The second (unnamed) std::size_t argument is accepted
// and ignored.
//
// The trailing return type `decltype(f(i))` removes this overload from
// overload resolution when `f(i)` is not a valid expression.
//
// The result must be returned explicitly: the declared return type is that of
// `f(i)`, and flowing off the end of a function with a non-void return type is
// undefined behaviour. `return <void expression>;` is well-formed, so this
// also works when `f` returns void.
template <class F>
auto thread_invoke(std::size_t i, std::size_t, F f) -> decltype(f(i))
{
    return f(i);
}
template <class F>
void par_for_impl(std::size_t n, std::size_t threadsize, F f)
{
if(threadsize <= 1)
{
for(std::size_t i = 0; i < n; i++)
f(i);
thread_invoke(i, 0, f);
}
else
{
......@@ -45,16 +57,18 @@ void par_for_impl(std::size_t n, std::size_t threadsize, F f)
std::size_t grainsize = std::ceil(static_cast<double>(n) / threads.size());
std::size_t work = 0;
std::generate(threads.begin(), threads.end(), [=, &work] {
std::size_t tid = 0;
std::generate(threads.begin(), threads.end(), [=, &work, &tid] {
auto result = joinable_thread([=] {
std::size_t start = work;
std::size_t last = std::min(n, work + grainsize);
for(std::size_t i = start; i < last; i++)
{
f(i);
thread_invoke(i, tid, f);
}
});
work += grainsize;
++tid;
return result;
});
assert(work >= n);
......
......@@ -4,12 +4,16 @@
#include <migraphx/op/identity.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
#include <unordered_set>
#include <thread>
#include <mutex>
#include <set>
#include <deque>
#include <chrono>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -303,31 +307,78 @@ struct stream_info
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>
get_conflicts(program& p)
{
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> conflict_table;
using conflict_table_type =
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>>;
conflict_table_type conflict_table;
auto concur_ins = this->find_concurrent_instructions(p);
for(auto&& merge : concur_ins)
std::vector<conflict_table_type> thread_conflict_tables(
std::thread::hardware_concurrency());
std::vector<instruction_ref> index_to_ins;
index_to_ins.reserve(concur_ins.size());
std::transform(concur_ins.begin(),
concur_ins.end(),
std::back_inserter(index_to_ins),
[](auto&& it) { return it.first; });
par_for(concur_ins.size(), [&](auto ins_index, auto tid) {
auto merge_first = index_to_ins[ins_index];
assert(concur_ins.count(merge_first) > 0);
auto& merge_second = concur_ins.at(merge_first);
// ensure there are enough elements for different threads
assert(tid < thread_conflict_tables.size());
auto& thrd_table = thread_conflict_tables.at(tid);
std::unordered_set<instruction_ref> checked_ins_set;
auto range_i = range(merge_second.begin(), std::prev(merge_second.end()));
for(auto it_i : iterator_for(range_i))
{
dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
if(i == j)
return;
for(auto ins1 : merge.second[i])
std::unordered_set<instruction_ref> ins1_set;
std::copy_if(it_i->begin(),
it_i->end(),
std::inserter(ins1_set, ins1_set.end()),
[&](auto i) { return not contains(checked_ins_set, i); });
checked_ins_set.insert(ins1_set.begin(), ins1_set.end());
auto range_j = range(std::next(it_i), merge_second.end());
std::unordered_set<instruction_ref> ins2_set;
for(auto it_j : iterator_for(range_j))
{
std::copy_if(it_j->begin(),
it_j->end(),
std::inserter(ins2_set, ins2_set.end()),
[&](auto i) { return not contains(checked_ins_set, i); });
}
for(auto ins1 : ins1_set)
{
auto p1 = std::distance(ins1, merge.first);
for(auto ins2 : merge.second[j])
auto p1 = std::distance(ins1, merge_first);
for(auto ins2 : ins2_set)
{
if(ins1 == ins2)
continue;
auto p2 = std::distance(ins2, merge.first);
auto p2 = std::distance(ins2, merge_first);
// The smaller distance means the instruction occurs later
if(p1 > p2)
conflict_table[ins2].insert(ins1);
thrd_table[ins2].insert(ins1);
else
conflict_table[ins1].insert(ins2);
thrd_table[ins1].insert(ins2);
}
}
}
});
// merge thread_conflict_tables together
for(auto& tbl : thread_conflict_tables)
{
for(auto& it : tbl)
{
conflict_table[it.first].insert(it.second.begin(), it.second.end());
}
}
// Remove duplicates
// Remove instructions from the conflict table of an earlier instruction
for(auto&& ip : conflict_table)
{
auto ins1 = ip.first;
......@@ -335,6 +386,7 @@ struct stream_info
if(contains(conflict_table[ins2], ins1))
conflict_table[ins2].erase(ins1);
}
return conflict_table;
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment