Commit eb98f95b authored by Paul's avatar Paul
Browse files

Reduce first

parent 99c96638
......@@ -48,10 +48,11 @@ __device__ void generic_binary_layernorm(
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
using reduce_type = vec_type<value_type>;
constexpr auto relements = r.template elements<Input1>();
auto means = r.reduce(op::sum{}, make_array<value_type>(0, 0), [&](auto x1, auto x2) {
auto means = r.reduce(op::sum{}, make_array<reduce_type>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2);
return make_array(x, x * x) / value_type{relements};
return make_array(vec_reduce(x, op::sum{}), vec_reduce(x * x, op::sum{})) / reduce_type{relements};
})(input1, input2);
auto mean_x = means[0];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment