Commit c2c5a632 authored by turneram's avatar turneram
Browse files

Formatting

parent cab8156e
...@@ -855,19 +855,22 @@ struct find_gelu_erf ...@@ -855,19 +855,22 @@ struct find_gelu_erf
{ {
static auto match_mul1() static auto match_mul1()
{ {
return match::name("mul")(args(match::any().bind("x"), match::skip_broadcasts(match::name("recip")))); return match::name("mul")(
args(match::any().bind("x"), match::skip_broadcasts(match::name("recip"))));
} }
static auto match_erf() { return match::name("erf")(match::arg(0)(match_mul1())); } static auto match_erf() { return match::name("erf")(match::arg(0)(match_mul1())); }
static auto match_add2() static auto match_add2()
{ {
return match::name("add")(args(match_erf(), match::skip_broadcasts(match::has_value(1.0f)))); return match::name("add")(
args(match_erf(), match::skip_broadcasts(match::has_value(1.0f))));
} }
static auto match_add1() static auto match_add1()
{ {
return match::name("add")(args(match::skip_broadcasts(match::is_constant()), match::name("dot"))); return match::name("add")(
args(match::skip_broadcasts(match::is_constant()), match::name("dot")));
} }
static auto match_mul2() { return match::name("mul")(match::args(match_add1(), match_add2())); } static auto match_mul2() { return match::name("mul")(match::args(match_add1(), match_add2())); }
...@@ -880,8 +883,8 @@ struct find_gelu_erf ...@@ -880,8 +883,8 @@ struct find_gelu_erf
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x = r.instructions["x"]; auto x = r.instructions["x"];
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}}); auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = m.insert_instruction( auto mul = m.insert_instruction(
...@@ -889,12 +892,12 @@ struct find_gelu_erf ...@@ -889,12 +892,12 @@ struct find_gelu_erf
mul = m.insert_instruction(ins, make_op("mul"), x, mul); mul = m.insert_instruction(ins, make_op("mul"), x, mul);
auto sig = m.insert_instruction(ins, make_op("neg"), mul); auto sig = m.insert_instruction(ins, make_op("neg"), mul);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}}); auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
one = m.insert_instruction( one = m.insert_instruction(
ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), one); ins, make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}}), one);
sig = m.insert_instruction(ins, make_op("exp"), sig); sig = m.insert_instruction(ins, make_op("exp"), sig);
sig = m.insert_instruction(ins, make_op("add"), one, sig); sig = m.insert_instruction(ins, make_op("add"), one, sig);
sig = m.insert_instruction(ins, make_op("div"), one, sig); sig = m.insert_instruction(ins, make_op("div"), one, sig);
sig = m.insert_instruction(ins, make_op("mul"), x, sig); sig = m.insert_instruction(ins, make_op("mul"), x, sig);
m.replace_instruction(ins, sig); m.replace_instruction(ins, sig);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment