Commit 3143befc authored by Paul's avatar Paul
Browse files

Fix variance calculation

parent 389bc830
......@@ -57,12 +57,13 @@ __device__ void generic_binary_layernorm(
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto x = op(x1, x2);
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_x2 - mean_x + value_type{1e-12}), xs...);
y = compute(m * rsqrt(variance + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...);
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment