Commit 03c6967e authored by Paul's avatar Paul
Browse files

Format

parent 07f37b4c
...@@ -49,10 +49,11 @@ __device__ void generic_binary_layernorm( ...@@ -49,10 +49,11 @@ __device__ void generic_binary_layernorm(
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) { auto means =
auto x = op(x1, x2); r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements}; auto x = op(x1, x2);
})(input1, input2); return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment