Commit 473a045a authored by Paul's avatar Paul
Browse files

Fix memory conflicts

parent 84221940
...@@ -3,21 +3,64 @@ ...@@ -3,21 +3,64 @@
#include <cassert> #include <cassert>
#include <type_traits> #include <type_traits>
#include <iterator>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
template <class T> struct iterator_for_select
{
template<class T>
static T deref(T x)
{
return x;
}
template<class T>
static auto begin(T* x)
{
return x->begin();
}
template<class T>
static auto end(T* x)
{
return x->end();
}
};
struct iterator_for_select_reverse
{
template<class T>
static auto deref(T x)
{
return std::prev(x.base());
}
template<class T>
static auto begin(T* x)
{
return std::make_reverse_iterator(x->end());
}
template<class T>
static auto end(T* x)
{
return std::make_reverse_iterator(x->begin());
}
};
template <class T, class Selector=iterator_for_select>
struct iterator_for_range struct iterator_for_range
{ {
T* base; T* base;
using base_iterator = std::remove_reference_t<decltype(base->begin())>; using base_iterator = std::remove_reference_t<decltype(Selector::begin(base))>;
struct iterator struct iterator
{ {
base_iterator i; base_iterator i;
base_iterator operator*() const { return i; } auto operator*() const { return Selector::deref(i); }
base_iterator operator++() { return ++i; } base_iterator operator++() { return ++i; }
bool operator!=(const iterator& rhs) const { return i != rhs.i; } bool operator!=(const iterator& rhs) const { return i != rhs.i; }
}; };
...@@ -25,12 +68,12 @@ struct iterator_for_range ...@@ -25,12 +68,12 @@ struct iterator_for_range
iterator begin() iterator begin()
{ {
assert(base != nullptr); assert(base != nullptr);
return {base->begin()}; return {Selector::begin(base)};
} }
iterator end() iterator end()
{ {
assert(base != nullptr); assert(base != nullptr);
return {base->end()}; return {Selector::end(base)};
} }
}; };
template <class T> template <class T>
...@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x) ...@@ -39,6 +82,12 @@ iterator_for_range<T> iterator_for(T& x)
return {&x}; return {&x};
} }
template <class T>
iterator_for_range<T, iterator_for_select_reverse> reverse_iterator_for(T& x)
{
return {&x};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -162,7 +162,7 @@ struct stream_info ...@@ -162,7 +162,7 @@ struct stream_info
} }
template <class Selector> template <class Selector>
auto get_streams(instruction_ref start, Selector select) const auto get_streams_from(instruction_ref start, Selector select) const
{ {
return [=](auto f) { return [=](auto f) {
return fix<bool>([&](auto self, auto ins) { return fix<bool>([&](auto self, auto ins) {
...@@ -184,16 +184,28 @@ struct stream_info ...@@ -184,16 +184,28 @@ struct stream_info
}; };
} }
std::unordered_set<std::size_t> get_streams(instruction_ref ins)
{
if (has_stream(ins))
return {get_stream(ins)};
std::unordered_set<std::size_t> result;
get_streams_from(ins, get_inputs())([&](auto s) {
result.insert(s);
return true;
});
return result;
}
template <class... Ts> template <class... Ts>
bool is_merge_point(instruction_ref ins, Ts... xs) const bool is_merge_point(instruction_ref ins, Ts... xs) const
{ {
return different(get_streams(ins, get_inputs()), xs...); return different(get_streams_from(ins, get_inputs()), xs...);
} }
template <class... Ts> template <class... Ts>
bool is_split_point(instruction_ref ins, Ts... xs) const bool is_split_point(instruction_ref ins, Ts... xs) const
{ {
return different(get_streams(ins, get_outputs()), xs...); return different(get_streams_from(ins, get_outputs()), xs...);
} }
std::vector<instruction_ref> get_recorded_instructions(instruction_ref start) std::vector<instruction_ref> get_recorded_instructions(instruction_ref start)
...@@ -225,7 +237,7 @@ struct stream_info ...@@ -225,7 +237,7 @@ struct stream_info
std::vector<std::size_t> wait_for(instruction_ref ins) const std::vector<std::size_t> wait_for(instruction_ref ins) const
{ {
std::vector<std::size_t> result; std::vector<std::size_t> result;
get_streams(ins, get_inputs())([&](auto s) { get_streams_from(ins, get_inputs())([&](auto s) {
result.push_back(s); result.push_back(s);
return true; return true;
}); });
...@@ -243,35 +255,27 @@ struct stream_info ...@@ -243,35 +255,27 @@ struct stream_info
find_concurrent_instructions(program& p) find_concurrent_instructions(program& p)
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
for(auto ins : iterator_for(p)) for(auto ins : reverse_iterator_for(p))
{ {
if(iweights[ins] == 0) for(auto&& arg : ins->outputs())
continue;
for(auto&& arg : ins->inputs())
{ {
if(is_split_point(arg)) if(is_merge_point(arg))
split_from[ins].insert(arg); merge_from[ins].insert(arg);
split_from[ins].insert(split_from[arg].begin(), split_from[arg].end()); merge_from[ins].insert(merge_from[arg].begin(), merge_from[arg].end());
} }
auto stream = get_stream(ins); auto streams = get_streams(ins);
// if (is_merge_point(ins))
// {
// // post-dominator kills split point.
// for(auto& split : split_from[ins])
// {
// if(strictly_post_dominates(ins, split))
// split_from[ins].erase(split);
// }
// }
// Collect concur instructions for each split point. // Collect concur instructions for each merge point.
for(auto& split : split_from[ins]) for(auto& merge : merge_from[ins])
{
for(auto stream:streams)
{ {
if(result[split].size() <= stream) if(result[merge].size() <= stream)
result[split].resize(stream + 1); result[merge].resize(stream + 1);
result[split][stream].push_back(ins); result[merge][stream].push_back(ins);
}
} }
} }
return result; return result;
...@@ -304,7 +308,7 @@ void schedule::apply(program& p) const ...@@ -304,7 +308,7 @@ void schedule::apply(program& p) const
std::cout << ":"; std::cout << ":";
std::cout << " weight=" << si.weights.at(ins); std::cout << " weight=" << si.weights.at(ins);
std::cout << " input={"; std::cout << " input={";
si.get_streams(ins, get_inputs())([&](auto s) { si.get_streams_from(ins, get_inputs())([&](auto s) {
std::cout << s << ","; std::cout << s << ",";
return true; return true;
}); });
...@@ -367,20 +371,16 @@ void schedule::apply(program& p) const ...@@ -367,20 +371,16 @@ void schedule::apply(program& p) const
// Add memory conflicts // Add memory conflicts
auto concur_ins = si.find_concurrent_instructions(p); auto concur_ins = si.find_concurrent_instructions(p);
for(auto&& split : concur_ins) for(auto&& merge : concur_ins)
{ {
dfor(split.second.size(), split.second.size())([&](auto i, auto j) { dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
if(i == j) if(i == j)
return; return;
for(auto ins1 : split.second[i]) for(auto ins1 : merge.second[i])
{ {
auto args = split.second[j]; auto args = merge.second[j];
args.insert(args.begin(), ins1); args.insert(args.begin(), ins1);
p.insert_instruction(merge.first, op::identity{}, args);
auto point = std::max_element(args.begin(), args.end(), [&](auto x, auto y) {
return std::distance(split.first, x) < std::distance(split.first, y);
});
p.insert_instruction(std::next(*point), op::identity{}, args);
} }
}); });
} }
......
...@@ -94,7 +94,7 @@ struct schedule_model_test ...@@ -94,7 +94,7 @@ struct schedule_model_test
} }
(*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id)); (*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
} }
void record(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const void record(migraphx::program&, migraphx::instruction_ref ins, std::size_t wait_id) const
{ {
(*wait2stream)[wait_id] = ins2stream->at(ins); (*wait2stream)[wait_id] = ins2stream->at(ins);
} }
...@@ -181,6 +181,9 @@ std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size ...@@ -181,6 +181,9 @@ std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins) std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
{ {
auto wait_ins = std::prev(ins); auto wait_ins = std::prev(ins);
// Skip identity operators
while(wait_ins->name() == "identity")
wait_ins = std::prev(wait_ins);
if(wait_ins->name() != "wait_event") if(wait_ins->name() != "wait_event")
return {}; return {};
auto wf = *migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for; auto wf = *migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for;
...@@ -338,7 +341,7 @@ TEST_CASE(double_entry) ...@@ -338,7 +341,7 @@ TEST_CASE(double_entry)
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) ==
get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)})); get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
// EXPECT(check_conflicts(p, onep, twop)); EXPECT(check_conflicts(p, onep, twop));
} }
TEST_CASE(two_branches) TEST_CASE(two_branches)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment