Commit 0e5dabb4 authored by Paul's avatar Paul
Browse files

Create intermediate partitions

parent 05208b60
......@@ -137,6 +137,22 @@ auto fold(F f)
return [=](auto&&... xs) { return fold_impl(f, std::forward<decltype(xs)>(xs)...); };
}
template<class F, class Proj>
auto by(F f, Proj proj)
{
return [=](auto&&... xs) {
return f(proj(std::forward<decltype(xs)>(xs))...);
};
}
template<class T>
auto index_of(T& x)
{
return [&](auto&& y) {
return x[y];
};
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -9,15 +9,11 @@
#include <unordered_map>
#include <unordered_set>
#include <set>
#include <deque>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
bool stream_free(instruction_ref ins)
{
return is_context_free(ins->get_operator()) or ins->get_operator().name().front() == '@';
}
auto get_inputs()
{
return [](auto i) { return i->inputs(); };
......@@ -54,48 +50,66 @@ struct stream_info
})(last);
}
void assign_streams(program& p, std::size_t streams)
struct partition
{
const std::size_t min_partition_threshold = 2;
for(std::size_t stream = 0; stream < streams - 1; stream++)
std::size_t weight = 0;
std::vector<instruction_ref> instructions{};
void add(instruction_ref ins, std::size_t w)
{
fix([&](auto self, auto ins) {
weight += w;
instructions.push_back(ins);
}
};
void assign_streams(program& p, std::size_t n)
{
const std::size_t min_partition_threshold = 2;
partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
fix([&](auto self, auto ins, auto& part) {
// If weight is zero then stop
if(this->weights[ins] == 0)
return;
// Only assign streams if not already assigned
if(not this->has_stream(ins) and this->iweights[ins] > 0)
this->set_stream(ins, stream);
instruction_ref child = p.end();
std::size_t w = 0;
part.add(ins, this->iweights[ins]);
auto max_it = std::max_element(ins->inputs().begin(), ins->inputs().end(), by(std::less<>{}, index_of(this->weights)));
for(auto i : ins->inputs())
{
const auto weight = this->weights[i];
// Skip instruction that already have stream assignment or too low of weights
if(this->has_stream(i) or weight <= min_partition_threshold)
if (i == *max_it or weight <= min_partition_threshold)
{
self(i);
self(i, part);
}
// Accumulate the max weight
else if(weight > w)
else
{
child = i;
w = weight;
}
partitions[ins].emplace_back();
self(i, partitions[ins].back());
}
if(child != p.end())
self(child);
})(std::prev(p.end()));
}
// Assign remaining instructions
for(auto ins : iterator_for(p))
})(std::prev(p.end()), critical);
// Set the critical partition to stream 0
set_stream(critical, 0);
std::vector<std::size_t> streams(n-1);
// Assign streams for the other partitions
for(auto&& ins_part:partitions)
{
if(has_stream(ins))
continue;
if(iweights[ins] == 0)
continue;
set_stream(ins, streams - 1);
std::sort(ins_part.second.begin(), ins_part.second.end(), by(std::greater<>{}, [](auto&& x) { return std::make_tuple(x.weight, x.instructions.size()); }));
for(auto&& part:ins_part.second)
{
auto stream = std::min_element(streams.begin(), streams.end()) - streams.begin();
set_stream(part, stream+1);
streams[stream] += part.weight;
}
}
}
void set_stream(const partition& p, std::size_t n)
{
for(auto ins:p.instructions)
if (iweights[ins] > 0)
set_stream(ins, n);
}
void set_stream(instruction_ref ins, std::size_t n)
......@@ -243,10 +257,9 @@ void schedule::apply(program& p) const
// Topo sort
fix([&](auto self, auto ins) {
auto args = ins->inputs();
std::sort(args.begin(), args.end(), [&](auto x, auto y) {
return std::make_tuple(si.weights[x], x->inputs().size()) <
std::make_tuple(si.weights[y], y->inputs().size());
});
std::sort(args.begin(), args.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(si.weights[x], x->inputs().size());
}));
for(auto i : args)
{
p.move_instruction(i, p.begin());
......
......@@ -369,14 +369,14 @@ TEST_CASE(seq_merge)
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 1);
EXPECT(stream.at(i1) == 2);
for(auto ins : c1)
EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary1) == 0);
EXPECT(stream.at(ins) == 3);
EXPECT(stream.at(binary1) == 3);
EXPECT(get_wait_for(binary1) == get_wait_for(stream[binary1], {stream[c1.back()], stream[i1]}));
check_conflicts(p, {c1, {i1}});
EXPECT(stream.at(i2) == 1);
EXPECT(stream.at(i2) == 3);
for(auto ins : c2)
EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary2) == 0);
......@@ -414,8 +414,8 @@ TEST_CASE(par_merge)
EXPECT(stream.at(i2) == 2);
for(auto ins : c2)
EXPECT(stream.at(ins) == 1);
EXPECT(stream.at(binary2) == 1);
EXPECT(stream.at(ins) == 3);
EXPECT(stream.at(binary2) == 3);
EXPECT(get_wait_for(binary2) == get_wait_for(stream[binary2], {stream[c2.back()], stream[i2]}));
check_conflicts(p, {c2, {i2}});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment