Commit 9e053941 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip fp8 kernel and _rocm_C extension

parent f850f22a
...@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC ...@@ -233,11 +233,11 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu" "csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu" # "csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" # "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu" "csrc/prepare_inputs/advance_step.cu"
...@@ -613,6 +613,7 @@ define_gpu_extension_target( ...@@ -613,6 +613,7 @@ define_gpu_extension_target(
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
#[[
if(VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "HIP")
# #
# _rocm_C extension # _rocm_C extension
...@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ...@@ -631,9 +632,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
USE_SABI 3 USE_SABI 3
WITH_SOABI) WITH_SOABI)
endif() endif()
]]
# For CUDA we also build and ship some external projects. # For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA") if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake) include(cmake/external_projects/flashmla.cmake)
include(cmake/external_projects/vllm_flash_attn.cmake) include(cmake/external_projects/vllm_flash_attn.cmake)
endif () endif ()
\ No newline at end of file
...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) ...@@ -119,7 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
"-DENABLE_FP8" #"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc") "-fno-gpu-rdc")
......
...@@ -17,660 +17,660 @@ ...@@ -17,660 +17,660 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <algorithm> #include <algorithm>
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm { namespace vllm {
// Utility function for attention softmax. // Utility function for attention softmax.
template <int NUM_WARPS> template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) { inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane. // Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE; int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE; int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp. // Compute the sum per warp.
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask); sum += VLLM_SHFL_XOR_SYNC(sum, mask);
} }
// Warp leaders store the data to shared memory. // Warp leaders store the data to shared memory.
if (lane == 0) { if (lane == 0) {
red_smem[warp] = sum; red_smem[warp] = sum;
} }
// Make sure the data is in shared memory. // Make sure the data is in shared memory.
__syncthreads(); __syncthreads();
// The warps compute the final sums. // The warps compute the final sums.
if (lane < NUM_WARPS) { if (lane < NUM_WARPS) {
sum = red_smem[lane]; sum = red_smem[lane];
} }
// Parallel reduction inside the warp. // Parallel reduction inside the warp.
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask); sum += VLLM_SHFL_XOR_SYNC(sum, mask);
} }
// Broadcast to other threads. // Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0); return VLLM_SHFL_SYNC(sum, 0);
} }
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning. int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size] // head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block. // No work to do. Terminate the thread block.
return; return;
} }
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx = const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS // divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv; const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread // The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a // group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 / // thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type; using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers. // Load the query to registers.
// Each thread in a thread group has a different part of the query. // Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in // For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) { i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs // memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS]; __shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE // x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time. // Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX; float qk_max = -FLT_MAX;
// Iterate over the key blocks. // Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration. // Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// blocksparse specific vars // blocksparse specific vars
int bs_block_offset; int bs_block_offset;
int q_bs_block_id; int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size); // blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size; q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0) if (blocksparse_head_sliding_step >= 0)
// sliding on q heads // sliding on q heads
bs_block_offset = bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else else
// sliding on kv heads // sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) + (-blocksparse_head_sliding_step) +
1; 1;
} }
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) { block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote = const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local = const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) { if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to // NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will // avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be // not be used at computing sum(softmax*v) as the blocks will be
// skipped. // skipped.
logits[token_idx - start_token_idx] = -FLT_MAX; logits[token_idx - start_token_idx] = -FLT_MAX;
} }
} }
continue; continue;
} }
} }
const int64_t physical_block_number = const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]); static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in // For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride + k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>( k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale); k_vec_quant, *k_scale);
} }
} }
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot( float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs); q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk; logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
} }
} }
} }
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
if (lane == 0) { if (lane == 0) {
red_smem[warp_idx] = qk_max; red_smem[warp_idx] = qk_max;
} }
__syncthreads(); __syncthreads();
// TODO(woosuk): Refactor this part. // TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence. // Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
// Broadcast the max qk value to all threads. // Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0); qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Get the sum of the exp values. // Get the sum of the exp values.
float exp_sum = 0.f; float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max); float val = __expf(logits[i] - qk_max);
logits[i] = val; logits[i] = val;
exp_sum += val; exp_sum += val;
} }
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum); exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum; logits[i] *= inv_sum;
} }
__syncthreads(); __syncthreads();
// If partitioning is enabled, store the max logit and exp_sum. // If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) { if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max; *max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum; *exp_sums_ptr = exp_sum;
} }
// Each thread will fetch 16 bytes from the value cache at a time. // Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type; using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type; using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f; accs[i] = 0.f;
} }
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) { block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue; continue;
} }
} }
const int64_t physical_block_number = const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]); static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx)); start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec; V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else { } else {
V_quant_vec v_quant_vec = V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
*v_scale); *v_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may // context, we should explicitly zero out the values since they may
// contain NaNs. See // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
accs[i] += dot(logits_vec, v_vec); accs[i] += dot(logits_vec, v_vec);
} }
} }
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i]; float acc = accs[i];
#pragma unroll #pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask); acc += VLLM_SHFL_XOR_SYNC(acc, mask);
} }
accs[i] = acc; accs[i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for // NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output. // logits is reused for the output.
__syncthreads(); __syncthreads();
// Perform reduction across warps. // Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem); float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll #pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) { for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i]; dst[row_idx] = accs[i];
} }
} }
} }
__syncthreads(); __syncthreads();
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE]; const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx]; accs[i] += src[row_idx];
} }
} }
} }
__syncthreads(); __syncthreads();
} }
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]); from_float(*(out_ptr + row_idx), accs[i]);
} }
} }
} }
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE> bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>( KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const int max_num_partitions) { const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const int num_heads = gridDim.x; // head_size/x, block_size, x]
const int head_idx = blockIdx.x; const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
const int seq_idx = blockIdx.y; // head_size, block_size]
const int seq_len = seq_lens[seq_idx]; const int num_kv_heads, // [num_heads]
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const float scale,
if (num_partitions == 1) { const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
// No need to reduce. Only copy tmp_out to out. const int* __restrict__ seq_lens, // [num_seqs]
scalar_t* out_ptr = const int max_num_blocks_per_seq,
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; const float* __restrict__ alibi_slopes, // [num_heads]
const scalar_t* tmp_out_ptr = const int q_stride, const int kv_block_stride, const int kv_head_stride,
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + const float* k_scale, const float* v_scale, const int tp_rank,
head_idx * max_num_partitions * HEAD_SIZE; const int blocksparse_local_blocks, const int blocksparse_vert_stride,
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
out_ptr[i] = tmp_out_ptr[i]; paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
} KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
// Terminate the thread block. exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
return; block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
} kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; blocksparse_head_sliding_step);
const int warp_idx = threadIdx.x / WARP_SIZE; }
const int lane = threadIdx.x % WARP_SIZE;
// Grid: (num_heads, num_seqs).
// Size: 2 * num_partitions. template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
extern __shared__ char shared_mem[]; int PARTITION_SIZE>
// Workspace for reduction. __global__ void paged_attention_v2_reduce_kernel(
__shared__ float red_smem[2 * NUM_WARPS]; scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// Load max logits to shared memory. // max_num_partitions]
float* shared_max_logits = reinterpret_cast<float*>(shared_mem); const float* __restrict__ max_logits, // [num_seqs, num_heads,
const float* max_logits_ptr = max_logits + // max_num_partitions]
seq_idx * num_heads * max_num_partitions + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
head_idx * max_num_partitions; // max_num_partitions, head_size]
float max_logit = -FLT_MAX; const int* __restrict__ seq_lens, // [num_seqs]
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const int max_num_partitions) {
const float l = max_logits_ptr[i]; const int num_heads = gridDim.x;
shared_max_logits[i] = l; const int head_idx = blockIdx.x;
max_logit = fmaxf(max_logit, l); const int seq_idx = blockIdx.y;
} const int seq_len = seq_lens[seq_idx];
__syncthreads(); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// Get the global max logit. // No need to reduce. Only copy tmp_out to out.
// Reduce within the warp. scalar_t* out_ptr =
#pragma unroll out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { const scalar_t* tmp_out_ptr =
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
} head_idx * max_num_partitions * HEAD_SIZE;
if (lane == 0) { for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
red_smem[warp_idx] = max_logit; out_ptr[i] = tmp_out_ptr[i];
} }
__syncthreads(); // Terminate the thread block.
// Reduce across warps. return;
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; }
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); const int warp_idx = threadIdx.x / WARP_SIZE;
} const int lane = threadIdx.x % WARP_SIZE;
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Load rescaled exp sums to shared memory. // Workspace for reduction.
float* shared_exp_sums = __shared__ float red_smem[2 * NUM_WARPS];
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums + // Load max logits to shared memory.
seq_idx * num_heads * max_num_partitions + float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
head_idx * max_num_partitions; const float* max_logits_ptr = max_logits +
float global_exp_sum = 0.0f; seq_idx * num_heads * max_num_partitions +
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { head_idx * max_num_partitions;
float l = shared_max_logits[i]; float max_logit = -FLT_MAX;
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
global_exp_sum += rescaled_exp_sum; const float l = max_logits_ptr[i];
shared_exp_sums[i] = rescaled_exp_sum; shared_max_logits[i] = l;
} max_logit = fmaxf(max_logit, l);
__syncthreads(); }
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum); __syncthreads();
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Get the global max logit.
// Aggregate tmp_out to out. // Reduce within the warp.
const scalar_t* tmp_out_ptr = #pragma unroll
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
head_idx * max_num_partitions * HEAD_SIZE; max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
scalar_t* out_ptr = }
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; if (lane == 0) {
#pragma unroll red_smem[warp_idx] = max_logit;
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { }
float acc = 0.0f; __syncthreads();
for (int j = 0; j < num_partitions; ++j) { // Reduce across warps.
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
inv_global_exp_sum; #pragma unroll
} for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
from_float(out_ptr[i], acc); max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
} }
} // Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
} // namespace vllm
// Load rescaled exp sums to shared memory.
#undef WARP_SIZE float* shared_exp_sums =
#undef MAX reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
#undef MIN const float* exp_sums_ptr = exp_sums +
#undef DIVIDE_ROUND_UP seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -728,4 +728,4 @@ void gather_cache( ...@@ -728,4 +728,4 @@ void gather_cache(
} else { } else {
TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
} }
} }
\ No newline at end of file
...@@ -5,238 +5,240 @@ ...@@ -5,238 +5,240 @@
* Currently, only static fp8 quantization is supported. * Currently, only static fp8 quantization is supported.
*/ */
#include "type_convert.cuh" #include "type_convert.cuh"
#include "quantization/fp8/common.cuh" #ifndef USE_ROCM
#include "dispatch_utils.h" #include "quantization/fp8/common.cuh"
#endif
#include <torch/cuda.h> #include "dispatch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#include <torch/cuda.h>
#ifndef USE_ROCM #include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#else #ifndef USE_ROCM
#include <hipcub/hipcub.hpp> #include <cub/cub.cuh>
#endif #else
#include <hipcub/hipcub.hpp>
namespace vllm { #endif
// TODO(woosuk): Further optimize this kernel. namespace vllm {
template <typename scalar_t, typename fp8_type>
__global__ void rms_norm_static_fp8_quant_kernel( // TODO(woosuk): Further optimize this kernel.
fp8_type* __restrict__ out, // [..., hidden_size] template <typename scalar_t, typename fp8_type>
const scalar_t* __restrict__ input, // [..., hidden_size] __global__ void rms_norm_static_fp8_quant_kernel(
const scalar_t* __restrict__ weight, // [hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
const float* __restrict__ scale, // [1] const scalar_t* __restrict__ input, // [..., hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) { const scalar_t* __restrict__ weight, // [hidden_size]
__shared__ float s_variance; const float* __restrict__ scale, // [1]
float variance = 0.0f; const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float variance = 0.0f;
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
} const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
using BlockReduce = cub::BlockReduce<float, 1024>; }
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
if (threadIdx.x == 0) { variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
s_variance = rsqrtf(variance / hidden_size + epsilon);
} if (threadIdx.x == 0) {
__syncthreads(); s_variance = rsqrtf(variance / hidden_size + epsilon);
}
// invert scale to avoid division __syncthreads();
float const scale_inv = 1.0f / *scale;
// invert scale to avoid division
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float const scale_inv = 1.0f / *scale;
float x = (float)input[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
out[blockIdx.x * hidden_size + idx] = float x = (float)input[blockIdx.x * hidden_size + idx];
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv); float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
} out[blockIdx.x * hidden_size + idx] =
} scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
/* Function specialization in the case of FP16/BF16 tensors. }
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the /* Function specialization in the case of FP16/BF16 tensors.
memory latency bottleneck. */ Additional optimizations we can make in this case are
template <typename scalar_t, int width, typename fp8_type> packed and vectorized operations, which help with the
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists> memory latency bottleneck. */
fused_add_rms_norm_static_fp8_quant_kernel( template <typename scalar_t, int width, typename fp8_type>
fp8_type* __restrict__ out, // [..., hidden_size] __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
scalar_t* __restrict__ input, // [..., hidden_size] fused_add_rms_norm_static_fp8_quant_kernel(
scalar_t* __restrict__ residual, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
const float* __restrict__ scale, // [1] scalar_t* __restrict__ residual, // [..., hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) { const scalar_t* __restrict__ weight, // [hidden_size]
// Sanity checks on our vector struct and type-punned pointer arithmetic const float* __restrict__ scale, // [1]
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); const float epsilon, const int num_tokens, const int hidden_size) {
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); // Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
const int vec_hidden_size = hidden_size / width; static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
__shared__ float s_variance;
float variance = 0.0f; const int vec_hidden_size = hidden_size / width;
/* These and the argument pointers are all declared `restrict` as they are __shared__ float s_variance;
not aliased in practice. Argument pointers should not be dereferenced float variance = 0.0f;
in this kernel as that would be undefined behavior */ /* These and the argument pointers are all declared `restrict` as they are
auto* __restrict__ input_v = not aliased in practice. Argument pointers should not be dereferenced
reinterpret_cast<_f16Vec<scalar_t, width>*>(input); in this kernel as that would be undefined behavior */
auto* __restrict__ residual_v = auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ weight_v = auto* __restrict__ residual_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id]; for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
temp += residual_v[id]; int id = blockIdx.x * vec_hidden_size + idx;
variance += temp.sum_squares(); _f16Vec<scalar_t, width> temp = input_v[id];
residual_v[id] = temp; temp += residual_v[id];
} variance += temp.sum_squares();
residual_v[id] = temp;
using BlockReduce = cub::BlockReduce<float, 1024>; }
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
if (threadIdx.x == 0) { variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
s_variance = rsqrtf(variance / hidden_size + epsilon);
} if (threadIdx.x == 0) {
__syncthreads(); s_variance = rsqrtf(variance / hidden_size + epsilon);
}
// invert scale to avoid division __syncthreads();
float const scale_inv = 1.0f / *scale;
// invert scale to avoid division
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { float const scale_inv = 1.0f / *scale;
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id]; for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
temp *= s_variance; int id = blockIdx.x * vec_hidden_size + idx;
temp *= weight_v[idx]; _f16Vec<scalar_t, width> temp = residual_v[id];
#pragma unroll temp *= s_variance;
for (int i = 0; i < width; ++i) { temp *= weight_v[idx];
out[id * width + i] = #pragma unroll
scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv); for (int i = 0; i < width; ++i) {
} out[id * width + i] =
} scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
} }
}
/* Generic fused_add_rms_norm_kernel }
The width field is not used here but necessary for other specializations.
*/ /* Generic fused_add_rms_norm_kernel
template <typename scalar_t, int width, typename fp8_type> The width field is not used here but necessary for other specializations.
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists> */
fused_add_rms_norm_static_fp8_quant_kernel( template <typename scalar_t, int width, typename fp8_type>
fp8_type* __restrict__ out, // [..., hidden_size] __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
scalar_t* __restrict__ input, // [..., hidden_size] fused_add_rms_norm_static_fp8_quant_kernel(
scalar_t* __restrict__ residual, // [..., hidden_size] fp8_type* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
const float* __restrict__ scale, // [1] scalar_t* __restrict__ residual, // [..., hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) { const scalar_t* __restrict__ weight, // [hidden_size]
__shared__ float s_variance; const float* __restrict__ scale, // [1]
float variance = 0.0f; const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float variance = 0.0f;
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx]; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)z; scalar_t z = input[blockIdx.x * hidden_size + idx];
variance += x * x; z += residual[blockIdx.x * hidden_size + idx];
residual[blockIdx.x * hidden_size + idx] = z; float x = (float)z;
} variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
using BlockReduce = cub::BlockReduce<float, 1024>; }
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
if (threadIdx.x == 0) { variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
s_variance = rsqrtf(variance / hidden_size + epsilon);
} if (threadIdx.x == 0) {
__syncthreads(); s_variance = rsqrtf(variance / hidden_size + epsilon);
}
// invert scale to avoid division __syncthreads();
float const scale_inv = 1.0f / *scale;
// invert scale to avoid division
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float const scale_inv = 1.0f / *scale;
float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
out[blockIdx.x * hidden_size + idx] = float x = (float)residual[blockIdx.x * hidden_size + idx];
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv); float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
} out[blockIdx.x * hidden_size + idx] =
} scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
}
} // namespace vllm }
void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] } // namespace vllm
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size] void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& scale, // [1] torch::Tensor& input, // [..., hidden_size]
double epsilon) { torch::Tensor& weight, // [hidden_size]
int hidden_size = input.size(-1); torch::Tensor& scale, // [1]
int num_tokens = input.numel() / hidden_size; double epsilon) {
int hidden_size = input.size(-1);
dim3 grid(num_tokens); int num_tokens = input.numel() / hidden_size;
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); dim3 grid(num_tokens);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 block(std::min(hidden_size, 1024));
VLLM_DISPATCH_FLOATING_TYPES( const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
input.scalar_type(), "rms_norm_kernel_scalar_type", [&] { const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FP8_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
out.scalar_type(), "rms_norm_kernel_fp8_type", [&] { input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t> VLLM_DISPATCH_FP8_TYPES(
<<<grid, block, 0, stream>>>( out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), <<<grid, block, 0, stream>>>(
epsilon, num_tokens, hidden_size); out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
}); weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
}); epsilon, num_tokens, hidden_size);
} });
});
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ }
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FP8_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \ input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \ VLLM_DISPATCH_FP8_TYPES( \
width, fp8_t> \ out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \
<<<grid, block, 0, stream>>>( \ vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, \
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \ width, fp8_t> \
residual.data_ptr<scalar_t>(), \ <<<grid, block, 0, stream>>>( \
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \ out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \ residual.data_ptr<scalar_t>(), \
}); \ weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
}); epsilon, num_tokens, hidden_size); \
void fused_add_rms_norm_static_fp8_quant( }); \
torch::Tensor& out, // [..., hidden_size], });
torch::Tensor& input, // [..., hidden_size] void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& out, // [..., hidden_size],
torch::Tensor& weight, // [hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& scale, // [1] torch::Tensor& residual, // [..., hidden_size]
double epsilon) { torch::Tensor& weight, // [hidden_size]
int hidden_size = input.size(-1); torch::Tensor& scale, // [1]
int num_tokens = input.numel() / hidden_size; double epsilon) {
int hidden_size = input.size(-1);
dim3 grid(num_tokens); int num_tokens = input.numel() / hidden_size;
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows dim3 grid(num_tokens);
for increased block occupancy on CUs and better latency /* This kernel is memory-latency bound in many scenarios.
hiding on global mem ops. */ When num_tokens is large, a smaller block size allows
const int max_block_size = (num_tokens < 256) ? 1024 : 256; for increased block occupancy on CUs and better latency
dim3 block(std::min(hidden_size, max_block_size)); hiding on global mem ops. */
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const int max_block_size = (num_tokens < 256) ? 1024 : 256;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); dim3 block(std::min(hidden_size, max_block_size));
/*If the tensor types are FP16/BF16, try to use the optimized kernel const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
with packed + vectorized ops. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Max optimization is achieved with a width-8 vector of FP16/BF16s /*If the tensor types are FP16/BF16, try to use the optimized kernel
since we can load at most 128 bits at once in a global memory op. with packed + vectorized ops.
However, this requires each tensor's data to be aligned to 16 Max optimization is achieved with a width-8 vector of FP16/BF16s
bytes. since we can load at most 128 bits at once in a global memory op.
*/ However, this requires each tensor's data to be aligned to 16
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); bytes.
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); */
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
bool ptrs_are_aligned = auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
if (ptrs_are_aligned && hidden_size % 8 == 0) { bool ptrs_are_aligned =
LAUNCH_FUSED_ADD_RMS_NORM(8); inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
} else { if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(0); LAUNCH_FUSED_ADD_RMS_NORM(8);
} } else {
} LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
\ No newline at end of file
...@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, ...@@ -58,15 +58,15 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input, // void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale, // torch::Tensor& weight, torch::Tensor& scale,
double epsilon); // double epsilon);
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, // void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input, // torch::Tensor& input,
torch::Tensor& residual, // torch::Tensor& residual,
torch::Tensor& weight, // torch::Tensor& weight,
torch::Tensor& scale, double epsilon); // torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out, void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input, torch::Tensor const& input,
...@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -213,15 +213,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale); // torch::Tensor const& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale); // torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant( // void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale, // torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub); // std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& A, const torch::Tensor& B,
......
#pragma once #pragma once
#ifndef USE_ROCM
#include <hip/hip_fp8.h> #include <hip/hip_fp8.h>
#endif
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { ...@@ -670,4 +672,4 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
} // namespace fp8 } // namespace fp8
#endif // USE_ROCM #endif // USE_ROCM
} // namespace vllm } // namespace vllm
\ No newline at end of file
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "quantization/vectorization.cuh" #include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead // TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh" // #include "quantization/fp8/common.cuh"
namespace vllm { namespace vllm {
......
...@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) { ...@@ -43,21 +43,21 @@ __device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) {
// //
#if defined(__CUDA_ARCH__) || defined(USE_ROCM) // #if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) // #if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { // __device__ __forceinline__ void atomicAdd(half* address, half val) {
atomicAdd_half(address, val); // atomicAdd_half(address, val);
} // }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) // #if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { // __device__ __forceinline__ void atomicAdd(half2* address, half2 val) {
atomicAdd_half2(address, val); // atomicAdd_half2(address, val);
} // }
#endif // #endif
#endif // #endif
#endif // #endif
} // namespace gptq } // namespace gptq
} // namespace vllm } // namespace vllm
......
...@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -126,20 +126,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm-quant // Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor. // Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def( // ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " // "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> " // "Tensor scale, float epsilon) -> "
"()"); // "()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA, // ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant); // &rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization. // In-place fused Add and RMS Normalization.
ops.def( // ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " // "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, " // "Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()"); // "Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, // ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant); // &fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels // Fused Layernorm + Quant kernels
ops.def( ops.def(
...@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -455,25 +455,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor. // Compute FP8 quantized tensor for given scaling factor.
ops.def( // ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " // "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()"); // "()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); // ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. // // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def( // ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) " // "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> " // "-> "
"()"); // "()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); // ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor. // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def( // ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, " // "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> " // "Tensor! scale, Tensor? scale_ub) -> "
"()"); // "()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, // ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant); // &dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor. // Compute int8 quantized tensor for given scaling factor.
ops.def( ops.def(
...@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { ...@@ -602,4 +602,4 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
} }
#endif #endif
REGISTER_EXTENSION(TORCH_EXTENSION_NAME) REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
\ No newline at end of file
...@@ -643,8 +643,8 @@ ext_modules = [] ...@@ -643,8 +643,8 @@ ext_modules = []
if _is_cuda() or _is_hip(): if _is_cuda() or _is_hip():
ext_modules.append(CMakeExtension(name="vllm._moe_C")) ext_modules.append(CMakeExtension(name="vllm._moe_C"))
if _is_hip(): # if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C")) # ext_modules.append(CMakeExtension(name="vllm._rocm_C"))
if _is_cuda(): if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
......
...@@ -98,30 +98,30 @@ def paged_attention_v2( ...@@ -98,30 +98,30 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_rocm( # def paged_attention_rocm(
out: torch.Tensor, # out: torch.Tensor,
exp_sum: torch.Tensor, # exp_sum: torch.Tensor,
max_logits: torch.Tensor, # max_logits: torch.Tensor,
tmp_out: torch.Tensor, # tmp_out: torch.Tensor,
query: torch.Tensor, # query: torch.Tensor,
key_cache: torch.Tensor, # key_cache: torch.Tensor,
value_cache: torch.Tensor, # value_cache: torch.Tensor,
num_kv_heads: int, # num_kv_heads: int,
scale: float, # scale: float,
block_tables: torch.Tensor, # block_tables: torch.Tensor,
seq_lens: torch.Tensor, # seq_lens: torch.Tensor,
block_size: int, # block_size: int,
max_seq_len: int, # max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], # alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, # kv_cache_dtype: str,
k_scale: torch.Tensor, # k_scale: torch.Tensor,
v_scale: torch.Tensor, # v_scale: torch.Tensor,
) -> None: # ) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, # torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, # key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens, # scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes, # block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale) # kv_cache_dtype, k_scale, v_scale)
# pos encoding ops # pos encoding ops
...@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache( ...@@ -1365,4 +1365,4 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
) )
return out, softmax_lse return out, softmax_lse
\ No newline at end of file
...@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -790,9 +790,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention( # use_custom = _use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio, # decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len) # decode_meta.max_decode_seq_len)
use_custom = False
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment