Unverified Commit 1a4ff504 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Horizontal fusions of gemms and convolutions (#472)



* Add decompose pass

* Add decompose test

* Formatting

* Add remap

* Formatting

* Add compute method for dot

* Formatting

* Add finder for horizontal fusion

* Formatting

* Formatting

* Reuse predicate

* Add gemm fusions

* Formatting

* Add some fixes for convolution

* Formatting

* Fix shape tests

* Formatting

* Reuse axis equal

* Add initial split fusion

* Formatting

* Update offset

* Workaround outputs that cant accept nonstandard shapes

* Formatting

* Add check for split concat

* Formatting

* Add missing headers

* Formatting

* Add tests

* Formatting

* Add more testing

* Formatting

* Fix when there is duplicate splits in inputs

* Formatting

* Fix mismatch iterators

* Add tests for dot fusions

* Formatting

* Add test for convolution

* Formatting

* Fix tidy issues

* Add more tests

* Formatting

* Ignore build directory for codecov

* Add test for groups

* Formatting

* Add more tests for groups

* Formatting

* Add test for missing end slice

* Add newline

* Remove unused function

* Add support for when beta is not 1

* Formatting

* Add test for scalar

* Add one more scalar test
Co-authored-by: default avatarmvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 45bb91ea
......@@ -563,6 +563,7 @@ TEST_CASE(simplify_rsqrt)
migraphx::program p2;
{
auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
p2.add_instruction(migraphx::op::rsqrt{}, x);
}
......@@ -585,4 +586,631 @@ TEST_CASE(simplify_rsqrt_multi_use)
EXPECT(p1 == p2);
}
TEST_CASE(simplify_split_add_relu)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
}
run_pass(p1);
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, add);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_split_add_relu_reshape)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto reshape1 = p1.add_instruction(r, relu1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto reshape2 = p1.add_instruction(r, relu2);
auto add = p1.add_instruction(migraphx::op::add{}, reshape1, reshape2);
p1.add_instruction(pass_op{}, add);
}
run_pass(p1);
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto r = migraphx::op::reshape{{3, 4}};
auto input = p2.add_parameter("input", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto slice1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto cont1 = p2.add_instruction(migraphx::op::contiguous{}, slice1);
auto reshape1 = p2.add_instruction(r, cont1);
auto slice2 = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto cont2 = p2.add_instruction(migraphx::op::contiguous{}, slice2);
auto reshape2 = p2.add_instruction(r, cont2);
auto add = p2.add_instruction(migraphx::op::add{}, reshape1, reshape2);
p2.add_instruction(pass_op{}, add);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_slice_different_axis)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}};
migraphx::program p1;
{
auto r = migraphx::op::reshape{{3, 2, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{3}, {0}, {1}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(migraphx::op::broadcast{1, {3, 1, 4, 2}}, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(migraphx::op::broadcast{3, {3, 2, 4, 1}}, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto reshape1 = p1.add_instruction(r, relu1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto reshape2 = p1.add_instruction(r, relu2);
auto add = p1.add_instruction(migraphx::op::add{}, reshape1, reshape2);
p1.add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_slice_missing_begining_slice)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_slice_missing_middle_slice)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_slice_missing_end_slice)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_split_add_relu_concat_same_axis)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto concat = p1.add_instruction(migraphx::op::concat{1}, relu1, relu2);
p1.add_instruction(pass_op{}, concat);
}
run_pass(p1);
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
p2.add_instruction(pass_op{}, relu);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_split_add_relu_multi_axes)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1, 3}, {0, 0}, {1, 3}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1, 3}, {1, 3}, {2, 6}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
p1.add_instruction(pass_op{}, add);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_split_add_relu_used_multiple_split1)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add1 = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = p1.add_instruction(migraphx::op::add{}, x, add1);
p1.add_instruction(pass_op{}, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto slice = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = p2.add_instruction(migraphx::op::add{}, x, y);
auto add2 = p2.add_instruction(migraphx::op::add{}, slice, add1);
p2.add_instruction(pass_op{}, add2);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_split_add_relu_used_multiple_split2)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto b = migraphx::op::broadcast{1, {3, 1, 4}};
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto z = p1.add_instruction(migraphx::op::relu{}, x);
auto one = p1.add_literal(1);
auto oneb = p1.add_instruction(b, one);
auto two = p1.add_literal(2);
auto twob = p1.add_instruction(b, two);
auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb);
auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1);
auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob);
auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2);
auto add1 = p1.add_instruction(migraphx::op::add{}, relu1, relu2);
auto add2 = p1.add_instruction(migraphx::op::add{}, z, add1);
p1.add_instruction(pass_op{}, add2);
}
run_pass(p1);
migraphx::program p2;
{
auto b = migraphx::op::broadcast{1, {3, 2, 4}};
auto input = p2.add_parameter("input", s);
auto slice = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto z = p2.add_instruction(migraphx::op::relu{}, slice);
auto one = p2.add_literal(1);
auto two = p2.add_literal(2);
auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two);
auto concatb = p2.add_instruction(b, concat);
auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb);
auto relu = p2.add_instruction(migraphx::op::relu{}, sum);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu);
auto add1 = p2.add_instruction(migraphx::op::add{}, x, y);
auto add2 = p2.add_instruction(migraphx::op::add{}, z, add1);
p2.add_instruction(pass_op{}, add2);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_split_between_add)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input);
auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_dot_horiz)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(s, 0));
auto b = p1.add_literal(migraphx::generate_literal(s, 1));
auto x = p1.add_instruction(migraphx::op::dot{}, input, a);
auto y = p1.add_instruction(migraphx::op::dot{}, input, b);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(s, 0));
auto b = p2.add_literal(migraphx::generate_literal(s, 1));
auto concat = p2.add_instruction(migraphx::op::concat{2}, a, b);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat);
auto x = p2.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = p2.add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_dot_horiz_same_constant)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(s, 0));
auto x = p1.add_instruction(migraphx::op::dot{}, input, a);
auto y = p1.add_instruction(migraphx::op::dot{}, input, a);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(s, 0));
auto concat = p2.add_instruction(migraphx::op::concat{2}, a, a);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat);
auto x = p2.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot);
auto y = p2.add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_dot_horiz_flipped)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(s, 0));
auto b = p1.add_literal(migraphx::generate_literal(s, 1));
auto x = p1.add_instruction(migraphx::op::dot{}, input, a);
auto y = p1.add_instruction(migraphx::op::dot{}, b, input);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
}
migraphx::program p2 = p1;
run_pass(p1);
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}};
auto ws = migraphx::shape{migraphx::shape::int32_type, {12, 3, 3, 3}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws, 1));
auto x = p1.add_instruction(migraphx::op::convolution{}, input, a);
auto y = p1.add_instruction(migraphx::op::convolution{}, input, b);
auto sum = p1.add_instruction(migraphx::op::add{}, x, y);
p1.add_instruction(pass_op{}, sum);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws, 1));
auto concat = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto conv = p2.add_instruction(migraphx::op::convolution{}, input, concat);
auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {12}}, conv);
auto y = p2.add_instruction(migraphx::op::slice{{1}, {12}, {24}}, conv);
auto sum = p2.add_instruction(migraphx::op::add{}, x, y);
p2.add_instruction(pass_op{}, sum);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_groups)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p1.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p1.add_literal(migraphx::generate_literal(ws2, 3));
auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c);
auto doty = p1.add_instruction(migraphx::op::dot{}, input, d);
auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
p1.add_instruction(pass_op{}, sum3);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p2.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p2.add_literal(migraphx::generate_literal(ws2, 3));
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d);
auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
p2.add_instruction(pass_op{}, sum3);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_groups_extra1)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p1.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p1.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p1.add_literal(migraphx::generate_literal(s, 4));
auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c);
auto doty = p1.add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = p1.add_instruction(migraphx::op::sqdiff{}, input, e);
auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = sqdiffx;
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3);
p1.add_instruction(pass_op{}, sum5);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p2.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p2.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p2.add_literal(migraphx::generate_literal(s, 4));
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d);
auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = p2.add_instruction(migraphx::op::sqdiff{}, input, e);
auto sum3 = sqdiffx;
auto sum4 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p2.add_instruction(migraphx::op::add{}, sum4, sum3);
p2.add_instruction(pass_op{}, sum5);
}
EXPECT(p1.sort() == p2.sort());
}
TEST_CASE(simplify_conv_horiz_groups_extra2)
{
auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}};
auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}};
migraphx::program p1;
{
auto input = p1.add_parameter("input", s);
auto a = p1.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p1.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p1.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p1.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p1.add_literal(migraphx::generate_literal(s, 4));
auto f = p1.add_literal(migraphx::generate_literal(s, 5));
auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a);
auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b);
auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c);
auto doty = p1.add_instruction(migraphx::op::dot{}, input, d);
auto sqdiffx = p1.add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = p1.add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy);
auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty);
auto sum3 = p1.add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3);
p1.add_instruction(pass_op{}, sum5);
}
run_pass(p1);
migraphx::program p2;
{
auto input = p2.add_parameter("input", s);
auto a = p2.add_literal(migraphx::generate_literal(ws1, 0));
auto b = p2.add_literal(migraphx::generate_literal(ws1, 1));
auto c = p2.add_literal(migraphx::generate_literal(ws2, 2));
auto d = p2.add_literal(migraphx::generate_literal(ws2, 3));
auto e = p2.add_literal(migraphx::generate_literal(s, 4));
auto f = p2.add_literal(migraphx::generate_literal(s, 5));
auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b);
auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d);
auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1);
auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv);
auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv);
auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy);
auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2);
auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot);
auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot);
auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty);
auto sqdiffx = p2.add_instruction(migraphx::op::sqdiff{}, input, e);
auto sqdiffy = p2.add_instruction(migraphx::op::sqdiff{}, input, f);
auto sum3 = p2.add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy);
auto sum4 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
auto sum5 = p2.add_instruction(migraphx::op::add{}, sum4, sum3);
p2.add_instruction(pass_op{}, sum5);
}
EXPECT(p1.sort() == p2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment