"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "f6e22d567d1aba13f2ab0e4a2753cfc965e46151"
Commit cb11ba46 authored by Paul's avatar Paul
Browse files

Fix more tests

parent 6df9f47f
......@@ -168,7 +168,7 @@ __device__ auto subwave_reduce(index idx, Op op, T init, Index n, F f)
{
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>));
type x = init;
auto x = type(init);
idx.local_subwave_stride<SubWaveSize>(
n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce<SubWaveSize>(x, op);
......@@ -185,12 +185,14 @@ template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
if(idx.max_nlocal() == idx.nlocal_wave())
#ifdef MIGRAPHX_HAS_CONST_LOCAL
if constexpr(decltype(idx.nlocal()){} == __AMDGCN_WAVEFRONT_SIZE)
return wave_reduce(idx, op, init, n, f);
#endif
constexpr index_int lanes_per_thread = __AMDGCN_WAVEFRONT_SIZE;
using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = type(init);
auto x = type(init);
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
dpp_reduce(x, op);
......@@ -215,7 +217,7 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(index::invoke_loop(f, 0, _c<0>));
__shared__ type buffer[idx.max_nlocal()];
type x = init;
auto x = type(init);
idx.local_stride(n, [&](auto i, auto d) { x = op(x, index::invoke_loop(f, i, d)); });
buffer[idx.local] = x;
__syncthreads();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment