Commit b8194bc6 authored by Paul's avatar Paul
Browse files

Format

parent 8b7894b0
...@@ -44,11 +44,9 @@ struct layernorm_matcher ...@@ -44,11 +44,9 @@ struct layernorm_matcher
auto variance() const auto variance() const
{ {
return f("reduce_mean")(arg(0)(any_of( return f("reduce_mean")(arg(0)(any_of(
f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))), f("pow")(arg(0)(x_minus_mean()), arg(1)(has_value(2.0f))),
f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())), f("mul")(arg(0)(x_minus_mean()), arg(1)(x_minus_mean())),
f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(f("reduce_mean")))) f("sqdiff")(either_arg(0, 1)(any().bind("x"), skip_broadcasts(f("reduce_mean")))))));
)
));
} }
auto sqrt_add_eps(const std::string& name) const auto sqrt_add_eps(const std::string& name) const
...@@ -59,8 +57,7 @@ struct layernorm_matcher ...@@ -59,8 +57,7 @@ struct layernorm_matcher
auto layernorm_onnx() const auto layernorm_onnx() const
{ {
auto div_sqrt = f("div")(arg(0)(x_minus_mean()), auto div_sqrt = f("div")(arg(0)(x_minus_mean()), arg(1)(sqrt_add_eps("sqrt")));
arg(1)(sqrt_add_eps("sqrt")));
auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt"))); auto mul_rsqrt = f("mul")(either_arg(0, 1)(x_minus_mean(), sqrt_add_eps("rsqrt")));
return any(any_of(div_sqrt, mul_rsqrt)); return any(any_of(div_sqrt, mul_rsqrt));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment