Commit 9882f6db authored by turneram's avatar turneram
Browse files

Formatting

parent 9ccf5e5c
...@@ -43,12 +43,12 @@ struct find_gelu_erf ...@@ -43,12 +43,12 @@ struct find_gelu_erf
return; return;
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}}); auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit}); auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
auto sig = m.insert_instruction(ins, make_op("neg"), mul); auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig); sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}}); auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
sig = insert_common_op(m, ins, make_op("add"), {sig, one}); sig = insert_common_op(m, ins, make_op("add"), {sig, one});
sig = m.insert_instruction(ins, make_op("div"), x, sig); sig = m.insert_instruction(ins, make_op("div"), x, sig);
m.replace_instruction(ins, sig); m.replace_instruction(ins, sig);
} }
}; };
......
...@@ -48,13 +48,13 @@ TEST_CASE(bias_gelu) ...@@ -48,13 +48,13 @@ TEST_CASE(bias_gelu)
auto b = m1.add_parameter("b", s1); auto b = m1.add_parameter("b", s1);
auto add1 = m1.add_instruction(migraphx::make_op("add"), a, b); auto add1 = m1.add_instruction(migraphx::make_op("add"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}}); auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
auto div = add_common_op(m1, migraphx::make_op("div"), {add1, l1}); auto div = add_common_op(m1, migraphx::make_op("div"), {add1, l1});
auto erf = m1.add_instruction(migraphx::make_op("erf"), div); auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2}); auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto mul = m1.add_instruction(migraphx::make_op("mul"), add1, add2); auto mul = m1.add_instruction(migraphx::make_op("mul"), add1, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}}); auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3}); mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
m1.add_return({mul}); m1.add_return({mul});
} }
migraphx::rewrite_gelu pass; migraphx::rewrite_gelu pass;
...@@ -72,8 +72,8 @@ TEST_CASE(bias_gelu) ...@@ -72,8 +72,8 @@ TEST_CASE(bias_gelu)
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul); auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig); sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2}); sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("div"), add, sig); sig = m2.add_instruction(migraphx::make_op("div"), add, sig);
m2.add_return({sig}); m2.add_return({sig});
} }
...@@ -86,17 +86,17 @@ TEST_CASE(non_bias_gelu) ...@@ -86,17 +86,17 @@ TEST_CASE(non_bias_gelu)
migraphx::shape s2{migraphx::shape::half_type}; migraphx::shape s2{migraphx::shape::half_type};
migraphx::module m1; migraphx::module m1;
{ {
auto a = m1.add_parameter("a", s1); auto a = m1.add_parameter("a", s1);
auto b = m1.add_parameter("b", s1); auto b = m1.add_parameter("b", s1);
auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b); auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}}); auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
auto div = add_common_op(m1, migraphx::make_op("div"), {sub, l1}); auto div = add_common_op(m1, migraphx::make_op("div"), {sub, l1});
auto erf = m1.add_instruction(migraphx::make_op("erf"), div); auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m1.add_literal(migraphx::literal{s2, {1.0f}});
auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2}); auto add2 = add_common_op(m1, migraphx::make_op("add"), {erf, l2});
auto mul = m1.add_instruction(migraphx::make_op("mul"), sub, add2); auto mul = m1.add_instruction(migraphx::make_op("mul"), sub, add2);
auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}}); auto l3 = m1.add_literal(migraphx::literal{s2, {0.5f}});
mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3}); mul = add_common_op(m1, migraphx::make_op("mul"), {mul, l3});
m1.add_return({mul}); m1.add_return({mul});
} }
migraphx::rewrite_gelu pass; migraphx::rewrite_gelu pass;
...@@ -114,8 +114,8 @@ TEST_CASE(non_bias_gelu) ...@@ -114,8 +114,8 @@ TEST_CASE(non_bias_gelu)
auto sig = m2.add_instruction(migraphx::make_op("neg"), mul); auto sig = m2.add_instruction(migraphx::make_op("neg"), mul);
sig = m2.add_instruction(migraphx::make_op("exp"), sig); sig = m2.add_instruction(migraphx::make_op("exp"), sig);
auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}}); auto l2 = m2.add_literal(migraphx::literal{s2, {1.0f}});
sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2}); sig = add_common_op(m2, migraphx::make_op("add"), {sig, l2});
sig = m2.add_instruction(migraphx::make_op("div"), sub, sig); sig = m2.add_instruction(migraphx::make_op("div"), sub, sig);
m2.add_return({sig}); m2.add_return({sig});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment