Commit 51597ed7 authored by Khalique's avatar Khalique
Browse files

fix tests and tf parser

parents 7bacd3ba bc80dee8
......@@ -148,6 +148,56 @@ TEST_CASE(match_arg7)
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_arg8)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args1)
{
migraphx::program p;
......@@ -307,6 +357,19 @@ TEST_CASE(match_all_of2)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_all_of3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_any_of1)
{
migraphx::program p;
......@@ -359,6 +422,132 @@ TEST_CASE(match_none_of2)
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass = p.add_instruction(pass_op{}, minus);
auto sum = p.add_instruction(sum_op{}, minus_pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass1 = p.add_instruction(pass_op{}, minus);
auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum = p.add_instruction(sum_op{}, pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum1 = p.add_instruction(sum_op{}, pass, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus1 = p.add_instruction(minus_op{}, two, one);
auto minus2 = p.add_instruction(minus_op{}, two, minus1);
auto sum = p.add_instruction(sum_op{}, one, minus2);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus1});
}
TEST_CASE(match_bind1)
{
migraphx::program p;
......
implicit_bcast-example:q

add2:u

0
12"Addtest-multi_bcastZ
1out"Add subtraction2Z
0




Z
1

Z
1


b
2

b
out




B
\ No newline at end of file
B
 subtraction2:q
add2:q

0
1out"Sub subtraction2Z
......@@ -10,11 +10,11 @@
Z
1


b

b
out




B
\ No newline at end of file
B
......@@ -350,7 +350,7 @@ TEST_CASE(implicit_add_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
......@@ -377,7 +377,7 @@ TEST_CASE(implicit_sub_bcast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3);
......@@ -784,6 +784,28 @@ TEST_CASE(logsoftmax)
EXPECT(p == prog);
}
TEST_CASE(argmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = migraphx::parse_onnx("argmax_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(argmin)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmin{3}, l0);
p.add_instruction(migraphx::op::squeeze{{3}}, ins);
auto prog = migraphx::parse_onnx("argmin_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(no_pad_test)
{
migraphx::program p;
......@@ -794,6 +816,38 @@ TEST_CASE(no_pad_test)
EXPECT(p == prog);
}
TEST_CASE(reducesum_test1)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, l1);
auto prog = migraphx::parse_onnx("reducesum_test1.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducesum_test2)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducesum_test2.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducesum_test3)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
auto prog = migraphx::parse_onnx("reducesum_test3.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test)
{
migraphx::program p;
......
reducesum-example:}
1
xy" ReduceSum*
axes@@*
keepdimstest_reducesumZ
x




b
y




B
......@@ -227,6 +227,16 @@ TEST_CASE(multibroadcast)
migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
}
TEST_CASE(broadcast)
......@@ -380,6 +390,38 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
template <class T>
void test_argop_var()
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{4}, input);
}
}
TEST_CASE(argmax) { test_argop_var<migraphx::op::argmax>(); }
TEST_CASE(argmin) { test_argop_var<migraphx::op::argmin>(); }
// 2 inputs arguments
TEST_CASE(matmul)
{
......
......@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_transposed)
TEST_CASE(test_shape_transposed1)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard());
......@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_transposed2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1, 2}, {2, 2, 2, 2, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_broadcasted)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
......
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
......@@ -165,4 +166,144 @@ TEST_CASE(transpose_double_contiguous)
EXPECT(p.has_instruction(t));
}
TEST_CASE(transpose_partial1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(transpose_partial2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}
TEST_CASE(transpose_partial3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
TEST_CASE(nop_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(nop_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}
TEST_CASE(nop_transpose3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(concat_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
TEST_CASE(concat_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
#include "test.hpp"
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
auto prog = migraphx::parse_tf(name, is_nhwc);
if(is_nhwc)
migraphx::run_passes(prog,
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
return prog;
}
TEST_CASE(add_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_tf("add_test.pb", false);
auto prog = optimize_tf("add_test.pb", false);
EXPECT(p == prog);
}
......@@ -28,7 +43,7 @@ TEST_CASE(add_bcast_test)
auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_tf("add_bcast_test.pb", false);
auto prog = optimize_tf("add_bcast_test.pb", false);
EXPECT(p == prog);
}
......@@ -51,7 +66,7 @@ TEST_CASE(batchnorm_test)
auto l4 = p.add_parameter("4", s0);
auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
p.add_instruction(op, l0, l1, l2, l3, l4);
auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
auto prog = optimize_tf("batchnorm_test.pb", true);
EXPECT(p == prog);
}
......@@ -65,7 +80,7 @@ TEST_CASE(biasadd_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_tf("biasadd_test.pb", true);
auto prog = optimize_tf("biasadd_test.pb", true);
EXPECT(p == prog);
}
......@@ -83,7 +98,7 @@ TEST_CASE(concat_test)
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
auto prog = migraphx::parse_tf("concat_test.pb", false);
auto prog = optimize_tf("concat_test.pb", false);
EXPECT(p == prog);
}
......@@ -92,7 +107,7 @@ TEST_CASE(const_test)
{
migraphx::program p;
p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto prog = migraphx::parse_tf("constant_test.pb", false);
auto prog = optimize_tf("constant_test.pb", false);
EXPECT(p == prog);
}
......@@ -112,10 +127,9 @@ TEST_CASE(conv_test)
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
p.add_instruction(op, l0, l3);
auto prog = migraphx::parse_tf("conv_test.pb", true);
auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
p.add_instruction(op, l0, l2);
auto prog = optimize_tf("conv_test.pb", true);
EXPECT(p == prog);
}
......@@ -136,12 +150,11 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
p.add_instruction(op, l0, l5);
auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
auto prog = optimize_tf("depthwise_conv_test.pb", true);
EXPECT(p == prog);
}
......@@ -151,11 +164,20 @@ TEST_CASE(expanddims_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(-1);
p.add_literal(0);
p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
auto prog = migraphx::parse_tf("expanddims_test.pb", true);
auto prog = optimize_tf("expanddims_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(expanddims_test_neg_dims)
{
// this check makes sure the pb parses negative dim value correctly
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
auto prog = optimize_tf("expanddims_neg_test.pb", true);
EXPECT(p == prog);
}
......@@ -165,7 +187,7 @@ TEST_CASE(identity_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_tf("identity_test.pb", false);
auto prog = optimize_tf("identity_test.pb", false);
EXPECT(p == prog);
}
......@@ -180,7 +202,7 @@ TEST_CASE(matmul_test)
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = migraphx::parse_tf("matmul_test.pb", false);
auto prog = optimize_tf("matmul_test.pb", false);
EXPECT(p == prog);
}
......@@ -197,7 +219,7 @@ TEST_CASE(mean_test)
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = migraphx::parse_tf("mean_test.pb", false);
auto prog = optimize_tf("mean_test.pb", false);
EXPECT(p == prog);
}
......@@ -207,14 +229,11 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l);
p.add_literal(l);
migraphx::op::pooling op;
op.lengths = {16, 16};
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog);
}
......@@ -226,7 +245,7 @@ TEST_CASE(mul_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::mul{}, l0, l1);
auto prog = migraphx::parse_tf("mul_test.pb", false);
auto prog = optimize_tf("mul_test.pb", false);
EXPECT(p == prog);
}
......@@ -248,7 +267,7 @@ TEST_CASE(pack_test)
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test.pb", false);
auto prog = optimize_tf("pack_test.pb", false);
EXPECT(p == prog);
}
......@@ -256,12 +275,15 @@ TEST_CASE(pack_test)
TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 1;
int64_t nchw_axis = 3;
std::transform(args.begin(),
args.end(),
......@@ -270,7 +292,7 @@ TEST_CASE(pack_test_nhwc)
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
auto prog = optimize_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
}
......@@ -287,9 +309,9 @@ TEST_CASE(pooling_test)
max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2};
p.add_instruction(avg_pool_op, l0);
p.add_instruction(max_pool_op, l0);
auto prog = migraphx::parse_tf("pooling_test.pb", true);
// p.add_instruction(avg_pool_op, l0);
auto prog = optimize_tf("pooling_test.pb", true);
EXPECT(p == prog);
}
......@@ -299,7 +321,7 @@ TEST_CASE(relu_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::relu{}, l0);
auto prog = migraphx::parse_tf("relu_test.pb", false);
auto prog = optimize_tf("relu_test.pb", false);
EXPECT(p == prog);
}
......@@ -309,7 +331,7 @@ TEST_CASE(relu6_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
auto prog = migraphx::parse_tf("relu6_test.pb", false);
auto prog = optimize_tf("relu6_test.pb", false);
EXPECT(p == prog);
}
......@@ -322,7 +344,7 @@ TEST_CASE(reshape_test)
// in tf, the second arg is a literal that contains new dimensions
p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
auto prog = migraphx::parse_tf("reshape_test.pb", false);
auto prog = optimize_tf("reshape_test.pb", false);
EXPECT(p == prog);
}
......@@ -335,7 +357,7 @@ TEST_CASE(softmax_test)
auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
auto prog = migraphx::parse_tf("softmax_test.pb", false);
auto prog = optimize_tf("softmax_test.pb", false);
EXPECT(p == prog);
}
......@@ -345,7 +367,7 @@ TEST_CASE(squeeze_test)
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
auto prog = migraphx::parse_tf("squeeze_test.pb", false);
auto prog = optimize_tf("squeeze_test.pb", false);
EXPECT(p == prog);
}
......@@ -357,18 +379,13 @@ TEST_CASE(stridedslice_test)
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 5, 1, 1};
op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(op, l0);
auto shrink_axis = 2;
auto shrink_axis = 1;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment