Commit cb11ba46 authored by Paul's avatar Paul
Browse files

Fix more tests

parent 6df9f47f
...@@ -168,7 +168,7 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f) ...@@ -168,7 +168,7 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init; auto x = type(init);
idx.local_subwave_stride<SubWaveSize>( idx.local_subwave_stride<SubWaveSize>(
n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce<SubWaveSize>(x, op); dpp_reduce<SubWaveSize>(x, op);
...@@ -185,12 +185,14 @@ template <class Op, class T, class Index, class F> ...@@ -185,12 +185,14 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
if(idx.max_nlocal() == idx.nlocal_wave()) #ifdef MIGRAPHX_HAS_CONST_LOCAL
if constexpr(decltype(idx.nlocal()){} == __AMDGCN_WAVEFRONT_SIZE)
return wave_reduce(idx, op, init, n, f); return wave_reduce(idx, op, init, n, f);
#endif
constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE; constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = type(init); auto x = type(init);
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op); dpp_reduce(x, op);
...@@ -215,7 +217,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) ...@@ -215,7 +217,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal()); MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>)); using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; auto x = type(init);
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); }); idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
buffer[idx.local] = x; buffer[idx.local] = x;
__syncthreads(); __syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment