Commit 0aa2480f authored by zhuwenwen's avatar zhuwenwen
Browse files

Refactoring the optimized kernel

parent 2f9e0bad
...@@ -181,7 +181,10 @@ set(VLLM_EXT_SRC ...@@ -181,7 +181,10 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_tgi_kernels.cu" "csrc/pos_encoding_tgi_kernels.cu"
"csrc/activation_kernels.cu" "csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/transpose_kernels.cu" "csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/attention/attention_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
......
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/all.h> #include <torch/all.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath> #include <cmath>
...@@ -24,60 +23,6 @@ __global__ void act_and_mul_kernel( ...@@ -24,60 +23,6 @@ __global__ void act_and_mul_kernel(
} }
} }
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize1(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx= blockIdx.x;
int idx = threadIdx.x * VEC;
if (idx < d) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_vectorize2(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename T> template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) { __device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x) // x * sigmoid(x)
...@@ -109,42 +54,19 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -109,42 +54,19 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
} // namespace vllm } // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ // Launch activation and gating kernel.
int d = input.size(-1) / 2; \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int64_t num_tokens = input.numel() / input.size(-1); \ int d = input.size(-1) / 2; \
dim3 grid(num_tokens); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 block(std::min(d, 1024)); \ dim3 grid(num_tokens); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
VLLM_DISPATCH_FLOATING_TYPES( \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
input.scalar_type(), "act_and_mul_kernel", [&] { \ VLLM_DISPATCH_FLOATING_TYPES( \
if (0 == d % 8 && d <= 16384) { \ input.scalar_type(), "act_and_mul_kernel", [&] { \
if (d <= 512) { \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 2> \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \ input.data_ptr<scalar_t>(), d); \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_vectorize1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_vectorize2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
}); });
void silu_and_mul(torch::Tensor& out, // [..., d] void silu_and_mul(torch::Tensor& out, // [..., d]
...@@ -237,5 +159,4 @@ void gelu_quick(torch::Tensor& out, // [..., d] ...@@ -237,5 +159,4 @@ void gelu_quick(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d] torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_quick_kernel);
} }
\ No newline at end of file
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
...@@ -20,8 +39,6 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -20,8 +39,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
...@@ -68,12 +85,9 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -68,12 +85,9 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0> // Zero means no partitioning. int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel_opt(
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -84,32 +98,30 @@ __device__ void paged_attention_kernel( ...@@ -84,32 +98,30 @@ __device__ void paged_attention_kernel(
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.y; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.y; const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block. // No work to do. Terminate the thread block.
return; return;
} }
if constexpr (sizeof(scalar_t)==2){
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
...@@ -132,38 +144,22 @@ __device__ void paged_attention_kernel( ...@@ -132,38 +144,22 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
// const int warp_idx_vec = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block const int head_idx = blockIdx.x;
int warp_idx =0; const int num_heads = gridDim.x;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx)
: "v"(warp_id_vec)
:);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
// const float alibi_slope = const int kv_head_idx = head_idx / num_queries_per_kv;
// alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread // The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a // group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 / // thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(32 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type; using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
...@@ -180,89 +176,61 @@ __device__ void paged_attention_kernel( ...@@ -180,89 +176,61 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the query, and the second thread // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
// const scalar_t* q_ptr = q + seq_idx * q_stride; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
__shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// #pragma unroll i += NUM_THREAD_GROUPS) {
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// i += NUM_THREAD_GROUPS) { q_vecs[thread_group_offset][i] =
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// q_vecs[thread_group_offset][i] = }
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// } // memory wall right before we use q_vecs
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS]; __shared__ float red_smem[2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE // x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time. // Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max[REUSE_KV_TIMES]; float qk_max = -FLT_MAX;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX;
}
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Iterate over the key blocks. // Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration. // Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote = const bool is_remote =
...@@ -286,8 +254,8 @@ __device__ void paged_attention_kernel( ...@@ -286,8 +254,8 @@ __device__ void paged_attention_kernel(
continue; continue;
} }
} }
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; const int64_t physical_block_number =
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
...@@ -295,104 +263,99 @@ __device__ void paged_attention_kernel( ...@@ -295,104 +263,99 @@ __device__ void paged_attention_kernel(
// the group has 0, 4, 8, ... th vectors of the key, and the second thread // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
if(reuse_kv_idx == 0) {
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride + k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>( k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale); k_vec_quant, k_scale);
}
} }
} }
__builtin_amdgcn_sched_barrier(0);
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs); float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk; logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
} }
} }
} }
}
}
// Get the sum of the exp values.
float exp_sum[REUSE_KV_TIMES] = {0.f};
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { #pragma unroll
const int head_idx = head_idx_soffset + reuse_kv_idx; for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
if(!odd_nheads || head_idx < q_boundary) { qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
#pragma unroll }
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { if (lane == 0) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask)); red_smem[warp_idx] = qk_max;
} }
if (lane == 0) { __syncthreads();
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
// Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { // TODO(woosuk): Refactor this part.
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]); // Get the max qk value for the sequence.
logits[(reuse_kv_idx * partition_size) + i] = val; qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
exp_sum[reuse_kv_idx] += val; #pragma unroll
} for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Compute softmax. // Get the sum of the exp values.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f); float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] *= inv_sum; float val = __expf(logits[i] - qk_max);
} logits[i] = val;
__syncthreads(); exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// If partitioning is enabled, store the max logit and exp_sum. // Compute softmax.
if (USE_PARTITIONING && thread_idx == 0) { const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
float* max_logits_ptr = max_logits + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
seq_idx * num_heads * max_num_partitions + logits[i] *= inv_sum;
head_idx * max_num_partitions + partition_idx; }
*max_logits_ptr = qk_max[reuse_kv_idx]; __syncthreads();
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; // If partitioning is enabled, store the max logit and exp_sum.
*exp_sums_ptr = exp_sum[reuse_kv_idx]; if (USE_PARTITIONING && thread_idx == 0) {
} float* max_logits_ptr = max_logits +
} seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum;
} }
// Each thread will fetch 16 bytes from the value cache at a time. // Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
...@@ -406,74 +369,44 @@ __device__ void paged_attention_kernel( ...@@ -406,74 +369,44 @@ __device__ void paged_attention_kernel(
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
#pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) { accs[i] = 0.f;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f;
}
} }
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) { block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
const int64_t physical_block_number = const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]); static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
V_vec v_vec;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if(!odd_nheads || head_idx < q_boundary) {
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if(reuse_kv_idx==0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
...@@ -495,41 +428,20 @@ __device__ void paged_attention_kernel( ...@@ -495,41 +428,20 @@ __device__ void paged_attention_kernel(
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
// if(threadIdx.x==0){ accs[i] += dot(logits_vec, v_vec);
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
} }
} }
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[reuse_kv_idx][i]; float acc = accs[i];
#pragma unroll #pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask); acc += VLLM_SHFL_XOR_SYNC(acc, mask);
} }
accs[reuse_kv_idx][i] = acc; accs[i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for // NOTE(woosuk): A barrier is required because the shared memory space for
...@@ -543,12 +455,12 @@ __device__ void paged_attention_kernel( ...@@ -543,12 +455,12 @@ __device__ void paged_attention_kernel(
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE]; float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[reuse_kv_idx][i]; dst[row_idx] = accs[i];
} }
} }
} }
...@@ -556,12 +468,12 @@ __device__ void paged_attention_kernel( ...@@ -556,12 +468,12 @@ __device__ void paged_attention_kernel(
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE]; const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[reuse_kv_idx][i] += src[row_idx]; accs[i] += src[row_idx];
} }
} }
} }
...@@ -574,33 +486,26 @@ __device__ void paged_attention_kernel( ...@@ -574,33 +486,26 @@ __device__ void paged_attention_kernel(
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]); from_float(*(out_ptr + row_idx), accs[i]);
}
}
} }
} }
} }
}
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
int REUSE_KV_TIMES = 1, bool IS_BLOCK_SPARSE>
bool IS_BLOCK_SPARSE, __global__ void paged_attention_v1_kernel_opt(
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
...@@ -608,14 +513,13 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel( ...@@ -608,14 +513,13 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>( KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
...@@ -626,10 +530,8 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel( ...@@ -626,10 +530,8 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES, int PARTITION_SIZE>
int PARTITION_SIZE, __global__ void paged_attention_v2_kernel_opt(
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -640,31 +542,29 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel( ...@@ -640,31 +542,29 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel( __global__ void paged_attention_v2_reduce_kernel_opt(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -772,34 +672,24 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel( ...@@ -772,34 +672,24 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \ ((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \ BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \ KV_DTYPE, IS_BLOCK_SPARSE>), \
shared_mem_size); \ shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
, dim3(grid), dim3(block), shared_mem_size, stream, \ <<<grid, block, shared_mem_size, stream>>>( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 128>
void paged_attention_v1_launcher( void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
...@@ -815,10 +705,6 @@ void paged_attention_v1_launcher( ...@@ -815,10 +705,6 @@ void paged_attention_v1_launcher(
int q_stride = query.stride(0); int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int num_threads = 128;
if(num_heads!=num_kv_heads){
num_threads =256;
}
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
...@@ -836,31 +722,51 @@ void paged_attention_v1_launcher( ...@@ -836,31 +722,51 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] { int padded_max_seq_len =
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] { DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
HEADSIZE_SWITCH(head_size, [&] { int logits_size = padded_max_seq_len * sizeof(float);
NUM_THREADS_SWITCH(num_threads, [&] { int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
OPT_SWITCH(num_heads == num_kv_heads, [&] { // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; // Keep that in sync with the logic here!
int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float); int shared_mem_size = std::max(logits_size, outputs_size);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len dim3 grid(num_heads, num_seqs, 1);
// Keep that in sync with the logic here! dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size)); switch (head_size) {
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl; // NOTE(woosuk): To reduce the compilation time, we only compile for the
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs); // head sizes that we use in the model. However, we can easily extend this
dim3 block(NUM_THREADS); // to support any head size which is a multiple of 16.
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); case 64:
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); LAUNCH_PAGED_ATTENTION_V1(64);
LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE); break;
}); case 80:
}); LAUNCH_PAGED_ATTENTION_V1(80);
}); break;
}); case 96:
}); LAUNCH_PAGED_ATTENTION_V1(96);
break;
case 112:
LAUNCH_PAGED_ATTENTION_V1(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V1(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V1(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V1(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V1(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...@@ -899,7 +805,7 @@ void paged_attention_v1_launcher( ...@@ -899,7 +805,7 @@ void paged_attention_v1_launcher(
break; \ break; \
} }
void paged_attention_v1( void paged_attention_v1_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
...@@ -923,25 +829,25 @@ void paged_attention_v1( ...@@ -923,25 +829,25 @@ void paged_attention_v1(
} }
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ vllm::paged_attention_v2_kernel_opt<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \ NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \ PARTITION_SIZE> \
, dim3(grid), dim3(block), shared_mem_size, stream, \ <<<grid, block, shared_mem_size, stream>>>( \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \ PARTITION_SIZE> \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions); max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 256, int PARTITION_SIZE = 512> int NUM_THREADS = 128, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher( void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
...@@ -980,33 +886,51 @@ void paged_attention_v2_launcher( ...@@ -980,33 +886,51 @@ void paged_attention_v2_launcher(
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] { int logits_size = PARTITION_SIZE * sizeof(float);
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] { int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
HEADSIZE_SWITCH(head_size, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] { // For paged attention v2 kernel.
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float); dim3 grid(num_heads, num_seqs, max_num_partitions);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float); int shared_mem_size = std::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
// For paged attention v2 kernel. dim3 reduce_grid(num_heads, num_seqs);
// dim3 grid(num_heads, max_num_partitions, num_seqs); int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 grid; dim3 block(NUM_THREADS);
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads; const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
grid.y = max_num_partitions; const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
grid.z = num_seqs; switch (head_size) {
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size)); // NOTE(woosuk): To reduce the compilation time, we only compile for the
int shared_mem_size = ::max(logits_size, outputs_size); // head sizes that we use in the model. However, we can easily extend this
// For paged attention v2 reduce kernel. // to support any head size which is a multiple of 16.
dim3 reduce_grid(num_heads, num_seqs); case 64:
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); LAUNCH_PAGED_ATTENTION_V2(64);
dim3 block(NUM_THREADS); break;
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query)); case 80:
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); LAUNCH_PAGED_ATTENTION_V2(80);
LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE); break;
}); case 96:
}); LAUNCH_PAGED_ATTENTION_V2(96);
}); break;
}); case 112:
LAUNCH_PAGED_ATTENTION_V2(112);
break;
case 120:
LAUNCH_PAGED_ATTENTION_V2(120);
break;
case 128:
LAUNCH_PAGED_ATTENTION_V2(128);
break;
case 192:
LAUNCH_PAGED_ATTENTION_V2(192);
break;
case 256:
LAUNCH_PAGED_ATTENTION_V2(256);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
...@@ -1046,7 +970,7 @@ void paged_attention_v2_launcher( ...@@ -1046,7 +970,7 @@ void paged_attention_v2_launcher(
break; \ break; \
} }
void paged_attention_v2( void paged_attention_v2_opt(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "static_switch.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm {
// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}
// Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES = 1,
bool odd_nheads = false,
int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block.
return;
}
if constexpr (sizeof(scalar_t)==2){
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
// const int warp_idx_vec = thread_idx / WARP_SIZE;
// int warp_idx =0;
// asm volatile("v_readfirstlane_b32 %0,%1"
// : "=s"(warp_idx)
// : "v"(warp_idx_vec)
// :);
// // const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
//const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
int warp_id_vec = threadIdx.x / WARP_SIZE; //warp id in a block
int warp_idx =0;
asm volatile("v_readfirstlane_b32 %0,%1"
: "=s"(warp_idx)
: "v"(warp_id_vec)
:);
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads;
// const float alibi_slope =
// alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(32 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
// const scalar_t* q_ptr = q + seq_idx * q_stride;
const scalar_t* q_ptr_offset = q + seq_idx * q_stride;
__shared__ Q_vec q_vecs[REUSE_KV_TIMES * THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
// i += NUM_THREAD_GROUPS) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[thread_group_offset][i] =
// *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// // memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// float (*red_smem)[2 * NUM_WARPS] = reinterpret_cast<float(*)[2 * NUM_WARPS]>(&shared_mem[10*1024]);
// __shared__ char shared_mem[12 * 1024];
// float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[REUSE_KV_TIMES][2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max[REUSE_KV_TIMES];
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
qk_max[reuse_kv_idx] = -FLT_MAX;
}
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int head_idx_soffset = (blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
const int kv_head_idx = head_idx_soffset / num_queries_per_kv;
const int q_boundary = (kv_head_idx + 1)* num_queries_per_kv;
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
const scalar_t* q_ptr = q_ptr_offset + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;//blockIdx.x * REUSE_KV_TIMES + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits[token_idx - start_token_idx] = -FLT_MAX;
}
}
continue;
}
}
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
if(reuse_kv_idx == 0) {
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale);
}
}
}
__builtin_amdgcn_sched_barrier(0);
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[reuse_kv_idx*THREAD_GROUP_SIZE + thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
__builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
}
}
}
}
}
// Get the sum of the exp values.
float exp_sum[REUSE_KV_TIMES] = {0.f};
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
const int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
if (lane == 0) {
red_smem[reuse_kv_idx][warp_idx] = qk_max[reuse_kv_idx];
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max[reuse_kv_idx] = lane < NUM_WARPS ? red_smem[reuse_kv_idx][lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max[reuse_kv_idx] = fmaxf(qk_max[reuse_kv_idx], VLLM_SHFL_XOR_SYNC(qk_max[reuse_kv_idx], mask));
}
// Broadcast the max qk value to all threads.
qk_max[reuse_kv_idx] = VLLM_SHFL_SYNC(qk_max[reuse_kv_idx], 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max[reuse_kv_idx]);
logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum[reuse_kv_idx] += val;
}
exp_sum[reuse_kv_idx] = block_sum<NUM_WARPS>(&red_smem[reuse_kv_idx][NUM_WARPS], exp_sum[reuse_kv_idx]);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum[reuse_kv_idx] + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[(reuse_kv_idx * partition_size) + i] *= inv_sum;
}
__syncthreads();
// If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max[reuse_kv_idx];
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum[reuse_kv_idx];
}
}
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[reuse_kv_idx][i] = 0.f;
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
V_vec v_vec;
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
// blocksparse specific vars
const int head_idx = head_idx_soffset + reuse_kv_idx;
int bs_block_offset;
int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0)
// sliding on q heads
bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else
// sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) +
1;
}
if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue;
}
}
if(!odd_nheads || head_idx < q_boundary) {
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + (reuse_kv_idx * partition_size) + token_idx - start_token_idx));
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// from_float(*(logits_vec_ptr+i), 1000);
// }
if(reuse_kv_idx==0) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else {
V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale);
}
if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
// if(threadIdx.x==0){
// scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
// scalar_t* logits_vec_ptr = reinterpret_cast<scalar_t*>(&logits_vec);
// for(int i=0;i<8;++i){
// printf("v_vec[%d] = %f\n",i, half_to_float(v_vec_ptr[i]));
// // from_float(*(v_vec_ptr + i), 1000);
// }
// for(int i=0;i<8;++i){
// printf("logits_vec[%d] = %f\n",i,half_to_float(logits_vec_ptr[i]));
// // from_float(*(logits_vec_ptr + i), 1000);
// }
// }
// accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
accs[reuse_kv_idx][i] += dot(logits_vec, v_vec);
}
}
}
}
// Perform reduction within each warp.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<REUSE_KV_TIMES; reuse_kv_idx++) {
int head_idx = head_idx_soffset + reuse_kv_idx;
if(!odd_nheads || head_idx < q_boundary) {
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[reuse_kv_idx][i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[reuse_kv_idx][i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + (warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[reuse_kv_idx][i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[(reuse_kv_idx * (NUM_WARPS / 2) * HEAD_SIZE) + warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[reuse_kv_idx][i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]);
}
}
}
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
int REUSE_KV_TIMES = 1,
bool IS_BLOCK_SPARSE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,
int REUSE_KV_TIMES,
int PARTITION_SIZE,
bool odd_nheads = false>
__global__ __launch_bounds__(256,1) void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out.
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
out_ptr[i] = tmp_out_ptr[i];
}
// Terminate the thread block.
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warp_idx = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
// Size: 2 * num_partitions.
extern __shared__ char shared_mem[];
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i];
shared_max_logits[i] = l;
max_logit = fmaxf(max_logit, l);
}
__syncthreads();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
}
__syncthreads();
// Reduce across warps.
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
// Load rescaled exp sums to shared memory.
float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, \
BLOCK_SIZE, NUM_THREADS, \
KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>), \
shared_mem_size); \
hipLaunchKernelGGL(( vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int num_threads = 128;
if(num_heads!=num_kv_heads){
num_threads =256;
}
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
NUM_THREADS_SWITCH(num_threads, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*padded_max_seq_len * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1, num_seqs);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE);
});
});
});
});
});
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
hipLaunchKernelGGL(( vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
REUSE_KV_TIMES, PARTITION_SIZE, odd_nheads>) \
, dim3(grid), dim3(block), shared_mem_size, stream, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
int NUM_THREADS = 256, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
HEADSIZE_SWITCH(head_size, [&] {
OPT_SWITCH(num_heads == num_kv_heads, [&] {
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * sizeof(float);
int outputs_size = REUSE_KV_TIMES*(NUM_WARPS / 2) * head_size * sizeof(float);
// For paged attention v2 kernel.
// dim3 grid(num_heads, max_num_partitions, num_seqs);
dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions;
grid.z = num_seqs;
// int shared_mem_size = ::max(1024*32, ::max(logits_size, outputs_size));
int shared_mem_size = ::max(logits_size, outputs_size);
// For paged attention v2 reduce kernel.
dim3 reduce_grid(num_heads, num_seqs);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
dim3 block(NUM_THREADS);
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE);
});
});
});
});
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -26,12 +26,39 @@ void paged_attention_v2( ...@@ -26,12 +26,39 @@ void paged_attention_v2(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon); double epsilon);
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon); torch::Tensor& weight, double epsilon);
void rms_norm_opt(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);
void fused_add_rms_norm_opt(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox); torch::Tensor& cos_sin_cache, bool is_neox);
...@@ -55,12 +82,20 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); ...@@ -55,12 +82,20 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
void silu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_tanh_and_mul_opt(torch::Tensor& out, torch::Tensor& input);
void gelu_new(torch::Tensor& out, torch::Tensor& input); void gelu_new(torch::Tensor& out, torch::Tensor& input);
void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_quick(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens, torch::Tensor& input_positions, torch::Tensor& seq_lens,
...@@ -151,8 +186,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -151,8 +186,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor const& scale); // torch::Tensor const& scale);
......
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <cmath>
#include "cuda_compat.h"
#include "../dispatch_utils.h"
namespace vllm {
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = ACT_FN(x) * y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_opt1(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx= blockIdx.x;
int idx = threadIdx.x * VEC;
if (idx < d) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), int VEC>
__global__ void act_and_mul_kernel_opt2(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
using VecType = at::native::memory::aligned_vector<scalar_t, VEC>;
const int64_t token_idx = blockIdx.x;
int idx = threadIdx.x * VEC;
for (; idx < d; idx += blockDim.x * VEC) {
const int64_t x_index = token_idx * 2 * d + idx;
const int64_t y_index = token_idx * d + idx;
VecType* x1 = (VecType*)(input + x_index);
VecType* x2 = (VecType*)(input + x_index + d);
VecType* y = (VecType*)(out + y_index);
scalar_t r_x1[VEC];
scalar_t r_x2[VEC];
scalar_t r_y[VEC];
*(VecType*)r_x1 = *x1;
*(VecType*)r_x2 = *x2;
#pragma unroll
for (int i = 0; i < VEC; i++) {
r_y[i] = ACT_FN(r_x1[i]) * r_x2[i];
}
*y = *(VecType*)r_y;
}
}
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2;
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube);
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
if (0 == d % 8 && d <= 16384) { \
if (d <= 512) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 2> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 1024) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 128, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 2048) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 256, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else if (d <= 4096) { \
vllm::act_and_mul_kernel_opt1<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 512, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} else { \
vllm::act_and_mul_kernel_opt2<scalar_t, KERNEL<scalar_t>, 8> \
<<<grid, 1024, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
} else { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
} \
});
void silu_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
void gelu_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
void gelu_tanh_and_mul_opt(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
\ No newline at end of file
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "../dispatch_utils.h"
#include "../reduction_utils.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
input_v[id] = temp;
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
}
}
} // namespace vllm
template <typename T,int reducesize=C10_WARP_SIZE>
__inline__ __device__ T WarpReduceSum_NEW(T val) {
#pragma unroll
for (int offset = reducesize/2; offset > 0; offset >>= 1) {
val += WARP_SHFL_DOWN(val, offset);
}
return val;
}
template <typename T,int block_size=512>
__inline__ __device__ T BlockReduceSum_NEW(T val, T* shared) {
constexpr int share_size=block_size/C10_WARP_SIZE;
val = WarpReduceSum_NEW<T>(val);
if constexpr(block_size==C10_WARP_SIZE)
{
return val;
}
else{
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
if (lid == 0&&wid<share_size) {
shared[wid] = val;
}
__syncthreads();
if (wid == 0&&lid<share_size) {
val = WarpReduceSum_NEW<T,share_size>(shared[lid]);
}
return val;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_add_rms_kernel_opt(scalar_t* input,scalar_t* residual,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
scalar_t residual_vec[Vec];
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
*(LoadT*)residual_vec = *(LoadT*)(residual+idx);
if (j < tcol) {
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
residual_vec[ii]+=intput_vec[ii];
val += static_cast<T_ACC>(residual_vec[ii])*static_cast<T_ACC>(residual_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(residual_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(residual+idx)=*(LoadT*)residual_vec;
*(LoadT*)(input+idx)=*(LoadT*)intput_vec;
}
}
template <typename scalar_t,typename T_ACC,int Vec=4,int block_size=512>
__global__ void fused_rms_kernel_opt(scalar_t* input,scalar_t* output,scalar_t* gamma,int cols,T_ACC eps)
{
constexpr int share_size=block_size/C10_WARP_SIZE;
__shared__ T_ACC val_shared[share_size];
__shared__ T_ACC s_rstd;
T_ACC val=0;
int i=blockIdx.x;
int j=threadIdx.x;
int tcol=cols/Vec;
using LoadT = at::native::memory::aligned_vector<scalar_t, Vec>;
scalar_t intput_vec[Vec];
T_ACC trstd;
int idx = i * tcol + j;
idx*=Vec;
*(LoadT*)intput_vec = *(LoadT*)(input+idx);
if (j < tcol) {
#pragma unroll
for (int ii = 0; ii < Vec; ii++) {
val += static_cast<T_ACC>(intput_vec[ii])*static_cast<T_ACC>(intput_vec[ii]);
}
}
val = BlockReduceSum_NEW<T_ACC,block_size>(val,val_shared);
if (j == 0) s_rstd=c10::cuda::compat::rsqrt(val/cols + eps);
__syncthreads();
trstd=s_rstd;
if (j < tcol) {
#pragma unroll
for(int ii=0;ii<Vec;ii++){
int jj=j*Vec+ii;
intput_vec[ii] = static_cast<T_ACC>(intput_vec[ii]) *trstd* static_cast<T_ACC>(gamma[jj]);
}
*(LoadT*)(output+idx)=*(LoadT*)intput_vec;
}
}
void rms_norm_opt(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size<=16384&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* out_data =out.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=2048){
fused_rms_kernel_opt<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
if(num_tokens>1200){
fused_rms_kernel_opt<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_opt<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
}
else if(hidden_size<=8192){
fused_rms_kernel_opt<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
else{
fused_rms_kernel_opt<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,out_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
num_tokens, hidden_size); \
});
void fused_add_rms_norm_opt(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if(hidden_size%16==0&&hidden_size>=2048&&hidden_size<=8192&&ptrs_are_aligned){
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fused_add_rms_norm_kernel",
[&] {
using T_ACC = at::acc_type<scalar_t, true>;
T_ACC eps = epsilon;
scalar_t* self_data = input.data_ptr<scalar_t>();
scalar_t* other_data =residual.data_ptr<scalar_t>();
scalar_t* weight_data=weight.data_ptr<scalar_t>();
if (hidden_size<=1024){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,128><<<num_tokens, 128, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=2048){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,256><<<num_tokens, 256, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else if(hidden_size<=4096){
if(num_tokens>1200){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,512><<<num_tokens, 512, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_opt<scalar_t,T_ACC,4,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
}
else if(hidden_size<=8192){
fused_add_rms_kernel_opt<scalar_t,T_ACC,8,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
else{
fused_add_rms_kernel_opt<scalar_t,T_ACC,16,1024><<<num_tokens, 1024, 0, stream>>>(self_data,other_data,weight_data,hidden_size,eps);
}
});
}
else{
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
}
\ No newline at end of file
...@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -47,6 +47,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
...@@ -60,6 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -60,6 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
// Activation function used in SwiGLU. (opt)
ops.def("silu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul_opt", torch::kCUDA, &silu_and_mul);
// Activation function used in GeGLU with `none` approximation. (opt)
ops.def("gelu_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul_opt", torch::kCUDA, &gelu_and_mul);
// Activation function used in GeGLU with `tanh` approximation. (opt)
ops.def("gelu_tanh_and_mul_opt(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_tanh_and_mul_opt", torch::kCUDA, &gelu_tanh_and_mul);
// GELU implementation used in GPT-2. // GELU implementation used in GPT-2.
ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_new", torch::kCUDA, &gelu_new); ops.impl("gelu_new", torch::kCUDA, &gelu_new);
...@@ -89,6 +129,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -89,6 +129,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"); "float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Apply Root Mean Square (RMS) Normalization to the input tensor. (opt)
ops.def(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm_opt", torch::kCUDA, &rms_norm_opt);
// In-place fused Add and RMS Normalization. (opt)
ops.def(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm_opt", torch::kCUDA, &fused_add_rms_norm_opt);
// Rotary embedding // Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key. // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def( ops.def(
...@@ -116,6 +168,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -116,6 +168,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache_offsets) -> ()"); " Tensor cos_sin_cache_offsets) -> ()");
ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
// Quantized GEMM for AQLM. // Quantized GEMM for AQLM.
...@@ -185,10 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -185,10 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantized GEMM for SqueezeLLM. // Quantized GEMM for SqueezeLLM.
ops.def( ops.def(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
......
...@@ -59,6 +59,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: ...@@ -59,6 +59,18 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul(out, x) torch.ops._C.gelu_tanh_and_mul(out, x)
def silu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul_opt(out, x)
def gelu_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_and_mul_opt(out, x)
def gelu_tanh_and_mul_opt(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.gelu_tanh_and_mul_opt(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
...@@ -135,6 +147,68 @@ def paged_attention_v2( ...@@ -135,6 +147,68 @@ def paged_attention_v2(
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
def paged_attention_v2_opt(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# pos encoding ops # pos encoding ops
def rotary_embedding( def rotary_embedding(
positions: torch.Tensor, positions: torch.Tensor,
...@@ -167,6 +241,17 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, ...@@ -167,6 +241,17 @@ def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None: weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
# layer norm ops (opt)
def rms_norm_opt(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
torch.ops._C.rms_norm_opt(out, input, weight, epsilon)
def fused_add_rms_norm_opt(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
torch.ops._C.fused_add_rms_norm_opt(input, residual, weight, epsilon)
def advance_step(num_seqs: int, num_queries: int, block_size: int, def advance_step(num_seqs: int, num_queries: int, block_size: int,
...@@ -180,6 +265,11 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int, ...@@ -180,6 +265,11 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
input_positions, seq_lens, slot_mapping, input_positions, seq_lens, slot_mapping,
block_tables) block_tables)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# quantization ops # quantization ops
# awq # awq
...@@ -247,10 +337,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, ...@@ -247,10 +337,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
quant_ops.gptq_shuffle(q_weight, q_perm, bit) quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# squeezellm # squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
......
...@@ -134,27 +134,50 @@ class PagedAttention: ...@@ -134,27 +134,50 @@ class PagedAttention:
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
ops.paged_attention_v1( if envs.VLLM_USE_OPT_OP:
output, ops.paged_attention_v1_opt(
query, output,
key_cache, query,
value_cache, key_cache,
num_kv_heads, value_cache,
scale, num_kv_heads,
block_tables, scale,
seq_lens, block_tables,
block_size, seq_lens,
max_seq_len, block_size,
alibi_slopes, max_seq_len,
kv_cache_dtype, alibi_slopes,
k_scale, kv_cache_dtype,
v_scale, k_scale,
tp_rank, v_scale,
blocksparse_local_blocks, tp_rank,
blocksparse_vert_stride, blocksparse_local_blocks,
blocksparse_block_size, blocksparse_vert_stride,
blocksparse_head_sliding_step, blocksparse_block_size,
) blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v1(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0 assert _PARTITION_SIZE % block_size == 0
...@@ -176,30 +199,56 @@ class PagedAttention: ...@@ -176,30 +199,56 @@ class PagedAttention:
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
ops.paged_attention_v2( if envs.VLLM_USE_OPT_OP:
output, ops.paged_attention_v2_opt(
exp_sums, output,
max_logits, exp_sums,
tmp_output, max_logits,
query, tmp_output,
key_cache, query,
value_cache, key_cache,
num_kv_heads, value_cache,
scale, num_kv_heads,
block_tables, scale,
seq_lens, block_tables,
block_size, seq_lens,
max_seq_len, block_size,
alibi_slopes, max_seq_len,
kv_cache_dtype, alibi_slopes,
k_scale, kv_cache_dtype,
v_scale, k_scale,
tp_rank, v_scale,
blocksparse_local_blocks, tp_rank,
blocksparse_vert_stride, blocksparse_local_blocks,
blocksparse_block_size, blocksparse_vert_stride,
blocksparse_head_sliding_step, blocksparse_block_size,
) blocksparse_head_sliding_step,
)
else:
ops.paged_attention_v2(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
)
return output return output
@staticmethod @staticmethod
......
...@@ -12,6 +12,7 @@ if TYPE_CHECKING: ...@@ -12,6 +12,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_AUTO: bool = False VLLM_USE_FLASH_ATTN_AUTO: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None CUDA_VISIBLE_DEVICES: Optional[str] = None
...@@ -188,6 +189,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -188,6 +189,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_AUTO", "True").lower() in
("true", "1")), ("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
("true", "1")),
# flag to control if vllm print pa parameters # flag to control if vllm print pa parameters
"VLLM_USE_PA_PRINT_PARAM": "VLLM_USE_PA_PRINT_PARAM":
lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in lambda: (os.environ.get("VLLM_USE_PA_PRINT_PARAM", "False").lower() in
......
...@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -11,6 +11,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import vllm.envs as envs
class SiluAndMul(CustomOp): class SiluAndMul(CustomOp):
...@@ -34,7 +35,10 @@ class SiluAndMul(CustomOp): ...@@ -34,7 +35,10 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x) if envs.VLLM_USE_OPT_OP:
ops.silu_and_mul_opt(out, x)
else:
ops.silu_and_mul(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
...@@ -75,9 +79,15 @@ class GeluAndMul(CustomOp): ...@@ -75,9 +79,15 @@ class GeluAndMul(CustomOp):
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none": if self.approximate == "none":
ops.gelu_and_mul(out, x) if envs.VLLM_USE_OPT_OP:
ops.gelu_and_mul_opt(out, x)
else:
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh": elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x) if envs.VLLM_USE_OPT_OP:
ops.gelu_tanh_and_mul_opt(out, x)
else:
ops.gelu_tanh_and_mul(out, x)
return out return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
import vllm.envs as envs
class RMSNorm(CustomOp): class RMSNorm(CustomOp):
...@@ -51,20 +52,36 @@ class RMSNorm(CustomOp): ...@@ -51,20 +52,36 @@ class RMSNorm(CustomOp):
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
if residual is not None: if residual is not None:
ops.fused_add_rms_norm( if envs.VLLM_USE_OPT_OP:
ops.fused_add_rms_norm_opt(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
else:
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
if envs.VLLM_USE_OPT_OP:
ops.rms_norm_opt(
out,
x,
self.weight.data,
self.variance_epsilon,
)
else:
ops.rms_norm(
out,
x, x,
residual,
self.weight.data, self.weight.data,
self.variance_epsilon, self.variance_epsilon,
) )
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out return out
def forward_xpu( def forward_xpu(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment