Unverified Commit ed10f3ce authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[ROCm] warpSize is being made non constexpr in ROCm 7.0 (#20330)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
parent b637e9dc
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#include "cuda_compat.h"
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
...@@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm } // namespace vllm
#undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
...@@ -18,12 +18,7 @@ ...@@ -18,12 +18,7 @@
*/ */
#include "attention_kernels.cuh" #include "attention_kernels.cuh"
#include "cuda_compat.h"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
...@@ -187,7 +182,6 @@ void paged_attention_v1( ...@@ -187,7 +182,6 @@ void paged_attention_v1(
CALL_V1_LAUNCHER_BLOCK_SIZE) CALL_V1_LAUNCHER_BLOCK_SIZE)
} }
#undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
...@@ -18,12 +18,7 @@ ...@@ -18,12 +18,7 @@
*/ */
#include "attention_kernels.cuh" #include "attention_kernels.cuh"
#include "cuda_compat.h"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
...@@ -197,7 +192,6 @@ void paged_attention_v2( ...@@ -197,7 +192,6 @@ void paged_attention_v2(
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
} }
#undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#endif #endif
#ifndef USE_ROCM #if defined(USE_ROCM) && defined(__GFX9__)
#define WARP_SIZE 32 #define WARP_SIZE 64
#else #else
#define WARP_SIZE warpSize #define WARP_SIZE 32
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment