Commit 276184b2 authored by Paul
Browse files

Load into registers first

parent c56d6e9e
...@@ -53,6 +53,17 @@ struct index ...@@ -53,6 +53,17 @@ struct index
return blockDim.x; // NOLINT return blockDim.x; // NOLINT
} }
#endif #endif
// Returns one plus the quotient n / nglobal().
// NOTE(review): presumably an upper bound on the number of iterations one
// thread executes in global_stride over n elements; when n is an exact
// multiple of nglobal() this is one larger than a ceiling division would
// give -- confirm that the over-estimate is intended.
template <class N>
constexpr auto max_global_stride_iterations(N n) const
{
    const auto whole_strides = n / nglobal();
    return _c<1> + whole_strides;
}
// Returns one plus the quotient n / nlocal().
// block::inner uses the type of this result as the capacity of the
// inner_array it fills from local_stride, where element j is stored in
// slot j / nlocal().
// NOTE(review): when n is an exact multiple of nlocal() this is one larger
// than a ceiling division would give; presumably a deliberate safe upper
// bound -- confirm.
template <class N>
constexpr auto max_local_stride_iterations(N n) const
{
    const auto whole_strides = n / nlocal();
    return _c<1> + whole_strides;
}
template <class F> template <class F>
__device__ void global_stride(index_int n, F f) const __device__ void global_stride(index_int n, F f) const
......
...@@ -22,24 +22,27 @@ __device__ void generic_binary_layernorm( ...@@ -22,24 +22,27 @@ __device__ void generic_binary_layernorm(
MIGRAPHX_ASSERT(relements > 0); MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
auto input = r.inner([&](auto x1, auto x2) {
return op(x1, x2);
})(input1, input2);
auto mean = [&](auto f) { auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { return r.reduce(op::sum{}, 0, [&](auto x) {
return f(x1, x2) / value_type{relements}; return f(x) / value_type{relements};
})(input1, input2); })(input);
}; };
// mean(x) // mean(x)
auto mean_x = mean(op); auto mean_x = mean(op::id{});
// mean(m ^ 2) // mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) { auto mean_m2 = mean([&](auto x) {
auto m = op(x1, x2) - mean_x; auto m = x - mean_x;
return m * m; return m * m;
}); });
r.inner([&](auto& y, auto x1, auto x2, auto... xs) { r.inner([&](auto& y, auto x, auto... xs) {
auto m = op(x1, x2) - mean_x; auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12) // m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...); y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...); })(output, input, inputs...);
}); });
} }
......
...@@ -147,25 +147,35 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f) ...@@ -147,25 +147,35 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
} }
#endif #endif
namespace reduce {
// Empty tag type. reduce_slice tests is_base_of<inner_array_base, Input> and
// returns such inputs unchanged instead of building a sliced tensor view;
// block::inner_array derives from it.
struct inner_array_base {};
template <class Output, class Input, class T> template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i) constexpr auto reduce_slice(Input input, T i)
{ {
constexpr auto lens = transform(get_shape_c<Input>{}.lens, if constexpr(is_base_of<inner_array_base, Input>{})
get_shape_c<Output>{}.lens, {
[](index_int x, index_int y) -> index_int { return input;
if(x == y) }
return 1; else
return x; {
}); constexpr auto lens = transform(get_shape_c<Input>{}.lens,
; get_shape_c<Output>{}.lens,
constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides); [](index_int x, index_int y) -> index_int {
MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <= if(x == y)
input.get_shape().element_space()); return 1;
return make_tensor_view(&input[i], s); return x;
});
;
constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides);
MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <=
input.get_shape().element_space());
return make_tensor_view(&input[i], s);
}
} }
namespace reduce {
template <class Slicer, class F> template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f) constexpr auto sliced(Slicer slicer, F f)
{ {
...@@ -217,11 +227,38 @@ struct block ...@@ -217,11 +227,38 @@ struct block
f(); f();
} }
// Fixed-capacity array of N values of type T that is indexed with the same
// indices used for the source data: element i lives in slot i / Stride.
// get_shape() returns a default-constructed Shape, so the object can be
// passed where a shaped input is expected.
// Deriving from inner_array_base lets reduce_slice recognise it and pass it
// through without slicing.
// `arr` has no initializer here; slots that are never assigned are not given
// a value by this struct.
template <class T, index_int N, index_int Stride, class Shape>
struct inner_array : inner_array_base
{
    array<T, N> arr;

    constexpr Shape get_shape() const { return Shape{}; }

    // Mutable access to the slot for index i.
    template <class Index>
    constexpr auto& operator[](Index i)
    {
        const auto slot = i / Stride;
        return arr[slot];
    }

    // Read-only access to the slot for index i.
    template <class Index>
    constexpr auto& operator[](Index i) const
    {
        const auto slot = i / Stride;
        return arr[slot];
    }
};
// Applies f element-wise over the sliced inputs using idx.local_stride.
//
// - If f returns void, f(x[j], xs[j]...) is invoked for each index j and
//   nothing is returned (f is expected to write through its arguments).
// - Otherwise each result is stored in an inner_array at index j and the
//   array is returned, so later stages can consume the computed values
//   without re-evaluating f.
//
// The inner_array capacity and stride are taken from the *types* of
// idx.max_local_stride_iterations(elements) and idx.nlocal().
// NOTE(review): this presumably requires both to be compile-time integral
// constant types; if nlocal() yields a plain runtime integer, `stride{}`
// would be zero -- confirm against the index implementation.
template <class F>
__device__ auto inner(F f) const
{
    return sliced(slicer, [=](auto x, auto... xs) {
        using result_type = decltype(f(x[0], xs[0]...));
        if constexpr(is_same<result_type, void>{})
        {
            idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
        }
        else
        {
            using max_iterations =
                decltype(idx.max_local_stride_iterations(x.get_shape().elements()));
            using stride     = decltype(idx.nlocal());
            using shape_type = decltype(x.get_shape());
            inner_array<result_type, max_iterations{}, stride{}, shape_type> y;
            idx.local_stride(x.get_shape().elements(),
                             [&](auto j) { y[j] = f(x[j], xs[j]...); });
            return y;
        }
    });
}
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment