Commit f568547e authored by Paul's avatar Paul
Browse files

Format

parent 801ca3ed
...@@ -134,7 +134,7 @@ struct array ...@@ -134,7 +134,7 @@ struct array
return result; return result;
} }
template<class F> template <class F>
constexpr auto apply(F f) const constexpr auto apply(F f) const
{ {
array<decltype(f(d[0])), N> result; array<decltype(f(d[0])), N> result;
...@@ -210,8 +210,8 @@ struct array ...@@ -210,8 +210,8 @@ struct array
} }
}; };
template<class T, class... Ts> template <class T, class... Ts>
constexpr array<T, sizeof...(Ts)+1> make_array(T x, Ts... xs) constexpr array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{ {
return {x, static_cast<T>(xs)...}; return {x, static_cast<T>(xs)...};
} }
......
...@@ -9,9 +9,7 @@ namespace migraphx { ...@@ -9,9 +9,7 @@ namespace migraphx {
template <class T, index_int N, class Op> template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op) constexpr auto vec_reduce(const array<T, N>& a, Op op)
{ {
return a.apply([&](auto x) { return a.apply([&](auto x) { return vec_reduce(x, op); });
return vec_reduce(x, op);
});
} }
template <index_int Axis, template <index_int Axis,
...@@ -30,12 +28,12 @@ __device__ void generic_binary_layernorm( ...@@ -30,12 +28,12 @@ __device__ void generic_binary_layernorm(
MIGRAPHX_ASSERT(relements > 0); MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
auto means = r.reduce(op::sum{}, make_array<value_type>(0, 0), [&](auto x1, auto x2) { auto means = r.reduce(op::sum{}, make_array<value_type>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2); auto x = op(x1, x2);
return make_array(x, x*x) / value_type{relements}; return make_array(x, x * x) / value_type{relements};
})(input1, input2); })(input1, input2);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment