Commit e3e8dc32 authored by Paul's avatar Paul
Browse files

Format

parent 8f1702bc
...@@ -47,15 +47,12 @@ __device__ void generic_binary_layernorm( ...@@ -47,15 +47,12 @@ __device__ void generic_binary_layernorm(
{ {
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) { auto input = r.inner([&](auto x1, auto x2) { return op(x1, x2); })(input1, input2);
return op(x1, x2); using value_type = typename Input1::type;
})(input1, input2);
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto means = auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) { return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements}; })(input);
})(input);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
......
...@@ -171,7 +171,7 @@ struct inner_storage_tag ...@@ -171,7 +171,7 @@ struct inner_storage_tag
{ {
}; };
template<class T> template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>; using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
template <class R, class F> template <class R, class F>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment