Commit b3f1d9d5 authored by Paul's avatar Paul
Browse files

Add more tests

parent de1d1056
...@@ -64,7 +64,6 @@ struct set_stream ...@@ -64,7 +64,6 @@ struct set_stream
argument compute(context& ctx, const shape&, const std::vector<argument>&) const argument compute(context& ctx, const shape&, const std::vector<argument>&) const
{ {
assert(stream >= 0);
ctx.set_stream(stream); ctx.set_stream(stream);
return {}; return {};
} }
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -26,9 +27,9 @@ struct unary_op ...@@ -26,9 +27,9 @@ struct unary_op
int output_alias(const std::vector<migraphx::shape>&) const { return 0; } int output_alias(const std::vector<migraphx::shape>&) const { return 0; }
}; };
struct binary_op struct nary_op
{ {
std::string name() const { return "binary"; } std::string name() const { return "nary"; }
migraphx::argument migraphx::argument
compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const compute(migraphx::context&, const migraphx::shape&, std::vector<migraphx::argument> args) const
{ {
...@@ -119,6 +120,17 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx ...@@ -119,6 +120,17 @@ bool check_conflicts(migraphx::program& p, migraphx::instruction_ref x, migraphx
return false; return false;
} }
void check_conflicts(migraphx::program& p, std::vector<std::vector<migraphx::instruction_ref>> conflicts)
{
migraphx::dfor(conflicts.size(), conflicts.size())([&](auto i, auto j) {
if (i == j)
return;
for(auto ins1:conflicts[i])
for(auto ins2:conflicts[j])
CHECK(check_conflicts(p, ins1, ins2));
});
}
std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for) std::vector<std::size_t> get_wait_for(std::size_t wait_on, std::vector<std::size_t> wait_for)
{ {
wait_for.erase(std::find(wait_for.begin(), wait_for.end(), wait_on)); wait_for.erase(std::find(wait_for.begin(), wait_for.end(), wait_on));
...@@ -136,6 +148,18 @@ std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins) ...@@ -136,6 +148,18 @@ std::vector<std::size_t> get_wait_for(migraphx::instruction_ref ins)
return wf; return wf;
} }
template<class T>
std::vector<migraphx::instruction_ref> chain(migraphx::program& p, std::size_t n, T x, migraphx::instruction_ref input)
{
std::vector<migraphx::instruction_ref> result;
for(std::size_t i = 0;i < n;i++)
{
result.push_back(p.add_instruction(x, input));
input = result.back();
}
return result;
}
TEST_CASE(single_entry) TEST_CASE(single_entry)
{ {
instruction_map stream; instruction_map stream;
...@@ -143,8 +167,9 @@ TEST_CASE(single_entry) ...@@ -143,8 +167,9 @@ TEST_CASE(single_entry)
auto one = p.add_literal(1); auto one = p.add_literal(1);
auto onep1 = p.add_instruction(unary_op{}, one); auto onep1 = p.add_instruction(unary_op{}, one);
auto onep2 = p.add_instruction(unary_op{}, one); auto onep2 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(binary_op{}, onep1, onep2); auto binary = p.add_instruction(nary_op{}, onep1, onep2);
p.compile(schedule_target{&stream}); p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(onep1) != stream.at(onep2)); EXPECT(stream.at(onep1) != stream.at(onep2));
EXPECT(stream.at(binary) == 0); EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep1], stream[onep2]})); EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep1], stream[onep2]}));
...@@ -159,12 +184,53 @@ TEST_CASE(double_entry) ...@@ -159,12 +184,53 @@ TEST_CASE(double_entry)
auto two = p.add_literal(2); auto two = p.add_literal(2);
auto onep = p.add_instruction(unary_op{}, one); auto onep = p.add_instruction(unary_op{}, one);
auto twop = p.add_instruction(unary_op{}, two); auto twop = p.add_instruction(unary_op{}, two);
auto binary = p.add_instruction(binary_op{}, onep, twop); auto binary = p.add_instruction(nary_op{}, onep, twop);
p.compile(schedule_target{&stream}); p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.count(two) == 0);
EXPECT(stream.at(onep) != stream.at(twop)); EXPECT(stream.at(onep) != stream.at(twop));
EXPECT(stream.at(binary) == 0); EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep], stream[twop]})); EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[onep], stream[twop]}));
EXPECT(check_conflicts(p, onep, twop)); // EXPECT(check_conflicts(p, onep, twop));
}
TEST_CASE(two_weights)
{
instruction_map stream;
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 1);
for(auto ins:c1)
EXPECT(stream.at(ins) == 0);
EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[i1]}));
check_conflicts(p, {c1, {i1}});
}
TEST_CASE(four_weights)
{
instruction_map stream;
migraphx::program p;
auto one = p.add_literal(1);
auto c1 = chain(p, 4, unary_op{}, one);
auto c2 = chain(p, 3, unary_op{}, one);
auto c3 = chain(p, 2, unary_op{}, one);
auto i1 = p.add_instruction(unary_op{}, one);
auto binary = p.add_instruction(nary_op{}, i1, c1.back());
p.compile(schedule_target{&stream});
EXPECT(stream.count(one) == 0);
EXPECT(stream.at(i1) == 3);
for(auto ins:c1) EXPECT(stream.at(ins) == 0);
for(auto ins:c2) EXPECT(stream.at(ins) == 1);
for(auto ins:c3) EXPECT(stream.at(ins) == 2);
EXPECT(stream.at(binary) == 0);
EXPECT(get_wait_for(binary) == get_wait_for(stream[binary], {stream[c1.back()], stream[c2.back()], stream[c3.back()], stream[i1]}));
check_conflicts(p, {c1, c2, c3, {i1}});
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment