Commit 6f692ebd authored by turneram's avatar turneram
Browse files

Formatting

parent 35f709a2
......@@ -33,8 +33,8 @@ struct find_gelu_erf
{
static auto match_div()
{
return match::name("div")(
match::either_arg(0, 1)(match::any().bind("x"), match::skip_broadcasts(match::has_value(1.414f, 1e-3))));
return match::name("div")(match::either_arg(0, 1)(
match::any().bind("x"), match::skip_broadcasts(match::has_value(1.414f, 1e-3))));
}
static auto match_erf() { return match::name("erf")(match::arg(0)(match_div())); }
......@@ -45,7 +45,10 @@ struct find_gelu_erf
match::either_arg(0, 1)(match_erf(), match::skip_broadcasts(match::has_value(1.0f))));
}
static auto match_mul() { return match::name("mul")(match::either_arg(0, 1)(match::any(), match_add())); }
static auto match_mul()
{
return match::name("mul")(match::either_arg(0, 1)(match::any(), match_add()));
}
auto matcher() const
{
......@@ -63,7 +66,7 @@ struct find_gelu_erf
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lit);
mul = m.insert_instruction(ins, make_op("mul"), x, mul);
auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig);
sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
one = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), one);
......@@ -74,10 +77,7 @@ struct find_gelu_erf
}
};
void rewrite_gelu::apply(module& m) const
{
match::find_matches(m, find_gelu_erf{});
}
void rewrite_gelu::apply(module& m) const { match::find_matches(m, find_gelu_erf{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -37,7 +37,6 @@
#include <migraphx/verify.hpp>
TEST_CASE(bias_gelu)
{
migraphx::shape s1{migraphx::shape::half_type, {2, 4, 8}};
......@@ -92,10 +91,10 @@ TEST_CASE(non_bias_gelu)
migraphx::shape s2{migraphx::shape::half_type};
migraphx::module m1;
{
auto a = m1.add_parameter("a", s1);
auto b = m1.add_parameter("b", s1);
auto a = m1.add_parameter("a", s1);
auto b = m1.add_parameter("b", s1);
auto sub = m1.add_instruction(migraphx::make_op("sub"), a, b);
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
auto l1 = m1.add_literal(migraphx::literal{s2, {1.4140625f}});
l1 = m1.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), l1);
auto div = m1.add_instruction(migraphx::make_op("div"), sub, l1);
auto erf = m1.add_instruction(migraphx::make_op("erf"), div);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment