Unverified Commit c9415c19 authored by kliuae's avatar kliuae Committed by GitHub
Browse files

[ROCm] Fix warp and lane calculation in blockReduceSum (#3321)

parent 4c922709
......@@ -29,12 +29,22 @@ __inline__ __device__ T warpReduceSum(T val) {
return val;
}
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
return warp_size - 1;
}
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
return 5 + (warp_size >> 6);
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
constexpr auto LANE_MASK = _calculateLaneMask(WARP_SIZE);
constexpr auto WID_SHIFT = _calculateWidShift(WARP_SIZE);
int lane = threadIdx.x & LANE_MASK;
int wid = threadIdx.x >> WID_SHIFT;
val = warpReduceSum<T>(val);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment