Commit 23fd64f8 authored by Paul's avatar Paul
Browse files

Formatting

parent 0e5dabb4
...@@ -137,20 +137,16 @@ auto fold(F f) ...@@ -137,20 +137,16 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); }; return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
} }
template<class F, class Proj> template <class F, class Proj>
auto by(F f, Proj proj) auto by(F f, Proj proj)
{ {
return [=](auto&&... xs) { return [=](auto&&... xs) { return f(proj(std::forward<decltype(xs)>(xs))...); };
return f(proj(std::forward<decltype(xs)>(xs))...);
};
} }
template<class T> template <class T>
auto index_of(T& x) auto index_of(T& x)
{ {
return [&](auto&& y) { return [&](auto&& y) { return x[y]; };
return x[y];
};
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -73,11 +73,13 @@ struct stream_info ...@@ -73,11 +73,13 @@ struct stream_info
return; return;
part.add(ins, this->iweights[ins]); part.add(ins, this->iweights[ins]);
auto max_it = std::max_element(ins->inputs().begin(), ins->inputs().end(), by(std::less<>{}, index_of(this->weights))); auto max_it = std::max_element(ins->inputs().begin(),
ins->inputs().end(),
by(std::less<>{}, index_of(this->weights)));
for(auto i : ins->inputs()) for(auto i : ins->inputs())
{ {
const auto weight = this->weights[i]; const auto weight = this->weights[i];
if (i == *max_it or weight <= min_partition_threshold) if(i == *max_it or weight <= min_partition_threshold)
{ {
self(i, part); self(i, part);
} }
...@@ -91,15 +93,18 @@ struct stream_info ...@@ -91,15 +93,18 @@ struct stream_info
// Set the critical partition to stream 0 // Set the critical partition to stream 0
set_stream(critical, 0); set_stream(critical, 0);
std::vector<std::size_t> streams(n-1); std::vector<std::size_t> streams(n - 1);
// Assign streams for the other partitions // Assign streams for the other partitions
for(auto&& ins_part:partitions) for(auto&& ins_part : partitions)
{ {
std::sort(ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) { return std::make_tuple(x.weight, x.instructions.size()); })); std::sort(
for(auto&& part:ins_part.second) ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) {
return std::make_tuple(x.weight, x.instructions.size());
}));
for(auto&& part : ins_part.second)
{ {
auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin(); auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
set_stream(part, stream+1); set_stream(part, stream + 1);
streams[stream] += part.weight; streams[stream] += part.weight;
} }
} }
...@@ -107,8 +112,8 @@ struct stream_info ...@@ -107,8 +112,8 @@ struct stream_info
void set_stream(const partition& p, std::size_t n) void set_stream(const partition& p, std::size_t n)
{ {
for(auto ins:p.instructions) for(auto ins : p.instructions)
if (iweights[ins] > 0) if(iweights[ins] > 0)
set_stream(ins, n); set_stream(ins, n);
} }
...@@ -258,8 +263,8 @@ void schedule::apply(program& p) const ...@@ -258,8 +263,8 @@ void schedule::apply(program& p) const
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
auto args = ins->inputs(); auto args = ins->inputs();
std::sort(args.begin(), args.end(), by(std::less<>{}, [&](auto x) { std::sort(args.begin(), args.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(si.weights[x], x->inputs().size()); return std::make_tuple(si.weights[x], x->inputs().size());
})); }));
for(auto i : args) for(auto i : args)
{ {
p.move_instruction(i, p.begin()); p.move_instruction(i, p.begin());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment