Commit ffc51fe7 authored by Paul's avatar Paul
Browse files

Lazy compute input streams

parent a1dcece1
......@@ -3,6 +3,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <unordered_map>
......@@ -27,10 +28,14 @@ struct stream_info
fix<std::size_t>([&](auto self, auto ins) -> std::size_t {
if(weights.count(ins) == 0)
{
std::size_t weight = 0;
auto&& op = ins->get_operator();
if(not is_context_free(op) and op.name()[0] != '@')
weight = model.weight(op);
weights[ins] =
std::accumulate(ins->inputs().begin(),
ins->inputs().end(),
model.weight(ins->get_operator()),
weight,
[&](std::size_t w, instruction_ref i) { return w + self(i); });
}
return weights[ins];
......@@ -94,31 +99,54 @@ struct stream_info
return not std::all_of(v.begin(), v.end(), [&](std::size_t x) { return x == v.front(); });
}
template <class Selector>
std::vector<std::size_t> get_streams(instruction_ref ins, Selector select) const
template<class F>
bool different(F f) const
{
std::vector<std::size_t> result;
for(auto i : select(ins))
{
if(weights.at(i) == 0)
bool first = true;
std::size_t stream = 0;
bool result = false;
f([&](auto s) {
if (not first and s != stream)
{
auto vv = get_input_streams(i);
result.insert(result.end(), vv.begin(), vv.end());
result = true;
return false;
}
else
{
result.emplace_back(get_stream(i));
}
}
stream = s;
first = false;
return true;
});
return result;
}
std::vector<std::size_t> get_input_streams(instruction_ref ins) const
template <class Selector>
auto get_streams(instruction_ref start, Selector select) const
{
return [=](auto f) {
return fix<bool>([&](auto self, auto ins) {
for(auto i : select(ins))
{
if(weights.at(i) == 0)
{
if (not self(i))
return false;
}
else
{
if (not f(get_stream(i)))
return false;
}
}
return true;
})(start);
};
}
auto get_input_streams(instruction_ref ins) const
{
return get_streams(ins, [](auto i) { return i->inputs(); });
}
std::vector<std::size_t> get_output_streams(instruction_ref ins) const
auto get_output_streams(instruction_ref ins) const
{
return get_streams(ins, [](auto i) { return i->outputs(); });
}
......@@ -129,15 +157,22 @@ struct stream_info
std::vector<std::size_t> wait_for(instruction_ref ins) const
{
std::vector<std::size_t> result = get_input_streams(ins);
std::vector<std::size_t> result;
get_input_streams(ins)([&](auto s) {
result.push_back(s);
return true;
});
// Remove duplicates
std::sort(result.begin(), result.end());
result.erase(std::unique(result.begin(), result.end()), result.end());
// Remove the merged stream
result.erase(std::find(result.begin(), result.end(), get_stream(ins)));
return result;
}
template <class F>
void find_concurrent_instructions(program& p, F f)
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> find_concurrent_instructions(program& p)
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> split_from;
for(auto ins : iterator_for(p))
{
......@@ -150,12 +185,26 @@ struct stream_info
split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
}
// if (is_merge_point(ins))
// {
// // post-dominator kills split point.
// for(auto& split : split_from[ins])
// {
// if(strictly_post_dominates(ins, split))
// split_from[ins].erase(split);
// }
// }
// Collect concur instructions for each split point.
for(auto& split : split_from[ins])
{
f(ins, split);
auto stream = get_stream(ins);
if (result[split].size() <= stream)
result[split].resize(stream+1);
result[split][stream].push_back(ins);
}
}
return result;
}
};
......@@ -186,8 +235,27 @@ void schedule::apply(program& p) const
model.schedule_instruction(p, ins, si.get_stream(ins));
}
si.find_concurrent_instructions(
p, [&](auto x, auto y) { p.insert_instruction(std::next(x), op::identity{}, x, y); });
// Add memory conflicts
auto concur_ins = si.find_concurrent_instructions(p);
for(auto&& split:concur_ins)
{
dfor(split.second.size(), split.second.size())([&](auto i, auto j) {
if (i == j)
return;
for(auto ins1:split.second[i])
{
auto idx1 = std::distance(split.first, ins1);
for(auto ins2:split.second[j])
{
if (ins1 == ins2)
continue;
auto idx2 = std::distance(split.first, ins2);
auto point = idx1 > idx2 ? ins1 : ins2;
p.insert_instruction(std::next(point), op::identity{}, ins1, ins2);
}
}
});
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -114,6 +114,8 @@ struct hip_device
std::size_t nstreams() const { return streams.size(); }
std::size_t stream_id() const { return current_stream; }
private:
std::size_t device_id = 0;
std::size_t current_stream = 0;
......
......@@ -21,7 +21,7 @@ hip_event_ptr create_event()
struct wait_event
{
std::vector<std::size_t> wait_for;
shared<hip_event_ptr> event;
shared<hip_event_ptr> event = nullptr;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
......@@ -33,13 +33,18 @@ struct wait_event
argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{
assert(event != nullptr);
assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) { return i == ctx.get_current_device().stream_id(); }));
for(auto n : wait_for)
ctx.get_stream(n).record(event.get());
ctx.get_stream().wait(event.get());
return {};
}
void finalize(context& ctx, const shape&, std::vector<shape>) { event = create_event(); }
void finalize(context& ctx, const shape&, std::vector<shape>)
{
assert(std::none_of(wait_for.begin(), wait_for.end(), [&](auto i) { return i == ctx.get_current_device().stream_id(); }));
event = create_event();
}
};
struct set_stream
......@@ -100,8 +105,6 @@ std::size_t schedule_model::weight(const operation& op) const
{
if(weight_map().count(op.name()) == 0)
{
if(is_context_free(op) or op.name()[0] == '@')
return 0;
return 1;
}
return weight_map().at(op.name());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment