Commit 0a39a958 authored by Paul's avatar Paul
Browse files

Formatting

parent c6d92ecc
...@@ -14,49 +14,33 @@ struct stream_info ...@@ -14,49 +14,33 @@ struct stream_info
{ {
std::unordered_map<instruction_ref, std::size_t> ins2stream; std::unordered_map<instruction_ref, std::size_t> ins2stream;
void set_stream(instruction_ref ins, std::size_t n) void set_stream(instruction_ref ins, std::size_t n) { ins2stream[ins] = n; }
{
ins2stream[ins] = n;
}
std::size_t get_stream(instruction_ref ins) const std::size_t get_stream(instruction_ref ins) const { return ins2stream.at(ins); }
{
return ins2stream.at(ins);
}
bool has_stream(instruction_ref ins) const bool has_stream(instruction_ref ins) const { return ins2stream.count(ins) > 0; }
{
return ins2stream.count(ins) > 0;
}
bool different(const std::vector<instruction_ref>& v) const bool different(const std::vector<instruction_ref>& v) const
{ {
if (v.size() < 2) if(v.size() < 2)
return false; return false;
auto stream = get_stream(v.front()); auto stream = get_stream(v.front());
return not std::all_of(v.begin(), v.end(), [&](instruction_ref x) { return not std::all_of(
return get_stream(x) == stream; v.begin(), v.end(), [&](instruction_ref x) { return get_stream(x) == stream; });
});
} }
bool is_split_point(instruction_ref ins) const bool is_split_point(instruction_ref ins) const { return different(ins->outputs()); }
{
return different(ins->outputs());
}
bool is_merge_point(instruction_ref ins) const bool is_merge_point(instruction_ref ins) const { return different(ins->inputs()); }
{
return different(ins->inputs());
}
std::vector<std::size_t> wait_for(instruction_ref ins) const std::vector<std::size_t> wait_for(instruction_ref ins) const
{ {
std::set<std::size_t> result; std::set<std::size_t> result;
auto s = get_stream(ins); auto s = get_stream(ins);
for(auto i:ins->inputs()) for(auto i : ins->inputs())
{ {
auto stream = get_stream(i); auto stream = get_stream(i);
if (stream != s) if(stream != s)
result.insert(stream); result.insert(stream);
} }
return {result.begin(), result.end()}; return {result.begin(), result.end()};
...@@ -85,39 +69,39 @@ void schedule::apply(program& p) const ...@@ -85,39 +69,39 @@ void schedule::apply(program& p) const
// Assign streams // Assign streams
auto streams = model.concurrency(); auto streams = model.concurrency();
stream_info si; stream_info si;
for(std::size_t stream = 0;stream < streams;stream++) for(std::size_t stream = 0; stream < streams; stream++)
{ {
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
// Only assign streams fi not already assigned // Only assign streams fi not already assigned
if (not si.has_stream(ins)) if(not si.has_stream(ins))
si.set_stream(ins, stream); si.set_stream(ins, stream);
instruction_ref child = p.end(); instruction_ref child = p.end();
std::size_t w = 0; std::size_t w = 0;
for(auto i:ins->inputs()) for(auto i : ins->inputs())
{ {
const auto weight = weights[i]; const auto weight = weights[i];
// Skip instruction that already have stream assignment or too low of weights // Skip instruction that already have stream assignment or too low of weights
if (si.has_stream(i) or weight <= min_partition_threshold) if(si.has_stream(i) or weight <= min_partition_threshold)
{ {
self(i); self(i);
} }
// Accumulate the max weight // Accumulate the max weight
else if (weight > w) else if(weight > w)
{ {
child = i; child = i;
w = weight; w = weight;
} }
} }
if (child != p.end()) if(child != p.end())
self(child); self(child);
})(last); })(last);
} }
// Assign remaining instructions // Assign remaining instructions
for(auto ins:iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if (si.has_stream(ins)) if(si.has_stream(ins))
continue; continue;
si.set_stream(ins, streams-1); si.set_stream(ins, streams - 1);
} }
// Topo sort // Topo sort
...@@ -127,18 +111,18 @@ void schedule::apply(program& p) const ...@@ -127,18 +111,18 @@ void schedule::apply(program& p) const
for(auto i : ins->inputs()) for(auto i : ins->inputs())
self(i); self(i);
})(last); })(last);
// Schedule instructions // Schedule instructions
for(auto ins:iterator_for(p)) for(auto ins : iterator_for(p))
{ {
if (si.is_merge_point(ins)) if(si.is_merge_point(ins))
{ {
assert(not si.wait_for(ins).empty()); assert(not si.wait_for(ins).empty());
model.wait(p, ins, si.get_stream(ins), si.wait_for(ins)); model.wait(p, ins, si.get_stream(ins), si.wait_for(ins));
continue; continue;
} }
// Skip scheduling instructions with no context // Skip scheduling instructions with no context
if (is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@') if(is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@')
continue; continue;
model.schedule_instruction(p, ins, si.get_stream(ins)); model.schedule_instruction(p, ins, si.get_stream(ins));
} }
......
...@@ -98,9 +98,9 @@ static const std::unordered_map<std::string, std::size_t>& weight_map() ...@@ -98,9 +98,9 @@ static const std::unordered_map<std::string, std::size_t>& weight_map()
std::size_t schedule_model::weight(const operation& op) const std::size_t schedule_model::weight(const operation& op) const
{ {
if(weight_map().count(op.name()) == 0) if(weight_map().count(op.name()) == 0)
{ {
if (is_context_free(op) or op.name()[0] == '@') if(is_context_free(op) or op.name()[0] == '@')
return 0; return 0;
return 1; return 1;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment