Commit 8b8395be authored by Paul
Browse files

Apply vec_reduce first

parent d8038982
...@@ -49,9 +49,9 @@ __device__ void generic_binary_layernorm( ...@@ -49,9 +49,9 @@ __device__ void generic_binary_layernorm(
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto means = r.reduce(op::sum{}, make_array<value_type>(0, 0), [&](auto x1, auto x2) { auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2); auto x = op(x1, x2);
return make_array(x, x * x) / value_type{relements}; return make_array(x, x * x) / vec_type<value_type>{relements};
})(input1, input2); })(input1, input2);
auto mean_x = means[0]; auto mean_x = means[0];
......
...@@ -201,12 +201,11 @@ struct block ...@@ -201,12 +201,11 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
return sliced(slicer, [=](auto x, auto... xs) { return sliced(slicer, [=](auto x, auto... xs) {
return vec_reduce(block_reduce(idx, return block_reduce(idx,
op, op,
init, init,
x.get_shape().elements(), x.get_shape().elements(),
[&](auto j) { return read(x[j], xs[j]...); }), [&](auto j) { return vec_reduce(read(x[j], xs[j]...), op); });
op);
}); });
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment