Commit ffc51fe7 authored by Paul

Lazily compute input streams

parent a1dcece1
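The change below replaces the eager get_input_streams/get_output_streams, which materialized a std::vector<std::size_t>, with lazy versions: get_streams now returns a callable that feeds each stream id to a caller-supplied visitor and stops as soon as the visitor returns false. A minimal standalone sketch of that continuation-passing shape (the names and the vector-backed producer are illustrative, not the MIGraphX API):

#include <cstddef>
#include <iostream>
#include <vector>

// Lazily enumerate stream ids: instead of materializing a vector, return a
// callable that feeds each id to a visitor and stops once the visitor
// returns false. This mirrors the shape of the new get_streams() below.
auto lazy_streams(std::vector<std::size_t> ids)
{
    return [ids](auto f) {
        for(auto id : ids)
        {
            if(not f(id))
                return false; // visitor asked to stop early
        }
        return true; // every id was visited
    };
}

int main()
{
    auto streams = lazy_streams({0, 1, 1, 2});

    // Consume lazily and stop as soon as stream 1 is seen.
    streams([](std::size_t s) {
        std::cout << s << '\n';
        return s != 1;
    });
}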
@@ -3,6 +3,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/operators.hpp>
 #include <migraphx/iterator_for.hpp>
+#include <migraphx/dfor.hpp>
 #include <migraphx/functional.hpp>
 #include <migraphx/ranges.hpp>
 #include <unordered_map>
@@ -27,10 +28,14 @@ struct stream_info
         fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
             if(weights.count(ins) == 0)
             {
+                std::size_t weight = 0;
+                auto&& op = ins->get_operator();
+                if(not is_context_free(op) and op.name()[0] != '@')
+                    weight = model.weight(op);
                 weights[ins] =
                     std::accumulate(ins->inputs().begin(),
                                     ins->inputs().end(),
-                                    model.weight(ins->get_operator()),
+                                    weight,
                                     [&](std::size_t w, instruction_ref i) { return w + self(i); });
             }
             return weights[ins];
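The hunk above memoizes a recursive sum: an instruction's accumulated weight is its own weight (zero for context-free operators and internal '@' operators) plus the accumulated weights of its inputs. A self-contained sketch of the same memoized accumulation over a toy DAG (node and accumulated_weight are stand-ins, not MIGraphX types):

#include <cstddef>
#include <numeric>
#include <unordered_map>
#include <vector>

// Toy DAG node; stands in for instruction_ref in this sketch.
struct node
{
    std::size_t own_weight;
    std::vector<const node*> inputs;
};

// Memoized accumulation: a node's weight is its own weight plus the
// accumulated weights of its inputs, computed once per node.
std::size_t accumulated_weight(const node* n,
                               std::unordered_map<const node*, std::size_t>& cache)
{
    auto it = cache.find(n);
    if(it != cache.end())
        return it->second;
    std::size_t w = std::accumulate(
        n->inputs.begin(), n->inputs.end(), n->own_weight,
        [&](std::size_t acc, const node* i) { return acc + accumulated_weight(i, cache); });
    cache[n] = w;
    return w;
}

int main()
{
    node a{2, {}};
    node b{3, {&a}};
    node c{0, {&a, &b}}; // zero own weight, like a context-free op

    std::unordered_map<const node*, std::size_t> cache;
    // c = 0 + a(2) + b(3 + 2) = 7; inputs are not deduplicated, as in the diff.
    return accumulated_weight(&c, cache) == 7 ? 0 : 1;
}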
@@ -94,31 +99,54 @@ struct stream_info
         return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
     }

+    template<class F>
+    bool different(F f) const
+    {
+        bool first = true;
+        std::size_t stream = 0;
+        bool result = false;
+        f([&](auto s) {
+            if (not first and s != stream)
+            {
+                result = true;
+                return false;
+            }
+            stream = s;
+            first = false;
+            return true;
+        });
+        return result;
+    }
+
     template <class Selector>
-    std::vector<std::size_t> get_streams(instruction_ref ins, Selector select) const
+    auto get_streams(instruction_ref start, Selector select) const
     {
-        std::vector<std::size_t> result;
+        return [=](auto f) {
+            return fix<bool>([&](auto self, auto ins) {
                 for(auto i : select(ins))
                 {
                     if(weights.at(i) == 0)
                     {
-                        auto vv = get_input_streams(i);
-                        result.insert(result.end(), vv.begin(), vv.end());
+                        if (not self(i))
+                            return false;
                     }
                     else
                     {
-                        result.emplace_back(get_stream(i));
+                        if (not f(get_stream(i)))
+                            return false;
                     }
                 }
-        return result;
+                return true;
+            })(start);
+        };
     }

-    std::vector<std::size_t> get_input_streams(instruction_ref ins) const
+    auto get_input_streams(instruction_ref ins) const
     {
         return get_streams(ins, [](auto i) { return i->inputs(); });
     }

-    std::vector<std::size_t> get_output_streams(instruction_ref ins) const
+    auto get_output_streams(instruction_ref ins) const
     {
         return get_streams(ins, [](auto i) { return i->outputs(); });
     }
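The new different helper is one consumer of these lazy producers: it reports whether the produced sequence contains more than one distinct stream id and asks the producer to stop at the first mismatch, so no intermediate vector is built. A standalone sketch of that producer/consumer pairing (names and the inline producer are illustrative):

#include <cstddef>
#include <initializer_list>

// Consumer in the same style as stream_info::different: returns true as soon
// as the lazily produced sequence contains two distinct stream ids, and asks
// the producer to stop producing at that point.
template <class Producer>
bool has_two_distinct(Producer produce)
{
    bool first = true;
    std::size_t seen = 0;
    bool result = false;
    produce([&](std::size_t s) {
        if(not first and s != seen)
        {
            result = true;
            return false; // stop: a second distinct id was found
        }
        seen = s;
        first = false;
        return true; // keep producing
    });
    return result;
}

int main()
{
    // An inline producer standing in for get_input_streams(ins).
    auto producer = [](auto f) {
        for(std::size_t s : {1, 2, 1})
            if(not f(s))
                return false;
        return true;
    };
    // Stops after seeing 1 and then 2; never visits the trailing 1.
    return has_two_distinct(producer) ? 0 : 1;
}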
@@ -129,15 +157,22 @@ struct stream_info
     std::vector<std::size_t> wait_for(instruction_ref ins) const
     {
-        std::vector<std::size_t> result = get_input_streams(ins);
+        std::vector<std::size_t> result;
+        get_input_streams(ins)([&](auto s) {
+            result.push_back(s);
+            return true;
+        });
+        // Remove duplicates
         std::sort(result.begin(), result.end());
         result.erase(std::unique(result.begin(), result.end()), result.end());
+        // Remove the merged stream
+        result.erase(std::find(result.begin(), result.end(), get_stream(ins)));
         return result;
     }

-    template <class F>
-    void find_concurrent_instructions(program& p, F f)
+    std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> find_concurrent_instructions(program& p)
     {
+        std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
         std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
         for(auto ins : iterator_for(p))
         {
@@ -150,12 +185,26 @@ struct stream_info
                 split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
             }

+            // if (is_merge_point(ins))
+            // {
+            //     // post-dominator kills split point.
+            //     for(auto& split : split_from[ins])
+            //     {
+            //         if(strictly_post_dominates(ins, split))
+            //             split_from[ins].erase(split);
+            //     }
+            // }
+
             // Collect concur instructions for each split point.
             for(auto& split : split_from[ins])
             {
-                f(ins, split);
+                auto stream = get_stream(ins);
+                if (result[split].size() <= stream)
+                    result[split].resize(stream+1);
+                result[split][stream].push_back(ins);
             }
         }
+        return result;
     }
 };
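find_concurrent_instructions now returns, for each split point, the concurrent instructions bucketed by their assigned stream; the per-split vector is grown on demand so its index always equals the stream id. A toy version of that grow-and-append step, with plain ints standing in for instruction_ref (add_to_bucket is a hypothetical helper, not part of MIGraphX):

#include <cstddef>
#include <unordered_map>
#include <vector>

// Buckets values by (key, stream): result[key][stream] is the list of values
// assigned to that stream. The per-key vector is resized lazily so that the
// vector index always equals the stream id, matching the hunk above.
template <class Key, class Value>
void add_to_bucket(std::unordered_map<Key, std::vector<std::vector<Value>>>& result,
                   const Key& key,
                   std::size_t stream,
                   const Value& value)
{
    if(result[key].size() <= stream)
        result[key].resize(stream + 1);
    result[key][stream].push_back(value);
}

int main()
{
    std::unordered_map<int, std::vector<std::vector<int>>> concur;
    add_to_bucket(concur, /*split point*/ 42, /*stream*/ 2, /*instruction*/ 7);
    add_to_bucket(concur, 42, 0, 9);
    // concur[42] now has three buckets: {9}, {}, {7}.
    return concur[42].size() == 3 ? 0 : 1;
}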
@@ -186,8 +235,27 @@ void schedule::apply(program& p) const
         model.schedule_instruction(p, ins, si.get_stream(ins));
     }

-    si.find_concurrent_instructions(
-        p, [&](auto x, auto y) { p.insert_instruction(std::next(x), op::identity{}, x, y); });
+    // Add memory conflicts
+    auto concur_ins = si.find_concurrent_instructions(p);
+    for(auto&& split:concur_ins)
+    {
+        dfor(split.second.size(), split.second.size())([&](auto i, auto j) {
+            if (i == j)
+                return;
+            for(auto ins1:split.second[i])
+            {
+                auto idx1 = std::distance(split.first, ins1);
+                for(auto ins2:split.second[j])
+                {
+                    if (ins1 == ins2)
+                        continue;
+                    auto idx2 = std::distance(split.first, ins2);
+                    auto point = idx1 > idx2 ? ins1 : ins2;
+                    p.insert_instruction(std::next(point), op::identity{}, ins1, ins2);
+                }
+            }
+        });
+    }
 }

 } // namespace MIGRAPHX_INLINE_NS
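dfor, from the newly included migraphx/dfor.hpp, drives the pair loop above by visiting every (i, j) combination of the given extents. A minimal two-dimensional stand-in, shown only to illustrate the iteration the conflict loop relies on (the real header is more general):

#include <cstddef>
#include <iostream>

// Two-dimensional stand-in for dfor: call f for every (i, j) pair.
template <class F>
void dfor2(std::size_t n, std::size_t m, F f)
{
    for(std::size_t i = 0; i < n; i++)
        for(std::size_t j = 0; j < m; j++)
            f(i, j);
}

int main()
{
    // Visit every ordered pair of stream buckets, skipping the diagonal,
    // just as the memory-conflict loop in schedule::apply does.
    dfor2(3, 3, [](std::size_t i, std::size_t j) {
        if(i == j)
            return;
        std::cout << i << "," << j << '\n';
    });
}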
...
@@ -114,6 +114,8 @@ struct hip_device
     std::size_t nstreams() const { return streams.size(); }

+    std::size_t stream_id() const { return current_stream; }
+
     private:
     std::size_t device_id = 0;
     std::size_t current_stream = 0;
...
@@ -21,7 +21,7 @@ hip_event_ptr create_event()
 struct wait_event
 {
     std::vector<std::size_t> wait_for;
-    shared<hip_event_ptr> event;
+    shared<hip_event_ptr> event = nullptr;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
     {
@@ -33,13 +33,18 @@ struct wait_event
     argument compute(context& ctx, const shape&, const std::vector<argument>&) const
     {
         assert(event != nullptr);
+        assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) { return i == ctx.get_current_device().stream_id(); }));
         for(auto n : wait_for)
             ctx.get_stream(n).record(event.get());
         ctx.get_stream().wait(event.get());
         return {};
     }

-    void finalize(context& ctx, const shape&, std::vector<shape>) { event = create_event(); }
+    void finalize(context& ctx, const shape&, std::vector<shape>)
+    {
+        assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) { return i == ctx.get_current_device().stream_id(); }));
+        event = create_event();
+    }
 };

 struct set_stream
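At the HIP level, wait_event boils down to recording an event on the producer streams and making the current stream wait on it; the added asserts guard against a stream being asked to wait on itself. A rough sketch of that record/wait pattern with the plain HIP runtime API (error handling omitted; this is not the MIGraphX wrapper code):

#include <hip/hip_runtime.h>
#include <vector>

// Make `current` wait until each producer stream reaches its current point.
// Rough illustration of the record/wait pattern used by wait_event; the real
// code reuses one event owned by the operation and created in finalize().
void wait_on_streams(hipStream_t current, const std::vector<hipStream_t>& producers)
{
    hipEvent_t event;
    hipEventCreateWithFlags(&event, hipEventDisableTiming);
    for(auto s : producers)
    {
        hipEventRecord(event, s);              // mark the producer's position
        hipStreamWaitEvent(current, event, 0); // current stream waits for it
    }
    hipEventDestroy(event);
}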
@@ -100,8 +105,6 @@ std::size_t schedule_model::weight(const operation& op) const
 {
     if(weight_map().count(op.name()) == 0)
     {
-        if(is_context_free(op) or op.name()[0] == '@')
-            return 0;
         return 1;
     }
     return weight_map().at(op.name());
...