"docs/source/en/api/models/stable_cascade_unet.md" did not exist on "c3d78cd3067612175ac9f0f8b234abf5a2e1f510"
Commit 8a8b94c8 authored by Paul's avatar Paul
Browse files

Fix bug with stream splitting

parent 64f71de8
...@@ -18,6 +18,16 @@ bool stream_free(instruction_ref ins) ...@@ -18,6 +18,16 @@ bool stream_free(instruction_ref ins)
return is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@'; return is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@';
} }
auto get_inputs()
{
return [](auto i) { return i->inputs(); };
}
auto get_outputs()
{
return [](auto i) { return i->outputs(); };
}
struct stream_info struct stream_info
{ {
std::unordered_map<instruction_ref, std::size_t> ins2stream; std::unordered_map<instruction_ref, std::size_t> ins2stream;
...@@ -100,24 +110,32 @@ struct stream_info ...@@ -100,24 +110,32 @@ struct stream_info
} }
template <class F> template <class F>
bool different(F f) const bool different(F f, std::size_t stream) const
{ {
bool first = true; bool result = false;
std::size_t stream = 0;
bool result = false;
f([&](auto s) { f([&](auto s) {
if(not first and s != stream) if(s != stream)
{ {
result = true; result = true;
return false; return false;
} }
stream = s; stream = s;
first = false;
return true; return true;
}); });
return result; return result;
} }
template <class F>
bool different(F f) const
{
bool result = false;
f([&](auto s) {
result = different(f, s);
return false;
});
return result;
}
template <class Selector> template <class Selector>
auto get_streams(instruction_ref start, Selector select) const auto get_streams(instruction_ref start, Selector select) const
{ {
...@@ -141,24 +159,16 @@ struct stream_info ...@@ -141,24 +159,16 @@ struct stream_info
}; };
} }
auto get_input_streams(instruction_ref ins) const template<class... Ts>
{ bool is_merge_point(instruction_ref ins, Ts... xs) const { return different(get_streams(ins, get_inputs()), xs...); }
return get_streams(ins, [](auto i) { return i->inputs(); });
}
auto get_output_streams(instruction_ref ins) const
{
return get_streams(ins, [](auto i) { return i->outputs(); });
}
bool is_merge_point(instruction_ref ins) const { return different(get_input_streams(ins)); }
bool is_split_point(instruction_ref ins) const { return different(get_output_streams(ins)); } template<class... Ts>
bool is_split_point(instruction_ref ins, Ts... xs) const { return different(get_streams(ins, get_outputs()), xs...); }
std::vector<std::size_t> wait_for(instruction_ref ins) const std::vector<std::size_t> wait_for(instruction_ref ins) const
{ {
std::vector<std::size_t> result; std::vector<std::size_t> result;
get_input_streams(ins)([&](auto s) { get_streams(ins, get_inputs())([&](auto s) {
result.push_back(s); result.push_back(s);
return true; return true;
}); });
...@@ -166,7 +176,9 @@ struct stream_info ...@@ -166,7 +176,9 @@ struct stream_info
std::sort(result.begin(), result.end()); std::sort(result.begin(), result.end());
result.erase(std::unique(result.begin(), result.end()), result.end()); result.erase(std::unique(result.begin(), result.end()), result.end());
// Remove the merged stream // Remove the merged stream
result.erase(std::find(result.begin(), result.end(), get_stream(ins))); auto it = std::find(result.begin(), result.end(), get_stream(ins));
if (it != result.end())
result.erase(it);
return result; return result;
} }
...@@ -186,6 +198,7 @@ struct stream_info ...@@ -186,6 +198,7 @@ struct stream_info
split_from[ins].insert(split_from[arg].begin(), split_from[arg].end()); split_from[ins].insert(split_from[arg].begin(), split_from[arg].end());
} }
auto stream = get_stream(ins);
// if (is_merge_point(ins)) // if (is_merge_point(ins))
// { // {
// // post-dominator kills split point. // // post-dominator kills split point.
...@@ -199,7 +212,6 @@ struct stream_info ...@@ -199,7 +212,6 @@ struct stream_info
// Collect concur instructions for each split point. // Collect concur instructions for each split point.
for(auto& split : split_from[ins]) for(auto& split : split_from[ins])
{ {
auto stream = get_stream(ins);
if(result[split].size() <= stream) if(result[split].size() <= stream)
result[split].resize(stream + 1); result[split].resize(stream + 1);
result[split][stream].push_back(ins); result[split][stream].push_back(ins);
...@@ -245,10 +257,11 @@ void schedule::apply(program& p) const ...@@ -245,10 +257,11 @@ void schedule::apply(program& p) const
// Only schedule instructions that have a stream // Only schedule instructions that have a stream
if(not si.has_stream(ins)) if(not si.has_stream(ins))
continue; continue;
if(si.is_merge_point(ins)) auto stream = si.get_stream(ins);
model.wait(p, ins, si.get_stream(ins), si.wait_for(ins)); if(si.is_merge_point(ins, stream))
model.wait(p, ins, stream, si.wait_for(ins));
else else
model.schedule_instruction(p, ins, si.get_stream(ins)); model.schedule_instruction(p, ins, stream);
} }
// Add memory conflicts // Add memory conflicts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment