Commit 3143befc authored by Paul's avatar Paul
Browse files

Fix variance calculation

parent 389bc830
...@@ -57,12 +57,13 @@ __device__ void generic_binary_layernorm( ...@@ -57,12 +57,13 @@ __device__ void generic_binary_layernorm(
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto x = op(x1, x2); auto x = op(x1, x2);
auto m = x - mean_x; auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12) // m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_x2 - mean_x + value_type{1e-12}), xs...); y = compute(m * rsqrt(variance + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...); })(output, input1, input2, inputs...);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment