Commit ff232fb8 authored by Sarunya Pumma's avatar Sarunya Pumma
Browse files

Fix reduce_block_into_lanes for multi_tensor_l2norm for ROCm

parent 76e4e054
...@@ -175,15 +175,16 @@ __device__ __forceinline__ T reduce_block_into_lanes ...@@ -175,15 +175,16 @@ __device__ __forceinline__ T reduce_block_into_lanes
{ {
int tid = threadIdx.x + threadIdx.y*blockDim.x; int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32. int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
auto double_warp_size = warpSize * 2;
if(blockSize >= 64) if(blockSize >= double_warp_size)
{ {
x[tid] = val; x[tid] = val;
__syncthreads(); __syncthreads();
} }
#pragma unroll #pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1) for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1)
{ {
if(tid < i) if(tid < i)
x[tid] = x[tid] + x[tid+i]; x[tid] = x[tid] + x[tid+i];
...@@ -192,18 +193,18 @@ __device__ __forceinline__ T reduce_block_into_lanes ...@@ -192,18 +193,18 @@ __device__ __forceinline__ T reduce_block_into_lanes
T final; T final;
if(tid < 32) if(tid < warpSize)
{ {
if(blockSize >= 64) if(blockSize >= double_warp_size)
final = x[tid] + x[tid+32]; final = x[tid] + x[tid + warpSize];
else else
final = val; final = val;
// __SYNCWARP(); // __SYNCWARP();
#pragma unroll #pragma unroll
for(int i = 16; i >= lanes; i >>= 1) { for(int i = warpSize / 2; i >= lanes; i >>= 1) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
final = final + __shfl_down(0xffffffff, final, i); final = final + __shfl_down(final, i);
#else #else
final = final + __shfl_down_sync(0xffffffff, final, i); final = final + __shfl_down_sync(0xffffffff, final, i);
#endif #endif
...@@ -230,15 +231,16 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op ...@@ -230,15 +231,16 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op
{ {
int tid = threadIdx.x + threadIdx.y*blockDim.x; int tid = threadIdx.x + threadIdx.y*blockDim.x;
int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32. int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
auto double_warp_size = warpSize * 2;
if(blockSize >= 64) if(blockSize >= double_warp_size)
{ {
x[tid] = val; x[tid] = val;
__syncthreads(); __syncthreads();
} }
#pragma unroll #pragma unroll
for(int i = (blockSize >> 1); i >= 64; i >>= 1) for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1)
{ {
if(tid < i) if(tid < i)
x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i])); x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
...@@ -247,10 +249,10 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op ...@@ -247,10 +249,10 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op
T final; T final;
if(tid < 32) if(tid < warpSize)
{ {
if(blockSize >= 64) if(blockSize >= double_warp_size)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32])); final = fmaxf(fabsf(x[tid]), fabsf(x[tid + warpSize]));
else else
final = val; final = val;
// __SYNCWARP(); // __SYNCWARP();
...@@ -258,7 +260,7 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op ...@@ -258,7 +260,7 @@ __device__ __forceinline__ T reduce_block_into_lanes_max_op
#pragma unroll #pragma unroll
for(int i = 16; i >= lanes; i >>= 1) { for(int i = 16; i >= lanes; i >>= 1) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
final = fmaxf(fabsf(final), fabsf(__shfl_down(0xffffffff, final, i))); final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i)));
#else #else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif #endif
......
...@@ -58,7 +58,6 @@ class TestMultiTensorL2Norm(unittest.TestCase): ...@@ -58,7 +58,6 @@ class TestMultiTensorL2Norm(unittest.TestCase):
self.assertTrue(self.overflow_buf.item() == 0) self.assertTrue(self.overflow_buf.item() == 0)
@unittest.skipIf(disabled, "amp_C is unavailable") @unittest.skipIf(disabled, "amp_C is unavailable")
@skipIfRocm
def test_fuzz(self): def test_fuzz(self):
input_size_pairs = ( input_size_pairs = (
(7777*77, 555*555), (7777*77, 555*555),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment