Commit 65cf0713 authored by Paul's avatar Paul
Browse files

Fix unit tests

parent 2f7db364
......@@ -49,11 +49,11 @@ struct nary_op
struct wait_event
{
std::vector<std::size_t> wait_for;
std::shared_ptr<std::vector<std::size_t>> wait_for = std::make_shared<std::vector<std::size_t>>();
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::pack(f(self.wait_for, "wait_for"));
return migraphx::pack(f(*self.wait_for, "wait_for"));
}
std::string name() const { return "wait_event"; }
migraphx::shape compute_shape(const std::vector<migraphx::shape>&) const { return {}; }
......@@ -62,28 +62,39 @@ struct wait_event
const migraphx::shape&,
const std::vector<migraphx::argument>&) const
{
assert(not wait_for.empty());
assert(wait_for != nullptr);
assert(not wait_for->empty());
return {};
}
};
using instruction_map = std::unordered_map<migraphx::instruction_ref, std::size_t>;
using wait_map = std::unordered_map<migraphx::instruction_ref, std::shared_ptr<std::vector<std::size_t>>>;
struct schedule_model_test
{
instruction_map* ins2stream;
std::shared_ptr<instruction_map> ins2stream = std::make_shared<instruction_map>();
std::shared_ptr<std::unordered_map<std::size_t, std::size_t>> wait2stream = std::make_shared<std::unordered_map<std::size_t, std::size_t>>();
std::shared_ptr<wait_map> ins2wait_for = std::make_shared<wait_map>();
std::size_t concurrency() const { return 4; }
void
schedule_instruction(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
sched(migraphx::program&, migraphx::instruction_ref ins, std::size_t n) const
{
(*ins2stream)[ins] = n;
}
void wait(migraphx::program& p,
migraphx::instruction_ref ins,
std::size_t,
const std::vector<std::size_t>& wait_for) const
void wait(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
{
p.insert_instruction(ins, wait_event{wait_for});
if (ins2wait_for->count(ins) == 0)
{
auto event = wait_event{};
p.insert_instruction(ins, event);
(*ins2wait_for)[ins] = event.wait_for;
}
(*ins2wait_for)[ins]->push_back(wait2stream->at(wait_id));
}
void record(migraphx::program& p, migraphx::instruction_ref ins, std::size_t wait_id) const
{
(*wait2stream)[wait_id] = ins2stream->at(ins);
}
std::size_t weight(const migraphx::operation& op) const
{
......@@ -96,13 +107,23 @@ struct schedule_model_test
struct schedule_target
{
instruction_map* ins2stream;
schedule_model_test model{};
std::string name() const { return "schedule"; }
std::vector<migraphx::pass> get_passes(migraphx::context&) const
{
return {migraphx::schedule{schedule_model_test{ins2stream}}};
return {migraphx::schedule{model}};
}
migraphx::context get_context() const { return {}; }
std::size_t get_stream(migraphx::instruction_ref ins)
{
return model.ins2stream->at(ins);
}
bool has_stream(migraphx::instruction_ref ins)
{
return model.ins2stream->count(ins) > 0;
}
};
bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx::instruction_ref y)
......@@ -166,7 +187,7 @@ std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
auto wait_ins = std::prev(ins);
if(wait_ins->name() != "wait_event")
return {};
auto wf = migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for;
auto wf = *migraphx::any_cast<wait_event>(wait_ins->get_operator()).wait_for;
std::sort(wf.begin(), wf.end());
return wf;
}
......@@ -183,36 +204,35 @@ chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
}
return result;
}
TEST_CASE(single_entry)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(onep1) != stream.at(onep2));
EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep1], stream[onep2]}));
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2)}));
EXPECT(check_conflicts(p, onep1, onep2));
}
TEST_CASE(zero_merge1)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(migraphx::op::identity{}, onep1, onep2);
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(onep1) != stream.at(onep2));
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(stream.count(binary) == 0);
EXPECT(not t.has_stream(binary));
// There is no wait
EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2));
......@@ -220,7 +240,7 @@ TEST_CASE(zero_merge1)
TEST_CASE(zero_merge2)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
......@@ -228,11 +248,11 @@ TEST_CASE(zero_merge2)
auto binary = p.add_instruction(migraphx::op::identity{},
p.add_instruction(migraphx::op::identity{}, onep1),
p.add_instruction(migraphx::op::identity{}, onep2));
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(onep1) != stream.at(onep2));
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(onep1) != t.get_stream(onep2));
// No stream assignment
EXPECT(stream.count(binary) == 0);
EXPECT(not t.has_stream(binary));
// There is no wait
EXPECT(get_wait_for(binary).empty());
EXPECT(check_conflicts(p, onep1, onep2));
......@@ -240,43 +260,43 @@ TEST_CASE(zero_merge2)
TEST_CASE(double_entry)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto onep = p.add_instruction(unary_op{}, one);
auto twop = p.add_instruction(unary_op{}, two);
auto binary = p.add_instruction(nary_op{}, onep, twop);
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.count(two) == 0);
EXPECT(stream.at(onep) != stream.at(twop));
EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep], stream[twop]}));
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(not t.has_stream(two));
EXPECT(t.get_stream(onep) != t.get_stream(twop));
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(onep), t.get_stream(twop)}));
// EXPECT(check_conflicts(p, onep, twop));
}
TEST_CASE(two_branches)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 1);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 1);
for(auto ins : c1)
EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[i1]}));
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(t.get_stream(binary), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
}
TEST_CASE(four_branches)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 4, unary_op{}, one);
......@@ -284,25 +304,25 @@ TEST_CASE(four_branches)
auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back());
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 3);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
EXPECT(stream.at(ins) == 0);
EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2)
EXPECT(stream.at(ins) == 1);
EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3)
EXPECT(stream.at(ins) == 2);
EXPECT(stream.at(binary) == 0);
EXPECT(t.get_stream(ins) == 2);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(stream[binary],
{stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]}));
get_wait_for(t.get_stream(binary),
{t.get_stream(c1.back()), t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, {i1}});
}
TEST_CASE(five_branches)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 5, unary_op{}, one);
......@@ -311,28 +331,28 @@ TEST_CASE(five_branches)
auto c4 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back(), c2.back(), c3.back(), c4.back());
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 3);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
EXPECT(stream.at(ins) == 0);
EXPECT(t.get_stream(ins) == 0);
for(auto ins : c2)
EXPECT(stream.at(ins) == 1);
EXPECT(t.get_stream(ins) == 1);
for(auto ins : c3)
EXPECT(stream.at(ins) == 2);
EXPECT(t.get_stream(ins) == 2);
for(auto ins : c4)
EXPECT(stream.at(ins) == 3);
EXPECT(stream.at(binary) == 0);
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary) == 0);
EXPECT(get_wait_for(binary) ==
get_wait_for(stream[binary],
{stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]}));
get_wait_for(t.get_stream(binary),
{t.get_stream(c1.back()), t.get_stream(c2.back()), t.get_stream(c3.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, c2, c3, c4});
check_conflicts(p, {c1, c2, c3, {i1}});
}
TEST_CASE(four_branches_eq)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one);
......@@ -340,22 +360,22 @@ TEST_CASE(four_branches_eq)
auto onep3 = p.add_instruction(unary_op{}, one);
auto onep4 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, onep1, onep2, onep3, onep4);
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(sorted<std::size_t>(
{stream.at(onep1), stream.at(onep2), stream.at(onep3), stream.at(onep4)}) ==
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}) ==
unique<std::size_t>(
{stream.at(onep1), stream.at(onep2), stream.at(onep3), stream.at(onep4)}));
EXPECT(stream.at(binary) == 0);
{t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
EXPECT(t.get_stream(binary) == 0);
EXPECT(
get_wait_for(binary) ==
get_wait_for(stream[binary], {stream[onep1], stream[onep2], stream[onep3], stream[onep4]}));
get_wait_for(t.get_stream(binary), {t.get_stream(onep1), t.get_stream(onep2), t.get_stream(onep3), t.get_stream(onep4)}));
check_conflicts(p, {{onep1}, {onep2}, {onep3}, {onep4}});
}
TEST_CASE(seq_merge)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
......@@ -366,27 +386,27 @@ TEST_CASE(seq_merge)
auto i2 = p.add_instruction(unary_op{}, binary1);
auto binary2 = p.add_instruction(nary_op{}, i2, c2.back());
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(stream.at(i1) == 2);
EXPECT(t.get_stream(i1) == 2);
for(auto ins : c1)
EXPECT(stream.at(ins) == 3);
EXPECT(stream.at(binary1) == 3);
EXPECT(get_wait_for(binary1) == get_wait_for(stream[binary1], {stream[c1.back()], stream[i1]}));
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary1) == 3);
EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
EXPECT(stream.at(i2) == 3);
EXPECT(t.get_stream(i2) == 3);
for(auto ins : c2)
EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary2) == 0);
EXPECT(get_wait_for(binary2) == get_wait_for(stream[binary2], {stream[c2.back()], stream[i2]}));
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary2) == 0);
EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}});
}
TEST_CASE(par_merge)
{
instruction_map stream;
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto start1 = p.add_instruction(unary_op{}, one);
......@@ -401,25 +421,24 @@ TEST_CASE(par_merge)
auto binary3 = p.add_instruction(nary_op{}, binary1, binary2);
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(binary3) == 0);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(binary3) == 0);
EXPECT(stream.at(i1) == 1);
EXPECT(t.get_stream(i1) == 2);
for(auto ins : c1)
EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary1) == 0);
EXPECT(get_wait_for(binary1) == get_wait_for(stream[binary1], {stream[c1.back()], stream[i1]}));
EXPECT(t.get_stream(ins) == 0);
EXPECT(t.get_stream(binary1) == 0);
EXPECT(get_wait_for(binary1) == get_wait_for(t.get_stream(binary1), {t.get_stream(c1.back()), t.get_stream(i1)}));
check_conflicts(p, {c1, {i1}});
EXPECT(stream.at(i2) == 2);
EXPECT(t.get_stream(i2) == 1);
for(auto ins : c2)
EXPECT(stream.at(ins) == 3);
EXPECT(stream.at(binary2) == 3);
EXPECT(get_wait_for(binary2) == get_wait_for(stream[binary2], {stream[c2.back()], stream[i2]}));
EXPECT(t.get_stream(ins) == 3);
EXPECT(t.get_stream(binary2) == 3);
EXPECT(get_wait_for(binary2) == get_wait_for(t.get_stream(binary2), {t.get_stream(c2.back()), t.get_stream(i2)}));
check_conflicts(p, {c2, {i2}});
EXPECT(check_conflicts(p, binary1, binary2));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment