Unverified Commit 2466dd6f authored by Shucai Xiao's avatar Shucai Xiao Committed by GitHub
Browse files

Refactor program to module (#684)



* code backup

* clang format

* change corresponding tool files

* clang format
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent de10423f
......@@ -10,34 +10,37 @@
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(*p.get_main_module(),
{migraphx::simplify_algebra{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(simplify_add1)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, one);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, two);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -46,27 +49,29 @@ TEST_CASE(simplify_add2)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraphx::op::add{}, two, y);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, x);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, two, y);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -75,25 +80,27 @@ TEST_CASE(simplify_add3)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraphx::op::add{}, one, two);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, x);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, one, two);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraphx::op::add{}, one, sum1);
auto sum3 = p2.add_instruction(migraphx::op::add{}, x, sum2);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, one, sum1);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, x, sum2);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -105,30 +112,32 @@ TEST_CASE(simplify_add_broadcast1)
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", outer);
auto y = p1.add_parameter("y", outer);
auto one = p1.add_literal({inner, {1, 1}});
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal({inner, {2, 2}});
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", outer);
auto y = mm1->add_parameter("y", outer);
auto one = mm1->add_literal({inner, {1, 1}});
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal({inner, {2, 2}});
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", outer);
auto y = p2.add_parameter("y", outer);
auto one = p2.add_literal({inner, {1, 1}});
auto two = p2.add_literal({inner, {2, 2}});
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum1b = p2.add_instruction(b, sum1);
auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1b);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", outer);
auto y = mm2->add_parameter("y", outer);
auto one = mm2->add_literal({inner, {1, 1}});
auto two = mm2->add_literal({inner, {2, 2}});
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum1b = mm2->add_instruction(b, sum1);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1b);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -140,15 +149,16 @@ TEST_CASE(simplify_add_broadcast2)
migraphx::op::broadcast b{1, {1, 2, 3, 3}};
auto create_program = [&] {
migraphx::program p;
auto x = p.add_parameter("x", outer);
auto y = p.add_parameter("y", outer);
auto one = p.add_literal({inner, {1, 1}});
auto oneb = p.add_instruction(b, one);
auto two = p.add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}});
auto sum1 = p.add_instruction(migraphx::op::add{}, x, y);
auto sum2 = p.add_instruction(migraphx::op::add{}, oneb, two);
auto sum3 = p.add_instruction(migraphx::op::add{}, sum2, sum1);
p.add_instruction(pass_op{}, sum3);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", outer);
auto y = mm->add_parameter("y", outer);
auto one = mm->add_literal({inner, {1, 1}});
auto oneb = mm->add_instruction(b, one);
auto two = mm->add_literal({outer, {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}});
auto sum1 = mm->add_instruction(migraphx::op::add{}, x, y);
auto sum2 = mm->add_instruction(migraphx::op::add{}, oneb, two);
auto sum3 = mm->add_instruction(migraphx::op::add{}, sum2, sum1);
mm->add_instruction(pass_op{}, sum3);
return p;
};
migraphx::program p1 = create_program();
......@@ -164,27 +174,29 @@ void simplify_add4()
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum1 = p1.add_instruction(migraphx::op::add{}, one, x);
auto sum2 = p1.add_instruction(migraphx::op::add{}, sum1, y);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum2, two);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, one, x);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, sum1, y);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum2, two);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum2, sum1);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, one, two);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum2, sum1);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1 == p2);
}
......@@ -192,14 +204,15 @@ void simplify_add4()
TEST_CASE(simplify_mul_conv1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
auto a = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = p.add_instruction(migraphx::op::mul{}, conv, b);
p.add_instruction(pass_op{}, mul);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
auto conv = mm->add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
auto a = mm->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
auto b = mm->add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
auto mul = mm->add_instruction(migraphx::op::mul{}, conv, b);
mm->add_instruction(pass_op{}, mul);
EXPECT(conv->outputs().front()->name() == "mul");
run_pass(p);
auto new_conv =
......@@ -211,36 +224,38 @@ TEST_CASE(simplify_mul_slice_conv1)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = p1.add_instruction(migraphx::op::add{}, mul, slice2);
p1.add_instruction(pass_op{}, add);
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = mm1->add_instruction(migraphx::op::add{}, mul, slice2);
mm1->add_instruction(pass_op{}, add);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p2.add_literal(
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm2->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto wslice1 = p2.add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto a = p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p2.add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a);
auto mul = p2.add_instruction(migraphx::op::mul{}, b, wslice1);
auto wslice2 = p2.add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat = p2.add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = p2.add_instruction(migraphx::op::convolution{}, x, concat);
auto slice1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto slice2 = p2.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = p2.add_instruction(migraphx::op::add{}, slice1, slice2);
p2.add_instruction(pass_op{}, add);
auto wslice1 = mm2->add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto a = mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm2->add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a);
auto mul = mm2->add_instruction(migraphx::op::mul{}, b, wslice1);
auto wslice2 = mm2->add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = mm2->add_instruction(migraphx::op::convolution{}, x, concat);
auto slice1 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto slice2 = mm2->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add = mm2->add_instruction(migraphx::op::add{}, slice1, slice2);
mm2->add_instruction(pass_op{}, add);
}
EXPECT(p1 == p2);
}
......@@ -249,17 +264,18 @@ TEST_CASE(simplify_mul_slice_conv_overlapping_slice)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {383}, {767}}, conv);
auto add = p1.add_instruction(migraphx::op::add{}, mul, slice2);
p1.add_instruction(pass_op{}, add);
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {383}, {767}}, conv);
auto add = mm1->add_instruction(migraphx::op::add{}, mul, slice2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -270,19 +286,20 @@ TEST_CASE(simplify_mul_slice_conv_not_all_slice)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b);
auto c = p1.add_literal(
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a = mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}));
auto b = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b);
auto c = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {1, 768, 17, 17}}));
auto add = p1.add_instruction(migraphx::op::add{}, conv, c);
auto concat = p1.add_instruction(migraphx::op::concat{1}, mul, add);
p1.add_instruction(pass_op{}, concat);
auto add = mm1->add_instruction(migraphx::op::add{}, conv, c);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, mul, add);
mm1->add_instruction(pass_op{}, concat);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -293,24 +310,26 @@ TEST_CASE(simplify_mul_add)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p1.add_literal(1);
auto two = p1.add_literal(2);
auto sum = p1.add_instruction(migraphx::op::add{}, one, x);
auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
p1.add_instruction(pass_op{}, mul);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm1->add_literal(1);
auto two = mm1->add_literal(2);
auto sum = mm1->add_instruction(migraphx::op::add{}, one, x);
auto mul = mm1->add_instruction(migraphx::op::mul{}, sum, two);
mm1->add_instruction(pass_op{}, mul);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto mul1 = p2.add_instruction(migraphx::op::mul{}, two, x);
auto mul2 = p2.add_instruction(migraphx::op::mul{}, two, one);
auto sum = p2.add_instruction(migraphx::op::add{}, mul1, mul2);
p2.add_instruction(pass_op{}, sum);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto mul1 = mm2->add_instruction(migraphx::op::mul{}, two, x);
auto mul2 = mm2->add_instruction(migraphx::op::mul{}, two, one);
auto sum = mm2->add_instruction(migraphx::op::add{}, mul1, mul2);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1 == p2);
}
......@@ -320,22 +339,24 @@ TEST_CASE(simplify_inner_broadcast)
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto xb = p1.add_instruction(b, x);
auto yb = p1.add_instruction(b, y);
auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb);
p1.add_instruction(pass_op{}, sum);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm1->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto xb = mm1->add_instruction(b, x);
auto yb = mm1->add_instruction(b, y);
auto sum = mm1->add_instruction(migraphx::op::add{}, xb, yb);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
auto sumb = p2.add_instruction(b, sum);
p2.add_instruction(pass_op{}, sumb);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto y = mm2->add_parameter("y", {migraphx::shape::int32_type, {1}});
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
auto sumb = mm2->add_instruction(b, sum);
mm2->add_instruction(pass_op{}, sumb);
}
EXPECT(p1 == p2);
}
......@@ -343,16 +364,17 @@ TEST_CASE(simplify_inner_broadcast)
TEST_CASE(simplify_add_conv1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}}));
auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}}));
auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}}));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 3, 3}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
......@@ -363,16 +385,17 @@ TEST_CASE(simplify_add_conv1)
TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}}));
auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}}));
auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}}));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 7, 7}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {3, 3}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
......@@ -384,16 +407,17 @@ TEST_CASE(simplify_add_conv_no_fusion_7x7_diff_strides)
TEST_CASE(simplify_add_conv_1x1_diff_strides1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
......@@ -404,16 +428,17 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides1)
TEST_CASE(simplify_add_conv_1x1_diff_strides2)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto v =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, x, w);
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 28}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
......@@ -424,16 +449,17 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides2)
TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 54, 83, 83}});
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 54, 83, 83}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}}));
auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 54, 165, 165}});
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}}));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 54, 165, 165}});
auto v =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}}));
auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {54, 54, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 2}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
......@@ -444,16 +470,17 @@ TEST_CASE(simplify_add_conv_1x1_diff_strides_odd)
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 14}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto v =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, x, w);
auto conv2 = p.add_instruction(migraphx::op::convolution{}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 28, 14}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
......@@ -465,16 +492,17 @@ TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides1)
TEST_CASE(simplify_add_conv_no_fusion_asymetrical_strides2)
{
migraphx::program p;
auto x = p.add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto w =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = p.add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 14}});
auto v =
p.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = p.add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = p.add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v);
auto sum = p.add_instruction(migraphx::op::add{}, conv1, conv2);
p.add_instruction(pass_op{}, sum);
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {migraphx::shape::float_type, {1, 128, 14, 14}});
auto w = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto y = mm->add_parameter("y", {migraphx::shape::float_type, {1, 128, 28, 14}});
auto v = mm->add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {256, 128, 1, 1}}));
auto conv1 = mm->add_instruction(migraphx::op::convolution{}, x, w);
auto conv2 = mm->add_instruction(migraphx::op::convolution{{0, 0}, {2, 1}}, y, v);
auto sum = mm->add_instruction(migraphx::op::add{}, conv1, conv2);
mm->add_instruction(pass_op{}, sum);
auto s = p.get_output_shapes().back();
run_pass(p);
EXPECT(s == p.get_output_shapes().back());
......@@ -488,30 +516,32 @@ TEST_CASE(simplify_concat_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal({s, {1}});
auto two = p1.add_literal({s, {2}});
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{0}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal({s, {1}});
auto two = mm1->add_literal({s, {2}});
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, one);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, two);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{0}, relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal({s, {1}});
auto two = p2.add_literal({s, {2}});
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto sum = p2.add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal({s, {1}});
auto two = mm2->add_literal({s, {2}});
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto sum = mm2->add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
}
......@@ -521,33 +551,35 @@ TEST_CASE(simplify_concat_add_relu_partial)
auto s = migraphx::shape{migraphx::shape::int32_type, {1}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal({s, {1}});
auto two = p1.add_literal({s, {2}});
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, one);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, two);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto sum3 = p1.add_instruction(migraphx::op::add{}, x, y);
auto concat = p1.add_instruction(migraphx::op::concat{0}, sum3, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal({s, {1}});
auto two = mm1->add_literal({s, {2}});
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, one);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, two);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm1->add_instruction(migraphx::op::concat{0}, sum3, relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal({s, {1}});
auto two = p2.add_literal({s, {2}});
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto sum1 = p2.add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p2.add_instruction(migraphx::op::add{}, x, y);
auto concat = p2.add_instruction(migraphx::op::concat{0}, sum2, relu);
p2.add_instruction(pass_op{}, concat);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal({s, {1}});
auto two = mm2->add_literal({s, {2}});
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, sum2, relu);
mm2->add_instruction(pass_op{}, concat);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -557,31 +589,33 @@ TEST_CASE(simplify_concat_add_relu_partial_broadcast)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
auto concat = p1.add_instruction(migraphx::op::concat{1}, sum, oneb, twob);
p1.add_instruction(pass_op{}, concat);
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, sum, oneb, twob);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat1);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{1}, sum, concatb);
p2.add_instruction(pass_op{}, concat2);
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = mm2->add_instruction(b, concat1);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{1}, sum, concatb);
mm2->add_instruction(pass_op{}, concat2);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -591,35 +625,37 @@ TEST_CASE(simplify_concat_add_relu_broadcast_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{1}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 2, 4, 5}};
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat1 = p2.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concat2b = p2.add_instruction(b, concat2);
auto sum = p2.add_instruction(migraphx::op::add{}, concat1, concat2b);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat1 = mm2->add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concat2b = mm2->add_instruction(b, concat2);
auto sum = mm2->add_instruction(migraphx::op::add{}, concat1, concat2b);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
}
......@@ -629,36 +665,38 @@ TEST_CASE(simplify_concat_add_relu_broadcast_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {2, 1, 4, 5}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{0}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{0}, relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto one = p2.add_literal(1);
auto oneb = p2.add_instruction(b, one);
auto two = p2.add_literal(2);
auto twob = p2.add_instruction(b, two);
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = p2.add_instruction(migraphx::op::concat{0}, oneb, twob);
auto sum = p2.add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto one = mm2->add_literal(1);
auto oneb = mm2->add_instruction(b, one);
auto two = mm2->add_literal(2);
auto twob = mm2->add_instruction(b, two);
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, x, y);
auto concat2 = mm2->add_instruction(migraphx::op::concat{0}, oneb, twob);
auto sum = mm2->add_instruction(migraphx::op::add{}, concat1, concat2);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1 == p2);
}
......@@ -667,18 +705,20 @@ TEST_CASE(simplify_div_const)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p1.add_literal(2);
p1.add_instruction(migraphx::op::div{}, x, two);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm1->add_literal(2);
mm1->add_instruction(migraphx::op::div{}, x, two);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p2.add_literal(2);
auto recip = p2.insert_instruction(std::next(two), migraphx::op::recip{}, two);
p2.add_instruction(migraphx::op::mul{}, x, recip);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm2->add_literal(2);
auto recip = mm2->insert_instruction(std::next(two), migraphx::op::recip{}, two);
mm2->add_instruction(migraphx::op::mul{}, x, recip);
}
EXPECT(p1 == p2);
}
......@@ -687,18 +727,20 @@ TEST_CASE(simplify_sub_const)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p1.add_literal(2);
p1.add_instruction(migraphx::op::sub{}, x, two);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm1->add_literal(2);
mm1->add_instruction(migraphx::op::sub{}, x, two);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = p2.add_literal(2);
auto neg = p2.insert_instruction(std::next(two), migraphx::op::neg{}, two);
p2.add_instruction(migraphx::op::add{}, x, neg);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto two = mm2->add_literal(2);
auto neg = mm2->insert_instruction(std::next(two), migraphx::op::neg{}, two);
mm2->add_instruction(migraphx::op::add{}, x, neg);
}
EXPECT(p1 == p2);
}
......@@ -707,17 +749,18 @@ TEST_CASE(simplify_rsqrt)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = p1.add_instruction(migraphx::op::sqrt{}, x);
p1.add_instruction(migraphx::op::recip{}, sqrt);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = mm1->add_instruction(migraphx::op::sqrt{}, x);
mm1->add_instruction(migraphx::op::recip{}, sqrt);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
p2.add_instruction(migraphx::op::rsqrt{}, x);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1}});
mm2->add_instruction(migraphx::op::rsqrt{}, x);
}
EXPECT(p1 == p2);
}
......@@ -726,11 +769,12 @@ TEST_CASE(simplify_rsqrt_multi_use)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = p1.add_instruction(migraphx::op::sqrt{}, x);
auto add = p1.add_instruction(migraphx::op::add{}, sqrt, sqrt);
auto rsqrt = p1.add_instruction(migraphx::op::recip{}, sqrt);
p1.add_instruction(migraphx::op::add{}, rsqrt, add);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1}});
auto sqrt = mm1->add_instruction(migraphx::op::sqrt{}, x);
auto add = mm1->add_instruction(migraphx::op::add{}, sqrt, sqrt);
auto rsqrt = mm1->add_instruction(migraphx::op::recip{}, sqrt);
mm1->add_instruction(migraphx::op::add{}, rsqrt, add);
}
migraphx::program p2{p1};
......@@ -744,24 +788,26 @@ TEST_CASE(simplify_slice_concat)
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto xslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {128}}, x);
auto xslice2 = p1.add_instruction(migraphx::op::slice{{0}, {128}, {256}}, x);
auto yslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {128}}, y);
auto yslice2 = p1.add_instruction(migraphx::op::slice{{0}, {128}, {256}}, y);
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {128}}, x);
auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {128}, {256}}, x);
auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {128}}, y);
auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {128}, {256}}, y);
auto concat =
p1.add_instruction(migraphx::op::concat{0}, xslice1, xslice2, yslice1, yslice2);
p1.add_instruction(pass_op{}, concat);
mm1->add_instruction(migraphx::op::concat{0}, xslice1, xslice2, yslice1, yslice2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto concat = p2.add_instruction(migraphx::op::concat{0}, x, y);
p2.add_instruction(pass_op{}, concat);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, x, y);
mm2->add_instruction(pass_op{}, concat);
}
EXPECT(p1 == p2);
}
......@@ -772,26 +818,28 @@ TEST_CASE(simplify_slice_concat_non_uniform)
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto xslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto xslice3 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto yslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto yslice3 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto concat = p1.add_instruction(
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto xslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto yslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto concat = mm1->add_instruction(
migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
p1.add_instruction(pass_op{}, concat);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", s);
auto y = p2.add_parameter("y", s);
auto concat = p2.add_instruction(migraphx::op::concat{0}, x, y);
p2.add_instruction(pass_op{}, concat);
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", s);
auto y = mm2->add_parameter("y", s);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, x, y);
mm2->add_instruction(pass_op{}, concat);
}
EXPECT(p1 == p2);
......@@ -803,17 +851,18 @@ TEST_CASE(simplify_slice_concat_flipped)
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto y = p1.add_parameter("y", s);
auto xslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto xslice3 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto yslice1 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = p1.add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto yslice3 = p1.add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto concat = p1.add_instruction(
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto y = mm1->add_parameter("y", s);
auto xslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, x);
auto xslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, x);
auto xslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, x);
auto yslice1 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {64}}, y);
auto yslice2 = mm1->add_instruction(migraphx::op::slice{{0}, {192}, {256}}, y);
auto yslice3 = mm1->add_instruction(migraphx::op::slice{{0}, {64}, {192}}, y);
auto concat = mm1->add_instruction(
migraphx::op::concat{0}, xslice1, xslice2, xslice3, yslice1, yslice2, yslice3);
p1.add_instruction(pass_op{}, concat);
mm1->add_instruction(pass_op{}, concat);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -826,37 +875,39 @@ TEST_CASE(simplify_split_add_relu)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, add);
auto input = mm2->add_parameter("input", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add = mm2->add_instruction(migraphx::op::add{}, x, y);
mm2->add_instruction(pass_op{}, add);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -866,41 +917,43 @@ TEST_CASE(simplify_split_add_relu_reshape)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto reshape1 = p1.add_instruction(r, relu1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto reshape2 = p1.add_instruction(r, relu2);
auto add = p1.add_instruction(migraphx::op::add{}, reshape1, reshape2);
p1.add_instruction(pass_op{}, add);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto reshape1 = mm1->add_instruction(r, relu1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto reshape2 = mm1->add_instruction(r, relu2);
auto add = mm1->add_instruction(migraphx::op::add{}, reshape1, reshape2);
mm1->add_instruction(pass_op{}, add);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto rsp = p2.add_instruction(migraphx::op::reshape{{3, 8}}, relu);
auto slc1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {4}}, rsp);
auto slc2 = p2.add_instruction(migraphx::op::slice{{1}, {4}, {8}}, rsp);
auto add = p2.add_instruction(migraphx::op::add{}, slc1, slc2);
p2.add_instruction(pass_op{}, add);
auto input = mm2->add_parameter("input", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto rsp = mm2->add_instruction(migraphx::op::reshape{{3, 8}}, relu);
auto slc1 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {4}}, rsp);
auto slc2 = mm2->add_instruction(migraphx::op::slice{{1}, {4}, {8}}, rsp);
auto add = mm2->add_instruction(migraphx::op::add{}, slc1, slc2);
mm2->add_instruction(pass_op{}, add);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -910,22 +963,23 @@ TEST_CASE(simplify_slice_different_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto r = migraphx::op::reshape{{3, 2, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{3}, {0}, {1}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(migraphx::op::broadcast{1, {3, 1, 4, 2}}, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(migraphx::op::broadcast{3, {3, 2, 4, 1}}, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto reshape1 = p1.add_instruction(r, relu1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto reshape2 = p1.add_instruction(r, relu2);
auto add = p1.add_instruction(migraphx::op::add{}, reshape1, reshape2);
p1.add_instruction(pass_op{}, add);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{3}, {0}, {1}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(migraphx::op::broadcast{1, {3, 1, 4, 2}}, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(migraphx::op::broadcast{3, {3, 2, 4, 1}}, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto reshape1 = mm1->add_instruction(r, relu1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto reshape2 = mm1->add_instruction(r, relu2);
auto add = mm1->add_instruction(migraphx::op::add{}, reshape1, reshape2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -938,20 +992,21 @@ TEST_CASE(simplify_slice_missing_begining_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -964,20 +1019,21 @@ TEST_CASE(simplify_slice_missing_middle_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -990,20 +1046,21 @@ TEST_CASE(simplify_slice_missing_end_slice)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -1016,34 +1073,36 @@ TEST_CASE(simplify_split_add_relu_concat_same_axis)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{1}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto concat = mm1->add_instruction(migraphx::op::concat{1}, relu1, relu2);
mm1->add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
auto input = mm2->add_parameter("input", s);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
mm2->add_instruction(pass_op{}, relu);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1053,20 +1112,21 @@ TEST_CASE(simplify_split_add_relu_multi_axes)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1, 3}, {0, 0}, {1, 3}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1, 3}, {1, 3}, {2, 6}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1, 3}, {0, 0}, {1, 3}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1, 3}, {1, 3}, {2, 6}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
mm1->add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -1078,40 +1138,42 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split1)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add1 = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = p1.add_instruction(migraphx::op::add{}, x, add1);
p1.add_instruction(pass_op{}, add2);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add1 = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = mm1->add_instruction(migraphx::op::add{}, x, add1);
mm1->add_instruction(pass_op{}, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto slice = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = p2.add_instruction(migraphx::op::add{}, x, y);
auto add2 = p2.add_instruction(migraphx::op::add{}, slice, add1);
p2.add_instruction(pass_op{}, add2);
auto input = mm2->add_parameter("input", s);
auto slice = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto add2 = mm2->add_instruction(migraphx::op::add{}, slice, add1);
mm2->add_instruction(pass_op{}, add2);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1121,42 +1183,44 @@ TEST_CASE(simplify_split_add_relu_used_multiple_split2)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto* mm1 = p1.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto z = p1.add_instruction(migraphx::op::relu{}, x);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add1 = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = p1.add_instruction(migraphx::op::add{}, z, add1);
p1.add_instruction(pass_op{}, add2);
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto z = mm1->add_instruction(migraphx::op::relu{}, x);
auto one = mm1->add_literal(1);
auto oneb = mm1->add_instruction(b, one);
auto two = mm1->add_literal(2);
auto twob = mm1->add_instruction(b, two);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = mm1->add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = mm1->add_instruction(migraphx::op::relu{}, sum2);
auto add1 = mm1->add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = mm1->add_instruction(migraphx::op::add{}, z, add1);
mm1->add_instruction(pass_op{}, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto* mm2 = p2.get_main_module();
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto slice = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto z = p2.add_instruction(migraphx::op::relu{}, slice);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = p2.add_instruction(migraphx::op::add{}, x, y);
auto add2 = p2.add_instruction(migraphx::op::add{}, z, add1);
p2.add_instruction(pass_op{}, add2);
auto input = mm2->add_parameter("input", s);
auto slice = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto z = mm2->add_instruction(migraphx::op::relu{}, slice);
auto one = mm2->add_literal(1);
auto two = mm2->add_literal(2);
auto concat = mm2->add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = mm2->add_instruction(b, concat);
auto sum = mm2->add_instruction(migraphx::op::add{}, input, concatb);
auto relu = mm2->add_instruction(migraphx::op::relu{}, sum);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = mm2->add_instruction(migraphx::op::add{}, x, y);
auto add2 = mm2->add_instruction(migraphx::op::add{}, z, add1);
mm2->add_instruction(pass_op{}, add2);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1166,11 +1230,12 @@ TEST_CASE(simplify_split_between_add)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto x = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = mm1->add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
mm1->add_instruction(pass_op{}, sum);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -1182,27 +1247,29 @@ TEST_CASE(simplify_dot_horiz)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(s, 0));
auto b = p1.add_literal(migraphx::generate_literal(s, 1));
auto x = p1.add_instruction(migraphx::op::dot{}, input, a);
auto y = p1.add_instruction(migraphx::op::dot{}, input, b);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(s, 0));
auto b = mm1->add_literal(migraphx::generate_literal(s, 1));
auto x = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto y = mm1->add_instruction(migraphx::op::dot{}, input, b);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(s, 0));
auto b = p2.add_literal(migraphx::generate_literal(s, 1));
auto concat = p2.add_instruction(migraphx::op::concat{2}, a, b);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat);
auto x = p2.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = p2.add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, sum);
auto* mm2 = p2.get_main_module();
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(s, 0));
auto b = mm2->add_literal(migraphx::generate_literal(s, 1));
auto concat = mm2->add_instruction(migraphx::op::concat{2}, a, b);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat);
auto x = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = mm2->add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1212,25 +1279,27 @@ TEST_CASE(simplify_dot_horiz_same_constant)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(s, 0));
auto x = p1.add_instruction(migraphx::op::dot{}, input, a);
auto y = p1.add_instruction(migraphx::op::dot{}, input, a);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(s, 0));
auto x = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto y = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(s, 0));
auto concat = p2.add_instruction(migraphx::op::concat{2}, a, a);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat);
auto x = p2.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = p2.add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, sum);
auto* mm2 = p2.get_main_module();
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(s, 0));
auto concat = mm2->add_instruction(migraphx::op::concat{2}, a, a);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat);
auto x = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = mm2->add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1240,13 +1309,14 @@ TEST_CASE(simplify_dot_horiz_flipped)
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(s, 0));
auto b = p1.add_literal(migraphx::generate_literal(s, 1));
auto x = p1.add_instruction(migraphx::op::dot{}, input, a);
auto y = p1.add_instruction(migraphx::op::dot{}, b, input);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(s, 0));
auto b = mm1->add_literal(migraphx::generate_literal(s, 1));
auto x = mm1->add_instruction(migraphx::op::dot{}, input, a);
auto y = mm1->add_instruction(migraphx::op::dot{}, b, input);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
mm1->add_instruction(pass_op{}, sum);
}
migraphx::program p2 = p1;
......@@ -1260,27 +1330,29 @@ TEST_CASE(simplify_conv_horiz)
auto ws = migraphx::shape{migraphx::shape::int32_type, {12, 3, 3, 3}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws, 1));
auto x = p1.add_instruction(migraphx::op::convolution{}, input, a);
auto y = p1.add_instruction(migraphx::op::convolution{}, input, b);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws, 1));
auto x = mm1->add_instruction(migraphx::op::convolution{}, input, a);
auto y = mm1->add_instruction(migraphx::op::convolution{}, input, b);
auto sum = mm1->add_instruction(migraphx::op::add{}, x, y);
mm1->add_instruction(pass_op{}, sum);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws, 1));
auto concat = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto conv = p2.add_instruction(migraphx::op::convolution{}, input, concat);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {12}}, conv);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {12}, {24}}, conv);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, sum);
auto* mm2 = p2.get_main_module();
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(ws, 0));
auto b = mm2->add_literal(migraphx::generate_literal(ws, 1));
auto concat = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto conv = mm2->add_instruction(migraphx::op::convolution{}, input, concat);
auto x = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {12}}, conv);
auto y = mm2->add_instruction(migraphx::op::slice{{1}, {12}, {24}}, conv);
auto sum = mm2->add_instruction(migraphx::op::add{}, x, y);
mm2->add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1291,14 +1363,15 @@ TEST_CASE(simplify_group_conv_horiz)
auto ws = migraphx::shape{migraphx::shape::int32_type, {32, 1, 7, 7}};
migraphx::program p1;
{
auto x = p1.add_parameter("x", s);
auto w1 = p1.add_literal(migraphx::generate_literal(ws, 1));
auto w2 = p1.add_literal(migraphx::generate_literal(ws, 2));
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", s);
auto w1 = mm1->add_literal(migraphx::generate_literal(ws, 1));
auto w2 = mm1->add_literal(migraphx::generate_literal(ws, 2));
auto conv1 =
p1.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w1);
mm1->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w1);
auto conv2 =
p1.add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w2);
p1.add_instruction(pass_op{}, conv1, conv2);
mm1->add_instruction(migraphx::op::convolution{{3, 3}, {2, 2}, {1, 1}, 32}, x, w2);
mm1->add_instruction(pass_op{}, conv1, conv2);
}
migraphx::program p2 = p1;
run_pass(p1);
......@@ -1313,42 +1386,44 @@ TEST_CASE(simplify_conv_horiz_grouped)
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p1.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p1.add_literal(migraphx::generate_literal(ws2, 3));
auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c);
auto doty = p1.add_instruction(migraphx::op::dot{}, input, d);
auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto convx = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = mm1->add_instruction(migraphx::op::dot{}, input, c);
auto doty = mm1->add_instruction(migraphx::op::dot{}, input, d);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
mm1->add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p2.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p2.add_literal(migraphx::generate_literal(ws2, 3));
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d);
auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3);
auto* mm2 = p2.get_main_module();
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm2->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm2->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm2->add_literal(migraphx::generate_literal(ws2, 3));
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
auto conv = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, convx, convy);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
mm2->add_instruction(pass_op{}, sum3);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1360,49 +1435,51 @@ TEST_CASE(simplify_conv_horiz_grouped_extra1)
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p1.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p1.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p1.add_literal(migraphx::generate_literal(s, 4));
auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c);
auto doty = p1.add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = p1.add_instruction(migraphx::op::sqdiff{}, input, e);
auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm1->add_literal(migraphx::generate_literal(s, 4));
auto convx = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = mm1->add_instruction(migraphx::op::dot{}, input, c);
auto doty = mm1->add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = mm1->add_instruction(migraphx::op::sqdiff{}, input, e);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = sqdiffx;
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3);
p1.add_instruction(pass_op{}, sum5);
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
mm1->add_instruction(pass_op{}, sum5);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p2.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p2.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p2.add_literal(migraphx::generate_literal(s, 4));
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d);
auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = p2.add_instruction(migraphx::op::sqdiff{}, input, e);
auto* mm2 = p2.get_main_module();
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm2->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm2->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm2->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm2->add_literal(migraphx::generate_literal(s, 4));
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
auto conv = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, convx, convy);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = mm2->add_instruction(migraphx::op::sqdiff{}, input, e);
auto sum3 = sqdiffx;
auto sum4 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p2.add_instruction(migraphx::op::add{}, sum4, sum3);
p2.add_instruction(pass_op{}, sum5);
auto sum4 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm2->add_instruction(migraphx::op::add{}, sum4, sum3);
mm2->add_instruction(pass_op{}, sum5);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1414,53 +1491,55 @@ TEST_CASE(simplify_conv_horiz_grouped_extra2)
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p1.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p1.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p1.add_literal(migraphx::generate_literal(s, 4));
auto f = p1.add_literal(migraphx::generate_literal(s, 5));
auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c);
auto doty = p1.add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = p1.add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = p1.add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3);
p1.add_instruction(pass_op{}, sum5);
auto* mm1 = p1.get_main_module();
auto input = mm1->add_parameter("input", s);
auto a = mm1->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm1->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm1->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm1->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm1->add_literal(migraphx::generate_literal(s, 4));
auto f = mm1->add_literal(migraphx::generate_literal(s, 5));
auto convx = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = mm1->add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = mm1->add_instruction(migraphx::op::dot{}, input, c);
auto doty = mm1->add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = mm1->add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = mm1->add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum1 = mm1->add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = mm1->add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = mm1->add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = mm1->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm1->add_instruction(migraphx::op::add{}, sum4, sum3);
mm1->add_instruction(pass_op{}, sum5);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p2.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p2.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p2.add_literal(migraphx::generate_literal(s, 4));
auto f = p2.add_literal(migraphx::generate_literal(s, 5));
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d);
auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = p2.add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = p2.add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p2.add_instruction(migraphx::op::add{}, sum4, sum3);
p2.add_instruction(pass_op{}, sum5);
auto* mm2 = p2.get_main_module();
auto input = mm2->add_parameter("input", s);
auto a = mm2->add_literal(migraphx::generate_literal(ws1, 0));
auto b = mm2->add_literal(migraphx::generate_literal(ws1, 1));
auto c = mm2->add_literal(migraphx::generate_literal(ws2, 2));
auto d = mm2->add_literal(migraphx::generate_literal(ws2, 3));
auto e = mm2->add_literal(migraphx::generate_literal(s, 4));
auto f = mm2->add_literal(migraphx::generate_literal(s, 5));
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = mm2->add_instruction(migraphx::op::concat{3}, c, d);
auto conv = mm2->add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = mm2->add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = mm2->add_instruction(migraphx::op::add{}, convx, convy);
auto dot = mm2->add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = mm2->add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = mm2->add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = mm2->add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = mm2->add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = mm2->add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum3 = mm2->add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = mm2->add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = mm2->add_instruction(migraphx::op::add{}, sum4, sum3);
mm2->add_instruction(pass_op{}, sum5);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1469,51 +1548,53 @@ TEST_CASE(simplify_mul_slice_conv_horiz_fusion)
{
migraphx::program p1;
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p1.add_literal(
auto* mm1 = p1.get_main_module();
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm1->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto conv = p1.add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto conv = mm1->add_instruction(migraphx::op::convolution{}, x, w);
auto slice1 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, conv);
auto a1 =
p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a1);
auto mul = p1.add_instruction(migraphx::op::mul{}, slice1, b1);
mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a1);
auto mul = mm1->add_instruction(migraphx::op::mul{}, slice1, b1);
auto a2 =
p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto b2 = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a2);
auto add1 = p1.add_instruction(migraphx::op::add{}, mul, b2);
mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto b2 = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a2);
auto add1 = mm1->add_instruction(migraphx::op::add{}, mul, b2);
auto a3 =
p1.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto b3 = p1.add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a3);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add2 = p1.add_instruction(migraphx::op::add{}, slice2, b3);
p1.add_instruction(pass_op{}, add1, add2);
mm1->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto b3 = mm1->add_instruction(migraphx::op::broadcast{1, {1, 384, 17, 17}}, a3);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, conv);
auto add2 = mm1->add_instruction(migraphx::op::add{}, slice2, b3);
mm1->add_instruction(pass_op{}, add1, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = p2.add_literal(
auto* mm2 = p2.get_main_module();
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {1, 1024, 17, 17}});
auto w = mm2->add_literal(
migraphx::generate_literal({migraphx::shape::int32_type, {768, 1024, 1, 1}}));
auto wslice1 = p2.add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto wslice1 = mm2->add_instruction(migraphx::op::slice{{0}, {0}, {384}}, w);
auto a1 =
p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = p2.add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a1);
auto mul = p2.add_instruction(migraphx::op::mul{}, b1, wslice1);
auto wslice2 = p2.add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = p2.add_instruction(migraphx::op::convolution{}, x, concat1);
mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 1));
auto b1 = mm2->add_instruction(migraphx::op::broadcast{0, {384, 1024, 1, 1}}, a1);
auto mul = mm2->add_instruction(migraphx::op::mul{}, b1, wslice1);
auto wslice2 = mm2->add_instruction(migraphx::op::slice{{0}, {384}, {768}}, w);
auto concat1 = mm2->add_instruction(migraphx::op::concat{0}, mul, wslice2);
auto conv = mm2->add_instruction(migraphx::op::convolution{}, x, concat1);
auto a2 =
p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 2));
auto a3 =
p2.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto concat2 = p2.add_instruction(migraphx::op::concat{}, a2, a3);
auto b4 = p2.add_instruction(migraphx::op::broadcast{1, {1, 768, 17, 17}}, concat2);
auto add = p2.add_instruction(migraphx::op::add{}, conv, b4);
auto slice1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {384}}, add);
auto slice2 = p2.add_instruction(migraphx::op::slice{{1}, {384}, {768}}, add);
p2.add_instruction(pass_op{}, slice1, slice2);
mm2->add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {384}}, 3));
auto concat2 = mm2->add_instruction(migraphx::op::concat{}, a2, a3);
auto b4 = mm2->add_instruction(migraphx::op::broadcast{1, {1, 768, 17, 17}}, concat2);
auto add = mm2->add_instruction(migraphx::op::add{}, conv, b4);
auto slice1 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {384}}, add);
auto slice2 = mm2->add_instruction(migraphx::op::slice{{1}, {384}, {768}}, add);
mm2->add_instruction(pass_op{}, slice1, slice2);
}
EXPECT(p1.sort() == p2.sort());
}
......@@ -1523,50 +1604,52 @@ TEST_CASE(reorder_reshape_slice)
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto create_p1 = [&](std::size_t batch_size) {
migraphx::program p1;
auto* mm1 = p1.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 10, 64};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm1}, r2);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, r2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::dot{}, sum, t2);
p1.add_return({ret});
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
mm1->add_return({ret});
return p1;
};
auto create_p2 = [&](std::size_t batch_size) {
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p2.add_parameter("input", s);
auto input = mm2->add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 128, 30, 64};
auto r = p2.add_instruction(migraphx::op::reshape{lens}, input);
auto r = mm2->add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = p2.add_instruction(migraphx::op::slice{{2}, {0}, {10}}, r);
auto slc1 = p2.add_instruction(migraphx::op::slice{{2}, {10}, {20}}, r);
auto slc2 = p2.add_instruction(migraphx::op::slice{{2}, {20}, {30}}, r);
auto slc0 = mm2->add_instruction(migraphx::op::slice{{2}, {0}, {10}}, r);
auto slc1 = mm2->add_instruction(migraphx::op::slice{{2}, {10}, {20}}, r);
auto slc2 = mm2->add_instruction(migraphx::op::slice{{2}, {20}, {30}}, r);
auto t0 = p2.add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = p2.add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = p2.add_instruction(migraphx::op::transpose{perm1}, slc2);
auto t0 = mm2->add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = mm2->add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = mm2->add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = p2.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p2.add_instruction(migraphx::op::dot{}, sum, t2);
p2.add_return({ret});
auto sum = mm2->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm2->add_instruction(migraphx::op::dot{}, sum, t2);
mm2->add_return({ret});
return p2;
};
......@@ -1587,52 +1670,54 @@ TEST_CASE(reorder_reshape_slice_move_axis1)
{
auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
auto* mm1 = p1.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm1}, r2);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, r0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, r1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, r2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::dot{}, sum, t2);
p1.add_return({ret});
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
mm1->add_return({ret});
return p1;
};
auto create_p2 = [](std::size_t batch_size) {
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 256, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p.add_parameter("input", s);
auto input = mm->add_parameter("input", s);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 64, 4, 96};
auto rsp = p.add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = p.add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto t0 = p.add_instruction(migraphx::op::transpose{perm0}, slc0);
auto slc1 = p.add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto t1 = p.add_instruction(migraphx::op::transpose{perm0}, slc1);
auto slc2 = p.add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto t2 = p.add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = p.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p.add_instruction(migraphx::op::dot{}, sum, t2);
p.add_return({ret});
auto rsp = mm->add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = mm->add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto t0 = mm->add_instruction(migraphx::op::transpose{perm0}, slc0);
auto slc1 = mm->add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto t1 = mm->add_instruction(migraphx::op::transpose{perm0}, slc1);
auto slc2 = mm->add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto t2 = mm->add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = mm->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm->add_instruction(migraphx::op::dot{}, sum, t2);
mm->add_return({ret});
return p;
};
......@@ -1652,41 +1737,43 @@ TEST_CASE(reorder_reshape_slice_move_axis2)
{
auto create_p1 = [] {
migraphx::program p1;
auto* mm1 = p1.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {1, 16, 8, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens}, c2);
auto sum = p1.add_instruction(migraphx::op::add{}, r0, r1);
auto ret = p1.add_instruction(migraphx::op::mul{}, sum, r2);
p1.add_return({ret});
auto sum = mm1->add_instruction(migraphx::op::add{}, r0, r1);
auto ret = mm1->add_instruction(migraphx::op::mul{}, sum, r2);
mm1->add_return({ret});
return p1;
};
auto create_p2 = [] {
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {128, 96}};
auto input = p.add_parameter("input", s);
auto input = mm->add_parameter("input", s);
std::vector<int64_t> lens = {1, 16, 8, 96};
auto rsp = p.add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = p.add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto slc1 = p.add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto slc2 = p.add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto rsp = mm->add_instruction(migraphx::op::reshape{lens}, input);
auto slc0 = mm->add_instruction(migraphx::op::slice{{3}, {0}, {32}}, rsp);
auto slc1 = mm->add_instruction(migraphx::op::slice{{3}, {32}, {64}}, rsp);
auto slc2 = mm->add_instruction(migraphx::op::slice{{3}, {64}, {96}}, rsp);
auto sum = p.add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = p.add_instruction(migraphx::op::mul{}, sum, slc2);
p.add_return({ret});
auto sum = mm->add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = mm->add_instruction(migraphx::op::mul{}, sum, slc2);
mm->add_return({ret});
return p;
};
......@@ -1701,24 +1788,25 @@ TEST_CASE(reorder_reshape_slice_not_apply)
{
auto create_p = [] {
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {128, 96}};
auto input = p.add_parameter("input", s);
auto slc0 = p.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = p.add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = p.add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto input = mm->add_parameter("input", s);
auto slc0 = mm->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, input);
auto slc1 = mm->add_instruction(migraphx::op::slice{{1}, {32}, {64}}, input);
auto slc2 = mm->add_instruction(migraphx::op::slice{{1}, {64}, {96}}, input);
auto c0 = p.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {1, 16, 16, 16};
auto r0 = p.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p.add_instruction(migraphx::op::reshape{lens}, c2);
auto r0 = mm->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm->add_instruction(migraphx::op::reshape{lens}, c2);
auto sum = p.add_instruction(migraphx::op::add{}, r0, r1);
auto ret = p.add_instruction(migraphx::op::mul{}, sum, r2);
p.add_return({ret});
auto sum = mm->add_instruction(migraphx::op::add{}, r0, r1);
auto ret = mm->add_instruction(migraphx::op::mul{}, sum, r2);
mm->add_return({ret});
return p;
};
......@@ -1733,25 +1821,26 @@ TEST_CASE(reorder_reshape_slice_diff_dims)
{
auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
auto* mm1 = p1.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 96, 96}};
std::vector<int64_t> perm0 = {0, 2, 1, 3};
std::vector<int64_t> perm1 = {0, 2, 3, 1};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {32}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {32}, {64}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {64}, {96}}, input);
auto c0 = p1.add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = p1.add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = p1.add_instruction(migraphx::op::contiguous{}, slc2);
auto c0 = mm1->add_instruction(migraphx::op::contiguous{}, slc0);
auto c1 = mm1->add_instruction(migraphx::op::contiguous{}, slc1);
auto c2 = mm1->add_instruction(migraphx::op::contiguous{}, slc2);
std::vector<int64_t> lens = {static_cast<int64_t>(batch_size), 32, 3, 32};
std::vector<int64_t> lens1 = {static_cast<int64_t>(batch_size), 48, 2, 32};
auto r0 = p1.add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = p1.add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = p1.add_instruction(migraphx::op::reshape{lens1}, c2);
auto r0 = mm1->add_instruction(migraphx::op::reshape{lens}, c0);
auto r1 = mm1->add_instruction(migraphx::op::reshape{lens}, c1);
auto r2 = mm1->add_instruction(migraphx::op::reshape{lens1}, c2);
p1.add_return({r0, r1, r2});
mm1->add_return({r0, r1, r2});
return p1;
};
......@@ -1772,36 +1861,38 @@ TEST_CASE(reorder_slice_trans)
std::vector<int64_t> perm = {0, 2, 1};
auto create_p1 = [&](std::size_t batch_size) {
migraphx::program p1;
auto* mm1 = p1.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm}, slc0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm}, slc1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm}, slc2);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm}, slc0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm}, slc1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm}, slc2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::mul{}, sum, t2);
p1.add_return({ret});
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::mul{}, sum, t2);
mm1->add_return({ret});
return p1;
};
auto create_p2 = [&](std::size_t batch_size) {
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto input = p2.add_parameter("input", s);
auto r = p2.add_instruction(migraphx::op::transpose{perm}, input);
auto input = mm2->add_parameter("input", s);
auto r = mm2->add_instruction(migraphx::op::transpose{perm}, input);
auto slc0 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {640}}, r);
auto slc1 = p2.add_instruction(migraphx::op::slice{{1}, {640}, {1280}}, r);
auto slc2 = p2.add_instruction(migraphx::op::slice{{1}, {1280}, {1920}}, r);
auto slc0 = mm2->add_instruction(migraphx::op::slice{{1}, {0}, {640}}, r);
auto slc1 = mm2->add_instruction(migraphx::op::slice{{1}, {640}, {1280}}, r);
auto slc2 = mm2->add_instruction(migraphx::op::slice{{1}, {1280}, {1920}}, r);
auto sum = p2.add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = p2.add_instruction(migraphx::op::mul{}, sum, slc2);
p2.add_return({ret});
auto sum = mm2->add_instruction(migraphx::op::add{}, slc0, slc1);
auto ret = mm2->add_instruction(migraphx::op::mul{}, sum, slc2);
mm2->add_return({ret});
return p2;
};
......@@ -1821,21 +1912,22 @@ TEST_CASE(reorder_slice_trans_diff_perm)
{
auto create_p1 = [](std::size_t batch_size) {
migraphx::program p1;
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
auto* mm1 = p1.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {batch_size, 128, 1920}};
std::vector<int64_t> perm0 = {0, 2, 1};
std::vector<int64_t> perm1 = {0, 1, 2};
auto input = p1.add_parameter("input", s);
auto slc0 = p1.add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = p1.add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = p1.add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto t0 = p1.add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = p1.add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = p1.add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = p1.add_instruction(migraphx::op::add{}, t0, t1);
auto ret = p1.add_instruction(migraphx::op::dot{}, sum, t2);
p1.add_return({ret});
auto input = mm1->add_parameter("input", s);
auto slc0 = mm1->add_instruction(migraphx::op::slice{{2}, {0}, {640}}, input);
auto slc1 = mm1->add_instruction(migraphx::op::slice{{2}, {640}, {1280}}, input);
auto slc2 = mm1->add_instruction(migraphx::op::slice{{2}, {1280}, {1920}}, input);
auto t0 = mm1->add_instruction(migraphx::op::transpose{perm0}, slc0);
auto t1 = mm1->add_instruction(migraphx::op::transpose{perm0}, slc1);
auto t2 = mm1->add_instruction(migraphx::op::transpose{perm1}, slc2);
auto sum = mm1->add_instruction(migraphx::op::add{}, t0, t1);
auto ret = mm1->add_instruction(migraphx::op::dot{}, sum, t2);
mm1->add_return({ret});
return p1;
};
......
......@@ -9,17 +9,20 @@
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
auto* mm = p.get_main_module();
migraphx::run_passes(*mm, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
}
TEST_CASE(double_contig)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_return({c2});
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t1);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, c1);
mm->add_return({c2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -33,10 +36,12 @@ TEST_CASE(double_contig)
TEST_CASE(double_transpose)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
p.add_return({t2});
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, t1);
mm->add_return({t2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -50,12 +55,14 @@ TEST_CASE(double_transpose)
TEST_CASE(double_transpose_contig)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t1);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, t2);
p.add_return({c2});
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t1);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, c1);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, t2);
mm->add_return({c2});
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -69,9 +76,11 @@ TEST_CASE(double_transpose_contig)
TEST_CASE(single_transpose)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_return({t1});
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
mm->add_return({t1});
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -85,9 +94,11 @@ TEST_CASE(single_transpose)
TEST_CASE(double_transpose_sin_pass)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
p.add_instruction(migraphx::op::transpose{{1, 0}}, t1);
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
mm->add_instruction(migraphx::op::transpose{{1, 0}}, t1);
EXPECT(p.get_output_shapes().back().standard());
EXPECT(not p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -102,8 +113,10 @@ TEST_CASE(double_transpose_sin_pass)
TEST_CASE(single_transpose_sin_pass)
{
migraphx::program p;
auto l = p.add_literal(get_2x2());
p.add_instruction(migraphx::op::transpose{{1, 0}}, l);
auto* mm = p.get_main_module();
auto l = mm->add_literal(get_2x2());
mm->add_instruction(migraphx::op::transpose{{1, 0}}, l);
EXPECT(not p.get_output_shapes().back().standard());
EXPECT(p.get_output_shapes().back().transposed());
run_pass(p);
......@@ -117,13 +130,15 @@ TEST_CASE(single_transpose_sin_pass)
TEST_CASE(reshape_transpose)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = p.add_parameter("x", s);
auto r1 = p.add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = p.add_instruction(migraphx::op::contiguous{}, t);
auto r2 = p.add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
p.add_return({r2});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 112, 56, 56}};
auto x = mm->add_parameter("x", s);
auto r1 = mm->add_instruction(migraphx::op::reshape{{1, 4, 28, 56, 56}}, x);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 1, 3, 4}}, r1);
auto ct = mm->add_instruction(migraphx::op::contiguous{}, t);
auto r2 = mm->add_instruction(migraphx::op::reshape{{1, 112, 56, 56}}, ct);
mm->add_return({r2});
EXPECT(p.get_output_shapes().back() == s);
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -134,11 +149,13 @@ TEST_CASE(reshape_transpose)
TEST_CASE(transpose_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
p.add_return({c1});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t);
mm->add_return({c1});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -149,28 +166,32 @@ TEST_CASE(transpose_contiguous)
TEST_CASE(transpose_double_contiguous)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = p.add_instruction(migraphx::op::contiguous{}, t);
auto c2 = p.add_instruction(migraphx::op::contiguous{}, c1);
p.add_return({c2});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {4, 4}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{1, 0}}, x);
auto c1 = mm->add_instruction(migraphx::op::contiguous{}, t);
auto c2 = mm->add_instruction(migraphx::op::contiguous{}, c1);
mm->add_return({c2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(p.get_output_shapes().back() == out_shape);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
EXPECT(p.has_instruction(t));
EXPECT(mm->has_instruction(t));
}
TEST_CASE(transpose_partial1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
p.add_return({t2});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
mm->add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -181,12 +202,14 @@ TEST_CASE(transpose_partial1)
TEST_CASE(transpose_partial2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
p.add_return({t3});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
mm->add_return({t3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -197,13 +220,15 @@ TEST_CASE(transpose_partial2)
TEST_CASE(transpose_partial3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
p.add_return({t4});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
auto t3 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
auto t4 = mm->add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
mm->add_return({t4});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -214,10 +239,12 @@ TEST_CASE(transpose_partial3)
TEST_CASE(nop_transpose1)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
p.add_return({t});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -228,13 +255,15 @@ TEST_CASE(nop_transpose1)
TEST_CASE(nop_transpose2)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
p.add_instruction(pass_op{}, t4);
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
auto t2 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
auto t3 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
auto t4 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
mm->add_instruction(pass_op{}, t4);
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -245,13 +274,15 @@ TEST_CASE(nop_transpose2)
TEST_CASE(nop_transpose3)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
p.add_return({t2});
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto concat = mm->add_instruction(migraphx::op::concat{3}, x, y);
auto t1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
auto t2 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
mm->add_return({t2});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -262,10 +293,12 @@ TEST_CASE(nop_transpose3)
TEST_CASE(nop_convert)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = p.add_parameter("x", s);
auto t = p.add_instruction(migraphx::op::convert{migraphx::shape::float_type}, x);
p.add_return({t});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
auto x = mm->add_parameter("x", s);
auto t = mm->add_instruction(migraphx::op::convert{migraphx::shape::float_type}, x);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -276,14 +309,16 @@ TEST_CASE(nop_convert)
TEST_CASE(concat_transpose1)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
p.add_return({t});
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{2}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -298,14 +333,16 @@ TEST_CASE(concat_transpose1)
TEST_CASE(concat_transpose2)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{-1}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_return({t});
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{-1}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -320,14 +357,16 @@ TEST_CASE(concat_transpose2)
TEST_CASE(concat_transpose3)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
auto y = p.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_return({t});
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 5, 3, 4}});
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
mm->add_return({t});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -342,32 +381,35 @@ TEST_CASE(concat_transpose3)
TEST_CASE(concat_transpose4)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto sx = migraphx::shape{migraphx::shape::float_type, {1, 1, 12, 64}};
auto sy = migraphx::shape{migraphx::shape::float_type, {1, 12, 1, 64}};
auto x = p.add_parameter("x", sx);
auto y = p.add_parameter("y", sy);
auto xt = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
p.add_return({t});
auto x = mm->add_parameter("x", sx);
auto y = mm->add_parameter("y", sy);
auto xt = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
auto yt = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
auto concat = mm->add_instruction(migraphx::op::concat{3}, xt, yt);
auto t = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
mm->add_return({t});
migraphx::program p1 = p;
run_pass(p);
EXPECT(p1 == p);
}
TEST_CASE(nested_concat)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2);
p.add_return({concat3});
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto concat1 = mm->add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = mm->add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = mm->add_instruction(migraphx::op::concat{1}, concat1, concat2);
mm->add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -379,15 +421,17 @@ TEST_CASE(nested_concat)
TEST_CASE(nested_concat_partial)
{
migraphx::program p;
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = p.add_parameter("x", s);
auto y = p.add_parameter("y", s);
auto l = p.add_literal(
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto l = mm->add_literal(
migraphx::generate_literal(migraphx::shape{migraphx::shape::float_type, {1, 4, 3, 4}}));
auto concat1 = p.add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = p.add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = p.add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
p.add_return({concat3});
auto concat1 = mm->add_instruction(migraphx::op::concat{1}, x, y);
auto concat2 = mm->add_instruction(migraphx::op::concat{1}, y, x);
auto concat3 = mm->add_instruction(migraphx::op::concat{1}, concat1, concat2, l);
mm->add_return({concat3});
auto out_shape = p.get_output_shapes().back();
auto n = std::distance(p.begin(), p.end());
run_pass(p);
......@@ -399,11 +443,13 @@ TEST_CASE(nested_concat_partial)
TEST_CASE(multibroadcast_simplify)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> s_lens{1, 2, 3, 4};
auto s = migraphx::shape{migraphx::shape::float_type, s_lens};
auto x = p.add_parameter("x", s);
auto y = p.add_instruction(migraphx::op::multibroadcast{s_lens}, x);
p.add_instruction(migraphx::op::mul{}, y, y);
auto x = mm->add_parameter("x", s);
auto y = mm->add_instruction(migraphx::op::multibroadcast{s_lens}, x);
mm->add_instruction(migraphx::op::mul{}, y, y);
auto n = std::distance(p.begin(), p.end());
run_pass(p);
EXPECT(std::distance(p.begin(), p.end()) == n - 1);
......@@ -412,19 +458,21 @@ TEST_CASE(multibroadcast_simplify)
TEST_CASE(double_slice1)
{
migraphx::program p1;
auto* mm1 = p1.get_main_module();
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1);
p1.add_return({slice2});
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {256}}, x);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {64}}, slice1);
mm1->add_return({slice2});
}
run_pass(p1);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x);
p2.add_return({slice});
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = mm2->add_instruction(migraphx::op::slice{{0}, {64}, {96}}, x);
mm2->add_return({slice});
}
EXPECT(p1 == p2);
}
......@@ -432,19 +480,21 @@ TEST_CASE(double_slice1)
TEST_CASE(double_slice2)
{
migraphx::program p1;
auto* mm1 = p1.get_main_module();
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{0}, {0}, {32}}, slice1);
p1.add_return({slice2});
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice1 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{0}, {0}, {32}}, slice1);
mm1->add_return({slice2});
}
run_pass(p1);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = p2.add_instruction(migraphx::op::slice{{0}, {32}, {64}}, x);
p2.add_return({slice});
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {256}});
auto slice = mm2->add_instruction(migraphx::op::slice{{0}, {32}, {64}}, x);
mm2->add_return({slice});
}
EXPECT(p1 == p2);
}
......@@ -452,19 +502,22 @@ TEST_CASE(double_slice2)
TEST_CASE(double_slice_multi_axes)
{
migraphx::program p1;
auto* mm1 = p1.get_main_module();
{
auto x = p1.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice1 = p1.add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = p1.add_instruction(migraphx::op::slice{{1}, {0}, {32}}, slice1);
p1.add_return({slice2});
auto x = mm1->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice1 = mm1->add_instruction(migraphx::op::slice{{0}, {32}, {128}}, x);
auto slice2 = mm1->add_instruction(migraphx::op::slice{{1}, {0}, {32}}, slice1);
mm1->add_return({slice2});
}
run_pass(p1);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice = p2.add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, x);
p2.add_return({slice});
auto x = mm2->add_parameter("x", {migraphx::shape::int32_type, {256, 128}});
auto slice = mm2->add_instruction(migraphx::op::slice{{0, 1}, {32, 0}, {128, 32}}, x);
mm2->add_return({slice});
}
EXPECT(p1 == p2);
}
......
......@@ -19,8 +19,9 @@ migraphx::program parse_tf(const std::string& name, bool is_nhwc)
migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
{
auto prog = migraphx::parse_tf(name, migraphx::tf_options{is_nhwc, 1});
auto* mm = prog.get_main_module();
if(is_nhwc)
migraphx::run_passes(prog,
migraphx::run_passes(*mm,
{migraphx::simplify_reshapes{},
migraphx::dead_code_elimination{},
migraphx::eliminate_identity{}});
......@@ -30,9 +31,11 @@ migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
TEST_CASE(add_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::add{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::add{}, l0, l1);
auto prog = optimize_tf("add_test.pb", false);
EXPECT(p == prog);
......@@ -42,12 +45,14 @@ TEST_CASE(add_bcast_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
p.add_instruction(migraphx::op::add{}, l2, l3);
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2, 1}});
auto l2 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
auto l3 = mm->add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
mm->add_instruction(migraphx::op::add{}, l2, l3);
auto prog = optimize_tf("add_bcast_test.pb", false);
EXPECT(p == prog);
......@@ -56,10 +61,12 @@ TEST_CASE(add_bcast_test)
TEST_CASE(argmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = p.add_instruction(migraphx::op::argmax{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::op::argmax{2}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = parse_tf("argmax_test.pb", false);
EXPECT(p == prog);
......@@ -68,10 +75,12 @@ TEST_CASE(argmax_test)
TEST_CASE(argmin_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = p.add_instruction(migraphx::op::argmin{2}, l0);
p.add_instruction(migraphx::op::squeeze{{2}}, ins);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {2}});
auto ins = mm->add_instruction(migraphx::op::argmin{2}, l0);
mm->add_instruction(migraphx::op::squeeze{{2}}, ins);
auto prog = parse_tf("argmin_test.pb", false);
EXPECT(p == prog);
......@@ -80,14 +89,16 @@ TEST_CASE(argmin_test)
TEST_CASE(assert_less_equal_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {2, 3}};
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", s0);
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", s0);
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {0, 1}};
auto l2 = p.add_literal(l);
p.add_instruction(migraphx::op::add{}, l0, l1);
auto l3 = p.add_instruction(migraphx::op::identity{}, l0, l1);
p.add_instruction(migraphx::op::identity{}, l3, l2);
auto l2 = mm->add_literal(l);
mm->add_instruction(migraphx::op::add{}, l0, l1);
auto l3 = mm->add_instruction(migraphx::op::identity{}, l0, l1);
mm->add_instruction(migraphx::op::identity{}, l3, l2);
auto prog = optimize_tf("assert_less_equal_test.pb", false);
EXPECT(p == prog);
......@@ -96,13 +107,15 @@ TEST_CASE(assert_less_equal_test)
TEST_CASE(batchmatmul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});
auto trans_l0 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 8, 4}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 4, 8}});
auto trans_l0 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l0);
auto trans_l1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
mm->add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = optimize_tf("batchmatmul_test.pb", false);
EXPECT(p == prog);
......@@ -114,18 +127,20 @@ TEST_CASE(batchnorm_test)
float momentum = 0.9f;
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::op::batch_norm_inference op{
epsilon, momentum, migraphx::op::batch_norm_inference::spatial};
migraphx::shape s0{migraphx::shape::float_type, {32}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 32, 16, 16}});
std::vector<float> const_vals(32);
std::fill(const_vals.begin(), const_vals.end(), 1.0f);
auto l2 = p.add_parameter("2", s0);
auto l3 = p.add_parameter("3", s0);
auto l4 = p.add_parameter("4", s0);
auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
p.add_instruction(op, l0, l1, l2, l3, l4);
auto l2 = mm->add_parameter("2", s0);
auto l3 = mm->add_parameter("3", s0);
auto l4 = mm->add_parameter("4", s0);
auto l1 = mm->add_literal(migraphx::literal{s0, const_vals});
mm->add_instruction(op, l0, l1, l2, l3, l4);
auto prog = optimize_tf("batchnorm_test.pb", true);
EXPECT(p == prog);
......@@ -134,12 +149,14 @@ TEST_CASE(batchnorm_test)
TEST_CASE(biasadd_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape s0{migraphx::shape::float_type, {1, 500, 1, 1}};
uint64_t axis = 1;
auto l0 = p.add_parameter("0", s0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
p.add_instruction(migraphx::op::add{}, l0, l2);
auto l0 = mm->add_parameter("0", s0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
auto l2 = mm->add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
mm->add_instruction(migraphx::op::add{}, l0, l2);
auto prog = optimize_tf("biasadd_test.pb", true);
EXPECT(p == prog);
......@@ -148,8 +165,10 @@ TEST_CASE(biasadd_test)
TEST_CASE(cast_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::convert{migraphx::shape::int32_type}, l0);
auto prog = optimize_tf("cast_test.pb", false);
EXPECT(p == prog);
......@@ -159,15 +178,17 @@ TEST_CASE(concat_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 7, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
int axis = 1;
// tf uses axis as the third input, and it is in int32 format
// add the literal using a vector in order to set stride to 1 (like in tf parser)
p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
mm->add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
p.add_instruction(migraphx::op::concat{axis}, l0, l1);
mm->add_instruction(migraphx::op::concat{axis}, l0, l1);
auto prog = optimize_tf("concat_test.pb", false);
EXPECT(p == prog);
......@@ -176,7 +197,9 @@ TEST_CASE(concat_test)
TEST_CASE(const_test)
{
migraphx::program p;
p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto* mm = p.get_main_module();
mm->add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
auto prog = optimize_tf("constant_test.pb", false);
EXPECT(p == prog);
......@@ -186,19 +209,21 @@ TEST_CASE(conv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
std::vector<float> weight_data(3 * 3 * 3 * 32);
std::fill(weight_data.begin(), weight_data.end(), 1.0f);
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data);
mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 32}}, weight_data);
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
op.padding = {1, 1};
op.stride = {1, 1};
op.dilation = {1, 1};
auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
p.add_instruction(op, l0, l2);
auto l2 = mm->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
mm->add_instruction(op, l0, l2);
auto prog = optimize_tf("conv_test.pb", true);
EXPECT(p == prog);
......@@ -208,11 +233,13 @@ TEST_CASE(depthwiseconv_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
std::vector<float> weight_data(3 * 3 * 3 * 1);
std::fill(weight_data.begin(), weight_data.end(), 1.0f);
auto l1 =
p.add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
mm->add_literal(migraphx::shape{migraphx::shape::float_type, {3, 3, 3, 1}}, weight_data);
migraphx::op::convolution op;
op.padding_mode = migraphx::op::padding_mode_t::same;
......@@ -220,10 +247,10 @@ TEST_CASE(depthwiseconv_test)
op.stride = {1, 1};
op.dilation = {1, 1};
op.group = 3;
auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
p.add_instruction(op, l0, l5);
auto l3 = mm->add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
auto l4 = mm->add_instruction(migraphx::op::contiguous{}, l3);
auto l5 = mm->add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
mm->add_instruction(op, l0, l5);
auto prog = optimize_tf("depthwise_conv_test.pb", true);
EXPECT(p == prog);
......@@ -233,9 +260,11 @@ TEST_CASE(expanddims_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(0);
p.add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
mm->add_literal(0);
mm->add_instruction(migraphx::op::reshape{{1, 2, 3, 4}}, l0);
auto prog = optimize_tf("expanddims_test.pb", false);
EXPECT(p == prog);
......@@ -246,9 +275,11 @@ TEST_CASE(expanddims_test_neg_dims)
// this check makes sure the pb parses negative dim value correctly
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
p.add_literal(-1);
p.add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 4}});
mm->add_literal(-1);
mm->add_instruction(migraphx::op::reshape{{2, 3, 4, 1}}, l0);
auto prog = optimize_tf("expanddims_neg_test.pb", false);
EXPECT(p == prog);
......@@ -258,13 +289,15 @@ TEST_CASE(gather_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
auto l1 =
p.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
p.add_literal(1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 4}});
auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 1}});
mm->add_literal(1);
int axis = 1;
p.add_instruction(migraphx::op::gather{axis}, l0, l1);
mm->add_instruction(migraphx::op::gather{axis}, l0, l1);
auto prog = optimize_tf("gather_test.pb", false);
EXPECT(p == prog);
......@@ -273,8 +306,10 @@ TEST_CASE(gather_test)
TEST_CASE(identity_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_tf("identity_test.pb", false);
EXPECT(p == prog);
......@@ -283,13 +318,15 @@ TEST_CASE(identity_test)
TEST_CASE(matmul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
auto trans_l0 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {8, 4}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {4, 8}});
p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto trans_l0 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l0);
auto trans_l1 = mm->add_instruction(migraphx::op::transpose{{1, 0}}, l1);
mm->add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
auto prog = optimize_tf("matmul_test.pb", false);
EXPECT(p == prog);
......@@ -298,14 +335,16 @@ TEST_CASE(matmul_test)
TEST_CASE(mean_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {2, 3}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(l);
p.add_literal(l);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_literal(l);
mm->add_literal(l);
migraphx::op::reduce_mean op{{2, 3}};
p.add_instruction(op, l0);
auto l3 = p.add_instruction(op, l0);
p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
mm->add_instruction(op, l0);
auto l3 = mm->add_instruction(op, l0);
mm->add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
auto prog = optimize_tf("mean_test.pb", false);
EXPECT(p == prog);
......@@ -314,12 +353,14 @@ TEST_CASE(mean_test)
TEST_CASE(mean_test_nhwc)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto l1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
migraphx::op::reduce_mean op{{1, 2}};
auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::op::squeeze{{1, 2}}, l2);
auto prog = optimize_tf("mean_test_nhwc.pb", true);
EXPECT(p == prog);
......@@ -328,10 +369,12 @@ TEST_CASE(mean_test_nhwc)
TEST_CASE(mul_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::mul{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
mm->add_instruction(migraphx::op::mul{}, l0, l1);
auto prog = optimize_tf("mul_test.pb", false);
EXPECT(p == prog);
......@@ -340,15 +383,17 @@ TEST_CASE(mul_test)
TEST_CASE(onehot_test)
{
migraphx::program p;
auto l0 = p.add_literal(
auto* mm = p.get_main_module();
auto l0 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {5}}, {1, 1, 1, 1, 1}});
p.add_literal(2);
p.add_literal(1.0f);
p.add_literal(0.0f);
auto l1 = p.add_literal(
mm->add_literal(2);
mm->add_literal(1.0f);
mm->add_literal(0.0f);
auto l1 = mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::float_type, {2, 2}}, {1, 0, 0, 1}});
int axis = 0;
p.add_instruction(migraphx::op::gather{axis}, l1, l0);
mm->add_instruction(migraphx::op::gather{axis}, l1, l0);
auto prog = optimize_tf("onehot_test.pb", false);
EXPECT(p == prog);
......@@ -357,9 +402,11 @@ TEST_CASE(onehot_test)
TEST_CASE(pack_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}});
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {2}});
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {2}});
std::vector<migraphx::instruction_ref> args{l0, l1, l2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t axis = 1;
......@@ -368,9 +415,9 @@ TEST_CASE(pack_test)
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
return mm->add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<int>(axis)}, unsqueezed_args);
mm->add_instruction(migraphx::op::concat{static_cast<int>(axis)}, unsqueezed_args);
auto prog = optimize_tf("pack_test.pb", false);
EXPECT(p == prog);
......@@ -379,12 +426,14 @@ TEST_CASE(pack_test)
TEST_CASE(pack_test_nhwc)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt0 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
auto l2 = mm->add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
auto lt2 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
std::vector<migraphx::instruction_ref> unsqueezed_args;
int64_t nchw_axis = 3;
......@@ -393,9 +442,9 @@ TEST_CASE(pack_test_nhwc)
args.end(),
std::back_inserter(unsqueezed_args),
[&](migraphx::instruction_ref arg) {
return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
return mm->add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
});
p.add_instruction(migraphx::op::concat{static_cast<int>(nchw_axis)}, unsqueezed_args);
mm->add_instruction(migraphx::op::concat{static_cast<int>(nchw_axis)}, unsqueezed_args);
auto prog = optimize_tf("pack_test_nhwc.pb", true);
EXPECT(p == prog);
......@@ -404,14 +453,16 @@ TEST_CASE(pack_test_nhwc)
TEST_CASE(pooling_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::op::pooling avg_pool_op{"average"};
migraphx::op::pooling max_pool_op{"max"};
avg_pool_op.stride = {2, 2};
max_pool_op.stride = {2, 2};
avg_pool_op.lengths = {2, 2};
max_pool_op.lengths = {2, 2};
p.add_instruction(max_pool_op, l0);
mm->add_instruction(max_pool_op, l0);
auto prog = optimize_tf("pooling_test.pb", true);
EXPECT(p == prog);
......@@ -420,9 +471,11 @@ TEST_CASE(pooling_test)
TEST_CASE(pow_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::pow{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::pow{}, l0, l1);
auto prog = optimize_tf("pow_test.pb", false);
EXPECT(p == prog);
......@@ -431,8 +484,10 @@ TEST_CASE(pow_test)
TEST_CASE(relu_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::relu{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::relu{}, l0);
auto prog = optimize_tf("relu_test.pb", false);
EXPECT(p == prog);
......@@ -441,13 +496,15 @@ TEST_CASE(relu_test)
TEST_CASE(relu6_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<size_t> input_lens{1, 3, 16, 16};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = p.add_literal(0.0f);
auto max_val = p.add_literal(6.0f);
min_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = p.add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
p.add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, input_lens});
auto min_val = mm->add_literal(0.0f);
auto max_val = mm->add_literal(6.0f);
min_val = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, min_val);
max_val = mm->add_instruction(migraphx::op::multibroadcast{input_lens}, max_val);
mm->add_instruction(migraphx::op::clip{}, l0, min_val, max_val);
auto prog = optimize_tf("relu6_test.pb", false);
EXPECT(p == prog);
......@@ -456,11 +513,13 @@ TEST_CASE(relu6_test)
TEST_CASE(reshape_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}};
// in tf, the second arg is a literal that contains new dimensions
p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
mm->add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
mm->add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
auto prog = optimize_tf("reshape_test.pb", false);
EXPECT(p == prog);
......@@ -469,8 +528,10 @@ TEST_CASE(reshape_test)
TEST_CASE(rsqrt_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::rsqrt{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::rsqrt{}, l0);
auto prog = optimize_tf("rsqrt_test.pb", false);
EXPECT(p == prog);
......@@ -479,8 +540,10 @@ TEST_CASE(rsqrt_test)
TEST_CASE(shape_test)
{
migraphx::program p;
p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_literal(
auto* mm = p.get_main_module();
mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {4}}, {1, 3, 16, 16}});
auto prog = optimize_tf("shape_test.pb", false);
......@@ -490,18 +553,20 @@ TEST_CASE(shape_test)
TEST_CASE(slice_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::size_t num_axes = 2;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 10}});
migraphx::shape s0{migraphx::shape::int32_type, {num_axes}};
p.add_literal(migraphx::literal{s0, {1, 0}});
p.add_literal(migraphx::literal{s0, {2, -1}});
mm->add_literal(migraphx::literal{s0, {1, 0}});
mm->add_literal(migraphx::literal{s0, {2, -1}});
migraphx::op::slice op;
op.starts = {1, 0};
op.ends = {3, 10};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
p.add_instruction(op, l0);
mm->add_instruction(op, l0);
auto prog = optimize_tf("slice_test.pb", false);
EXPECT(p == prog);
......@@ -510,8 +575,10 @@ TEST_CASE(slice_test)
TEST_CASE(softmax_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
p.add_instruction(migraphx::op::softmax{1}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3}});
mm->add_instruction(migraphx::op::softmax{1}, l0);
auto prog = optimize_tf("softmax_test.pb", false);
EXPECT(p == prog);
......@@ -520,17 +587,19 @@ TEST_CASE(softmax_test)
TEST_CASE(split_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> axes{0, 1};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
p.add_literal(3); // num_splits
p.add_literal(1); // split axis
p.add_literal(1); // concat axis
p.add_literal(1); // concat axis
auto l1 = p.add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 10}}, l0);
auto l2 = p.add_instruction(migraphx::op::slice{axes, {0, 10}, {5, 20}}, l0);
auto l3 = p.add_instruction(migraphx::op::slice{axes, {0, 20}, {5, 30}}, l0);
p.add_instruction(migraphx::op::concat{1}, l1, l2);
p.add_instruction(migraphx::op::concat{1}, l2, l3);
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
mm->add_literal(3); // num_splits
mm->add_literal(1); // split axis
mm->add_literal(1); // concat axis
mm->add_literal(1); // concat axis
auto l1 = mm->add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 10}}, l0);
auto l2 = mm->add_instruction(migraphx::op::slice{axes, {0, 10}, {5, 20}}, l0);
auto l3 = mm->add_instruction(migraphx::op::slice{axes, {0, 20}, {5, 30}}, l0);
mm->add_instruction(migraphx::op::concat{1}, l1, l2);
mm->add_instruction(migraphx::op::concat{1}, l2, l3);
auto prog = parse_tf("split_test.pb", false);
......@@ -540,10 +609,12 @@ TEST_CASE(split_test)
TEST_CASE(split_test_one_output)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
p.add_literal(1); // num_splits
p.add_literal(1); // split axis
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
mm->add_literal(1); // num_splits
mm->add_literal(1); // split axis
mm->add_instruction(migraphx::op::identity{}, l0);
auto prog = parse_tf("split_test_one_output.pb", false);
......@@ -553,19 +624,21 @@ TEST_CASE(split_test_one_output)
TEST_CASE(split_test_vector_as_input)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> axes{0, 1};
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {5, 30}});
// split sizes
p.add_literal(
mm->add_literal(
migraphx::literal{migraphx::shape{migraphx::shape::int32_type, {3}}, {4, 15, 11}});
p.add_literal(1); // split axis
p.add_literal(1); // concat axis
p.add_literal(1); // concat axis
auto l1 = p.add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 4}}, l0);
auto l2 = p.add_instruction(migraphx::op::slice{axes, {0, 4}, {5, 19}}, l0);
auto l3 = p.add_instruction(migraphx::op::slice{axes, {0, 19}, {5, 30}}, l0);
p.add_instruction(migraphx::op::concat{1}, l1, l2);
p.add_instruction(migraphx::op::concat{1}, l2, l3);
mm->add_literal(1); // split axis
mm->add_literal(1); // concat axis
mm->add_literal(1); // concat axis
auto l1 = mm->add_instruction(migraphx::op::slice{axes, {0, 0}, {5, 4}}, l0);
auto l2 = mm->add_instruction(migraphx::op::slice{axes, {0, 4}, {5, 19}}, l0);
auto l3 = mm->add_instruction(migraphx::op::slice{axes, {0, 19}, {5, 30}}, l0);
mm->add_instruction(migraphx::op::concat{1}, l1, l2);
mm->add_instruction(migraphx::op::concat{1}, l2, l3);
auto prog = parse_tf("split_test_vector_as_input.pb", false);
......@@ -575,9 +648,11 @@ TEST_CASE(split_test_vector_as_input)
TEST_CASE(sqdiff_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sqdiff{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::sqdiff{}, l0, l1);
auto prog = optimize_tf("sqdiff_test.pb", false);
EXPECT(p == prog);
......@@ -586,8 +661,10 @@ TEST_CASE(sqdiff_test)
TEST_CASE(squeeze_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
mm->add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
auto prog = optimize_tf("squeeze_test.pb", false);
EXPECT(p == prog);
......@@ -596,8 +673,10 @@ TEST_CASE(squeeze_test)
TEST_CASE(stopgradient_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_tf("stopgradient_test.pb", false);
EXPECT(p == prog);
......@@ -606,17 +685,19 @@ TEST_CASE(stopgradient_test)
TEST_CASE(stridedslice_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 1, 1}});
auto l1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 0, 0, 0};
op.ends = {1, 1, 1, 5};
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
auto l2 = p.add_instruction(op, l1);
auto l2 = mm->add_instruction(op, l1);
auto shrink_axis = 1;
p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2);
mm->add_instruction(migraphx::op::squeeze{{shrink_axis}}, l2);
auto prog = optimize_tf("stridedslice_test.pb", true);
EXPECT(p == prog);
......@@ -625,7 +706,9 @@ TEST_CASE(stridedslice_test)
TEST_CASE(stridedslice_masks_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 10, 3, 3}});
std::size_t num_axes = 4;
migraphx::op::slice op;
op.starts = {0, 1, 1, 0};
......@@ -633,13 +716,16 @@ TEST_CASE(stridedslice_masks_test)
op.axes = std::vector<int64_t>(num_axes);
std::iota(op.axes.begin(), op.axes.end(), 0);
// add literals for starts, ends, and strides in tf (NHWC format)
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 1, 1, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
auto l1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l2 = p.add_instruction(op, l1);
p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2);
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{0, 1, 1, 0});
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{0, 0, 0, 0});
mm->add_literal(migraphx::shape{migraphx::shape::int32_type, {4}},
std::vector<int>{1, 1, 1, 1});
auto l1 = mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto l2 = mm->add_instruction(op, l1);
mm->add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l2);
auto prog = parse_tf("stridedslice_masks_test.pb", true);
EXPECT(p == prog);
......@@ -648,9 +734,11 @@ TEST_CASE(stridedslice_masks_test)
TEST_CASE(sub_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sub{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::sub{}, l0, l1);
auto prog = parse_tf("sub_test.pb", false);
EXPECT(p == prog);
......@@ -659,9 +747,11 @@ TEST_CASE(sub_test)
TEST_CASE(tanh_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
p.add_instruction(migraphx::op::sub{}, l0, l1);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
mm->add_instruction(migraphx::op::sub{}, l0, l1);
auto prog = parse_tf("sub_test.pb", false);
EXPECT(p == prog);
......@@ -670,10 +760,12 @@ TEST_CASE(tanh_test)
TEST_CASE(transpose_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
migraphx::shape s0{migraphx::shape::int32_type, {4}};
p.add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
mm->add_literal(migraphx::literal{s0, {0, 2, 3, 1}});
mm->add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
auto prog = optimize_tf("transpose_test.pb", false);
EXPECT(p == prog);
......@@ -682,8 +774,10 @@ TEST_CASE(transpose_test)
TEST_CASE(variable_batch_test)
{
migraphx::program p;
auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
p.add_instruction(migraphx::op::identity{}, l0);
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
mm->add_instruction(migraphx::op::identity{}, l0);
auto prog = optimize_tf("variable_batch_test.pb", false);
EXPECT(p == prog);
......
......@@ -7,10 +7,10 @@
TEST_CASE(simple_test)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
mm->add_instruction(sum_op{}, one, two);
EXPECT(bool{p.validate() == p.end()});
auto result = p.eval({});
EXPECT(result.back() == migraphx::literal{3});
......@@ -20,21 +20,21 @@ TEST_CASE(simple_test)
TEST_CASE(out_of_order)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ins = p.add_instruction(sum_op{}, one, two);
p.move_instruction(two, p.end());
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
mm->move_instruction(two, p.end());
EXPECT(bool{p.validate() == ins});
}
TEST_CASE(incomplete_args)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ins = p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
ins->clear_arguments();
EXPECT(bool{p.validate() == ins});
}
......@@ -47,10 +47,10 @@ MIGRAPHX_ROB(access_ins_arguments,
TEST_CASE(invalid_args)
{
migraphx::program p;
auto one = p.add_literal(1);
auto two = p.add_literal(2);
auto ins = p.add_instruction(sum_op{}, one, two);
auto* mm = p.get_main_module();
auto one = mm->add_literal(1);
auto two = mm->add_literal(2);
auto ins = mm->add_instruction(sum_op{}, one, two);
access_ins_arguments(*ins).clear();
EXPECT(bool{p.validate() == p.begin()});
}
......
......@@ -9,16 +9,17 @@ struct batch_quant_dot_1 : verify_program<batch_quant_dot_1>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 8, 2}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 7, 8}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto tl1 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = p.add_parameter("b", m2_shape);
auto tl2 = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
auto l1 = mm->add_parameter("a", m1_shape);
auto tl1 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l1);
auto l2 = mm->add_parameter("b", m2_shape);
auto tl2 = mm->add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, l2);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{3, 2}, tl1, tl2, l3);
return p;
}
};
......@@ -9,14 +9,15 @@ struct batch_quant_dot_2 : verify_program<batch_quant_dot_2>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::int8_type, {3, 2, 2, 8}};
migraphx::shape m2_shape{migraphx::shape::int8_type, {3, 2, 8, 7}};
migraphx::shape m3_shape{migraphx::shape::int32_type, {3, 2, 2, 7}};
auto l1 = p.add_parameter("a", m1_shape);
auto l2 = p.add_parameter("b", m2_shape);
auto l3 = p.add_parameter("c", m3_shape);
p.add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
auto l1 = mm->add_parameter("a", m1_shape);
auto l2 = mm->add_parameter("b", m2_shape);
auto l3 = mm->add_parameter("c", m3_shape);
mm->add_instruction(migraphx::op::quant_dot{1, 3}, l1, l2, l3);
return p;
}
};
......@@ -9,14 +9,15 @@ struct gemm_2args_bmv : verify_program<gemm_2args_bmv>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto bul2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2);
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto bul2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 5, 1}}, ul2);
p.add_instruction(migraphx::op::dot{}, l1, bul2);
mm->add_instruction(migraphx::op::dot{}, l1, bul2);
return p;
}
......
......@@ -9,13 +9,14 @@ struct gemm_2args_mm_1 : verify_program<gemm_2args_mm_1>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2);
mm->add_instruction(migraphx::op::dot{}, l1, bl2);
return p;
}
......
......@@ -9,13 +9,14 @@ struct gemm_2args_mm_2 : verify_program<gemm_2args_mm_2>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, bl2);
mm->add_instruction(migraphx::op::dot{}, l1, bl2);
return p;
}
......
......@@ -9,13 +9,14 @@ struct gemm_2args_mm_3 : verify_program<gemm_2args_mm_3>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
......
......@@ -9,13 +9,14 @@ struct gemm_2args_mm_4 : verify_program<gemm_2args_mm_4>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
......
......@@ -9,13 +9,14 @@ struct gemm_2args_mm_5 : verify_program<gemm_2args_mm_5>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
......
......@@ -9,14 +9,15 @@ struct gemm_2args_mm_6 : verify_program<gemm_2args_mm_6>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto bl2 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2);
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto bl2 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 3, 4}}, l2);
p.add_instruction(migraphx::op::dot{}, bl1, bl2);
mm->add_instruction(migraphx::op::dot{}, bl1, bl2);
return p;
}
......
......@@ -9,13 +9,14 @@ struct gemm_2args_mm_7 : verify_program<gemm_2args_mm_7>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto bl1 = p.add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto l1 = mm->add_parameter("1", m1_shape);
auto bl1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 3, 2, 3}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
p.add_instruction(migraphx::op::dot{}, bl1, l2);
mm->add_instruction(migraphx::op::dot{}, bl1, l2);
return p;
}
......
......@@ -9,13 +9,14 @@ struct gemm_2args_mv : verify_program<gemm_2args_mv>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {3, 5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l2);
p.add_instruction(migraphx::op::dot{}, l1, ul2);
mm->add_instruction(migraphx::op::dot{}, l1, ul2);
return p;
}
......
......@@ -9,16 +9,17 @@ struct gemm_2args_vbm : verify_program<gemm_2args_vbm>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 2, 5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto bul1 = p.add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1);
auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto bul1 = mm->add_instruction(migraphx::op::multibroadcast{{2, 2, 1, 5}}, ul1);
auto l2 = p.add_parameter("2", m2_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto res = p.add_instruction(migraphx::op::dot{}, bul1, l2);
p.add_instruction(migraphx::op::squeeze{{2}}, res);
auto res = mm->add_instruction(migraphx::op::dot{}, bul1, l2);
mm->add_instruction(migraphx::op::squeeze{{2}}, res);
return p;
}
......
......@@ -9,14 +9,15 @@ struct gemm_2args_vm : verify_program<gemm_2args_vm>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {5}};
migraphx::shape m2_shape{migraphx::shape::float_type, {5, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto res = p.add_instruction(migraphx::op::dot{}, ul1, l2);
p.add_instruction(migraphx::op::squeeze{{0}}, res);
auto res = mm->add_instruction(migraphx::op::dot{}, ul1, l2);
mm->add_instruction(migraphx::op::squeeze{{0}}, res);
return p;
}
......
......@@ -9,17 +9,18 @@ struct gemm_2args_vv : verify_program<gemm_2args_vv>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {8}};
migraphx::shape m2_shape{migraphx::shape::float_type, {8}};
auto l1 = p.add_parameter("1", m1_shape);
auto ul1 = p.add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = p.add_parameter("2", m2_shape);
auto ul2 = p.add_instruction(migraphx::op::unsqueeze{{1}}, l2);
auto l1 = mm->add_parameter("1", m1_shape);
auto ul1 = mm->add_instruction(migraphx::op::unsqueeze{{0}}, l1);
auto l2 = mm->add_parameter("2", m2_shape);
auto ul2 = mm->add_instruction(migraphx::op::unsqueeze{{1}}, l2);
float alpha = 0.23f;
auto res = p.add_instruction(migraphx::op::dot{alpha}, ul1, ul2);
auto sres = p.add_instruction(migraphx::op::squeeze{{0}}, res);
p.add_instruction(migraphx::op::squeeze{{0}}, sres);
auto res = mm->add_instruction(migraphx::op::dot{alpha}, ul1, ul2);
auto sres = mm->add_instruction(migraphx::op::squeeze{{0}}, res);
mm->add_instruction(migraphx::op::squeeze{{0}}, sres);
return p;
}
......
......@@ -9,16 +9,17 @@ struct gemm_multi_3args : verify_program<gemm_multi_3args>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {2, 3, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {2, 3, 3, 2}};
migraphx::shape m3_shape{migraphx::shape::float_type, {2, 3, 2, 2}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.35;
float beta = 0.41;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
......
......@@ -9,16 +9,17 @@ struct gemm_multi_3args_alpha0 : verify_program<gemm_multi_3args_alpha0>
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::float_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::float_type, {1, 2, 4}};
auto l1 = p.add_parameter("1", m1_shape);
auto l2 = p.add_parameter("2", m2_shape);
auto l3 = p.add_parameter("3", m3_shape);
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape);
float alpha = 0.0f;
float beta = 1.0f;
p.add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
mm->add_instruction(migraphx::op::dot{alpha, beta}, l1, l2, l3);
return p;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment