Commit c4add749 authored by Paul's avatar Paul
Browse files

Allow integral constant to be passed

parent 57a5c827
......@@ -63,6 +63,7 @@ struct index
template <class F, class N, class Stride>
static constexpr void for_stride(index_int start, N n, Stride stride, F f)
{
static_assert(not is_integral<N>{}, "");
if constexpr(not is_integral<N>{} and not is_integral<Stride>{} and
max_stride_iterations(n, stride) == 1)
{
......
......@@ -94,8 +94,8 @@ MIGRAPHX_DPP_REDUCE(op::max, v_max)
MIGRAPHX_DPP_REDUCE(op::min, v_min)
MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
#if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16;
......@@ -123,8 +123,8 @@ __device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
return y;
}
#else
template <class Op, class T, class F>
__device__ auto block_reduce(index idx, Op op, T init, index_int n, F f)
template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{
using type = decltype(f(0));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment