Commit 9b7a906d authored by turneram's avatar turneram
Browse files

Add test

parent c2c5a632
...@@ -1091,7 +1091,6 @@ void simplify_algebra::apply(module& m) const ...@@ -1091,7 +1091,6 @@ void simplify_algebra::apply(module& m) const
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(m, match::find_matches(m,
find_gelu_erf{},
find_inner_broadcast{}, find_inner_broadcast{},
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
find_add_lit_broadcast{}, find_add_lit_broadcast{},
...@@ -1101,6 +1100,7 @@ void simplify_algebra::apply(module& m) const ...@@ -1101,6 +1100,7 @@ void simplify_algebra::apply(module& m) const
find_mul_slice_conv{}, find_mul_slice_conv{},
find_mul_add{}, find_mul_add{},
find_div_const{}, find_div_const{},
find_gelu_erf{},
find_sub_const{}, find_sub_const{},
find_rsqrt{}, find_rsqrt{},
find_concat_op{}, find_concat_op{},
......
...@@ -734,6 +734,51 @@ TEST_CASE(simplify_div_const) ...@@ -734,6 +734,51 @@ TEST_CASE(simplify_div_const)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(simplify_gelu_erf_with_bias)
{
migraphx::shape s1{migraphx::shape::half_type, {2, 4, 8}};
migraphx::shape s2{migraphx::shape::half_type, {1}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", s1);
auto b = m1.add_parameter("b", s1);
auto add1 = m1.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto div = m1.add_instruction(migraphx::make_op("div"), add1, l1);
auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2);
auto add2 = m1.add_instruction(migraphx::make_op("add"), erf, l2);
auto mul = m1.add_instruction(migraphx::make_op("mul"), add1, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
l3 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l3);
m1.add_instruction(migraphx::make_op("mul"), mul, l3);
}
run_pass(m1);
run_pass(m1);
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", s1);
auto b = m2.add_parameter("b", s1);
auto add = m2.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m2.add_literal(migraphx::literal{s2, {1.702f}});
l1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), add, l1);
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
l2 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l2);
sig = m2.add_instruction(migraphx::make_op("add"), sig, l2);
sig = m2.add_instruction(migraphx::make_op("div"), l2, sig);
m2.add_instruction(migraphx::make_op("mul"), sig, add);
}
EXPECT(m1 == m2);
}
TEST_CASE(simplify_sub_const) TEST_CASE(simplify_sub_const)
{ {
migraphx::module m1; migraphx::module m1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment