"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "969be85cedefa265d785ee5e6fe6bbbbddfaa069"
Commit 8f1702bc authored by Paul
Browse files

Update layernorm to use regs

parent 8f91ab39
...@@ -47,26 +47,27 @@ __device__ void generic_binary_layernorm( ...@@ -47,26 +47,27 @@ __device__ void generic_binary_layernorm(
{ {
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) {
return op(x1, x2);
})(input1, input2);
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>(); constexpr auto relements = r.template elements<Input1>();
auto means = auto means =
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) { r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
auto x = op(x1, x2);
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements}; return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2); })(input);
auto mean_x = means[0]; auto mean_x = means[0];
auto mean_x2 = means[1]; auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x); auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = eps; // implicit conversion for eps value_type eps_val = eps; // implicit conversion for eps
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto x = op(x1, x2);
auto m = x - mean_x; auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + epsilon) // m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...); y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...); })(output, input, inputs...);
}); });
} }
......
...@@ -171,6 +171,9 @@ struct inner_storage_tag ...@@ -171,6 +171,9 @@ struct inner_storage_tag
{ {
}; };
// Trait: tests whether T, with any reference and cv-qualifiers stripped,
// derives from inner_storage_tag.
// NOTE(review): assumes is_base_of, remove_cv_t and remove_reference_t behave
// like their std namesakes -- their definitions are not visible here.
// Stripping matters for callers whose T is deduced through a forwarding
// reference (e.g. get_size(T&& x)), where T may be an lvalue-reference type.
template<class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
template <class R, class F> template <class R, class F>
struct storage_access : F struct storage_access : F
{ {
...@@ -213,7 +216,7 @@ struct reducer_base ...@@ -213,7 +216,7 @@ struct reducer_base
template <class T> template <class T>
__device__ auto make_inner_slice(T x) const __device__ auto make_inner_slice(T x) const
{ {
if constexpr(is_base_of<inner_storage_tag, T>{}) if constexpr(is_inner_storage<T>{})
{ {
return x; return x;
} }
...@@ -237,7 +240,7 @@ struct reducer_base ...@@ -237,7 +240,7 @@ struct reducer_base
template <class T, class... Ts> template <class T, class... Ts>
constexpr auto get_size(T&& x) const constexpr auto get_size(T&& x) const
{ {
if constexpr(is_base_of<inner_storage_tag, T>{}) if constexpr(is_inner_storage<T>{})
{ {
return x.rsize(); return x.rsize();
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment