Commit 8b7894b0 authored by Paul
Browse files

Fuse layernorm with different patterns

parent 635502be
......@@ -43,15 +43,26 @@ struct layernorm_matcher
// Builds the pattern for the variance term: a "reduce_mean" whose argument 0
// is a squared-deviation expression built from x_minus_mean().
// NOTE(review): this text is a rendered diff hunk with the +/- markers
// stripped, so the body shows two consecutive returns. The first looks like
// the pre-commit line and the second like its replacement — confirm against
// the repository before treating this as compilable code.
auto variance() const
{
// Presumably the removed line: accepts only a "pow" whose arg 0 is
// x_minus_mean() and whose arg 1 has the value 2.0f.
return f("reduce_mean")(arg(0)(f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f)))));
// Presumably the added lines: any_of over three spellings of the squared
// deviation — the same "pow" form, a "mul" with x_minus_mean() as both
// arguments, or a "sqdiff" pairing a value bound as "x" with a
// "reduce_mean" wrapped in skip_broadcasts, through either_arg(0, 1).
return f("reduce_mean")(arg(0)(any_of(
f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(f("reduce_mean"))))
)
));
}
// NOTE(review): lines from the old and new versions of this region are
// interleaved (diff rendering without +/- markers), so the text below is not
// compilable as written. Read as a diff, it presumably replaces a
// single-expression layernorm_onnx() with a sqrt_add_eps(name) helper plus a
// layernorm_onnx() that accepts two forms — confirm against the repository.
// Presumably removed: header of the old layernorm_onnx().
auto layernorm_onnx() const
// Presumably added: helper parameterized on the operator name that is applied
// to the "add" of variance() and the constant bound as "eps".
auto sqrt_add_eps(const std::string& name) const
{
// Presumably removed: start of the old "div" expression (continues below).
return f("div")(arg(0)(x_minus_mean()),
// Presumably added: body of sqrt_add_eps. The "add" pairs variance() with a
// constant bound as "eps" through either_arg(0, 1); the result is used as
// arg 0 of f(name) and wrapped in skip_broadcasts.
auto add_eps = f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps")));
return skip_broadcasts(f(name)(arg(0)(add_eps)));
}
// Presumably removed: remainder of the old expression, which spelled out the
// "sqrt" of add(variance, eps) inline as arg 1 of the "div".
arg(1)(skip_broadcasts(f("sqrt")(arg(0)(
f("add")(either_arg(0, 1)(variance(), is_constant().bind("eps"))))))));
// Presumably added: the new layernorm_onnx().
auto layernorm_onnx() const
{
// "div" with x_minus_mean() as arg 0 and sqrt_add_eps("sqrt") as arg 1.
auto div_sqrt = f("div")(arg(0)(x_minus_mean()),
arg(1)(sqrt_add_eps("sqrt")));
// "mul" pairing x_minus_mean() with sqrt_add_eps("rsqrt") through either_arg(0, 1).
auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
// The returned pattern is any(...) around any_of the two forms above.
return any(any_of(div_sqrt, mul_rsqrt));
}
auto matcher() const { return layernorm_onnx(); }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment