Commit 473a045a authored by Paul's avatar Paul
Browse files

Fix memory conflicts

parent 84221940
......@@ -3,21 +3,64 @@
#include <cassert>
#include <type_traits>
#include <iterator>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
template <class T>
struct iterator_for_select
{
template<class T>
static T deref(T x)
{
return x;
}
template<class T>
static auto begin(T* x)
{
return x->begin();
}
template<class T>
static auto end(T* x)
{
return x->end();
}
};
struct iterator_for_select_reverse
{
template<class T>
static auto deref(T x)
{
return std::prev(x.base());
}
template<class T>
static auto begin(T* x)
{
return std::make_reverse_iterator(x->end());
}
template<class T>
static auto end(T* x)
{
return std::make_reverse_iterator(x->begin());
}
};
template <class T, class Selector=iterator_for_select>
struct iterator_for_range
{
T* base;
using base_iterator = std::remove_reference_t<decltype(base->begin())>;
using base_iterator = std::remove_reference_t<decltype(Selector::begin(base))>;
struct iterator
{
base_iterator i;
base_iterator operator*() const { return i; }
auto operator*() const { return Selector::deref(i); }
base_iterator operator++() { return ++i; }
bool operator!=(const iterator& rhs) const { return i != rhs.i; }
};
......@@ -25,12 +68,12 @@ struct iterator_for_range
iterator begin()
{
assert(base != nullptr);
return {base->begin()};
return {Selector::begin(base)};
}
iterator end()
{
assert(base != nullptr);
return {base->end()};
return {Selector::end(base)};
}
};
template <class T>
......@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
return {&x};
}
template <class T>
iterator_for_range<T, iterator_for_select_reverse> reverse_iterator_for(T& x)
{
return {&x};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -162,7 +162,7 @@ struct stream_info
}
template <class Selector>
auto get_streams(instruction_ref start, Selector select) const
auto get_streams_from(instruction_ref start, Selector select) const
{
return [=](auto f) {
return fix<bool>([&](auto self, auto ins) {
......@@ -184,16 +184,28 @@ struct stream_info
};
}
std::unordered_set<std::size_t> get_streams(instruction_ref ins)
{
if (has_stream(ins))
return {get_stream(ins)};
std::unordered_set<std::size_t> result;
get_streams_from(ins, get_inputs())([&](auto s) {
result.insert(s);
return true;
});
return result;
}
template <class... Ts>
bool is_merge_point(instruction_ref ins, Ts... xs) const
{
return different(get_streams(ins, get_inputs()), xs...);
return different(get_streams_from(ins, get_inputs()), xs...);
}
template <class... Ts>
bool is_split_point(instruction_ref ins, Ts... xs) const
{
return different(get_streams(ins, get_outputs()), xs...);
return different(get_streams_from(ins, get_outputs()), xs...);
}
std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
......@@ -225,7 +237,7 @@ struct stream_info
std::vector<std::size_t> wait_for(instruction_ref ins) const
{
std::vector<std::size_t> result;
get_streams(ins, get_inputs())([&](auto s) {
get_streams_from(ins, get_inputs())([&](auto s) {
result.push_back(s);
return true;
});
......@@ -243,35 +255,27 @@ struct stream_info
find_concurrent_instructions(program& p)
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
for(auto ins : iterator_for(p))
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
for(auto ins : reverse_iterator_for(p))
{
if(iweights[ins] == 0)
continue;
for(auto&& arg : ins->inputs())
for(auto&& arg : ins->outputs())
{
if(is_split_point(arg))
split_from[ins].insert(arg);
split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
if(is_merge_point(arg))
merge_from[ins].insert(arg);
merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
}
auto stream = get_stream(ins);
// if (is_merge_point(ins))
// {
// // post-dominator kills split point.
// for(auto& split : split_from[ins])
// {
// if(strictly_post_dominates(ins, split))
// split_from[ins].erase(split);
// }
// }
// Collect concur instructions for each split point.
for(auto& split : split_from[ins])
auto streams = get_streams(ins);
// Collect concur instructions for each merge point.
for(auto& merge : merge_from[ins])
{
if(result[split].size() <= stream)
result[split].resize(stream + 1);
result[split][stream].push_back(ins);
for(auto stream:streams)
{
if(result[merge].size() <= stream)
result[merge].resize(stream + 1);
result[merge][stream].push_back(ins);
}
}
}
return result;
......@@ -304,7 +308,7 @@ void schedule::apply(program& p) const
std::cout << ":";
std::cout << " weight=" << si.weights.at(ins);
std::cout << " input={";
si.get_streams(ins, get_inputs())([&](auto s) {
si.get_streams_from(ins, get_inputs())([&](auto s) {
std::cout << s << ",";
return true;
});
......@@ -367,20 +371,16 @@ void schedule::apply(program& p) const
// Add memory conflicts
auto concur_ins = si.find_concurrent_instructions(p);
for(auto&& split : concur_ins)
for(auto&& merge : concur_ins)
{
dfor(split.second.size(), split.second.size())([&](auto i, auto j) {
dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
if(i == j)
return;
for(auto ins1 : split.second[i])
for(auto ins1 : merge.second[i])
{
auto args = split.second[j];
auto args = merge.second[j];
args.insert(args.begin(), ins1);
auto point = std::max_element(args.begin(), args.end(), [&](auto x, auto y) {
return std::distance(split.first, x) < std::distance(split.first, y);
});
p.insert_instruction(std::next(*point), op::identity{}, args);
p.insert_instruction(merge.first, op::identity{}, args);
}
});
}
......
......@@ -94,7 +94,7 @@ struct schedule_model_test
}
(*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
}
void record(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
void record(migraphx::program&, migraphx::instruction_ref ins, std::size_t wait_id) const
{
(*wait2stream)[wait_id] = ins2stream->at(ins);
}
......@@ -181,6 +181,9 @@ std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{
auto wait_ins = std::prev(ins);
// Skip identity operators
while(wait_ins->name() == "identity")
wait_ins = std::prev(wait_ins);
if(wait_ins->name() != "wait_event")
return {};
auto wf = *migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for;
......@@ -338,7 +341,7 @@ TEST_CASE(double_entry)
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
// EXPECT(check_conflicts(p, onep, twop));
EXPECT(check_conflicts(p, onep, twop));
}
TEST_CASE(two_branches)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment