Commit fe13db50 authored by Paul
Browse files

Sort in reverse

parent e46a1cb2
...@@ -505,6 +505,16 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params) ...@@ -505,6 +505,16 @@ void program::perf_report(std::ostream& os, std::size_t n, parameter_map params)
void program::debug_print() const { std::cout << *this << std::endl; } void program::debug_print() const { std::cout << *this << std::endl; }
void program::debug_print(instruction_ref ins) const void program::debug_print(instruction_ref ins) const
{ {
if (ins == this->end())
{
std::cout << "End instruction" << std::endl;
return;
}
if (not has_instruction(ins))
{
std::cout << "Instruction not part of program" << std::endl;
return;
}
std::stringstream ss; std::stringstream ss;
print_program(ss, *this, [&](auto x, auto&& names) { print_program(ss, *this, [&](auto x, auto&& names) {
if(x == ins) if(x == ins)
......
...@@ -52,37 +52,24 @@ struct stream_info ...@@ -52,37 +52,24 @@ struct stream_info
std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args) std::vector<instruction_ref>::iterator sort_args(std::vector<instruction_ref>& args)
{ {
const std::size_t min_partition_threshold = 2;
auto compare = by(std::less<>{}, [&](auto x) {
return std::make_tuple(this->weights[x], x->inputs().size());
});
if(args.size() < 2) if(args.size() < 2)
{ {
return args.end(); return args.end();
} }
else if(args.size() == 2)
{
auto w1 = this->weights[args[0]];
auto w2 = this->weights[args[1]];
if(std::make_tuple(w1, args[0]->inputs().size()) >
std::make_tuple(w2, args[1]->inputs().size()))
{
std::swap(args[0], args[1]);
std::swap(w1, w2);
}
if(w1 > min_partition_threshold)
return args.begin();
if(w2 > min_partition_threshold)
return args.begin() + 1;
return args.end();
}
const std::size_t min_partition_threshold = 2;
auto compare = by(std::greater<>{}, [&](auto x) {
return std::make_tuple(this->weights[x], x->inputs().size());
});
std::sort(args.begin(), args.end(), compare); std::sort(args.begin(), args.end(), compare);
return std::upper_bound(args.begin(), auto it = std::lower_bound(std::next(args.begin()),
args.end(), args.end(),
min_partition_threshold, min_partition_threshold,
[&](std::size_t w, auto i) { return w < this->weights[i]; }); [&](auto i, std::size_t w) { return this->weights[i] > w; });
assert(it == args.end() or this->weights[*it] <= min_partition_threshold);
assert(it == args.end() or std::prev(it) == args.begin() or this->weights[*std::prev(it)] > min_partition_threshold);
return it;
} }
struct partition struct partition
...@@ -103,28 +90,31 @@ struct stream_info ...@@ -103,28 +90,31 @@ struct stream_info
std::unordered_map<instruction_ref, std::deque<partition>> partitions; std::unordered_map<instruction_ref, std::deque<partition>> partitions;
partitions.reserve(weights.size()); partitions.reserve(weights.size());
fix([&](auto self, auto ins, auto& part) { fix([&](auto self, auto ins, auto& part) {
assert(ins != p.end());
if(contains(partitions, ins)) if(contains(partitions, ins))
return; return;
assert(p.has_instruction(ins));
// Add an entry so we know the instruction was visited
partitions[ins]; partitions[ins];
part.add(ins, this->iweights[ins]); part.add(ins, this->iweights[ins]);
auto args = ins->inputs(); auto args = ins->inputs();
auto threshold_it = sort_args(args); auto threshold_it = sort_args(args);
for(auto i : range(args.begin(), threshold_it))
if (not args.empty())
{ {
self(i, part); assert(threshold_it != args.begin());
self(args.front(), part);
for(auto i : range(std::next(args.begin()), threshold_it))
{
partitions[ins].emplace_back();
self(i, partitions[ins].back());
} }
for(auto i : range(threshold_it, args.end())) for(auto i : range(threshold_it, args.end()))
{
if(i == args.back())
{ {
self(i, part); self(i, part);
} }
else
{
partitions[ins].emplace_back();
self(i, partitions[ins].back());
}
} }
// Sort instructions // Sort instructions
p.move_instruction(ins, p.end()); p.move_instruction(ins, p.end());
......
...@@ -521,7 +521,6 @@ TEST_CASE(seq_merge) ...@@ -521,7 +521,6 @@ TEST_CASE(seq_merge)
p.compile(t); p.compile(t);
EXPECT(not t.has_stream(one)); EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) != t.get_stream(i2));
EXPECT(t.get_stream(i1) != t.get_stream(c1.back())); EXPECT(t.get_stream(i1) != t.get_stream(c1.back()));
for(auto ins : c1) for(auto ins : c1)
EXPECT(t.get_stream(ins) == t.get_stream(c1.back())); EXPECT(t.get_stream(ins) == t.get_stream(c1.back()));
...@@ -533,7 +532,6 @@ TEST_CASE(seq_merge) ...@@ -533,7 +532,6 @@ TEST_CASE(seq_merge)
EXPECT(t.get_stream(i2) != t.get_stream(c2.back())); EXPECT(t.get_stream(i2) != t.get_stream(c2.back()));
for(auto ins : c2) for(auto ins : c2)
EXPECT(t.get_stream(ins) == t.get_stream(c2.back())); EXPECT(t.get_stream(ins) == t.get_stream(c2.back()));
EXPECT(t.get_stream(c1.back()) != t.get_stream(c2.back()));
EXPECT(t.get_stream(binary2) == 0); EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) == EXPECT(get_wait_for(binary2) ==
get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)})); get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
...@@ -695,7 +693,10 @@ TEST_CASE(inner_split1) ...@@ -695,7 +693,10 @@ TEST_CASE(inner_split1)
EXPECT(get_wait_for(output) == EXPECT(get_wait_for(output) ==
get_wait_for( get_wait_for(
t.get_stream(output), t.get_stream(output),
{t.get_stream(c1.back()), t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)})); {t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)}));
EXPECT(get_wait_for(s1).empty());
// TODO: Remove the extra wait here
// EXPECT(get_wait_for(s2).empty());
check_conflicts(p, {c1, {i1}, {s1}, {s2}}); check_conflicts(p, {c1, {i1}, {s1}, {s2}});
} }
...@@ -719,10 +720,10 @@ TEST_CASE(inner_split2) ...@@ -719,10 +720,10 @@ TEST_CASE(inner_split2)
EXPECT(t.get_stream(output) == 0); EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output), EXPECT(get_wait_for(output) == get_wait_for(t.get_stream(output),
{t.get_stream(c1.back()), {t.get_stream(i1),
t.get_stream(i1),
t.get_stream(s1.back()), t.get_stream(s1.back()),
t.get_stream(s2.back())})); t.get_stream(s2.back())}));
EXPECT(get_wait_for(s1.front()) == get_wait_for({t.get_stream(c1.back())}));
check_conflicts(p, {c1, {i1}, s1, s2}); check_conflicts(p, {c1, {i1}, s1, s2});
} }
...@@ -846,8 +847,8 @@ TEST_CASE(inception1) ...@@ -846,8 +847,8 @@ TEST_CASE(inception1)
p.compile(t); p.compile(t);
EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39, i94}) == EXPECT(t.get_streams({i7, i11, i17, i23, i25, i31, i37, i39}) ==
t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7, i7})); t.get_streams({i7, i7, i7, i7, i7, i7, i7, i7}));
EXPECT(t.get_streams({i48, i54, i61, output}) == EXPECT(t.get_streams({i48, i54, i61, output}) ==
t.get_streams({output, output, output, output})); t.get_streams({output, output, output, output}));
EXPECT(t.get_streams({i80, i86}) == t.get_streams({i80, i80})); EXPECT(t.get_streams({i80, i86}) == t.get_streams({i80, i80}));
...@@ -856,15 +857,12 @@ TEST_CASE(inception1) ...@@ -856,15 +857,12 @@ TEST_CASE(inception1)
EXPECT(t.get_stream(i7) != t.get_stream(i80)); EXPECT(t.get_stream(i7) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i80)); EXPECT(t.get_stream(i69) != t.get_stream(i80));
EXPECT(t.get_stream(i69) != t.get_stream(i7)); EXPECT(t.get_stream(i69) != t.get_stream(i7));
EXPECT(t.get_stream(output) != t.get_stream(i7));
EXPECT(t.get_stream(output) != t.get_stream(i69)); EXPECT(t.get_stream(output) != t.get_stream(i69));
EXPECT(t.get_stream(output) != t.get_stream(i80)); EXPECT(t.get_stream(output) != t.get_stream(i80));
EXPECT(get_wait_for(i48) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i80) == get_wait_for({t.get_stream(i39)})); EXPECT(get_wait_for(i80) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i69) == get_wait_for({t.get_stream(i39)})); EXPECT(get_wait_for(i69) == get_wait_for({t.get_stream(i39)}));
// We dont wait twice EXPECT(get_wait_for(i94) == get_wait_for({t.get_stream(i39)}));
EXPECT(get_wait_for(i94).empty());
EXPECT( EXPECT(
get_wait_for(output) == get_wait_for(output) ==
get_wait_for(t.get_stream(output), get_wait_for(t.get_stream(output),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment