"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "11b6b7e4fc13c52be064dd58b1126723ee9988dd"
Commit 6ed9a3e4 authored by wenjh
Browse files

[DCU] Add width to __shfl


Signed-off-by: wenjh <wenjh@sugon.com>
parent b27e513d
...@@ -135,7 +135,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -135,7 +135,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// broadcast the amax to all threads in a warp from the lane 0 // broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0; constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = __shfl(warp_tile_amax, lane_zero); warp_tile_amax = __shfl(warp_tile_amax, lane_zero, THREADS_PER_WARP);
__syncthreads();
#else #else
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
#endif #endif
...@@ -362,7 +363,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose ...@@ -362,7 +363,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
// broadcast the amax to all threads in a warp from the lane 0 // broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0; constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__ #ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = __shfl(warp_tile_amax, lane_zero); warp_tile_amax = __shfl(warp_tile_amax, lane_zero, THREADS_PER_WARP);
__syncthreads();
#else #else
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
#endif #endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment