Commit 06607821 authored by Paul's avatar Paul
Browse files

Merge branch 'jit-layernorm-merge' into bert-opt-layernorm

parents 6ef8bd98 d8038982
...@@ -134,6 +134,15 @@ struct array ...@@ -134,6 +134,15 @@ struct array
return result; return result;
} }
template <class F>
constexpr auto apply(F f) const
{
array<decltype(f(d[0])), N> result;
for(index_int i = 0; i < N; i++)
result[i] = f(d[i]);
return result;
}
MIGRAPHX_DEVICE_ARRAY_OP(+=, +) MIGRAPHX_DEVICE_ARRAY_OP(+=, +)
MIGRAPHX_DEVICE_ARRAY_OP(-=, -) MIGRAPHX_DEVICE_ARRAY_OP(-=, -)
MIGRAPHX_DEVICE_ARRAY_OP(*=, *) MIGRAPHX_DEVICE_ARRAY_OP(*=, *)
...@@ -201,6 +210,11 @@ struct array ...@@ -201,6 +210,11 @@ struct array
} }
}; };
template <class T, class... Ts>
constexpr array<T, sizeof...(Ts) + 1> make_array(T x, Ts... xs)
{
return {x, static_cast<T>(xs)...};
}
template <class T, T... Xs> template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(Xs)> struct integral_const_array : array<T, sizeof...(Xs)>
{ {
......
...@@ -29,6 +29,12 @@ ...@@ -29,6 +29,12 @@
namespace migraphx { namespace migraphx {
template <class T, index_int N, class Op>
constexpr auto vec_reduce(const array<T, N>& a, Op op)
{
return a.apply([&](auto x) { return vec_reduce(x, op); });
}
template <index_int Axis, template <index_int Axis,
class F, class F,
class BinOp, class BinOp,
...@@ -43,23 +49,19 @@ __device__ void generic_binary_layernorm( ...@@ -43,23 +49,19 @@ __device__ void generic_binary_layernorm(
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) { auto means = r.reduce(op::sum{}, make_array<value_type>(0, 0), [&](auto x1, auto x2) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { auto x = op(x1, x2);
return f(x1, x2) / value_type{relements}; return make_array(x, x * x) / value_type{relements};
})(input1, input2); })(input1, input2);
};
// mean(x) auto mean_x = means[0];
auto mean_x = mean(op); auto mean_x2 = means[1];
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x; auto x = op(x1, x2);
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12) // m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...); y = compute(m * rsqrt(mean_x2 - mean_x + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...); })(output, input1, input2, inputs...);
}); });
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment