Unverified Commit 9d16eaca authored by Paul Fultz II, committed by GitHub
Browse files

Fix compile error with no dpp reductions (#571)


Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 5cc6e160
......@@ -85,13 +85,18 @@ struct highest
};
#ifdef MIGRAPHX_NO_DPP
template <index_int N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
template <index_int N,
class Op,
class T,
class ForStride,
class F,
MIGRAPHX_REQUIRES(not std::is_integral<ForStride>{})>
__device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
{
using type = decltype(f(idx.local));
using type = decltype(f(deduce_for_stride(fs)));
MIGRAPHX_DEVICE_SHARED type buffer[N];
type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
fs([&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x;
__syncthreads();
......@@ -218,7 +223,7 @@ __device__ auto block_reduce(index idx, Op op, T init, ForStride fs, F f)
}
return y;
}
#endif
template <index_int N, class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
{
......@@ -229,8 +234,6 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
return block_reduce<N>(
idx, op, init, midx.for_stride(fs), [&](auto mi) __device__ { return f(mi[0]); });
}
#endif
constexpr index_int compute_block_size(index_int n, index_int max_block_size)
{
size_t block_size = 64;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment