Commit 658110e1 authored by Paul
Browse files

Add wave reduction

parent d8011adf
...@@ -134,6 +134,12 @@ struct index ...@@ -134,6 +134,12 @@ struct index
#endif #endif
// Returns nglobal() divided by max_nlocal().
constexpr auto ngroup() const { return nglobal() / max_nlocal(); } constexpr auto ngroup() const { return nglobal() / max_nlocal(); }
// Wavefront width as a compile-time constant; the value comes from the
// compiler-provided __AMDGCN_WAVEFRONT_SIZE macro.
constexpr index_constant<__AMDGCN_WAVEFRONT_SIZE> nlocal_wave() const
{
    return index_constant<__AMDGCN_WAVEFRONT_SIZE>{};
}
// Position of this thread inside its wavefront: `local` modulo the wavefront
// width returned by nlocal_wave().
constexpr auto local_wave() const
{
    return local % nlocal_wave();
}
// Number of wavefronts that fit in max_nlocal() threads
// (max_nlocal() divided by the wavefront width).
constexpr auto nwave() const
{
    return max_nlocal() / nlocal_wave();
}
// Index of the wavefront this thread belongs to: `local` divided by the
// wavefront width returned by nlocal_wave().
constexpr auto wave() const
{
    return local / nlocal_wave();
}
template <class N, class Stride> template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride) static constexpr auto max_stride_iterations(N n, Stride stride)
{ {
...@@ -152,6 +158,12 @@ struct index ...@@ -152,6 +158,12 @@ struct index
return max_stride_iterations(n, nlocal()); return max_stride_iterations(n, nlocal());
} }
// Upper bound on the number of iterations a single lane performs when
// striding over `n` items with a step of one wavefront width. Forwards to
// max_stride_iterations with nlocal_wave() as the stride.
template <class N>
constexpr auto max_local_wave_stride_iterations(N n) const
{
    const auto wave_width = nlocal_wave();
    return max_stride_iterations(n, wave_width);
}
template <class F, class I, class D> template <class F, class I, class D>
static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d)) static constexpr auto invoke_loop(F f, I i, D d) -> decltype(f(i, d))
{ {
...@@ -241,6 +253,12 @@ struct index ...@@ -241,6 +253,12 @@ struct index
{ {
for_stride<false>(group, n, ngroup(), f); for_stride<false>(group, n, ngroup(), f);
} }
// Invokes `f` over the range [0, n) via for_stride<false>, starting at this
// thread's lane position within its wavefront (local_wave()) and advancing by
// the wavefront width (nlocal_wave()).
template <class F, class N>
__device__ void local_wave_stride(N n, F f) const
{
    const auto first_item = local_wave();
    const auto wave_width = nlocal_wave();
    for_stride<false>(first_item, n, wave_width, f);
}
}; };
#ifdef MIGRAPHX_NLOCAL #ifdef MIGRAPHX_NLOCAL
......
...@@ -95,10 +95,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u) ...@@ -95,10 +95,24 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::max, v_max, _i) MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i) MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
// Reduces `n` values produced by `f` within a single wavefront.
//
//   idx  - thread index object; asserted to satisfy max_nlocal() == nlocal()
//   op   - binary reduction operator
//   init - starting value of each lane's private accumulator
//   n    - number of items to visit
//   f    - element producer, called through index::invoke_loop(f, i, d)
//
// Each lane first folds the items it visits via idx.local_wave_stride into a
// private accumulator, then the per-lane accumulators are combined with
// dpp_reduce.
//
// NOTE(review): dpp_reduce is not defined in this part of the file, so it is
// unclear whether every lane or only one particular lane ends up holding the
// complete result - confirm before consuming the return value from an
// arbitrary lane.
template <class Op, class T, class Index, class F>
__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
{
    MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
    using type = decltype(index::invoke_loop(f, 0, _c<0>));

    // Per-lane partial result.
    type acc = init;
    auto accumulate = [&](auto i, auto d) { acc = op(acc, index::invoke_loop(f, i, d)); };
    idx.local_wave_stride(n, accumulate);

    // Combine the partial results across the lanes of the wavefront.
    dpp_reduce(acc, op);
    return acc;
}
template <class Op, class T, class Index, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
if (idx.max_nlocal() == idx.nlocal_wave())
return wave_reduce(idx, op, init, n, f);
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16; constexpr index_int lanes_per_thread = 16;
#else #else
...@@ -470,6 +484,81 @@ struct block_large ...@@ -470,6 +484,81 @@ struct block_large
} }
}; };
// Reduction algorithm in which the lanes of one wavefront cooperate on a
// single output element. run() hands every group of nlocal_wave() consecutive
// global indices the same output coordinate.
struct wave
{
    template <class Slicer>
    struct reducer : reducer_base<reducer<Slicer>>
    {
        index idx;
        Slicer slice;

        // Per-lane storage for the results of an inner (element-wise) pass.
        // Elements are addressed by the second call argument only; the first
        // argument is ignored.
        template <class T, index_int N, class Size>
        struct inner_storage : inner_storage_tag
        {
            using type = T;
            array<T, N> arr;

            // Default-constructed Size object describing the reduction size.
            constexpr Size rsize() const { return {}; }

            template <class U, class V>
            constexpr auto& operator()(U, V d) const
            {
                return arr[d];
            }

            template <class U, class V>
            constexpr auto& operator()(U, V d)
            {
                return arr[d];
            }
        };

        // Reduces `n` items with `op`, starting from `init`. Each item is
        // obtained by applying `read` to the inputs `xs` at (j, d) and
        // collapsing the result with vec_reduce before it is handed to
        // wave_reduce.
        template <class Op, class T, class Read, class N, class... Ts>
        __device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
        {
            return wave_reduce(idx, op, init, n, [&](auto j, auto d) {
                return vec_reduce(read(xs(j, d)...), op);
            });
        }

        // Runs `f` on exactly one lane of every wavefront.
        //
        // run() gives each wavefront its own output element, so the guard has
        // to select one lane per wave. Comparing idx.local against 0 would
        // match only a single thread of the whole workgroup and skip `f` for
        // every wave after the first when a workgroup holds several waves.
        // For single-wave workgroups both conditions are identical.
        //
        // NOTE(review): this assumes the value returned by wave_reduce is
        // valid in lane 0 of the wave - confirm against dpp_reduce.
        template <class F>
        __device__ void outer(F f) const
        {
            if(idx.local_wave() == 0)
                f();
        }

        // Applies `f` to the inputs `xs` at every (j, d) this lane visits
        // while striding over `n` items; nothing is returned.
        template <class F, class N, class... Ts>
        __device__ void inner_void_impl(F f, N n, Ts&&... xs) const
        {
            idx.local_wave_stride(n, [&](auto j, auto d) { f(xs(j, d)...); });
        }

        // Like inner_void_impl, but stores each result of `f` (as type R) in
        // an inner_storage sized by max_local_wave_stride_iterations(n) and
        // returns that storage.
        template <class R, class F, class N, class... Ts>
        __device__ auto inner_impl(F f, N n, Ts&&... xs) const
        {
            using max_iterations = decltype(idx.max_local_wave_stride_iterations(n));
            inner_storage<R, max_iterations{}, N> storage;
            idx.local_wave_stride(n, [&](auto j, auto d) { storage(j, d) = f(xs(j, d)...); });
            return storage;
        }
    };

    // Builds a reducer bound to `idx` that uses `slicer` to select its input.
    template <class Slicer>
    static __device__ auto make(index idx, Slicer slicer)
    {
        return reducer<Slicer>{{}, idx, slicer};
    }

    // Entry point: calls `f(out_idx, reducer)` for the output coordinates of
    // shape `Output`. The global index space is scaled by nlocal_wave() and
    // divided back down, so nlocal_wave() consecutive indices share one
    // output coordinate.
    template <class Output, class F>
    static __device__ void run(F f)
    {
        auto idx                 = make_index();
        constexpr auto nelements = get_shape_c<Output>{}.elements();
        idx.global_stride(nelements * idx.nlocal_wave(), [&](auto i) {
            const auto out_idx = get_shape_c<Output>{}.multi(i / idx.nlocal_wave());
            f(out_idx, make(idx, [&](auto input) { return reduce_slice<Output>(input, out_idx); }));
        });
    }
};
struct lane struct lane
{ {
template <class Slicer> template <class Slicer>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment