Commit 4008675f authored by Paul
Browse files

Sort while creating partitions

parent 6e9142b5
...@@ -104,11 +104,9 @@ instruction_ref program::insert_instruction(instruction_ref ins, ...@@ -104,11 +104,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) && args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction"); "Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@")); assert(not starts_with(op.name(), "@"));
// TODO: Use move
shape r = compute_shape(op, args); shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)}); auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
instruction::backreference(result); instruction::backreference(result);
// assert(result->inputs() == args);
assert(result->valid(begin())); assert(result->valid(begin()));
return result; return result;
} }
......
...@@ -50,6 +50,39 @@ struct stream_info ...@@ -50,6 +50,39 @@ struct stream_info
})(last); })(last);
} }
/// Reorders `args` in place, ascending by the tuple
/// (this->weights[arg], arg->inputs().size()), and reports where the
/// arguments whose weight exceeds the minimum partition threshold (2) begin.
///
/// @param args  Arguments to reorder; modified in place.
/// @return Iterator to the first element of `args` whose weight is greater
///         than the threshold, or `args.end()` when there is none. When
///         `args` has fewer than two elements it is left untouched and
///         `args.end()` is returned without looking at any weight.
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
{
    const std::size_t min_partition_threshold = 2;

    // Zero or one argument: nothing to order, nothing to split off.
    if(args.size() < 2)
        return args.end();

    // Exactly two arguments: order the pair by hand rather than calling std::sort.
    if(args.size() == 2)
    {
        auto lighter = this->weights[args[0]];
        auto heavier = this->weights[args[1]];
        const auto first_key  = std::make_tuple(lighter, args[0]->inputs().size());
        const auto second_key = std::make_tuple(heavier, args[1]->inputs().size());
        if(first_key > second_key)
        {
            // Keep the cached weights in step with the swapped elements.
            std::swap(args[0], args[1]);
            std::swap(lighter, heavier);
        }
        if(lighter > min_partition_threshold)
            return args.begin();
        return heavier > min_partition_threshold ? args.begin() + 1 : args.end();
    }

    // General case: full sort, then binary-search for the first weight above the threshold.
    std::sort(args.begin(), args.end(), by(std::less<>{}, [&](auto x) {
                  return std::make_tuple(this->weights[x], x->inputs().size());
              }));
    auto above_limit = [&](std::size_t limit, auto arg) { return limit < this->weights[arg]; };
    return std::upper_bound(args.begin(), args.end(), min_partition_threshold, above_limit);
}
struct partition struct partition
{ {
std::size_t weight = 0; std::size_t weight = 0;
...@@ -64,22 +97,24 @@ struct stream_info ...@@ -64,22 +97,24 @@ struct stream_info
void assign_streams(program& p, std::size_t n) void assign_streams(program& p, std::size_t n)
{ {
const std::size_t min_partition_threshold = 2;
partition critical; partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions; std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) { fix([&](auto self, auto ins, auto& part) {
// If weight is zero then stop if (contains(partitions, ins))
if(this->weights[ins] == 0)
return; return;
partitions[ins];
part.add(ins, this->iweights[ins]); part.add(ins, this->iweights[ins]);
auto max_it = std::max_element(ins->inputs().begin(), auto args = ins->inputs();
ins->inputs().end(), auto threshold_it = sort_args(args);
by(std::less<>{}, index_of(this->weights))); for(auto i:range(args.begin(), threshold_it))
for(auto i : ins->inputs())
{ {
const auto weight = this->weights[i]; self(i, part);
if(i == *max_it or weight <= min_partition_threshold) }
for(auto i:range(threshold_it, args.end()))
{
if (i == args.back())
{ {
self(i, part); self(i, part);
} }
...@@ -89,6 +124,8 @@ struct stream_info ...@@ -89,6 +124,8 @@ struct stream_info
self(i, partitions[ins].back()); self(i, partitions[ins].back());
} }
} }
// Sort instructions
p.move_instruction(ins, p.end());
})(std::prev(p.end()), critical); })(std::prev(p.end()), critical);
// Set the critical partition to stream 0 // Set the critical partition to stream 0
...@@ -233,6 +270,8 @@ struct stream_info ...@@ -233,6 +270,8 @@ struct stream_info
{ {
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result; std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from; std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
result.reserve(p.size());
merge_from.reserve(p.size());
for(auto ins : reverse_iterator_for(p)) for(auto ins : reverse_iterator_for(p))
{ {
for(auto&& arg : ins->outputs()) for(auto&& arg : ins->outputs())
...@@ -254,7 +293,7 @@ struct stream_info ...@@ -254,7 +293,7 @@ struct stream_info
auto&& r = result[merge][stream]; auto&& r = result[merge][stream];
r.push_back(ins); r.push_back(ins);
// Copy inputs if they dont have a stream(and are not a builtin and context // Copy inputs if they dont have a stream(and are not a builtin and context
// free) Inputs without a stream can have a implicit dependency // free). Inputs without a stream can have a implicit dependency
std::copy_if(ins->inputs().begin(), std::copy_if(ins->inputs().begin(),
ins->inputs().end(), ins->inputs().end(),
std::back_inserter(r), std::back_inserter(r),
...@@ -277,19 +316,6 @@ void schedule::apply(program& p) const ...@@ -277,19 +316,6 @@ void schedule::apply(program& p) const
si.accumulate_weights(last, model); si.accumulate_weights(last, model);
si.assign_streams(p, model.concurrency()); si.assign_streams(p, model.concurrency());
// Topo sort
fix([&](auto self, auto ins) {
auto args = ins->inputs();
std::sort(args.begin(), args.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(si.weights[x], x->inputs().size());
}));
for(auto i : args)
{
p.move_instruction(i, p.begin());
self(i);
}
})(last);
if(enabled(MIGRAPHX_TRACE_COMPILE{})) if(enabled(MIGRAPHX_TRACE_COMPILE{}))
{ {
p.annotate(std::cout, [&](auto ins) { p.annotate(std::cout, [&](auto ins) {
...@@ -308,10 +334,12 @@ void schedule::apply(program& p) const ...@@ -308,10 +334,12 @@ void schedule::apply(program& p) const
} }
// Schedule instructions // Schedule instructions
std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::size_t wait_id = 0; std::size_t wait_id = 0;
std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for; std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited; std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(p.size());
ins2waited.reserve(p.size());
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
{ {
// Only schedule instructions that have a stream // Only schedule instructions that have a stream
...@@ -364,6 +392,10 @@ void schedule::apply(program& p) const ...@@ -364,6 +392,10 @@ void schedule::apply(program& p) const
dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) { dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
if(i == j) if(i == j)
return; return;
if (merge.second[i].empty())
return;
if (merge.second[j].empty())
return;
for(auto ins1 : merge.second[i]) for(auto ins1 : merge.second[i])
{ {
auto args = merge.second[j]; auto args = merge.second[j];
......
...@@ -521,17 +521,19 @@ TEST_CASE(seq_merge) ...@@ -521,17 +521,19 @@ TEST_CASE(seq_merge)
p.compile(t); p.compile(t);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 2); EXPECT(t.get_stream(i1) != t.get_stream(i2));
EXPECT(t.get_stream(i1) != t.get_stream(c1.back()));
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == t.get_stream(c1.back()));
EXPECT(t.get_stream(binary1) == 3); EXPECT(t.get_stream(binary1) == t.get_stream(c1.back()));
EXPECT(get_wait_for(binary1) == EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) == 3); EXPECT(t.get_stream(i2) != t.get_stream(c2.back()));
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == 0); EXPECT(t.get_stream(ins) == t.get_stream(c2.back()));
EXPECT(t.get_stream(c1.back()) != t.get_stream(c2.back()));
EXPECT(t.get_stream(binary2) == 0); EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) == EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
...@@ -559,7 +561,7 @@ TEST_CASE(par_merge) ...@@ -559,7 +561,7 @@ TEST_CASE(par_merge)
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(binary3) == 0); EXPECT(t.get_stream(binary3) == 0);
EXPECT(t.get_stream(i1) == 2); EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0); EXPECT(t.get_stream(binary1) == 0);
...@@ -567,7 +569,6 @@ TEST_CASE(par_merge) ...@@ -567,7 +569,6 @@ TEST_CASE(par_merge)
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)})); get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) == 1);
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3); EXPECT(t.get_stream(binary2) == 3);
...@@ -684,7 +685,8 @@ TEST_CASE(inner_split1) ...@@ -684,7 +685,8 @@ TEST_CASE(inner_split1)
auto output = p.add_instruction(nary_op{}, i1, s1, s2); auto output = p.add_instruction(nary_op{}, i1, s1, s2);
p.compile(t); p.compile(t);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3); EXPECT(t.get_stream(i1) != t.get_stream(s1));
EXPECT(t.get_stream(i1) != t.get_stream(s2));
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1)); EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1) != t.get_stream(s2)); EXPECT(t.get_stream(s1) != t.get_stream(s2));
...@@ -709,7 +711,8 @@ TEST_CASE(inner_split2) ...@@ -709,7 +711,8 @@ TEST_CASE(inner_split2)
auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back()); auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back());
p.compile(t); p.compile(t);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 2); EXPECT(t.get_stream(i1) != t.get_stream(s1.back()));
EXPECT(t.get_stream(i1) != t.get_stream(s2.back()));
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1)); EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1.back()) != t.get_stream(s2.back())); EXPECT(t.get_stream(s1.back()) != t.get_stream(s2.back()));
...@@ -735,7 +738,7 @@ TEST_CASE(inception_resnet) ...@@ -735,7 +738,7 @@ TEST_CASE(inception_resnet)
auto output = p.add_instruction(nary_op{}, binary, input); auto output = p.add_instruction(nary_op{}, binary, input);
p.compile(t); p.compile(t);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 2); EXPECT(t.get_stream(i1) != 0);
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0); EXPECT(t.get_stream(binary) == 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment