Commit c2fa2267 authored by Paul
Browse files

Format

parent 658110e1
...@@ -95,13 +95,12 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u) ...@@ -95,13 +95,12 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul, _u)
MIGRAPHX_DPP_REDUCE(op::max, v_max, _i) MIGRAPHX_DPP_REDUCE(op::max, v_max, _i)
MIGRAPHX_DPP_REDUCE(op::min, v_min, _i) MIGRAPHX_DPP_REDUCE(op::min, v_min, _i)
template <class Op, class T, class Index, class F> template <class Op, class T, class Index, class F>
__device__ auto wave_reduce(index idx, Op op, T init, Index n, F f) __device__ auto wave_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init; type x = init;
idx.local_wave_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_wave_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce(x, op);
return x; return x;
...@@ -111,7 +110,7 @@ template <class Op, class T, class Index, class F> ...@@ -111,7 +110,7 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
if (idx.max_nlocal() == idx.nlocal_wave()) if(idx.max_nlocal() == idx.nlocal_wave())
return wave_reduce(idx, op, init, n, f); return wave_reduce(idx, op, init, n, f);
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16; constexpr index_int lanes_per_thread = 16;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment