Commit 4008675f authored by Paul's avatar Paul
Browse files

Sort while creating partitions

parent 6e9142b5
......@@ -104,11 +104,9 @@ instruction_ref program::insert_instruction(instruction_ref ins,
args.begin(), args.end(), [&](instruction_ref x) { return has_instruction(x); }) &&
"Argument is not an exisiting instruction");
assert(not starts_with(op.name(), "@"));
// TODO: Use move
shape r = compute_shape(op, args);
auto result = impl->instructions.insert(ins, {op, r, std::move(args)});
instruction::backreference(result);
// assert(result->inputs() == args);
assert(result->valid(begin()));
return result;
}
......
......@@ -50,6 +50,39 @@ struct stream_info
})(last);
}
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
{
const std::size_t min_partition_threshold = 2;
auto compare = by(std::less<>{}, [&](auto x) {
return std::make_tuple(this->weights[x], x->inputs().size());
});
if (args.size() < 2)
{
return args.end();
}
else if (args.size() == 2)
{
auto w1 = this->weights[args[0]];
auto w2 = this->weights[args[1]];
if (std::make_tuple(w1, args[0]->inputs().size()) > std::make_tuple(w2, args[1]->inputs().size()))
{
std::swap(args[0], args[1]);
std::swap(w1, w2);
}
if (w1 > min_partition_threshold)
return args.begin();
if (w2 > min_partition_threshold)
return args.begin()+1;
return args.end();
}
std::sort(args.begin(), args.end(), compare);
return std::upper_bound(args.begin(), args.end(), min_partition_threshold, [&](std::size_t w, auto i) {
return w < this->weights[i];
});
}
struct partition
{
std::size_t weight = 0;
......@@ -64,22 +97,24 @@ struct stream_info
void assign_streams(program& p, std::size_t n)
{
const std::size_t min_partition_threshold = 2;
partition critical;
std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) {
// If weight is zero then stop
if(this->weights[ins] == 0)
if (contains(partitions, ins))
return;
partitions[ins];
part.add(ins, this->iweights[ins]);
auto max_it = std::max_element(ins->inputs().begin(),
ins->inputs().end(),
by(std::less<>{}, index_of(this->weights)));
for(auto i : ins->inputs())
auto args = ins->inputs();
auto threshold_it = sort_args(args);
for(auto i:range(args.begin(), threshold_it))
{
const auto weight = this->weights[i];
if(i == *max_it or weight <= min_partition_threshold)
self(i, part);
}
for(auto i:range(threshold_it, args.end()))
{
if (i == args.back())
{
self(i, part);
}
......@@ -89,6 +124,8 @@ struct stream_info
self(i, partitions[ins].back());
}
}
// Sort instructions
p.move_instruction(ins, p.end());
})(std::prev(p.end()), critical);
// Set the critical partition to stream 0
......@@ -233,6 +270,8 @@ struct stream_info
{
std::unordered_map<instruction_ref, std::vector<std::vector<instruction_ref>>> result;
std::unordered_map<instruction_ref, std::unordered_set<instruction_ref>> merge_from;
result.reserve(p.size());
merge_from.reserve(p.size());
for(auto ins : reverse_iterator_for(p))
{
for(auto&& arg : ins->outputs())
......@@ -254,7 +293,7 @@ struct stream_info
auto&& r = result[merge][stream];
r.push_back(ins);
// Copy inputs if they dont have a stream(and are not a builtin and context
// free) Inputs without a stream can have a implicit dependency
// free). Inputs without a stream can have a implicit dependency
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(r),
......@@ -277,19 +316,6 @@ void schedule::apply(program& p) const
si.accumulate_weights(last, model);
si.assign_streams(p, model.concurrency());
// Topo sort
fix([&](auto self, auto ins) {
auto args = ins->inputs();
std::sort(args.begin(), args.end(), by(std::less<>{}, [&](auto x) {
return std::make_tuple(si.weights[x], x->inputs().size());
}));
for(auto i : args)
{
p.move_instruction(i, p.begin());
self(i);
}
})(last);
if(enabled(MIGRAPHX_TRACE_COMPILE{}))
{
p.annotate(std::cout, [&](auto ins) {
......@@ -308,10 +334,12 @@ void schedule::apply(program& p) const
}
// Schedule instructions
std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::size_t wait_id = 0;
std::unordered_map<instruction_ref, std::size_t> ins2wait;
std::unordered_map<std::size_t, std::unordered_set<std::size_t>> waited_for;
std::unordered_map<instruction_ref, std::unordered_set<std::size_t>> ins2waited;
ins2wait.reserve(p.size());
ins2waited.reserve(p.size());
for(auto ins : iterator_for(p))
{
// Only schedule instructions that have a stream
......@@ -364,6 +392,10 @@ void schedule::apply(program& p) const
dfor(merge.second.size(), merge.second.size())([&](auto i, auto j) {
if(i == j)
return;
if (merge.second[i].empty())
return;
if (merge.second[j].empty())
return;
for(auto ins1 : merge.second[i])
{
auto args = merge.second[j];
......
......@@ -521,17 +521,19 @@ TEST_CASE(seq_merge)
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 2);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
EXPECT(t.get_stream(i1) != t.get_stream(c1.back()));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary1) == 3);
EXPECT(t.get_stream(ins) == t.get_stream(c1.back()));
EXPECT(t.get_stream(binary1) == t.get_stream(c1.back()));
EXPECT(get_wait_for(binary1) ==
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) == 3);
EXPECT(t.get_stream(i2) != t.get_stream(c2.back()));
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(ins) == t.get_stream(c2.back()));
EXPECT(t.get_stream(c1.back()) != t.get_stream(c2.back()));
EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
......@@ -559,7 +561,7 @@ TEST_CASE(par_merge)
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(binary3) == 0);
EXPECT(t.get_stream(i1) == 2);
EXPECT(t.get_stream(i1) != t.get_stream(i2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
......@@ -567,7 +569,6 @@ TEST_CASE(par_merge)
get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
EXPECT(t.get_stream(i2) == 1);
for(auto ins : c2)
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
......@@ -684,7 +685,8 @@ TEST_CASE(inner_split1)
auto output = p.add_instruction(nary_op{}, i1, s1, s2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
EXPECT(t.get_stream(i1) != t.get_stream(s1));
EXPECT(t.get_stream(i1) != t.get_stream(s2));
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1) != t.get_stream(s2));
......@@ -709,7 +711,8 @@ TEST_CASE(inner_split2)
auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 2);
EXPECT(t.get_stream(i1) != t.get_stream(s1.back()));
EXPECT(t.get_stream(i1) != t.get_stream(s2.back()));
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1.back()) != t.get_stream(s2.back()));
......@@ -735,7 +738,7 @@ TEST_CASE(inception_resnet)
auto output = p.add_instruction(nary_op{}, binary, input);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 2);
EXPECT(t.get_stream(i1) != 0);
for(auto ins : c1)
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment