Commit d589e598 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.2-dev_wm' into 'v0.6.2-dev'

[feat]优化medusa代码,通过VLLM_TREE_DECODING环境变量控制是否采用tree-style解码,计算逻辑与主干隔离

See merge request dcutoolkit/deeplearing/vllm!51
parents 54b92ba4 0bb491f8
...@@ -198,7 +198,10 @@ set(VLLM_EXT_SRC ...@@ -198,7 +198,10 @@ set(VLLM_EXT_SRC
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu" "csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu" "csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp") "csrc/torch_bindings.cpp"
"csrc/attention/attention_with_mask_kernels.cu"
"csrc/attention/attention_with_mask_kernels_opt.cu"
"csrc/attention/attention_with_mask_kernels_opt_tc.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
......
...@@ -107,8 +107,7 @@ __device__ void paged_attention_kernel( ...@@ -107,8 +107,7 @@ __device__ void paged_attention_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
...@@ -297,14 +296,6 @@ __device__ void paged_attention_kernel( ...@@ -297,14 +296,6 @@ __device__ void paged_attention_kernel(
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
}
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
...@@ -524,8 +515,7 @@ __global__ void paged_attention_v1_kernel( ...@@ -524,8 +515,7 @@ __global__ void paged_attention_v1_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>( KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
...@@ -533,7 +523,7 @@ __global__ void paged_attention_v1_kernel( ...@@ -533,7 +523,7 @@ __global__ void paged_attention_v1_kernel(
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
...@@ -561,15 +551,14 @@ __global__ void paged_attention_v2_kernel( ...@@ -561,15 +551,14 @@ __global__ void paged_attention_v2_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
...@@ -695,8 +684,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -695,8 +684,7 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step);
attn_masks_stride);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
...@@ -709,9 +697,7 @@ void paged_attention_v1_launcher( ...@@ -709,9 +697,7 @@ void paged_attention_v1_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -736,12 +722,6 @@ void paged_attention_v1_launcher( ...@@ -736,12 +722,6 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len = int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
...@@ -798,8 +778,7 @@ void paged_attention_v1_launcher( ...@@ -798,8 +778,7 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step);
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -845,9 +824,7 @@ void paged_attention_v1( ...@@ -845,9 +824,7 @@ void paged_attention_v1(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
...@@ -864,8 +841,7 @@ void paged_attention_v1( ...@@ -864,8 +841,7 @@ void paged_attention_v1(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step); \
attn_masks_ptr, attn_masks_stride); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \ PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \ <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
...@@ -883,9 +859,7 @@ void paged_attention_v2_launcher( ...@@ -883,9 +859,7 @@ void paged_attention_v2_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -913,10 +887,6 @@ void paged_attention_v2_launcher( ...@@ -913,10 +887,6 @@ void paged_attention_v2_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float); int logits_size = PARTITION_SIZE * sizeof(float);
...@@ -976,7 +946,7 @@ void paged_attention_v2_launcher( ...@@ -976,7 +946,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -1026,9 +996,7 @@ void paged_attention_v2( ...@@ -1026,9 +996,7 @@ void paged_attention_v2(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
......
...@@ -94,8 +94,7 @@ __device__ void paged_attention_kernel_opt( ...@@ -94,8 +94,7 @@ __device__ void paged_attention_kernel_opt(
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y; const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y; const int max_num_partitions = gridDim.y;
...@@ -328,25 +327,11 @@ __device__ void paged_attention_kernel_opt( ...@@ -328,25 +327,11 @@ __device__ void paged_attention_kernel_opt(
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk = -FLT_MAX;
}
}
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
// used for tree-style attention
/*if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
}*/
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk; logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk); qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
...@@ -627,8 +612,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt( ...@@ -627,8 +612,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
...@@ -636,7 +620,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt( ...@@ -636,7 +620,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
...@@ -668,15 +652,14 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt( ...@@ -668,15 +652,14 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
...@@ -802,8 +785,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt( ...@@ -802,8 +785,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step);
attn_masks_stride);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ // #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ // vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
...@@ -826,9 +808,7 @@ void paged_attention_v1_launcher( ...@@ -826,9 +808,7 @@ void paged_attention_v1_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -857,12 +837,6 @@ void paged_attention_v1_launcher( ...@@ -857,12 +837,6 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] { REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] { BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
...@@ -896,8 +870,7 @@ void paged_attention_v1_launcher( ...@@ -896,8 +870,7 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step);
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -943,9 +916,7 @@ void paged_attention_v1_opt( ...@@ -943,9 +916,7 @@ void paged_attention_v1_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
...@@ -962,8 +933,7 @@ void paged_attention_v1_opt( ...@@ -962,8 +933,7 @@ void paged_attention_v1_opt(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step); \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \ PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \ , dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
...@@ -981,9 +951,7 @@ void paged_attention_v2_launcher( ...@@ -981,9 +951,7 @@ void paged_attention_v2_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -1011,10 +979,6 @@ void paged_attention_v2_launcher( ...@@ -1011,10 +979,6 @@ void paged_attention_v2_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] { REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
...@@ -1053,7 +1017,7 @@ void paged_attention_v2_launcher( ...@@ -1053,7 +1017,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -1103,9 +1067,7 @@ void paged_attention_v2_opt( ...@@ -1103,9 +1067,7 @@ void paged_attention_v2_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
......
...@@ -168,8 +168,7 @@ __device__ void paged_attention_kernel_TC( ...@@ -168,8 +168,7 @@ __device__ void paged_attention_kernel_TC(
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y; const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y; const int max_num_partitions = gridDim.y;
...@@ -293,14 +292,6 @@ __device__ void paged_attention_kernel_TC( ...@@ -293,14 +292,6 @@ __device__ void paged_attention_kernel_TC(
qk_vec[i] = alibi; qk_vec[i] = alibi;
} }
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
if (attn_masks_ptr[token_idx] == 0) {
qk_vec[i] = -FLT_MAX;
}
}
const bool mask = (token_idx >= seq_len); const bool mask = (token_idx >= seq_len);
if(mask){ if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
...@@ -565,8 +556,7 @@ __global__ void paged_attention_v1_kernel_TC( ...@@ -565,8 +556,7 @@ __global__ void paged_attention_v1_kernel_TC(
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#if defined(__gfx936__) || defined(__gfx928__) #if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>( KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>(
...@@ -575,7 +565,7 @@ __global__ void paged_attention_v1_kernel_TC( ...@@ -575,7 +565,7 @@ __global__ void paged_attention_v1_kernel_TC(
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
#endif #endif
} }
...@@ -605,8 +595,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( ...@@ -605,8 +595,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#if defined(__gfx936__) || defined(__gfx928__) #if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac,
...@@ -615,7 +604,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( ...@@ -615,7 +604,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq,
alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
#endif #endif
} }
...@@ -742,8 +731,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t ...@@ -742,8 +731,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step);
attn_masks_stride);
void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){ void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){
//mha //mha
...@@ -809,9 +797,7 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -809,9 +797,7 @@ void paged_attention_v1_launcher_opt_tc(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -840,12 +826,6 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -840,12 +826,6 @@ void paged_attention_v1_launcher_opt_tc(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...@@ -880,8 +860,7 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -880,8 +860,7 @@ void paged_attention_v1_launcher_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step);
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -927,9 +906,7 @@ void paged_attention_v1_opt( ...@@ -927,9 +906,7 @@ void paged_attention_v1_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step);
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride);
void paged_attention_v1_opt_tc( void paged_attention_v1_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
...@@ -947,17 +924,14 @@ void paged_attention_v1_opt_tc( ...@@ -947,17 +924,14 @@ void paged_attention_v1_opt_tc(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){ block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads, paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step, blocksparse_block_size,blocksparse_head_sliding_step);
attn_masks, attn_masks_stride);
} }
else{ else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
...@@ -976,8 +950,7 @@ void paged_attention_v1_opt_tc( ...@@ -976,8 +950,7 @@ void paged_attention_v1_opt_tc(
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, \ blocksparse_head_sliding_step); \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \ (vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \ PARTITION_SIZE>), \
...@@ -1028,9 +1001,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1028,9 +1001,7 @@ void paged_attention_v2_launcher_opt_tc(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -1058,10 +1029,6 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1058,10 +1029,6 @@ void paged_attention_v2_launcher_opt_tc(
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_grid(num_heads, num_seqs);
...@@ -1103,7 +1070,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1103,7 +1070,7 @@ void paged_attention_v2_launcher_opt_tc(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks, attn_masks_stride); blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -1153,9 +1120,7 @@ void paged_attention_v2_opt( ...@@ -1153,9 +1120,7 @@ void paged_attention_v2_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step);
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride);
void paged_attention_v2_opt_tc( void paged_attention_v2_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
...@@ -1177,17 +1142,14 @@ void paged_attention_v2_opt_tc( ...@@ -1177,17 +1142,14 @@ void paged_attention_v2_opt_tc(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step) {
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){ block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step, attn_masks, blocksparse_block_size,blocksparse_head_sliding_step);
attn_masks_stride);
} }
else{ else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -6,6 +6,71 @@ ...@@ -6,6 +6,71 @@
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
// paged_attention with attn_masks
void paged_attention_v1_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
...@@ -17,7 +82,7 @@ void paged_attention_v1( ...@@ -17,7 +82,7 @@ void paged_attention_v1(
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0); const int64_t attn_masks_stride=0);
void paged_attention_v2( void paged_attention_v2_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
...@@ -30,7 +95,7 @@ void paged_attention_v2( ...@@ -30,7 +95,7 @@ void paged_attention_v2(
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0); const int64_t attn_masks_stride=0);
void paged_attention_v1_opt( void paged_attention_v1_opt_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
...@@ -42,7 +107,7 @@ void paged_attention_v1_opt( ...@@ -42,7 +107,7 @@ void paged_attention_v1_opt(
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0); const int64_t attn_masks_stride=0);
void paged_attention_v2_opt( void paged_attention_v2_opt_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
...@@ -55,7 +120,7 @@ void paged_attention_v2_opt( ...@@ -55,7 +120,7 @@ void paged_attention_v2_opt(
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0); const int64_t attn_masks_stride=0);
void paged_attention_v1_opt_tc( void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
...@@ -67,7 +132,7 @@ void paged_attention_v1_opt_tc( ...@@ -67,7 +132,7 @@ void paged_attention_v1_opt_tc(
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0); const int64_t attn_masks_stride=0);
void paged_attention_v2_opt_tc( void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
...@@ -80,6 +145,7 @@ void paged_attention_v2_opt_tc( ...@@ -80,6 +145,7 @@ void paged_attention_v2_opt_tc(
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0); const int64_t attn_masks_stride=0);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon); double epsilon);
......
...@@ -30,14 +30,98 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -30,14 +30,98 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
// PagedAttention V2.
ops.def(
"paged_attention_v2("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt)
ops.def(
"paged_attention_v1_opt_tc("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc);
// PagedAttention V2 (opt).
ops.def(
"paged_attention_v2_opt_tc("
" Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc);
// paged_attention with atth_masks
ops.def(
"paged_attention_v1_with_mask("
" Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
" Tensor? attn_masks," " Tensor? attn_masks,"
" int attn_masks_stride) -> ()"); " int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); ops.impl("paged_attention_v1_with_mask", torch::kCUDA, &paged_attention_v1_with_mask);
// PagedAttention V2. // PagedAttention V2.
ops.def( ops.def(
"paged_attention_v2(" "paged_attention_v2_with_mask("
" Tensor! out, Tensor! exp_sums, Tensor! max_logits," " Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
" Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor! tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
...@@ -49,12 +133,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -49,12 +133,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
" Tensor? attn_masks," " Tensor? attn_masks,"
" int attn_masks_stride) -> ()"); " int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); ops.impl("paged_attention_v2_with_mask", torch::kCUDA, &paged_attention_v2_with_mask);
// Compute the attention between an input query and the cached // Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt) // keys/values using PagedAttention. (opt)
ops.def( ops.def(
"paged_attention_v1_opt(" "paged_attention_v1_opt_with_mask("
" Tensor! out, Tensor query, Tensor key_cache," " Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
...@@ -65,11 +149,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -65,11 +149,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
" Tensor? attn_masks," " Tensor? attn_masks,"
" int attn_masks_stride) -> ()"); " int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt); ops.impl("paged_attention_v1_opt_with_mask", torch::kCUDA, &paged_attention_v1_opt_with_mask);
// PagedAttention V2 (opt). // PagedAttention V2 (opt).
ops.def( ops.def(
"paged_attention_v2_opt(" "paged_attention_v2_opt_with_mask("
" Tensor! out, Tensor exp_sums, Tensor max_logits," " Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache," " Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
...@@ -81,12 +165,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -81,12 +165,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
" Tensor? attn_masks," " Tensor? attn_masks,"
" int attn_masks_stride) -> ()"); " int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt); ops.impl("paged_attention_v2_opt_with_mask", torch::kCUDA, &paged_attention_v2_opt_with_mask);
// Compute the attention between an input query and the cached // Compute the attention between an input query and the cached
// keys/values using PagedAttention. (opt) // keys/values using PagedAttention. (opt)
ops.def( ops.def(
"paged_attention_v1_opt_tc(" "paged_attention_v1_opt_tc_with_mask("
" Tensor! out, Tensor query, Tensor key_cache," " Tensor! out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
...@@ -97,11 +181,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -97,11 +181,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
" Tensor? attn_masks," " Tensor? attn_masks,"
" int attn_masks_stride) -> ()"); " int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc); ops.impl("paged_attention_v1_opt_tc_with_mask", torch::kCUDA, &paged_attention_v1_opt_tc_with_mask);
// PagedAttention V2 (opt). // PagedAttention V2 (opt).
ops.def( ops.def(
"paged_attention_v2_opt_tc(" "paged_attention_v2_opt_tc_with_mask("
" Tensor! out, Tensor exp_sums, Tensor max_logits," " Tensor! out, Tensor exp_sums, Tensor max_logits,"
" Tensor tmp_out, Tensor query, Tensor key_cache," " Tensor tmp_out, Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
...@@ -113,7 +197,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -113,7 +197,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
" Tensor? attn_masks," " Tensor? attn_masks,"
" int attn_masks_stride) -> ()"); " int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc); ops.impl("paged_attention_v2_opt_tc_with_mask", torch::kCUDA, &paged_attention_v2_opt_tc_with_mask);
// Activation ops // Activation ops
// Activation function used in SwiGLU. // Activation function used in SwiGLU.
......
# Medusa Decoding # Medusa Decoding
本文说明如何使用vllm构建和运行medusa模型
本文说明如何使用vllm构建和运行medusa模型,目前medusa支持tree-style generation,target model和draft model均可多卡推理
## Overview ## Overview
Medusa是一种大模型并行解码算法,除了支持官方提供的Top1-proposer,我们还支持tree-style并行解码,target model和draft model均可多卡推理
与其他模型不同,medusa解码需要一个base model和若干Medusa heads. 与其他模型不同,medusa解码需要一个base model和若干Medusa heads.
Vllm medusa model的实现在[vllm/model_executor/models/medusa.py] Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
...@@ -19,28 +20,43 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py] ...@@ -19,28 +20,43 @@ Vllm medusa model的实现在[vllm/model_executor/models/medusa.py]
```bash ```bash
python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]" python medusa_weight_converter.py --medusa_num_heads 4 --medusa_num_layers 1 --medusa_model_path /work/model.bin --vocab_size 152064 --hidden_size 8192 --output_dir /work/medusa/vllm-medusa-qwen2-72b-head-4 --medusa_choices="[(0), (0, 0), (0, 0, 0), (0, 1), (1), (1, 0), (0, 0, 0, 0), (0, 0, 1), (0, 2), (0, 1, 0), (2), (0, 0, 2), (0, 3), (1, 0, 0), (2, 0), (0, 2, 0), (0, 4), (0, 0, 3), (3), (0, 0, 0, 1), (0, 5), (0, 0, 1, 0), (0, 0, 4)]"
``` ```
此处model.bin是训练后保存的medusa head权重 此处model.bin是训练后保存的medusa head权重,如果希望采用Top1-proposer,medusa_choices可以不设置
### Run ### Run tree-style generation server
```bash ```bash
python3 -m vllm.entrypoints.openai.api_server \ VLLM_TREE_DECODING=1 python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \ --served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \ --model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \ --max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \ --speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \ --speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 4 \ --speculative-disable-by-batch-size 9 \
--use-v2-block-manager \ --use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \ --spec-decoding-acceptance-method typical_acceptance_sampler \
--dtype float16 --trust-remote-code --port 8086\ --dtype float16 --trust-remote-code --port 8086\
--tree-style-spec-decoding True\
--num-speculative-heads 4 --num-speculative-tokens 24 --num-speculative-heads 4 --num-speculative-tokens 24
``` ```
注意:
num_speculative_tokens = len(medusa_choices) + 1
medusa_choices个数不能太多,否则多batch下会降低推理速度
speculative-disable-by-batch-size要大于max-num-seqs,否则当batch等于max-num-seqs时,不会走并行解码
merge-lora可以将lora权重和base model权重融合,提升整体推理速度,若对精度有严格要求,可不设置此参数 ### Run Top1-proposer server
num-speculative-tokens和medusa choices的个数相关,num_speculative_tokens = len(medusa_choices) + 1 python3 -m vllm.entrypoints.openai.api_server \
--served-model-name qwen_medusa \
--model /models/Qwen2-72B-Instruct/ -tp 4 \
--max-model-len 1024 --max-num-seqs 8 --gpu-memory-utilization 0.8 \
--speculative-model /work/medusa/vllm-medusa-qwen2-72b-head-4 \
--speculative-draft-tensor-parallel-size 4 \
--speculative-disable-by-batch-size 9 \
--use-v2-block-manager \
--spec-decoding-acceptance-method typical_acceptance_sampler \
--dtype float16 --trust-remote-code --port 8086\
--num-speculative-tokens 4
注意:
使用Top1-proposer时,num-speculative-tokens就是medusa head的个数
# do request # do request
```bash ```bash
...@@ -54,8 +70,14 @@ curl http://localhost:8086/v1/completions \ ...@@ -54,8 +70,14 @@ curl http://localhost:8086/v1/completions \
}' }'
``` ```
### benchmark ### Run tree-style benchmark
python medusa_benchmark_throughput.py --model /data/llm-models/qwen2/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 1 --dataset /work/test/medusa_benchmark_data.json --max-model-len 4096 --gpu-memory-utilization 0.9 ```bash
VLLM_TREE_DECODING=1 python /work/test/medusa_benchmark_throughput.py --model /models/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 4 --speculative-model /work/medusa/vllm-medusa1-qwen2-72b-head-4 --speculative-draft-tensor-parallel-size 4 --speculative-disable-by-batch-size 9 --use-v2-block-manager --spec-decoding-acceptance-method typical_acceptance_sampler --max-model-len 1024 --dataset /work/medusa_benchmark_data.json --num-speculative-heads 4 --num-speculative-tokens 24 --gpu-memory-utilization 0.95
```
### Run Top1-proposer benchmark
```bash
python /work/test/medusa_benchmark_throughput.py --model /models/Qwen2-72B-Instruct/ -tp 4 --dtype float16 --trust-remote-code --max-num-seqs 4 --speculative-model /work/medusa/vllm-medusa1-qwen2-72b-head-4 --speculative-draft-tensor-parallel-size 4 --speculative-disable-by-batch-size 9 --use-v2-block-manager --spec-decoding-acceptance-method typical_acceptance_sampler --max-model-len 1024 --dataset /work/medusa_benchmark_data.json --num-speculative-tokens 4 --gpu-memory-utilization 0.95
```
可设置max-num-seqs对不同的batch进行性能测试 可设置max-num-seqs对不同的batch进行性能测试
...@@ -98,7 +98,6 @@ def run_vllm( ...@@ -98,7 +98,6 @@ def run_vllm(
merge_lora: bool = False, merge_lora: bool = False,
lora_extra_vocab_size: int = 0, lora_extra_vocab_size: int = 0,
lora_target_modules: List[str] = None, lora_target_modules: List[str] = None,
tree_style_spec_decoding: bool = False,
num_speculative_heads: int = 5, num_speculative_heads: int = 5,
num_speculative_tokens: int = 64, num_speculative_tokens: int = 64,
use_new_beam_search_impl: bool = False, use_new_beam_search_impl: bool = False,
...@@ -138,7 +137,6 @@ def run_vllm( ...@@ -138,7 +137,6 @@ def run_vllm(
merge_lora=merge_lora, merge_lora=merge_lora,
lora_extra_vocab_size=lora_extra_vocab_size, lora_extra_vocab_size=lora_extra_vocab_size,
lora_target_modules=lora_target_modules, lora_target_modules=lora_target_modules,
tree_style_spec_decoding=tree_style_spec_decoding,
num_speculative_heads=num_speculative_heads, num_speculative_heads=num_speculative_heads,
num_speculative_tokens=num_speculative_tokens num_speculative_tokens=num_speculative_tokens
) )
...@@ -234,7 +232,6 @@ async def run_vllm_async( ...@@ -234,7 +232,6 @@ async def run_vllm_async(
merge_lora: bool = False, merge_lora: bool = False,
lora_extra_vocab_size: int = 0, lora_extra_vocab_size: int = 0,
lora_target_modules: List[str] = None, lora_target_modules: List[str] = None,
tree_style_spec_decoding: bool = False,
num_speculative_heads: int = 5, num_speculative_heads: int = 5,
num_speculative_tokens: int = 64, num_speculative_tokens: int = 64,
use_new_beam_search_impl: bool = False, use_new_beam_search_impl: bool = False,
...@@ -276,7 +273,6 @@ async def run_vllm_async( ...@@ -276,7 +273,6 @@ async def run_vllm_async(
merge_lora=merge_lora, merge_lora=merge_lora,
lora_extra_vocab_size=lora_extra_vocab_size, lora_extra_vocab_size=lora_extra_vocab_size,
lora_target_modules=lora_target_modules, lora_target_modules=lora_target_modules,
tree_style_spec_decoding=tree_style_spec_decoding,
num_speculative_heads=num_speculative_heads, num_speculative_heads=num_speculative_heads,
num_speculative_tokens=num_speculative_tokens num_speculative_tokens=num_speculative_tokens
) )
...@@ -350,7 +346,7 @@ def main(args: argparse.Namespace): ...@@ -350,7 +346,7 @@ def main(args: argparse.Namespace):
args.speculative_model, args.speculative_draft_tensor_parallel_size, args.speculative_model, args.speculative_draft_tensor_parallel_size,
args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method, args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method,
args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size, args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size,
args.lora_target_modules, args.tree_style_spec_decoding, args.num_speculative_heads, args.lora_target_modules, args.num_speculative_heads,
args.num_speculative_tokens args.num_speculative_tokens
] ]
else: else:
...@@ -368,7 +364,7 @@ def main(args: argparse.Namespace): ...@@ -368,7 +364,7 @@ def main(args: argparse.Namespace):
args.speculative_model, args.speculative_draft_tensor_parallel_size, args.speculative_model, args.speculative_draft_tensor_parallel_size,
args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method, args.speculative_disable_by_batch_size, args.spec_decoding_acceptance_method,
args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size, args.enable_lora, args.max_lora_rank, args.merge_lora, args.lora_extra_vocab_size,
args.lora_target_modules, args.tree_style_spec_decoding, args.num_speculative_heads, args.lora_target_modules, args.num_speculative_heads,
args.num_speculative_tokens args.num_speculative_tokens
] ]
...@@ -625,11 +621,6 @@ if __name__ == "__main__": ...@@ -625,11 +621,6 @@ if __name__ == "__main__":
default=None, default=None,
help='List of lora module name, If not specified, modules will be chosen according to the model architecture.') help='List of lora module name, If not specified, modules will be chosen according to the model architecture.')
parser.add_argument('--tree-style-spec-decoding',
type=bool,
default=False,
help='If set to True, tree-style generation will be activated.')
parser.add_argument( parser.add_argument(
'--num-speculative-heads', '--num-speculative-heads',
type=int, type=int,
......
...@@ -369,7 +369,7 @@ def main(args): ...@@ -369,7 +369,7 @@ def main(args):
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
save_model(medusa_model, os.path.join(args.output_dir, "model.safetensors")) save_model(medusa_model, os.path.join(args.output_dir, "model.safetensors"))
medusa_choices = ast.literal_eval(args.medusa_choices) medusa_choices = ast.literal_eval(args.medusa_choices) if args.medusa_choices is not None else None
to_save_config = CustomMedusaConfig(name_or_path=os.path.join(args.output_dir, "config.json"), to_save_config = CustomMedusaConfig(name_or_path=os.path.join(args.output_dir, "config.json"),
hidden_size=args.hidden_size, hidden_size=args.hidden_size,
num_heads=medusa_head_num, num_heads=medusa_head_num,
...@@ -403,7 +403,7 @@ if __name__ == "__main__": ...@@ -403,7 +403,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
'--medusa_choices', '--medusa_choices',
type=str, type=str,
required=True, default=None,
help="Medusa choice to use, if not none, will use Medusa decoding." help="Medusa choice to use, if not none, will use Medusa decoding."
" E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens." " E.g.: [[0, 0, 0, 0], [0, 1, 0], [1, 0], [1, 1]] for 9 medusa tokens."
) )
......
...@@ -225,7 +225,7 @@ def test_paged_attention( ...@@ -225,7 +225,7 @@ def test_paged_attention(
opcheck(torch.ops._C.paged_attention_v1, opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale, (output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0), kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0] cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0])) and block_size == BLOCK_SIZES[0]))
else: else:
...@@ -291,7 +291,7 @@ def test_paged_attention( ...@@ -291,7 +291,7 @@ def test_paged_attention(
(output, exp_sums, max_logits, tmp_output, query, (output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, None, 0), kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0] cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0])) and block_size == BLOCK_SIZES[0]))
else: else:
......
...@@ -60,16 +60,6 @@ class MockAttentionBackend(AttentionBackend): ...@@ -60,16 +60,6 @@ class MockAttentionBackend(AttentionBackend):
) -> None: ) -> None:
pass pass
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
def test_model_runner_input(): def test_model_runner_input():
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
......
...@@ -89,6 +89,66 @@ def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: ...@@ -89,6 +89,66 @@ def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
# page attention ops # page attention ops
def paged_attention_v1( def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v1_with_mask(
out: torch.Tensor, out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
...@@ -111,7 +171,7 @@ def paged_attention_v1( ...@@ -111,7 +171,7 @@ def paged_attention_v1(
attn_masks: Optional[torch.Tensor] = None, attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0, attn_masks_stride: int = 0,
) -> None: ) -> None:
torch.ops._C.paged_attention_v1( torch.ops._C.paged_attention_v1_with_mask(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
...@@ -120,7 +180,7 @@ def paged_attention_v1( ...@@ -120,7 +180,7 @@ def paged_attention_v1(
attn_masks_stride) attn_masks_stride)
def paged_attention_v2( def paged_attention_v2_with_mask(
out: torch.Tensor, out: torch.Tensor,
exp_sum: torch.Tensor, exp_sum: torch.Tensor,
max_logits: torch.Tensor, max_logits: torch.Tensor,
...@@ -146,7 +206,7 @@ def paged_attention_v2( ...@@ -146,7 +206,7 @@ def paged_attention_v2(
attn_masks: Optional[torch.Tensor] = None, attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0, attn_masks_stride: int = 0,
) -> None: ) -> None:
torch.ops._C.paged_attention_v2( torch.ops._C.paged_attention_v2_with_mask(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
...@@ -157,6 +217,67 @@ def paged_attention_v2( ...@@ -157,6 +217,67 @@ def paged_attention_v2(
# page attention ops (opt) # page attention ops (opt)
def paged_attention_v1_opt( def paged_attention_v1_opt(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
def paged_attention_v2_opt(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v1_opt_with_mask(
out: torch.Tensor, out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
...@@ -179,7 +300,7 @@ def paged_attention_v1_opt( ...@@ -179,7 +300,7 @@ def paged_attention_v1_opt(
attn_masks: Optional[torch.Tensor] = None, attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0, attn_masks_stride: int = 0,
) -> None: ) -> None:
torch.ops._C.paged_attention_v1_opt( torch.ops._C.paged_attention_v1_opt_with_mask(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
...@@ -188,7 +309,7 @@ def paged_attention_v1_opt( ...@@ -188,7 +309,7 @@ def paged_attention_v1_opt(
attn_masks_stride) attn_masks_stride)
def paged_attention_v2_opt( def paged_attention_v2_opt_with_mask(
out: torch.Tensor, out: torch.Tensor,
exp_sum: torch.Tensor, exp_sum: torch.Tensor,
max_logits: torch.Tensor, max_logits: torch.Tensor,
...@@ -214,7 +335,7 @@ def paged_attention_v2_opt( ...@@ -214,7 +335,7 @@ def paged_attention_v2_opt(
attn_masks: Optional[torch.Tensor] = None, attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0, attn_masks_stride: int = 0,
) -> None: ) -> None:
torch.ops._C.paged_attention_v2_opt( torch.ops._C.paged_attention_v2_opt_with_mask(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
...@@ -225,6 +346,67 @@ def paged_attention_v2_opt( ...@@ -225,6 +346,67 @@ def paged_attention_v2_opt(
# page attention ops (opt) # page attention ops (opt)
def paged_attention_v1_opt_tc( def paged_attention_v1_opt_tc(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v1_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
def paged_attention_v2_opt_tc(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0
) -> None:
torch.ops._C.paged_attention_v2_opt_tc(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
# page attention ops (opt)
def paged_attention_v1_opt_tc_with_mask(
out: torch.Tensor, out: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key_cache: torch.Tensor, key_cache: torch.Tensor,
...@@ -247,7 +429,7 @@ def paged_attention_v1_opt_tc( ...@@ -247,7 +429,7 @@ def paged_attention_v1_opt_tc(
attn_masks: Optional[torch.Tensor] = None, attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0, attn_masks_stride: int = 0,
) -> None: ) -> None:
torch.ops._C.paged_attention_v1_opt_tc( torch.ops._C.paged_attention_v1_opt_tc_with_mask(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
...@@ -255,7 +437,7 @@ def paged_attention_v1_opt_tc( ...@@ -255,7 +437,7 @@ def paged_attention_v1_opt_tc(
attn_masks, attn_masks_stride) attn_masks, attn_masks_stride)
def paged_attention_v2_opt_tc( def paged_attention_v2_opt_tc_with_mask(
out: torch.Tensor, out: torch.Tensor,
exp_sum: torch.Tensor, exp_sum: torch.Tensor,
max_logits: torch.Tensor, max_logits: torch.Tensor,
...@@ -281,7 +463,7 @@ def paged_attention_v2_opt_tc( ...@@ -281,7 +463,7 @@ def paged_attention_v2_opt_tc(
attn_masks: Optional[torch.Tensor] = None, attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0, attn_masks_stride: int = 0,
) -> None: ) -> None:
torch.ops._C.paged_attention_v2_opt_tc( torch.ops._C.paged_attention_v2_opt_tc_with_mask(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
......
...@@ -83,17 +83,6 @@ class AttentionBackend(ABC): ...@@ -83,17 +83,6 @@ class AttentionBackend(ABC):
) -> None: ) -> None:
raise NotImplementedError raise NotImplementedError
@staticmethod
@abstractmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
raise NotImplementedError
def advance_step(self, model_input: "ModelRunnerInputBase", def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor], sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int) -> None: block_size: int, num_seqs: int, num_queries: int) -> None:
...@@ -206,8 +195,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]): ...@@ -206,8 +195,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
@abstractmethod @abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int], def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int, cuda_graph_pad_size: int, batch_size: int) -> T:
tree_attention_masks_tensor: Optional[torch.Tensor] = None) -> T:
"""Build attention metadata with on-device tensors.""" """Build attention metadata with on-device tensors."""
raise NotImplementedError raise NotImplementedError
......
...@@ -129,50 +129,6 @@ class BlocksparseFlashAttentionBackend(AttentionBackend): ...@@ -129,50 +129,6 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
key_caches = []
value_caches = []
num_layers = len(kv_caches)
token_num = src_to_dists.shape[0]
tmp_store_kv = torch.empty(
(2, num_layers, token_num, num_kv_heads, head_size),
dtype=kv_caches[0].dtype, device=kv_caches[0].device)
keys = tmp_store_kv[0].contiguous()
values = tmp_store_kv[1].contiguous()
for kv_cache in kv_caches:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, num_kv_heads, head_size)
key_caches.append(key_cache)
value_caches.append(value_cache)
ops.read_cache(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 0].contiguous(),
kv_cache_dtype
)
ops.write_cache_multi_layers(
keys,
values,
key_caches,
value_caches,
src_to_dists[:, 1].contiguous(),
kv_cache_dtype
)
@dataclass @dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata): class BlocksparseFlashAttentionMetadata(AttentionMetadata):
...@@ -236,8 +192,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -236,8 +192,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
_cached_decode_metadata: Optional[ _cached_decode_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None "BlocksparseFlashAttentionMetadata"] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None
block_tables_list: Optional[List[int]] = None block_tables_list: Optional[List[int]] = None
@property @property
...@@ -271,7 +225,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -271,7 +225,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills], context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False, use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list block_tables_list=self.block_tables_list
) )
return self._cached_prefill_metadata return self._cached_prefill_metadata
...@@ -301,7 +254,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -301,7 +254,6 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=None, context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:], block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph, use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor,
block_tables_list=self.block_tables_list block_tables_list=self.block_tables_list
) )
return self._cached_decode_metadata return self._cached_decode_metadata
......
...@@ -221,16 +221,6 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -221,16 +221,6 @@ class FlashAttentionBackend(AttentionBackend):
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists) ops.copy_blocks(key_caches, value_caches, src_to_dists)
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
@dataclass @dataclass
class FlashAttentionMetadata(AttentionMetadata): class FlashAttentionMetadata(AttentionMetadata):
......
...@@ -93,16 +93,6 @@ class FlashInferBackend(AttentionBackend): ...@@ -93,16 +93,6 @@ class FlashInferBackend(AttentionBackend):
else: else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
class FlashInferState(AttentionState): class FlashInferState(AttentionState):
......
...@@ -62,16 +62,6 @@ class IpexAttnBackend(AttentionBackend): ...@@ -62,16 +62,6 @@ class IpexAttnBackend(AttentionBackend):
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists) ops.copy_blocks(key_caches, value_caches, src_to_dists)
@staticmethod
def move_cache(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
kv_cache_dtype: str,
num_kv_heads: int,
head_size: int,
) -> None:
NotImplementedError
@dataclass @dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment