Commit 0e5dabb4 authored by Paul's avatar Paul
Browse files

Create intermediate partitions

parent 05208b60
...@@ -137,6 +137,22 @@ auto fold(F f) ...@@ -137,6 +137,22 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); }; return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
} }
template<class F, class Proj>
auto by(F f, Proj proj)
{
return [=](auto&&... xs) {
return f(proj(std::forward<decltype(xs)>(xs))...);
};
}
template<class T>
auto index_of(T& x)
{
return [&](auto&& y) {
return x[y];
};
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -9,15 +9,11 @@ ...@@ -9,15 +9,11 @@
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <set> #include <set>
#include <deque>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
bool stream_free(instruction_ref ins)
{
return is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@';
}
auto get_inputs() auto get_inputs()
{ {
return [](auto i) { return i->inputs(); }; return [](auto i) { return i->inputs(); };
...@@ -54,50 +50,68 @@ struct stream_info ...@@ -54,50 +50,68 @@ struct stream_info
})(last); })(last);
} }
void assign_streams(program& p, std::size_t streams) struct partition
{ {
const std::size_t min_partition_threshold = 2; std::size_t weight = 0;
for(std::size_t stream = 0; stream < streams - 1; stream++) std::vector<instruction_ref> instructions{};
void add(instruction_ref ins, std::size_t w)
{ {
fix([&](auto self, auto ins) { weight += w;
// If weight is zero then stop instructions.push_back(ins);
if(this->weights[ins] == 0) }
return; };
// Only assign streams if not already assigned
if(not this->has_stream(ins) and this->iweights[ins] > 0) void assign_streams(program& p, std::size_t n)
this->set_stream(ins, stream); {
instruction_ref child = p.end(); const std::size_t min_partition_threshold = 2;
std::size_t w = 0; partition critical;
for(auto i : ins->inputs()) std::unordered_map<instruction_ref, std::deque<partition>> partitions;
fix([&](auto self, auto ins, auto& part) {
// If weight is zero then stop
if(this->weights[ins] == 0)
return;
part.add(ins, this->iweights[ins]);
auto max_it = std::max_element(ins->inputs().begin(), ins->inputs().end(), by(std::less<>{}, index_of(this->weights)));
for(auto i : ins->inputs())
{
const auto weight = this->weights[i];
if (i == *max_it or weight <= min_partition_threshold)
{ {
const auto weight = this->weights[i]; self(i, part);
// Skip instruction that already have stream assignment or too low of weights
if(this->has_stream(i) or weight <= min_partition_threshold)
{
self(i);
}
// Accumulate the max weight
else if(weight > w)
{
child = i;
w = weight;
}
} }
if(child != p.end()) else
self(child); {
})(std::prev(p.end())); partitions[ins].emplace_back();
} self(i, partitions[ins].back());
// Assign remaining instructions }
for(auto ins : iterator_for(p)) }
})(std::prev(p.end()), critical);
// Set the critical partition to stream 0
set_stream(critical, 0);
std::vector<std::size_t> streams(n-1);
// Assign streams for the other partitions
for(auto&& ins_part:partitions)
{ {
if(has_stream(ins)) std::sort(ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) { return std::make_tuple(x.weight, x.instructions.size()); }));
continue; for(auto&& part:ins_part.second)
if(iweights[ins] == 0) {
continue; auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
set_stream(ins, streams - 1); set_stream(part, stream+1);
streams[stream] += part.weight;
}
} }
} }
void set_stream(const partition& p, std::size_t n)
{
for(auto ins:p.instructions)
if (iweights[ins] > 0)
set_stream(ins, n);
}
void set_stream(instruction_ref ins, std::size_t n) void set_stream(instruction_ref ins, std::size_t n)
{ {
assert(iweights[ins] > 0); assert(iweights[ins] > 0);
...@@ -243,10 +257,9 @@ void schedule::apply(program& p) const ...@@ -243,10 +257,9 @@ void schedule::apply(program& p) const
// Topo sort // Topo sort
fix([&](auto self, auto ins) { fix([&](auto self, auto ins) {
auto args = ins->inputs(); auto args = ins->inputs();
std::sort(args.begin(), args.end(), [&](auto x, auto y) { std::sort(args.begin(), args.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(si.weights[x], x->inputs().size()) < return std::make_tuple(si.weights[x], x->inputs().size());
std::make_tuple(si.weights[y], y->inputs().size()); }));
});
for(auto i : args) for(auto i : args)
{ {
p.move_instruction(i, p.begin()); p.move_instruction(i, p.begin());
......
...@@ -369,14 +369,14 @@ TEST_CASE(seq_merge) ...@@ -369,14 +369,14 @@ TEST_CASE(seq_merge)
p.compile(schedule_target{&stream}); p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0); EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 1); EXPECT(stream.at(i1) == 2);
for(auto ins : c1) for(auto ins : c1)
EXPECT(stream.at(ins) == 0); EXPECT(stream.at(ins) == 3);
EXPECT(stream.at(binary1) == 0); EXPECT(stream.at(binary1) == 3);
EXPECT(get_wait_for(binary1) == get_wait_for(stream[binary1], {stream[c1.back()], stream[i1]})); EXPECT(get_wait_for(binary1) == get_wait_for(stream[binary1], {stream[c1.back()], stream[i1]}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
EXPECT(stream.at(i2) == 1); EXPECT(stream.at(i2) == 3);
for(auto ins : c2) for(auto ins : c2)
EXPECT(stream.at(ins) == 0); EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary2) == 0); EXPECT(stream.at(binary2) == 0);
...@@ -414,8 +414,8 @@ TEST_CASE(par_merge) ...@@ -414,8 +414,8 @@ TEST_CASE(par_merge)
EXPECT(stream.at(i2) == 2); EXPECT(stream.at(i2) == 2);
for(auto ins : c2) for(auto ins : c2)
EXPECT(stream.at(ins) == 1); EXPECT(stream.at(ins) == 3);
EXPECT(stream.at(binary2) == 1); EXPECT(stream.at(binary2) == 3);
EXPECT(get_wait_for(binary2) == get_wait_for(stream[binary2], {stream[c2.back()], stream[i2]})); EXPECT(get_wait_for(binary2) == get_wait_for(stream[binary2], {stream[c2.back()], stream[i2]}));
check_conflicts(p, {c2, {i2}}); check_conflicts(p, {c2, {i2}});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment