Commit 65cf0713 authored by Paul's avatar Paul
Browse files

Fix unit tests

parent 2f7db364
...@@ -49,11 +49,11 @@ struct nary_op ...@@ -49,11 +49,11 @@ struct nary_op
struct wait_event struct wait_event
{ {
std::vector<std::size_t> wait_for; std::shared_ptr<std::vector<std::size_t>> wait_for = std::make_shared<std::vector<std::size_t>>();
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return migraphx::pack(f(self.wait_for, "wait_for")); return migraphx::pack(f(*self.wait_for, "wait_for"));
} }
std::string name() const { return "wait_event"; } std::string name() const { return "wait_event"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; } migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
...@@ -62,28 +62,39 @@ struct wait_event ...@@ -62,28 +62,39 @@ struct wait_event
const migraphx::shape&, const migraphx::shape&,
const std::vector<migraphx::argument>&) const const std::vector<migraphx::argument>&) const
{ {
assert(not wait_for.empty()); assert(wait_for != nullptr);
assert(not wait_for->empty());
return {}; return {};
} }
}; };
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>; using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;
using wait_map = std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
struct schedule_model_test struct schedule_model_test
{ {
instruction_map* ins2stream; std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
std::shared_ptr<std::unordered_map<std::size_t, std::size_t>> wait2stream = std::make_shared<std::unordered_map<std::size_t, std::size_t>>();
std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>();
std::size_t concurrency() const { return 4; } std::size_t concurrency() const { return 4; }
void void
schedule_instruction(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
{ {
(*ins2stream)[ins] = n; (*ins2stream)[ins] = n;
} }
void wait(migraphx::program& p, void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
migraphx::instruction_ref ins,
std::size_t,
const std::vector<std::size_t>& wait_for) const
{ {
p.insert_instruction(ins, wait_event{wait_for}); if (ins2wait_for->count(ins) == 0)
{
auto event = wait_event{};
p.insert_instruction(ins, event);
(*ins2wait_for)[ins] = event.wait_for;
}
(*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
}
void record(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
{
(*wait2stream)[wait_id] = ins2stream->at(ins);
} }
std::size_t weight(const migraphx::operation& op) const std::size_t weight(const migraphx::operation& op) const
{ {
...@@ -96,13 +107,23 @@ struct schedule_model_test ...@@ -96,13 +107,23 @@ struct schedule_model_test
struct schedule_target struct schedule_target
{ {
instruction_map* ins2stream; schedule_model_test model{};
std::string name() const { return "schedule"; } std::string name() const { return "schedule"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const std::vector<migraphx::pass> get_passes(migraphx::context&) const
{ {
return {migraphx::schedule{schedule_model_test{ins2stream}}}; return {migraphx::schedule{model}};
} }
migraphx::context get_context() const { return {}; } migraphx::context get_context() const { return {}; }
std::size_t get_stream(migraphx::instruction_ref ins)
{
return model.ins2stream->at(ins);
}
bool has_stream(migraphx::instruction_ref ins)
{
return model.ins2stream->count(ins) > 0;
}
}; };
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y) bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
...@@ -166,7 +187,7 @@ std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins) ...@@ -166,7 +187,7 @@ std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
auto wait_ins = std::prev(ins); auto wait_ins = std::prev(ins);
if(wait_ins->name() != "wait_event") if(wait_ins->name() != "wait_event")
return {}; return {};
auto wf = migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for; auto wf = *migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for;
std::sort(wf.begin(), wf.end()); std::sort(wf.begin(), wf.end());
return wf; return wf;
} }
...@@ -183,36 +204,35 @@ chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input) ...@@ -183,36 +204,35 @@ chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
} }
return result; return result;
} }
TEST_CASE(single_entry) TEST_CASE(single_entry)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one); auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2); auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(onep1) != stream.at(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(stream.at(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep1], stream[onep2]})); EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2)); EXPECT(check_conflicts(p, onep1, onep2));
} }
TEST_CASE(zero_merge1) TEST_CASE(zero_merge1)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one); auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2); auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(onep1) != stream.at(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment // No stream assignment
EXPECT(stream.count(binary) == 0); EXPECT(not t.has_stream(binary));
// There is no wait // There is no wait
EXPECT(get_wait_for(binary).empty()); EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2)); EXPECT(check_conflicts(p, onep1, onep2));
...@@ -220,7 +240,7 @@ TEST_CASE(zero_merge1) ...@@ -220,7 +240,7 @@ TEST_CASE(zero_merge1)
TEST_CASE(zero_merge2) TEST_CASE(zero_merge2)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
...@@ -228,11 +248,11 @@ TEST_CASE(zero_merge2) ...@@ -228,11 +248,11 @@ TEST_CASE(zero_merge2)
auto binary = p.add_instruction(migraphx::op::identity{}, auto binary = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1), p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2)); p.add_instruction(migraphx::op::identity{}, onep2));
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(onep1) != stream.at(onep2)); EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment // No stream assignment
EXPECT(stream.count(binary) == 0); EXPECT(not t.has_stream(binary));
// There is no wait // There is no wait
EXPECT(get_wait_for(binary).empty()); EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2)); EXPECT(check_conflicts(p, onep1, onep2));
...@@ -240,43 +260,43 @@ TEST_CASE(zero_merge2) ...@@ -240,43 +260,43 @@ TEST_CASE(zero_merge2)
TEST_CASE(double_entry) TEST_CASE(double_entry)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto onep = p.add_instruction(unary_op{}, one); auto onep = p.add_instruction(unary_op{}, one);
auto twop = p.add_instruction(unary_op{}, two); auto twop = p.add_instruction(unary_op{}, two);
auto binary = p.add_instruction(nary_op{}, onep, twop); auto binary = p.add_instruction(nary_op{}, onep, twop);
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.count(two) == 0); EXPECT(not t.has_stream(two));
EXPECT(stream.at(onep) != stream.at(twop)); EXPECT(t.get_stream(onep) != t.get_stream(twop));
EXPECT(stream.at(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep], stream[twop]})); EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
// EXPECT(check_conflicts(p, onep, twop)); // EXPECT(check_conflicts(p, onep, twop));
} }
TEST_CASE(two_branches) TEST_CASE(two_branches)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one); auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(i1) == 1); EXPECT(t.get_stream(i1) == 1);
for(auto ins : c1) for(auto ins : c1)
EXPECT(stream.at(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(stream.at(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[i1]})); EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
} }
TEST_CASE(four_branches) TEST_CASE(four_branches)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 4, unary_op{}, one); auto c1 = chain(p, 4, unary_op{}, one);
...@@ -284,25 +304,25 @@ TEST_CASE(four_branches) ...@@ -284,25 +304,25 @@ TEST_CASE(four_branches)
auto c3 = chain(p, 2, unary_op{}, one); auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(i1) == 3); EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1) for(auto ins : c1)
EXPECT(stream.at(ins) == 0); EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2) for(auto ins : c2)
EXPECT(stream.at(ins) == 1); EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3) for(auto ins : c3)
EXPECT(stream.at(ins) == 2); EXPECT(t.get_stream(ins) == 2);
EXPECT(stream.at(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) ==
get_wait_for(stream[binary], get_wait_for(t.get_stream(binary),
{stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]})); {t.get_stream(c1.back()), t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, {i1}}); check_conflicts(p, {c1, c2, c3, {i1}});
} }
TEST_CASE(five_branches) TEST_CASE(five_branches)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 5, unary_op{}, one); auto c1 = chain(p, 5, unary_op{}, one);
...@@ -311,28 +331,28 @@ TEST_CASE(five_branches) ...@@ -311,28 +331,28 @@ TEST_CASE(five_branches)
auto c4 = chain(p, 2, unary_op{}, one); auto c4 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one); auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back()); auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back());
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(i1) == 3); EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1) for(auto ins : c1)
EXPECT(stream.at(ins) == 0); EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2) for(auto ins : c2)
EXPECT(stream.at(ins) == 1); EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3) for(auto ins : c3)
EXPECT(stream.at(ins) == 2); EXPECT(t.get_stream(ins) == 2);
for(auto ins : c4) for(auto ins : c4)
EXPECT(stream.at(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(stream.at(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == EXPECT(get_wait_for(binary) ==
get_wait_for(stream[binary], get_wait_for(t.get_stream(binary),
{stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]})); {t.get_stream(c1.back()), t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, c4}); check_conflicts(p, {c1, c2, c3, c4});
check_conflicts(p, {c1, c2, c3, {i1}}); check_conflicts(p, {c1, c2, c3, {i1}});
} }
TEST_CASE(four_branches_eq) TEST_CASE(four_branches_eq)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
...@@ -340,22 +360,22 @@ TEST_CASE(four_branches_eq) ...@@ -340,22 +360,22 @@ TEST_CASE(four_branches_eq)
auto onep3 = p.add_instruction(unary_op{}, one); auto onep3 = p.add_instruction(unary_op{}, one);
auto onep4 = p.add_instruction(unary_op{}, one); auto onep4 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4); auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(sorted<std::size_t>( EXPECT(sorted<std::size_t>(
{stream.at(onep1), stream.at(onep2), stream.at(onep3), stream.at(onep4)}) == {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) ==
unique<std::size_t>( unique<std::size_t>(
{stream.at(onep1), stream.at(onep2), stream.at(onep3), stream.at(onep4)})); {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
EXPECT(stream.at(binary) == 0); EXPECT(t.get_stream(binary) == 0);
EXPECT( EXPECT(
get_wait_for(binary) == get_wait_for(binary) ==
get_wait_for(stream[binary], {stream[onep1], stream[onep2], stream[onep3], stream[onep4]})); get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}}); check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
} }
TEST_CASE(seq_merge) TEST_CASE(seq_merge)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one); auto c1 = chain(p, 2, unary_op{}, one);
...@@ -366,27 +386,27 @@ TEST_CASE(seq_merge) ...@@ -366,27 +386,27 @@ TEST_CASE(seq_merge)
auto i2 = p.add_instruction(unary_op{}, binary1); auto i2 = p.add_instruction(unary_op{}, binary1);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back()); auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(i1) == 2); EXPECT(t.get_stream(i1) == 2);
for(auto ins : c1) for(auto ins : c1)
EXPECT(stream.at(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(stream.at(binary1) == 3); EXPECT(t.get_stream(binary1) == 3);
EXPECT(get_wait_for(binary1) == get_wait_for(stream[binary1], {stream[c1.back()], stream[i1]})); EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
EXPECT(stream.at(i2) == 3); EXPECT(t.get_stream(i2) == 3);
for(auto ins : c2) for(auto ins : c2)
EXPECT(stream.at(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(stream.at(binary2) == 0); EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) == get_wait_for(stream[binary2], {stream[c2.back()], stream[i2]})); EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); check_conflicts(p, {c2, {i2}});
} }
TEST_CASE(par_merge) TEST_CASE(par_merge)
{ {
instruction_map stream; schedule_target t{};
migraphx::program p; migraphx::program p;
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one); auto start1 = p.add_instruction(unary_op{}, one);
...@@ -401,25 +421,24 @@ TEST_CASE(par_merge) ...@@ -401,25 +421,24 @@ TEST_CASE(par_merge)
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2); auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(schedule_target{&stream}); p.compile(t);
EXPECT(stream.count(one) == 0); EXPECT(not t.has_stream(one));
EXPECT(stream.at(binary3) == 0); EXPECT(t.get_stream(binary3) == 0);
EXPECT(stream.at(i1) == 1); EXPECT(t.get_stream(i1) == 2);
for(auto ins : c1) for(auto ins : c1)
EXPECT(stream.at(ins) == 0); EXPECT(t.get_stream(ins) == 0);
EXPECT(stream.at(binary1) == 0); EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) == get_wait_for(stream[binary1], {stream[c1.back()], stream[i1]})); EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}}); check_conflicts(p, {c1, {i1}});
EXPECT(stream.at(i2) == 2); EXPECT(t.get_stream(i2) == 1);
for(auto ins : c2) for(auto ins : c2)
EXPECT(stream.at(ins) == 3); EXPECT(t.get_stream(ins) == 3);
EXPECT(stream.at(binary2) == 3); EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) == get_wait_for(stream[binary2], {stream[c2.back()], stream[i2]})); EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}}); check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2)); EXPECT(check_conflicts(p, binary1, binary2));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment