Commit cab8156e authored by turneram's avatar turneram
Browse files

Add find_gelu_erf to simplify_algebra

parent f2531606
...@@ -851,6 +851,54 @@ struct find_div_const ...@@ -851,6 +851,54 @@ struct find_div_const
} }
}; };
/// Rewrites the erf-based expression
///     0.5 * (y * (erf(x * recip(c)) + 1)),   y = add(constant, dot)
/// into the sigmoid form
///     x * (1 / (1 + exp(-1.702 * x))).
/// The rewrite is applied only when `x` and `y` are the same instruction;
/// otherwise the replacement (which uses `x` alone) would drop `y`.
struct find_gelu_erf
{
    /// Matches `x * recip(...)` and binds the first operand as "x".
    /// NOTE(review): the operand of `recip` is not inspected here; presumably it
    /// is the sqrt(2) constant of the GELU formula -- confirm against the models
    /// this pass targets.
    static auto match_mul1()
    {
        return match::name("mul")(match::args(
            match::any().bind("x"), match::skip_broadcasts(match::name("recip"))));
    }

    /// Matches `erf(x * recip(...))`.
    static auto match_erf() { return match::name("erf")(match::arg(0)(match_mul1())); }

    /// Matches `erf(...) + 1`, looking through broadcasts on the constant.
    static auto match_add2()
    {
        return match::name("add")(
            match::args(match_erf(), match::skip_broadcasts(match::has_value(1.0f))));
    }

    /// Matches `add(constant, dot)` and binds it as "y" so that `apply` can
    /// verify it is the same instruction as the "x" feeding `erf`.
    static auto match_add1()
    {
        return match::name("add")(
                   match::args(match::skip_broadcasts(match::is_constant()), match::name("dot")))
            .bind("y");
    }

    /// Matches `y * (erf(...) + 1)`.
    static auto match_mul2()
    {
        return match::name("mul")(match::args(match_add1(), match_add2()));
    }

    /// Matches the whole expression `(y * (erf(...) + 1)) * 0.5`.
    auto matcher() const
    {
        return match::name("mul")(
            match::args(match_mul2(), match::skip_broadcasts(match::has_value(0.5f))));
    }

    /// Replaces the matched instruction with `x * (1 / (1 + exp(-1.702 * x)))`.
    /// Does nothing when the two factors of the matched product do not share
    /// the same input instruction.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto ins = r.result;
        auto x   = r.instructions["x"];

        // The matcher cannot express "same instruction in both branches", so
        // enforce it here; the sigmoid form below is built from `x` only.
        if(x != r.instructions["y"])
            return;

        // Adds a scalar literal of x's element type and broadcasts it to x's lens.
        auto broadcast_scalar = [&](float value) {
            auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {value}});
            return m.insert_instruction(
                ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), lit);
        };

        auto scale  = broadcast_scalar(1.702f);
        auto scaled = m.insert_instruction(ins, make_op("mul"), x, scale);
        auto sig    = m.insert_instruction(ins, make_op("neg"), scaled);
        auto one    = broadcast_scalar(1.0f);
        sig         = m.insert_instruction(ins, make_op("exp"), sig);
        sig         = m.insert_instruction(ins, make_op("add"), one, sig);
        sig         = m.insert_instruction(ins, make_op("div"), one, sig);
        sig         = m.insert_instruction(ins, make_op("mul"), x, sig);
        m.replace_instruction(ins, sig);
    }
};
struct find_sub_const struct find_sub_const
{ {
auto matcher() const auto matcher() const
...@@ -1040,6 +1088,7 @@ void simplify_algebra::apply(module& m) const ...@@ -1040,6 +1088,7 @@ void simplify_algebra::apply(module& m) const
for(int i = 0; i < 8; i++) for(int i = 0; i < 8; i++)
{ {
match::find_matches(m, match::find_matches(m,
find_gelu_erf{},
find_inner_broadcast{}, find_inner_broadcast{},
find_double_add_lit_broadcast{}, find_double_add_lit_broadcast{},
find_add_lit_broadcast{}, find_add_lit_broadcast{},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment