Commit 15285630 authored by Paul's avatar Paul
Browse files

Format

parent ba72ce42
......@@ -41,9 +41,9 @@ __device__ void generic_binary_layernorm(
{
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) {
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements};
})(input1, input2);
......
......@@ -225,11 +225,11 @@ struct block
});
}
template<class Input>
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
using value_type = typename Input::type;
using reduce_type = decltype(slicer(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
......@@ -294,7 +294,7 @@ struct lane
});
}
template<class Input>
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment