Commit 6d4311cd authored by Paul's avatar Paul
Browse files

Add more asserts

parent 7ffd56a8
...@@ -50,7 +50,7 @@ inline __device__ __attribute__((const)) index_int compute_global_size() ...@@ -50,7 +50,7 @@ inline __device__ __attribute__((const)) index_int compute_global_size()
// We cant just use blockDim.x to get the local size since its broken on hip // We cant just use blockDim.x to get the local size since its broken on hip
// when global is not divisible by local size. In this case, we calulate the // when global is not divisible by local size. In this case, we calulate the
// size for the last group. // size for the last group.
inline __device__ __attribute__((const)) index_int compute_local_size() inline __device__ __attribute__((const)) index_int compute_local_size()
{ {
#ifdef MIGRAPHX_NLOCAL #ifdef MIGRAPHX_NLOCAL
...@@ -89,7 +89,10 @@ struct index ...@@ -89,7 +89,10 @@ struct index
index_int group = 0; index_int group = 0;
#ifdef MIGRAPHX_NGLOBAL #ifdef MIGRAPHX_NGLOBAL
constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const { return {}; } constexpr index_constant<MIGRAPHX_NGLOBAL> nglobal() const {
static_assert(MIGRAPHX_NGLOBAL > 0, "Global size must be greater than 0");
return {};
}
#else #else
__device__ index_int nglobal() const __device__ index_int nglobal() const
{ {
...@@ -99,18 +102,31 @@ struct index ...@@ -99,18 +102,31 @@ struct index
#endif #endif
#ifdef MIGRAPHX_HAS_CONST_LOCAL #ifdef MIGRAPHX_HAS_CONST_LOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const { return {}; } constexpr index_constant<MIGRAPHX_NLOCAL> nlocal() const {
static_assert(MIGRAPHX_NLOCAL > 0, "Local size must be greater than 0");
return {};
}
#else #else
__device__ index_int nlocal() const __device__ index_int nlocal() const
{ {
#ifdef MIGRAPHX_NGROUP #ifdef MIGRAPHX_NGROUP
static_assert((MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL != 0) and (MIGRAPHX_NGROUP > 1), static_assert((MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL != 0) and (MIGRAPHX_NGROUP > 1),
"Local should be const"); "Local size should be const");
#endif #endif
MIGRAPHX_ASSERT(compute_local_size() > 0); MIGRAPHX_ASSERT(compute_local_size() > 0);
return compute_local_size(); // NOLINT return compute_local_size(); // NOLINT
} }
#endif #endif
#ifdef MIGRAPHX_NLOCAL
constexpr index_constant<MIGRAPHX_NLOCAL> max_nlocal() const { return {}; }
#else
__device__ index_int max_nlocal() const
{
MIGRAPHX_ASSERT(blockDim.x > 0);
return blockDim.x;
}
#endif
template <class N, class Stride> template <class N, class Stride>
static constexpr auto max_stride_iterations(N n, Stride stride) static constexpr auto max_stride_iterations(N n, Stride stride)
{ {
......
...@@ -97,13 +97,14 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul) ...@@ -97,13 +97,14 @@ MIGRAPHX_DPP_REDUCE(op::product, v_mul)
template <class Op, class T, class Index, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
#if __AMDGCN_WAVEFRONT_SIZE == 32 #if __AMDGCN_WAVEFRONT_SIZE == 32
constexpr index_int lanes_per_thread = 16; constexpr index_int lanes_per_thread = 16;
#else #else
constexpr index_int lanes_per_thread = 64; constexpr index_int lanes_per_thread = 64;
#endif #endif
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal() / lanes_per_thread]; __shared__ type buffer[idx.max_nlocal() / lanes_per_thread];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
dpp_reduce(x, op); dpp_reduce(x, op);
...@@ -126,9 +127,9 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f) ...@@ -126,9 +127,9 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
template <class Op, class T, class Index, class F> template <class Op, class T, class Index, class F>
__device__ auto block_reduce(index idx, Op op, T init, Index n, F f) __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
{ {
MIGRAPHX_ASSERT(idx.max_nlocal() == idx.nlocal());
using type = decltype(f(0)); using type = decltype(f(0));
__shared__ type buffer[idx.nlocal()]; __shared__ type buffer[idx.max_nlocal()];
type x = init; type x = init;
idx.local_stride(n, [&](auto i) { x = op(x, f(i)); }); idx.local_stride(n, [&](auto i) { x = op(x, f(i)); });
buffer[idx.local] = x; buffer[idx.local] = x;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment