Commit 6f692ebd authored by turneram's avatar turneram
Browse files

Formatting

parent 35f709a2
...@@ -33,8 +33,8 @@ struct find_gelu_erf ...@@ -33,8 +33,8 @@ struct find_gelu_erf
{ {
static auto match_div() static auto match_div()
{ {
return match::name("div")( return match::name("div")(match::either_arg(0, 1)(
match::either_arg(0, 1)(match::any().bind("x"), match::skip_broadcasts(match::has_value(1.414f, 1e-3)))); match::any().bind("x"), match::skip_broadcasts(match::has_value(1.414f, 1e-3))));
} }
static auto match_erf() { return match::name("erf")(match::arg(0)(match_div())); } static auto match_erf() { return match::name("erf")(match::arg(0)(match_div())); }
...@@ -45,7 +45,10 @@ struct find_gelu_erf ...@@ -45,7 +45,10 @@ struct find_gelu_erf
match::either_arg(0, 1)(match_erf(), match::skip_broadcasts(match::has_value(1.0f)))); match::either_arg(0, 1)(match_erf(), match::skip_broadcasts(match::has_value(1.0f))));
} }
static auto match_mul() { return match::name("mul")(match::either_arg(0, 1)(match::any(), match_add())); } static auto match_mul()
{
return match::name("mul")(match::either_arg(0, 1)(match::any(), match_add()));
}
auto matcher() const auto matcher() const
{ {
...@@ -63,7 +66,7 @@ struct find_gelu_erf ...@@ -63,7 +66,7 @@ struct find_gelu_erf
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lit); ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lit);
mul = m.insert_instruction(ins, make_op("mul"), x, mul); mul = m.insert_instruction(ins, make_op("mul"), x, mul);
auto sig = m.insert_instruction(ins, make_op("neg"), mul); auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig); sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}}); auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
one = m.insert_instruction( one = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), one); ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), one);
...@@ -74,10 +77,7 @@ struct find_gelu_erf ...@@ -74,10 +77,7 @@ struct find_gelu_erf
} }
}; };
void rewrite_gelu::apply(module& m) const void rewrite_gelu::apply(module& m) const { match::find_matches(m, find_gelu_erf{}); }
{
match::find_matches(m, find_gelu_erf{});
}
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -37,7 +37,6 @@ ...@@ -37,7 +37,6 @@
#include <migraphx/verify.hpp> #include <migraphx/verify.hpp>
TEST_CASE(bias_gelu) TEST_CASE(bias_gelu)
{ {
migraphx::shape s1{migraphx::shape::half_type, {2, 4, 8}}; migraphx::shape s1{migraphx::shape::half_type, {2, 4, 8}};
...@@ -92,10 +91,10 @@ TEST_CASE(non_bias_gelu) ...@@ -92,10 +91,10 @@ TEST_CASE(non_bias_gelu)
migraphx::shape s2{migraphx::shape::half_type}; migraphx::shape s2{migraphx::shape::half_type};
migraphx::module m1; migraphx::module m1;
{ {
auto a = m1.add_parameter("a", s1); auto a = m1.add_parameter("a", s1);
auto b = m1.add_parameter("b", s1); auto b = m1.add_parameter("b", s1);
auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b); auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}}); auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1); l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto div = m1.add_instruction(migraphx::make_op("div"), sub, l1); auto div = m1.add_instruction(migraphx::make_op("div"), sub, l1);
auto erf = m1.add_instruction(migraphx::make_op("erf"), div); auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment