Commit 8a8b94c8 authored by Paul's avatar Paul
Browse files

Fix bug with stream splitting

parent 64f71de8
......@@ -18,6 +18,16 @@ bool stream_free(instruction_ref ins)
return is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@';
}
auto get_inputs()
{
return [](auto i) { return i->inputs(); };
}
auto get_outputs()
{
return [](auto i) { return i->outputs(); };
}
struct stream_info
{
std::unordered_map<instruction_ref, std::size_t> ins2stream;
......@@ -100,24 +110,32 @@ struct stream_info
}
template <class F>
bool different(F f) const
bool different(F f, std::size_t stream) const
{
bool first = true;
std::size_t stream = 0;
bool result = false;
bool result = false;
f([&](auto s) {
if(not first and s != stream)
if(s != stream)
{
result = true;
return false;
}
stream = s;
first = false;
return true;
});
return result;
}
template <class F>
bool different(F f) const
{
bool result = false;
f([&](auto s) {
result = different(f, s);
return false;
});
return result;
}
template <class Selector>
auto get_streams(instruction_ref start, Selector select) const
{
......@@ -141,24 +159,16 @@ struct stream_info
};
}
auto get_input_streams(instruction_ref ins) const
{
return get_streams(ins, [](auto i) { return i->inputs(); });
}
auto get_output_streams(instruction_ref ins) const
{
return get_streams(ins, [](auto i) { return i->outputs(); });
}
bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); }
template<class... Ts>
bool is_merge_point(instruction_ref ins, Ts... xs) const { return different(get_streams(ins, get_inputs()), xs...); }
bool is_split_point(instruction_ref ins) const { return different(get_output_streams(ins)); }
template<class... Ts>
bool is_split_point(instruction_ref ins, Ts... xs) const { return different(get_streams(ins, get_outputs()), xs...); }
std::vector<std::size_t> wait_for(instruction_ref ins) const
{
std::vector<std::size_t> result;
get_input_streams(ins)([&](auto s) {
get_streams(ins, get_inputs())([&](auto s) {
result.push_back(s);
return true;
});
......@@ -166,7 +176,9 @@ struct stream_info
std::sort(result.begin(), result.end());
result.erase(std::unique(result.begin(), result.end()), result.end());
// Remove the merged stream
result.erase(std::find(result.begin(), result.end(), get_stream(ins)));
auto it = std::find(result.begin(), result.end(), get_stream(ins));
if (it != result.end())
result.erase(it);
return result;
}
......@@ -186,6 +198,7 @@ struct stream_info
split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
}
auto stream = get_stream(ins);
// if (is_merge_point(ins))
// {
// // post-dominator kills split point.
......@@ -199,7 +212,6 @@ struct stream_info
// Collect concur instructions for each split point.
for(auto& split : split_from[ins])
{
auto stream = get_stream(ins);
if(result[split].size() <= stream)
result[split].resize(stream + 1);
result[split][stream].push_back(ins);
......@@ -245,10 +257,11 @@ void schedule::apply(program& p) const
// Only schedule instructions that have a stream
if(not si.has_stream(ins))
continue;
if(si.is_merge_point(ins))
model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
auto stream = si.get_stream(ins);
if(si.is_merge_point(ins, stream))
model.wait(p, ins, stream, si.wait_for(ins));
else
model.schedule_instruction(p, ins, si.get_stream(ins));
model.schedule_instruction(p, ins, stream);
}
// Add memory conflicts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment