Commit 276184b2 authored by Paul
Browse files

Load into registers first

parent c56d6e9e
......@@ -53,6 +53,17 @@ struct index
return blockDim.x; // NOLINT
}
#endif
// Returns one more than the quotient n / nglobal().
// NOTE(review): judging by the name, this is meant as an upper bound on the
// number of iterations a single thread performs in global_stride — confirm.
template <class N>
constexpr auto max_global_stride_iterations(N n) const
{
    const auto whole_passes = n / nglobal();
    return _c<1> + whole_passes;
}
// Returns one more than the quotient n / nlocal().
// The result's type is used by reduce::block::inner as the capacity of the
// per-thread inner_array it fills through local_stride.
template <class N>
constexpr auto max_local_stride_iterations(N n) const
{
    const auto whole_passes = n / nlocal();
    return _c<1> + whole_passes;
}
template <class F>
__device__ void global_stride(index_int n, F f) const
......
......@@ -22,24 +22,27 @@ __device__ void generic_binary_layernorm(
MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
auto input = r.inner([&](auto x1, auto x2) {
return op(x1, x2);
})(input1, input2);
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements};
})(input1, input2);
return r.reduce(op::sum{}, 0, [&](auto x) {
return f(x) / value_type{relements};
})(input);
};
// mean(x)
auto mean_x = mean(op);
auto mean_x = mean(op::id{});
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
auto mean_m2 = mean([&](auto x) {
auto m = x - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x;
r.inner([&](auto& y, auto x, auto... xs) {
auto m = x - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...);
})(output, input, inputs...);
});
}
......
......@@ -147,25 +147,35 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
}
#endif
namespace reduce {
// Empty tag type used as the base of inner_array. reduce_slice tests
// is_base_of<inner_array_base, Input> and returns such inputs unchanged
// instead of building a sliced tensor view over them.
struct inner_array_base {};
// Builds the view of `input` that belongs to the reduction starting at index `i`.
//
// - If Input derives from inner_array_base, the input is returned as-is;
//   no tensor view is constructed for it.
// - Otherwise a tensor view starting at &input[i] is returned. Its lengths are
//   Input's lengths with every dimension that matches Output's length set to 1,
//   and its strides are Input's strides.
//
// Output is only used for its compile-time shape (get_shape_c<Output>).
template <class Output, class Input, class T>
constexpr auto reduce_slice(Input input, T i)
{
    if constexpr(is_base_of<inner_array_base, Input>{})
    {
        return input;
    }
    else
    {
        // Dimensions where input and output agree are not reduced over, so the
        // slice has extent 1 there; the remaining dimensions keep their extent.
        constexpr auto lens = transform(get_shape_c<Input>{}.lens,
                                        get_shape_c<Output>{}.lens,
                                        [](index_int x, index_int y) -> index_int {
                                            if(x == y)
                                                return 1;
                                            return x;
                                        });
        constexpr auto s = make_shape(lens, get_shape_c<Input>{}.strides);
        // The slice must not extend past the end of the underlying input buffer.
        MIGRAPHX_ASSERT((input.get_shape().index(i) + s.element_space()) <=
                        input.get_shape().element_space());
        return make_tensor_view(&input[i], s);
    }
}
namespace reduce {
template <class Slicer, class F>
constexpr auto sliced(Slicer slicer, F f)
{
......@@ -217,11 +227,38 @@ struct block
f();
}
// Array of N values of type T that is indexed with the same indices as the
// tensor it stands in for: element i is stored in slot i / Stride.
// get_shape() reports a default-constructed Shape, so the shape is carried
// entirely by the type. Deriving from inner_array_base lets reduce_slice
// recognise this type and pass it through unsliced.
template <class T, index_int N, index_int Stride, class Shape>
struct inner_array : inner_array_base
{
    array<T, N> arr;

    constexpr Shape get_shape() const { return Shape{}; }

    // Read-only element access; i is divided by Stride to find the slot.
    template <class U>
    constexpr auto& operator[](U i) const
    {
        const auto slot = i / Stride;
        return arr[slot];
    }

    // Mutable element access; same index mapping as the const overload.
    template <class U>
    constexpr auto& operator[](U i)
    {
        const auto slot = i / Stride;
        return arr[slot];
    }
};
// Wraps `f` into a callable (built with sliced(slicer, ...)) that applies `f`
// element-wise over the sliced inputs, visiting indices through
// idx.local_stride over the element count of the first input.
//
// - If f(x[0], xs[0]...) has type void, f is invoked once per visited index
//   purely for its side effects and nothing is returned.
// - Otherwise each result is stored at y[j] in an inner_array whose capacity
//   is idx.max_local_stride_iterations(elements), whose stride is
//   idx.nlocal(), and whose shape is that of the first input; that array is
//   returned.
//
// f is invoked exactly once per visited index in either case.
template <class F>
__device__ auto inner(F f) const
{
    return sliced(slicer, [=](auto x, auto... xs) {
        using result_type = decltype(f(x[0], xs[0]...));
        if constexpr(is_same<result_type, void>{})
        {
            idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
        }
        else
        {
            // Capacity and stride are taken from the types returned by idx, so
            // they must be compile-time constants to serve as template arguments.
            inner_array<result_type,
                        decltype(idx.max_local_stride_iterations(x.get_shape().elements())){},
                        decltype(idx.nlocal()){},
                        decltype(x.get_shape())>
                y;
            idx.local_stride(x.get_shape().elements(),
                             [&](auto j) { y[j] = f(x[j], xs[j]...); });
            return y;
        }
    });
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment