Commit 8f1702bc authored by Paul
Browse files

Update layernorm to use regs

parent 8f91ab39
......@@ -47,26 +47,27 @@ __device__ void generic_binary_layernorm(
{
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
auto input = r.inner([&](auto x1, auto x2) {
return op(x1, x2);
})(input1, input2);
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto means =
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2);
r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x) {
return make_array(x, x * x) * vec_type<value_type>{1.0 / relements};
})(input1, input2);
})(input);
auto mean_x = means[0];
auto mean_x2 = means[1];
auto variance = mean_x2 - (mean_x * mean_x);
value_type eps_val = eps; // implicit conversion for eps
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto x = op(x1, x2);
r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + epsilon)
y = compute(m * rsqrt(variance + eps_val), xs...);
})(output, input1, input2, inputs...);
})(output, input, inputs...);
});
}
......
......@@ -171,6 +171,9 @@ struct inner_storage_tag
{
};
// Trait alias: applies is_base_of<inner_storage_tag, ...> to T after removing any
// reference and then any const/volatile qualification. This lets call sites pass
// a deduced parameter type directly (e.g. get_size(T&& x), where T may be deduced
// as a reference type) without decaying it first.
// NOTE(review): is_base_of, remove_cv_t and remove_reference_t are declared
// elsewhere; presumably they mirror the <type_traits> semantics - confirm.
template<class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
template <class R, class F>
struct storage_access : F
{
......@@ -213,7 +216,7 @@ struct reducer_base
template <class T>
__device__ auto make_inner_slice(T x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
if constexpr(is_inner_storage<T>{})
{
return x;
}
......@@ -237,7 +240,7 @@ struct reducer_base
template <class T, class... Ts>
constexpr auto get_size(T&& x) const
{
if constexpr(is_base_of<inner_storage_tag, T>{})
if constexpr(is_inner_storage<T>{})
{
return x.rsize();
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment