"composable_kernel/include/utility/functional.hpp" did not exist on "81497a93a0840d5a1b5e84c1e47a90ae39d0fee6"
Commit 801ca3ed authored by Paul's avatar Paul
Browse files

Merge reduction into one

parent c56d6e9e
...@@ -134,6 +134,15 @@ struct array ...@@ -134,6 +134,15 @@ struct array
return result; return result;
} }
template<class F>
constexpr auto apply(F f) const
{
array<decltype(f(d[0])), N> result;
for(index_int i = 0; i < N; i++)
result[i] = f(d[i]);
return result;
}
MIGRAPHX_DEVICE_ARRAY_OP(+=, +) MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(-=, -) MIGRAPHX_DEVICE_ARRAY_OP(-=, -)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *) MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
...@@ -201,6 +210,11 @@ struct array ...@@ -201,6 +210,11 @@ struct array
} }
}; };
template<class T, class... Ts>
constexpr array<T, sizeof...(Ts)+1> make_array(T x, Ts... xs)
{
return {x, static_cast<T>(xs)...};
}
template <class T, T... Xs> template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(Xs)> struct integral_const_array : array<T, sizeof...(Xs)>
{ {
......
...@@ -6,6 +6,14 @@ ...@@ -6,6 +6,14 @@
namespace migraphx { namespace migraphx {
template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op)
{
return a.apply([&](auto x) {
return vec_reduce(x, op);
});
}
template <index_int Axis, template <index_int Axis,
class F, class F,
class BinOp, class BinOp,
...@@ -22,23 +30,19 @@ __device__ void generic_binary_layernorm( ...@@ -22,23 +30,19 @@ __device__ void generic_binary_layernorm(
MIGRAPHX_ASSERT(relements > 0); MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
auto mean = [&](auto f) { auto means = r.reduce(op::sum{}, make_array<value_type>(0, 0), [&](auto x1, auto x2) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { auto x = op(x1, x2);
return f(x1, x2) / value_type{relements}; return make_array(x, x*x) / value_type{relements};
})(input1, input2); })(input1, input2);
};
// mean(x) auto mean_x = means[0];
auto mean_x = mean(op); auto mean_x2 = means[1];
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x; auto x = op(x1, x2);
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12) // m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...); y = compute(m * rsqrt(mean_x2 - mean_x + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...); })(output, input1, input2, inputs...);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment