Commit 9e053941 authored by zhuwenwen
Browse files

skip fp8 kernel and _rocm_C extension

parent f850f22a
...@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC ...@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu" # "csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" # "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu" "csrc/prepare_inputs/advance_step.cu"
...@@ -613,6 +613,7 @@ define_gpu_extension_target( ...@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "HIP")
# #
# _rocm_C extension # _rocm_C extension
...@@ -631,6 +632,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ...@@ -631,6 +632,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
endif() endif()
]]
# For CUDA we also build and ship some external projects. # For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA") if (VLLM_GPU_LANG STREQUAL "CUDA")
......
...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
"-DENABLE_FP8" #"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc") "-fno-gpu-rdc")
......
...@@ -17,43 +17,43 @@ ...@@ -17,43 +17,43 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <algorithm> #include <algorithm>
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm { namespace vllm {
// Utility function for attention softmax. // Utility function for attention softmax.
template <int NUM_WARPS> template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) { inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane. // Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE; int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE; int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp. // Compute the sum per warp.
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask); sum += VLLM_SHFL_XOR_SYNC(sum, mask);
} }
...@@ -72,22 +72,22 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -72,22 +72,22 @@ inline __device__ float block_sum(float* red_smem, float sum) {
} }
// Parallel reduction inside the warp. // Parallel reduction inside the warp.
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask); sum += VLLM_SHFL_XOR_SYNC(sum, mask);
} }
// Broadcast to other threads. // Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0); return VLLM_SHFL_SYNC(sum, 0);
} }
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning. int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -178,7 +178,7 @@ __device__ void paged_attention_kernel( ...@@ -178,7 +178,7 @@ __device__ void paged_attention_kernel(
// q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) { i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
...@@ -268,7 +268,7 @@ __device__ void paged_attention_kernel( ...@@ -268,7 +268,7 @@ __device__ void paged_attention_kernel(
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride + k_cache + physical_block_number * kv_block_stride +
...@@ -310,7 +310,7 @@ __device__ void paged_attention_kernel( ...@@ -310,7 +310,7 @@ __device__ void paged_attention_kernel(
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
...@@ -322,7 +322,7 @@ __device__ void paged_attention_kernel( ...@@ -322,7 +322,7 @@ __device__ void paged_attention_kernel(
// TODO(woosuk): Refactor this part. // TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence. // Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
...@@ -370,7 +370,7 @@ __device__ void paged_attention_kernel( ...@@ -370,7 +370,7 @@ __device__ void paged_attention_kernel(
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f; accs[i] = 0.f;
} }
...@@ -401,7 +401,7 @@ __device__ void paged_attention_kernel( ...@@ -401,7 +401,7 @@ __device__ void paged_attention_kernel(
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
...@@ -423,7 +423,7 @@ __device__ void paged_attention_kernel( ...@@ -423,7 +423,7 @@ __device__ void paged_attention_kernel(
// contain NaNs. See // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
...@@ -434,10 +434,10 @@ __device__ void paged_attention_kernel( ...@@ -434,10 +434,10 @@ __device__ void paged_attention_kernel(
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i]; float acc = accs[i];
#pragma unroll #pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask); acc += VLLM_SHFL_XOR_SYNC(acc, mask);
} }
...@@ -450,13 +450,13 @@ __device__ void paged_attention_kernel( ...@@ -450,13 +450,13 @@ __device__ void paged_attention_kernel(
// Perform reduction across warps. // Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem); float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll #pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) { for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
...@@ -469,7 +469,7 @@ __device__ void paged_attention_kernel( ...@@ -469,7 +469,7 @@ __device__ void paged_attention_kernel(
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE]; const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
...@@ -485,7 +485,7 @@ __device__ void paged_attention_kernel( ...@@ -485,7 +485,7 @@ __device__ void paged_attention_kernel(
scalar_t* out_ptr = scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
...@@ -493,13 +493,13 @@ __device__ void paged_attention_kernel( ...@@ -493,13 +493,13 @@ __device__ void paged_attention_kernel(
} }
} }
} }
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE> bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
...@@ -524,14 +524,14 @@ __global__ void paged_attention_v1_kernel( ...@@ -524,14 +524,14 @@ __global__ void paged_attention_v1_kernel(
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -559,12 +559,12 @@ __global__ void paged_attention_v2_kernel( ...@@ -559,12 +559,12 @@ __global__ void paged_attention_v2_kernel(
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel( __global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -617,7 +617,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -617,7 +617,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Get the global max logit. // Get the global max logit.
// Reduce within the warp. // Reduce within the warp.
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
} }
...@@ -627,7 +627,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -627,7 +627,7 @@ __global__ void paged_attention_v2_reduce_kernel(
__syncthreads(); __syncthreads();
// Reduce across warps. // Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
} }
...@@ -657,7 +657,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -657,7 +657,7 @@ __global__ void paged_attention_v2_reduce_kernel(
head_idx * max_num_partitions * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr = scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f; float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) { for (int j = 0; j < num_partitions; ++j) {
...@@ -666,11 +666,11 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -666,11 +666,11 @@ __global__ void paged_attention_v2_reduce_kernel(
} }
from_float(out_ptr[i], acc); from_float(out_ptr[i], acc);
} }
} }
} // namespace vllm } // namespace vllm
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -5,24 +5,26 @@ ...@@ -5,24 +5,26 @@
* Currently, only static fp8 quantization is supported. * Currently, only static fp8 quantization is supported.
*/ */
#include "type_convert.cuh" #include "type_convert.cuh"
#include "quantization/fp8/common.cuh" #ifndef USE_ROCM
#include "dispatch_utils.h" #include "quantization/fp8/common.cuh"
#endif
#include "dispatch_utils.h"
#include <torch/cuda.h> #include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cub/cub.cuh> #include <cub/cub.cuh>
#else #else
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
#endif #endif
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template <typename scalar_t, typename fp8_type> template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel( __global__ void rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
...@@ -54,15 +56,15 @@ __global__ void rms_norm_static_fp8_quant_kernel( ...@@ -54,15 +56,15 @@ __global__ void rms_norm_static_fp8_quant_kernel(
out[blockIdx.x * hidden_size + idx] = out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv); scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
} }
} }
/* Function specialization in the case of FP16/BF16 tensors. /* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template <typename scalar_t, int width, typename fp8_type> template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists> __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel( fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
...@@ -111,20 +113,20 @@ fused_add_rms_norm_static_fp8_quant_kernel( ...@@ -111,20 +113,20 @@ fused_add_rms_norm_static_fp8_quant_kernel(
_f16Vec<scalar_t, width> temp = residual_v[id]; _f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance; temp *= s_variance;
temp *= weight_v[idx]; temp *= weight_v[idx];
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
out[id * width + i] = out[id * width + i] =
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv); scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
} }
} }
} }
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template <typename scalar_t, int width, typename fp8_type> template <typename scalar_t, int width, typename fp8_type>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists> __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel( fused_add_rms_norm_static_fp8_quant_kernel(
fp8_type* __restrict__ out, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
...@@ -160,11 +162,11 @@ fused_add_rms_norm_static_fp8_quant_kernel( ...@@ -160,11 +162,11 @@ fused_add_rms_norm_static_fp8_quant_kernel(
out[blockIdx.x * hidden_size + idx] = out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv); scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
} }
} }
} // namespace vllm } // namespace vllm
void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1] torch::Tensor& scale, // [1]
...@@ -187,9 +189,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] ...@@ -187,9 +189,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon, num_tokens, hidden_size); epsilon, num_tokens, hidden_size);
}); });
}); });
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \ input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
VLLM_DISPATCH_FP8_TYPES( \ VLLM_DISPATCH_FP8_TYPES( \
...@@ -203,7 +205,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] ...@@ -203,7 +205,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
epsilon, num_tokens, hidden_size); \ epsilon, num_tokens, hidden_size); \
}); \ }); \
}); });
void fused_add_rms_norm_static_fp8_quant( void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size], torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
...@@ -239,4 +241,4 @@ void fused_add_rms_norm_static_fp8_quant( ...@@ -239,4 +241,4 @@ void fused_add_rms_norm_static_fp8_quant(
} else { } else {
LAUNCH_FUSED_ADD_RMS_NORM(0); LAUNCH_FUSED_ADD_RMS_NORM(0);
} }
} }
\ No newline at end of file
...@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, ...@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale, // torch::Tensor& weight, torch::Tensor& scale,
double epsilon); // double epsilon);
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, // void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input, // torch::Tensor& input,
torch::Tensor& residual, // torch::Tensor& residual,
torch::Tensor& weight, // torch::Tensor& weight,
torch::Tensor& scale, double epsilon); // torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out, void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input, torch::Tensor const& input,
...@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale); // torch::Tensor const& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale); // torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant( // void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, // torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub); // std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& A, const torch::Tensor& B,
......
#pragma once #pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h> #include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh" #include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead // TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh" // #include "quantization/fp8/common.cuh"
namespace vllm { namespace vllm {
......
...@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { ...@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
// //
#if defined(__CUDA_ARCH__) || defined(USE_ROCM) // #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) // #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { // __device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val); // atomicAdd_half(address, val);
} // }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) // #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { // __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val); // atomicAdd_half2(address, val);
} // }
#endif // #endif
#endif // #endif
#endif // #endif
} // namespace gptq } // namespace gptq
} // namespace vllm } // namespace vllm
......
...@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant // Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( // ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " // "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> " // "Tensor scale, float epsilon) -> "
"()"); // "()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, // ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant); // &rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization. // In-place fused Add and RMS Normalization.
ops.def( // ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " // "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, " // "Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()"); // "Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, // ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant); // &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels // Fused Layernorm + Quant kernels
ops.def( ops.def(
...@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor. // Compute FP8 quantized tensor for given scaling factor.
ops.def( // ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " // "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"); // "()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. // // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def( // ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " // "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> " // "-> "
"()"); // "()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def( // ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " // "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> " // "Tensor! scale, Tensor? scale_ub) -> "
"()"); // "()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, // ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant); // &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor. // Compute int8 quantized tensor for given scaling factor.
ops.def( ops.def(
......
...@@ -643,8 +643,8 @@ ext_modules = [] ...@@ -643,8 +643,8 @@ ext_modules = []
if _is_cuda() or _is_hip(): if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C")) ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _is_hip(): # if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C")) # ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda(): if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
......
...@@ -98,30 +98,30 @@ def paged_attention_v2( ...@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_rocm( # def paged_attention_rocm(
out: torch.Tensor, # out: torch.Tensor,
exp_sum: torch.Tensor, # exp_sum: torch.Tensor,
max_logits: torch.Tensor, # max_logits: torch.Tensor,
tmp_out: torch.Tensor, # tmp_out: torch.Tensor,
query: torch.Tensor, # query: torch.Tensor,
key_cache: torch.Tensor, # key_cache: torch.Tensor,
value_cache: torch.Tensor, # value_cache: torch.Tensor,
num_kv_heads: int, # num_kv_heads: int,
scale: float, # scale: float,
block_tables: torch.Tensor, # block_tables: torch.Tensor,
seq_lens: torch.Tensor, # seq_lens: torch.Tensor,
block_size: int, # block_size: int,
max_seq_len: int, # max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], # alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, # kv_cache_dtype: str,
k_scale: torch.Tensor, # k_scale: torch.Tensor,
v_scale: torch.Tensor, # v_scale: torch.Tensor,
) -> None: # ) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, # torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, # key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens, # scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes, # block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale) # kv_cache_dtype, k_scale, v_scale)
# pos encoding ops # pos encoding ops
......
...@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention( # use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio, # decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len) # decode_meta.max_decode_seq_len)
use_custom = False
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment