Unverified Commit e4a28e53 authored by Douglas Lehr's avatar Douglas Lehr Committed by GitHub
Browse files

[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA (#3262)

parent 0bba88df
...@@ -15,9 +15,6 @@ ...@@ -15,9 +15,6 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
...@@ -31,11 +28,6 @@ ...@@ -31,11 +28,6 @@
#include <algorithm> #include <algorithm>
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
......
#pragma once #pragma once
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg) #define VLLM_LDG(arg) __ldg(arg)
#else #else
......
...@@ -24,7 +24,7 @@ namespace vllm { ...@@ -24,7 +24,7 @@ namespace vllm {
template<typename T> template<typename T>
__inline__ __device__ T warpReduceSum(T val) { __inline__ __device__ T warpReduceSum(T val) {
#pragma unroll #pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask); val += VLLM_SHFL_XOR_SYNC(val, mask);
return val; return val;
} }
...@@ -32,7 +32,7 @@ __inline__ __device__ T warpReduceSum(T val) { ...@@ -32,7 +32,7 @@ __inline__ __device__ T warpReduceSum(T val) {
/* Calculate the sum of all elements in a block */ /* Calculate the sum of all elements in a block */
template<typename T> template<typename T>
__inline__ __device__ T blockReduceSum(T val) { __inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32]; static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f; int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5; int wid = threadIdx.x >> 5;
...@@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) { ...@@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) {
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32 // blockDim.x is not divided by 32
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); val = (threadIdx.x < (blockDim.x / (WARP_SIZE * 1.0f))) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
return val; return val;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment