"vscode:/vscode.git/clone" did not exist on "f40147ce58e746eb2825e6755cdf9a734a0b3cf4"
Commit 51597ed7 authored by Khalique's avatar Khalique
Browse files

fix tests and tf parser

parents 7bacd3ba bc80dee8
...@@ -148,6 +148,56 @@ TEST_CASE(match_arg7) ...@@ -148,6 +148,56 @@ TEST_CASE(match_arg7)
EXPECT(bool{r.result == sum}); EXPECT(bool{r.result == sum});
} }
TEST_CASE(match_arg8)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
match::arg(1)(match::name("@literal"))),
match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::nargs(2), match::standard_shape());
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_nargs3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::nargs(2)));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_args1) TEST_CASE(match_args1)
{ {
migraphx::program p; migraphx::program p;
...@@ -307,6 +357,19 @@ TEST_CASE(match_all_of2) ...@@ -307,6 +357,19 @@ TEST_CASE(match_all_of2)
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
TEST_CASE(match_all_of3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto sum = p.add_instruction(sum_op{}, one, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("sum")(match::all_of(match::all_of(
match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
auto r = find_match(p, m);
EXPECT(bool{r.result == sum});
}
TEST_CASE(match_any_of1) TEST_CASE(match_any_of1)
{ {
migraphx::program p; migraphx::program p;
...@@ -359,6 +422,132 @@ TEST_CASE(match_none_of2) ...@@ -359,6 +422,132 @@ TEST_CASE(match_none_of2)
EXPECT(bool{r.result == p.end()}); EXPECT(bool{r.result == p.end()});
} }
TEST_CASE(match_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::output(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output1)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum = p.add_instruction(sum_op{}, minus, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output2)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass = p.add_instruction(pass_op{}, minus);
auto sum = p.add_instruction(sum_op{}, minus_pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output3)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto minus_pass1 = p.add_instruction(pass_op{}, minus);
auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
auto sum = p.add_instruction(sum_op{}, minus_pass3, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output4)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum = p.add_instruction(sum_op{}, pass, two);
p.add_instruction(pass_op{}, sum);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == two});
}
TEST_CASE(match_skip_output5)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto pass = p.add_instruction(pass_op{}, one);
auto sum1 = p.add_instruction(sum_op{}, pass, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == p.end()});
}
TEST_CASE(match_skip_output6)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus = p.add_instruction(minus_op{}, two, one);
auto sum1 = p.add_instruction(sum_op{}, minus, two);
auto sum2 = p.add_instruction(sum_op{}, sum1, one);
auto sum3 = p.add_instruction(sum_op{}, sum2, two);
p.add_instruction(pass_op{}, sum3);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus});
}
TEST_CASE(match_skip_output7)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto minus1 = p.add_instruction(minus_op{}, two, one);
auto minus2 = p.add_instruction(minus_op{}, two, minus1);
auto sum = p.add_instruction(sum_op{}, one, minus2);
p.add_instruction(pass_op{}, sum);
auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
auto r = find_match(p, m);
EXPECT(bool{r.result == minus1});
}
TEST_CASE(match_bind1) TEST_CASE(match_bind1)
{ {
migraphx::program p; migraphx::program p;
......
implicit_bcast-example:q add2:u
 
0 0
12"Addtest-multi_bcastZ 1out"Add subtraction2Z
0 0
 
 
 
 
Z Z
1 1
 
 
b 
2 b
out
 
 
 
 
B B
\ No newline at end of file
 subtraction2:q add2:q
 
0 0
1out"Sub subtraction2Z 1out"Sub subtraction2Z
...@@ -10,11 +10,11 @@ ...@@ -10,11 +10,11 @@
Z Z
1 1
 
 
b b
out out
 
 
 
 
B B
\ No newline at end of file
...@@ -350,7 +350,7 @@ TEST_CASE(implicit_add_bcast_test) ...@@ -350,7 +350,7 @@ TEST_CASE(implicit_add_bcast_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3); p.add_instruction(migraphx::op::add{}, l2, l3);
...@@ -377,7 +377,7 @@ TEST_CASE(implicit_sub_bcast_test) ...@@ -377,7 +377,7 @@ TEST_CASE(implicit_sub_bcast_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {3, 4}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 5}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4, 5}}, l1);
p.add_instruction(migraphx::op::sub{}, l2, l3); p.add_instruction(migraphx::op::sub{}, l2, l3);
...@@ -784,6 +784,28 @@ TEST_CASE(logsoftmax) ...@@ -784,6 +784,28 @@ TEST_CASE(logsoftmax)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(argmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = migraphx::parse_onnx("argmax_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(argmin)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto ins = p.add_instruction(migraphx::op::argmin{3}, l0);
p.add_instruction(migraphx::op::squeeze{{3}}, ins);
auto prog = migraphx::parse_onnx("argmin_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(no_pad_test) TEST_CASE(no_pad_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -794,6 +816,38 @@ TEST_CASE(no_pad_test) ...@@ -794,6 +816,38 @@ TEST_CASE(no_pad_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(reducesum_test1)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2}}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, l1);
auto prog = migraphx::parse_onnx("reducesum_test1.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducesum_test2)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
auto l1 = p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l1);
auto prog = migraphx::parse_onnx("reducesum_test2.onnx");
EXPECT(p == prog);
}
TEST_CASE(reducesum_test3)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_instruction(migraphx::op::reduce_sum{{2, 3}}, l0);
auto prog = migraphx::parse_onnx("reducesum_test3.onnx");
EXPECT(p == prog);
}
TEST_CASE(clip_test) TEST_CASE(clip_test)
{ {
migraphx::program p; migraphx::program p;
......
reducesum-example:}
1
xy" ReduceSum*
axes@@*
keepdimstest_reducesumZ
x




b
y




B
...@@ -227,6 +227,16 @@ TEST_CASE(multibroadcast) ...@@ -227,6 +227,16 @@ TEST_CASE(multibroadcast)
migraphx::shape input{migraphx::shape::float_type, {}}; migraphx::shape input{migraphx::shape::float_type, {}};
throws_shape(migraphx::op::multibroadcast{lens}, input); throws_shape(migraphx::op::multibroadcast{lens}, input);
} }
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
{
std::vector<std::size_t> lens{2, 3, 4, 5};
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4}};
throws_shape(migraphx::op::multibroadcast{lens}, input);
}
} }
TEST_CASE(broadcast) TEST_CASE(broadcast)
...@@ -380,6 +390,38 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); } ...@@ -380,6 +390,38 @@ TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); } TEST_CASE(logsoftmax) { test_softmax_variations<migraphx::op::logsoftmax>(); }
template <class T>
void test_argop_var()
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}}, T{0}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}}, T{1}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}}, T{2}, input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}}, T{3}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(T{4}, input);
}
}
TEST_CASE(argmax) { test_argop_var<migraphx::op::argmax>(); }
TEST_CASE(argmin) { test_argop_var<migraphx::op::argmin>(); }
// 2 inputs arguments // 2 inputs arguments
TEST_CASE(matmul) TEST_CASE(matmul)
{ {
......
...@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed) ...@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_transposed) TEST_CASE(test_shape_transposed1)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
EXPECT(not s.standard()); EXPECT(not s.standard());
...@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed) ...@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
EXPECT(not s.broadcasted()); EXPECT(not s.broadcasted());
} }
TEST_CASE(test_shape_transposed2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1, 2}, {2, 2, 2, 2, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_broadcasted) TEST_CASE(test_shape_broadcasted)
{ {
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}}; migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
......
#include <migraphx/simplify_reshapes.hpp> #include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp> #include <basic_ops.hpp>
#include <test.hpp> #include <test.hpp>
...@@ -165,4 +166,144 @@ TEST_CASE(transpose_double_contiguous) ...@@ -165,4 +166,144 @@ TEST_CASE(transpose_double_contiguous)
EXPECT(p.has_instruction(t)); EXPECT(p.has_instruction(t));
} }
TEST_CASE(transpose_partial1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(transpose_partial2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_instruction(pass_op{}, t3);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
}
TEST_CASE(transpose_partial3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
}
TEST_CASE(nop_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(nop_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 4);
}
TEST_CASE(nop_transpose3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_instruction(pass_op{}, t2);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
}
TEST_CASE(concat_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 3);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
}
TEST_CASE(concat_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_instruction(pass_op{}, t);
auto out_shape = p.get_shape();
auto n = std::distance(p.begin(), p.end());
p.compile(simplify_reshapes_target{});
EXPECT(p.get_shape().lens() == out_shape.lens());
EXPECT(std::distance(p.begin(), p.end()) == n - 2);
auto new_concat =
std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
EXPECT(bool{new_concat != p.end()});
EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/operators.hpp> #include <migraphx/operators.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp> #include <migraphx/tf.hpp>
#include "test.hpp" #include "test.hpp"
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
auto prog = migraphx::parse_tf(name, is_nhwc);
if(is_nhwc)
migraphx::run_passes(prog,
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
return prog;
}
TEST_CASE(add_test) TEST_CASE(add_test)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::add{}, l0, l1); p.add_instruction(migraphx::op::add{}, l0, l1);
auto prog = migraphx::parse_tf("add_test.pb", false); auto prog = optimize_tf("add_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -28,7 +43,7 @@ TEST_CASE(add_bcast_test) ...@@ -28,7 +43,7 @@ TEST_CASE(add_bcast_test)
auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0); auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1); auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3); p.add_instruction(migraphx::op::add{}, l2, l3);
auto prog = migraphx::parse_tf("add_bcast_test.pb", false); auto prog = optimize_tf("add_bcast_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -51,7 +66,7 @@ TEST_CASE(batchnorm_test) ...@@ -51,7 +66,7 @@ TEST_CASE(batchnorm_test)
auto l4 = p.add_parameter("4", s0); auto l4 = p.add_parameter("4", s0);
auto l1 = p.add_literal(migraphx::literal{s0, const_vals}); auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
p.add_instruction(op, l0, l1, l2, l3, l4); p.add_instruction(op, l0, l1, l2, l3, l4);
auto prog = migraphx::parse_tf("batchnorm_test.pb", true); auto prog = optimize_tf("batchnorm_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -65,7 +80,7 @@ TEST_CASE(biasadd_test) ...@@ -65,7 +80,7 @@ TEST_CASE(biasadd_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1); auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2); p.add_instruction(migraphx::op::add{}, l0, l2);
auto prog = migraphx::parse_tf("biasadd_test.pb", true); auto prog = optimize_tf("biasadd_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -83,7 +98,7 @@ TEST_CASE(concat_test) ...@@ -83,7 +98,7 @@ TEST_CASE(concat_test)
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis}); p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1); p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
auto prog = migraphx::parse_tf("concat_test.pb", false); auto prog = optimize_tf("concat_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -92,7 +107,7 @@ TEST_CASE(const_test) ...@@ -92,7 +107,7 @@ TEST_CASE(const_test)
{ {
migraphx::program p; migraphx::program p;
p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f}); p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto prog = migraphx::parse_tf("constant_test.pb", false); auto prog = optimize_tf("constant_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -112,10 +127,9 @@ TEST_CASE(conv_test) ...@@ -112,10 +127,9 @@ TEST_CASE(conv_test)
op.padding = {1, 1}; op.padding = {1, 1};
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1); auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2); p.add_instruction(op, l0, l2);
p.add_instruction(op, l0, l3); auto prog = optimize_tf("conv_test.pb", true);
auto prog = migraphx::parse_tf("conv_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -136,12 +150,11 @@ TEST_CASE(depthwiseconv_test) ...@@ -136,12 +150,11 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1}; op.stride = {1, 1};
op.dilation = {1, 1}; op.dilation = {1, 1};
op.group = 3; op.group = 3;
auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1); auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3); auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4); auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
p.add_instruction(op, l0, l5); p.add_instruction(op, l0, l5);
auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true); auto prog = optimize_tf("depthwise_conv_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -151,11 +164,20 @@ TEST_CASE(expanddims_test) ...@@ -151,11 +164,20 @@ TEST_CASE(expanddims_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(-1);
p.add_literal(0);
p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0); p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
auto prog = migraphx::parse_tf("expanddims_test.pb", true); auto prog = optimize_tf("expanddims_test.pb", true);
EXPECT(p == prog);
}
TEST_CASE(expanddims_test_neg_dims)
{
// this check makes sure the pb parses negative dim value correctly
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
auto prog = optimize_tf("expanddims_neg_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -165,7 +187,7 @@ TEST_CASE(identity_test) ...@@ -165,7 +187,7 @@ TEST_CASE(identity_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0); p.add_instruction(migraphx::op::identity{}, l0);
auto prog = migraphx::parse_tf("identity_test.pb", false); auto prog = optimize_tf("identity_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -180,7 +202,7 @@ TEST_CASE(matmul_test) ...@@ -180,7 +202,7 @@ TEST_CASE(matmul_test)
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1); auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1); p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = migraphx::parse_tf("matmul_test.pb", false); auto prog = optimize_tf("matmul_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -197,7 +219,7 @@ TEST_CASE(mean_test) ...@@ -197,7 +219,7 @@ TEST_CASE(mean_test)
p.add_instruction(op, l0); p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0); auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = migraphx::parse_tf("mean_test.pb", false); auto prog = optimize_tf("mean_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -207,14 +229,11 @@ TEST_CASE(mean_test_nhwc) ...@@ -207,14 +229,11 @@ TEST_CASE(mean_test_nhwc)
migraphx::program p; migraphx::program p;
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}}; migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l);
p.add_literal(l);
migraphx::op::pooling op; migraphx::op::pooling op;
op.lengths = {16, 16}; op.lengths = {16, 16};
p.add_instruction(op, l0); auto l3 = p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3); p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true); auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -226,7 +245,7 @@ TEST_CASE(mul_test) ...@@ -226,7 +245,7 @@ TEST_CASE(mul_test)
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::mul{}, l0, l1); p.add_instruction(migraphx::op::mul{}, l0, l1);
auto prog = migraphx::parse_tf("mul_test.pb", false); auto prog = optimize_tf("mul_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -248,7 +267,7 @@ TEST_CASE(pack_test) ...@@ -248,7 +267,7 @@ TEST_CASE(pack_test)
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg); return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
}); });
p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args); p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test.pb", false); auto prog = optimize_tf("pack_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -256,12 +275,15 @@ TEST_CASE(pack_test) ...@@ -256,12 +275,15 @@ TEST_CASE(pack_test)
TEST_CASE(pack_test_nhwc) TEST_CASE(pack_test_nhwc)
{ {
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}}); auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2}; auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args; std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 1; int64_t nchw_axis = 3;
std::transform(args.begin(), std::transform(args.begin(),
args.end(), args.end(),
...@@ -270,7 +292,7 @@ TEST_CASE(pack_test_nhwc) ...@@ -270,7 +292,7 @@ TEST_CASE(pack_test_nhwc)
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg); return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
}); });
p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args); p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true); auto prog = optimize_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -287,9 +309,9 @@ TEST_CASE(pooling_test) ...@@ -287,9 +309,9 @@ TEST_CASE(pooling_test)
max_pool_op.stride = {2, 2}; max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2}; avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2}; max_pool_op.lengths = {2, 2};
p.add_instruction(avg_pool_op, l0);
p.add_instruction(max_pool_op, l0); p.add_instruction(max_pool_op, l0);
auto prog = migraphx::parse_tf("pooling_test.pb", true); // p.add_instruction(avg_pool_op, l0);
auto prog = optimize_tf("pooling_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -299,7 +321,7 @@ TEST_CASE(relu_test) ...@@ -299,7 +321,7 @@ TEST_CASE(relu_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::relu{}, l0); p.add_instruction(migraphx::op::relu{}, l0);
auto prog = migraphx::parse_tf("relu_test.pb", false); auto prog = optimize_tf("relu_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -309,7 +331,7 @@ TEST_CASE(relu6_test) ...@@ -309,7 +331,7 @@ TEST_CASE(relu6_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0); p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
auto prog = migraphx::parse_tf("relu6_test.pb", false); auto prog = optimize_tf("relu6_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -322,7 +344,7 @@ TEST_CASE(reshape_test) ...@@ -322,7 +344,7 @@ TEST_CASE(reshape_test)
// in tf, the second arg is a literal that contains new dimensions // in tf, the second arg is a literal that contains new dimensions
p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}}); p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0); p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
auto prog = migraphx::parse_tf("reshape_test.pb", false); auto prog = optimize_tf("reshape_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -335,7 +357,7 @@ TEST_CASE(softmax_test) ...@@ -335,7 +357,7 @@ TEST_CASE(softmax_test)
auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0); auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
auto s = p.add_instruction(migraphx::op::softmax{}, r); auto s = p.add_instruction(migraphx::op::softmax{}, r);
p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s); p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
auto prog = migraphx::parse_tf("softmax_test.pb", false); auto prog = optimize_tf("softmax_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -345,7 +367,7 @@ TEST_CASE(squeeze_test) ...@@ -345,7 +367,7 @@ TEST_CASE(squeeze_test)
migraphx::program p; migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}}); auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0); p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
auto prog = migraphx::parse_tf("squeeze_test.pb", false); auto prog = optimize_tf("squeeze_test.pb", false);
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -357,18 +379,13 @@ TEST_CASE(stridedslice_test) ...@@ -357,18 +379,13 @@ TEST_CASE(stridedslice_test)
std::size_t num_axes = 4; std::size_t num_axes = 4;
migraphx::op::slice op; migraphx::op::slice op;
op.starts = {0, 0, 0, 0}; op.starts = {0, 0, 0, 0};
op.ends = {1, 5, 1, 1}; op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes); op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0); std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(op, l0); auto l1 = p.add_instruction(op, l0);
auto shrink_axis = 2; auto shrink_axis = 1;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1); p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
auto prog = migraphx::parse_tf("stridedslice_test.pb", true); auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog); EXPECT(p == prog);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment