Commit 2c9f9c48 authored by Paul's avatar Paul
Browse files

Add more tests

parent fba751eb
......@@ -182,15 +182,13 @@ std::vector<T> unique(std::vector<T> x)
std::vector<std::size_t> get_wait_for(std::vector<std::size_t> wait_for)
{
std::sort(wait_for.begin(), wait_for.end());
return wait_for;
return unique(wait_for);
}
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{
wait_for.erase(std::find(wait_for.begin(), wait_for.end(), wait_on));
std::sort(wait_for.begin(), wait_for.end());
return wait_for;
return unique(wait_for);
}
std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
......@@ -631,6 +629,52 @@ TEST_CASE(par_merge_multi_entry)
check_conflicts(p, {c1, {i1}, c2, {i2}});
}
TEST_CASE(inner_split1)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto s1 = p.add_instruction(unary_op{}, c1);
auto s2 = p.add_instruction(unary_op{}, c1);
auto output = p.add_instruction(nary_op{}, i1, s1, s2);
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 3);
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1) != t.get_stream(s2));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) ==
get_wait_for(t.get_stream(output), {t.get_stream(c1.back()), t.get_stream(i1), t.get_stream(s1), t.get_stream(s2)}));
check_conflicts(p, {c1, {i1}, {s1}, {s2}});
}
TEST_CASE(inner_split2)
{
schedule_target t{};
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto s1 = chain(p, 3, unary_op{}, c1.back());
auto s2 = chain(p, 4, unary_op{}, c1.back());
auto output = p.add_instruction(nary_op{}, i1, s1.back(), s2.back());
p.compile(t);
EXPECT(not t.has_stream(one));
EXPECT(t.get_stream(i1) == 2);
for(auto ins : c1)
EXPECT(t.get_stream(ins) != t.get_stream(i1));
EXPECT(t.get_stream(s1.back()) != t.get_stream(s2.back()));
EXPECT(t.get_stream(output) == 0);
EXPECT(get_wait_for(output) ==
get_wait_for(t.get_stream(output), {t.get_stream(c1.back()), t.get_stream(i1), t.get_stream(s1.back()), t.get_stream(s2.back())}));
check_conflicts(p, {c1, {i1}, s1, s2});
}
TEST_CASE(inception_resnet)
{
schedule_target t{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment