Commit 8b7894b0 authored by Paul
Browse files

Fuse layernorm with different patterns

parent 635502be
...@@ -43,15 +43,26 @@ struct layernorm_matcher ...@@ -43,15 +43,26 @@ struct layernorm_matcher
auto variance() const auto variance() const
{ {
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))))); return f("reduce_mean")(arg(0)(any_of(
f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(f("reduce_mean"))))
)
));
} }
auto layernorm_onnx() const auto sqrt_add_eps(const std::string& name) const
{ {
return f("div")(arg(0)(x_minus_mean()), auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
return skip_broadcasts(f(name)(arg(0)(add_eps)));
}
arg(1)(skip_broadcasts(f("sqrt")(arg(0)( auto layernorm_onnx() const
f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")))))))); {
auto div_sqrt = f("div")(arg(0)(x_minus_mean()),
arg(1)(sqrt_add_eps("sqrt")));
auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
return any(any_of(div_sqrt, mul_rsqrt));
} }
auto matcher() const { return layernorm_onnx(); } auto matcher() const { return layernorm_onnx(); }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment