Commit 19bc93d9 authored by 王敏's avatar 王敏
Browse files

增加medusa并行解码功能,后续增加使用说明和测试文档

parent aba40fda
......@@ -107,7 +107,8 @@ __device__ void paged_attention_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z;
......@@ -299,7 +300,14 @@ __device__ void paged_attention_kernel(
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
bool mask = token_idx >= seq_len;
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
}
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
......@@ -515,7 +523,8 @@ __global__ void paged_attention_v1_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
......@@ -523,7 +532,7 @@ __global__ void paged_attention_v1_kernel(
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
......@@ -551,14 +560,15 @@ __global__ void paged_attention_v2_kernel(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs).
......@@ -684,7 +694,8 @@ __global__ void paged_attention_v2_reduce_kernel(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
......@@ -697,7 +708,9 @@ void paged_attention_v1_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -722,6 +735,12 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
......@@ -778,7 +797,8 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
......@@ -824,7 +844,9 @@ void paged_attention_v1(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
......@@ -841,7 +863,8 @@ void paged_attention_v1(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE> \
<<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
......@@ -859,7 +882,9 @@ void paged_attention_v2_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -887,6 +912,10 @@ void paged_attention_v2_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
......@@ -946,7 +975,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
......@@ -996,7 +1025,9 @@ void paged_attention_v2(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
......
......@@ -94,7 +94,8 @@ __device__ void paged_attention_kernel_opt(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y;
......@@ -330,7 +331,13 @@ __device__ void paged_attention_kernel_opt(
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;
bool mask = token_idx >= seq_len;
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
}
logits[(reuse_kv_idx * partition_size) + (token_idx - start_token_idx)] = mask ? 0.f : qk;
// Update the max value.
qk_max[reuse_kv_idx] = mask ? qk_max[reuse_kv_idx] : fmaxf(qk_max[reuse_kv_idx], qk);
......@@ -611,7 +618,8 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
......@@ -619,7 +627,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
......@@ -651,14 +659,15 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES, odd_nheads, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
}
// Grid: (num_heads, num_seqs).
......@@ -784,7 +793,8 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
......@@ -807,7 +817,9 @@ void paged_attention_v1_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -836,6 +848,12 @@ void paged_attention_v1_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
REUSEKV_SWITCH_V1(num_heads * num_seqs , [&] {
BOOL_SWITCH((num_heads/num_kv_heads % REUSE_KV_TIMES != 0), odd_nheads, [&] {
......@@ -869,7 +887,8 @@ void paged_attention_v1_launcher(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
......@@ -915,7 +934,9 @@ void paged_attention_v1_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
......@@ -932,7 +953,8 @@ void paged_attention_v1_opt(
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>) \
, dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, \
......@@ -950,7 +972,9 @@ void paged_attention_v2_launcher(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -978,6 +1002,10 @@ void paged_attention_v2_launcher(
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
REUSEKV_SWITCH(num_heads * max_num_partitions * num_seqs , [&] {
......@@ -1016,7 +1044,7 @@ void paged_attention_v2_launcher(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
......@@ -1066,7 +1094,9 @@ void paged_attention_v2_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
......
......@@ -168,7 +168,8 @@ __device__ void paged_attention_kernel_TC(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y;
const int max_num_partitions = gridDim.y;
......@@ -291,7 +292,13 @@ __device__ void paged_attention_kernel_TC(
float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
qk_vec[i] += alibi;
}
const bool mask = (token_idx >= seq_len);
bool mask = (token_idx >= seq_len);
// used for tree-style attention
if (attn_masks != nullptr) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
mask |= attn_masks_ptr[token_idx] == 0;
}
if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
}
......@@ -555,7 +562,8 @@ __global__ void paged_attention_v1_kernel_TC(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#ifdef __gfx928__
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>(
......@@ -564,7 +572,7 @@ __global__ void paged_attention_v1_kernel_TC(
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#endif
}
......@@ -594,7 +602,8 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#ifdef __gfx928__
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac,
......@@ -603,7 +612,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq,
alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#endif
}
......@@ -730,7 +739,8 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride);
void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){
//mha
......@@ -796,7 +806,9 @@ void paged_attention_v1_launcher_opt_tc(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -824,6 +836,13 @@ void paged_attention_v1_launcher_opt_tc(
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks
? attn_masks.value().data_ptr<int>()
: nullptr;
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......@@ -858,7 +877,8 @@ void paged_attention_v1_launcher_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
......@@ -904,7 +924,9 @@ void paged_attention_v1_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride);
void paged_attention_v1_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
......@@ -922,14 +944,17 @@ void paged_attention_v1_opt_tc(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
blocksparse_block_size,blocksparse_head_sliding_step,
attn_masks, attn_masks_stride);
}
else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
......@@ -948,7 +973,8 @@ void paged_attention_v1_opt_tc(
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \
blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
......@@ -999,7 +1025,9 @@ void paged_attention_v2_launcher_opt_tc(
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
......@@ -1026,6 +1054,11 @@ void paged_attention_v2_launcher_opt_tc(
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
......@@ -1067,7 +1100,7 @@ void paged_attention_v2_launcher_opt_tc(
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
......@@ -1117,7 +1150,9 @@ void paged_attention_v2_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride);
void paged_attention_v2_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
......@@ -1139,14 +1174,17 @@ void paged_attention_v2_opt_tc(
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) {
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
blocksparse_block_size,blocksparse_head_sliding_step, attn_masks,
attn_masks_stride);
}
else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
......
......@@ -31,3 +31,19 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
// Declaration of the host entry point that copies K/V data out of the
// per-layer caches into `keys` / `values`.
// NOTE(review): only the prototype is visible here. The description below is
// inferred from the parameter names and from the kernel of the same name;
// confirm against the definition.
//   keys, values      - destination tensors (presumably
//                       [num_layers, num_tokens, num_heads, head_size]).
//   key_caches,
//   value_caches      - one cache tensor per layer (presumably).
//   slot_mapping      - cache slot index for each token (presumably int64).
//   kv_cache_dtype    - string selecting the cache element format.
void read_cache(
torch::Tensor& keys,
torch::Tensor& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Declaration of the host entry point that writes `keys` / `values` for
// several layers into the corresponding per-layer caches.
// NOTE(review): only the prototype is visible here. The description below is
// inferred from the parameter names and from the kernel of the same name;
// confirm against the definition.
//   keys, values      - source tensors (presumably
//                       [num_layers, num_tokens, num_heads, head_size]).
//   key_caches,
//   value_caches      - one cache tensor per layer (presumably).
//   slot_mapping      - cache slot index for each token (presumably int64).
//   kv_cache_dtype    - string selecting the cache element format.
void write_cache_multi_layers(
torch::Tensor& keys,
torch::Tensor& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
......@@ -245,6 +245,133 @@ __global__ void reshape_and_cache_flash_kernel(
}
}
}
// Scatters the keys/values of one token of one layer into that layer's
// key/value cache.
//
// Work mapping (taken from the indexing below):
//   blockIdx.x  -> layer index
//   blockIdx.y  -> token index
//   threadIdx.x -> strides over the num_heads * head_size elements of the token
// There is no bounds check on blockIdx.y, so the launch is expected to use
// gridDim.y <= num_tokens (NOTE(review): confirm against the launcher).
//
// Parameters:
//   keys / values       Source data; layer l, token t starts at
//                       (l * num_tokens + t) * key_stride (resp. value_stride).
//   key_cache_ptrs /
//   value_cache_ptrs    Per-layer table of cache base addresses stored as
//                       int64_t and reinterpreted as cache_t*.
//   slot_mapping        Per-token cache slot; negative entries are skipped.
//   key_stride /
//   value_stride        Element stride between consecutive tokens in
//                       keys / values.
//   x                   Innermost packing width of the key cache layout.
//
// All source offsets are computed in 64-bit arithmetic:
// layer_idx * num_tokens * stride can exceed INT_MAX for large inputs.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void write_cache_multi_layers_kernel(
    scalar_t* __restrict__ keys,    // [num_layers, num_tokens, num_heads, head_size]
    scalar_t* __restrict__ values,  // [num_layers, num_tokens, num_heads, head_size]
    int64_t* key_cache_ptrs,    // per-layer addresses; each cache is
                                // [num_blocks, num_heads, head_size/x,
                                //  block_size, x]
    int64_t* value_cache_ptrs,  // per-layer addresses; each cache is
                                // [num_blocks, num_heads, head_size,
                                //  block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size,
    const int x, const int num_tokens) {
  const int layer_idx = blockIdx.x;
  const int token_idx = blockIdx.y;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  cache_t* key_cache = reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
  cache_t* value_cache =
      reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);

  // Widen before multiplying so the layer base offset cannot wrap in int.
  const int64_t layer_token_base =
      static_cast<int64_t>(layer_idx) * num_tokens;
  scalar_t* key = keys + layer_token_base * key_stride;
  scalar_t* value = values + layer_token_base * value_stride;

  // Loop-invariant start of this token inside the layer slice (64-bit).
  const int64_t src_key_base = static_cast<int64_t>(token_idx) * key_stride;
  const int64_t src_value_base =
      static_cast<int64_t>(token_idx) * value_stride;

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    // Key cache: [block, head, head_size/x, block_size, x].
    const int64_t tgt_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    // Value cache: [block, head, head_size, block_size].
    const int64_t tgt_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;

    const scalar_t src_key = key[src_key_base + i];
    const scalar_t src_value = value[src_value_base + i];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key_cache[tgt_key_idx] = src_key;
      value_cache[tgt_value_idx] = src_value;
    } else {
      // Unit scale: the conversion changes only the storage format.
      key_cache[tgt_key_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src_key, 1.0f);
      value_cache[tgt_value_idx] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src_value, 1.0f);
    }
  }
}
// Gathers the cached keys/values of one token of one layer out of that
// layer's key/value cache into the dense keys / values buffers.
//
// Work mapping (taken from the indexing below):
//   blockIdx.x  -> layer index
//   blockIdx.y  -> token index
//   threadIdx.x -> strides over the num_heads * head_size elements of the token
// There is no bounds check on blockIdx.y, so the launch is expected to use
// gridDim.y <= num_tokens (NOTE(review): confirm against the launcher).
//
// Parameters:
//   keys / values       Destination data; layer l, token t starts at
//                       (l * num_tokens + t) * key_stride (resp. value_stride).
//   key_cache_ptrs /
//   value_cache_ptrs    Per-layer table of cache base addresses stored as
//                       int64_t and reinterpreted as cache_t*.
//   slot_mapping        Per-token cache slot; negative entries are skipped
//                       (their destination rows are left untouched).
//   key_stride /
//   value_stride        Element stride between consecutive tokens in
//                       keys / values.
//   x                   Innermost packing width of the key cache layout.
//
// All destination offsets are computed in 64-bit arithmetic:
// layer_idx * num_tokens * stride can exceed INT_MAX for large inputs.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void read_cache_kernel(
    scalar_t* __restrict__ keys,    // [num_layers, num_tokens, num_heads, head_size]
    scalar_t* __restrict__ values,  // [num_layers, num_tokens, num_heads, head_size]
    int64_t* key_cache_ptrs,    // per-layer addresses; each cache is
                                // [num_blocks, num_heads, head_size/x,
                                //  block_size, x]
    int64_t* value_cache_ptrs,  // per-layer addresses; each cache is
                                // [num_blocks, num_heads, head_size,
                                //  block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size,
    const int x, const int num_tokens) {
  const int layer_idx = blockIdx.x;
  const int token_idx = blockIdx.y;
  const int64_t slot_idx = slot_mapping[token_idx];
  if (slot_idx < 0) {
    // Padding token that should be ignored.
    return;
  }

  cache_t* key_cache = reinterpret_cast<cache_t*>(key_cache_ptrs[layer_idx]);
  cache_t* value_cache =
      reinterpret_cast<cache_t*>(value_cache_ptrs[layer_idx]);

  // Widen before multiplying so the layer base offset cannot wrap in int.
  const int64_t layer_token_base =
      static_cast<int64_t>(layer_idx) * num_tokens;
  scalar_t* key = keys + layer_token_base * key_stride;
  scalar_t* value = values + layer_token_base * value_stride;

  // Loop-invariant start of this token inside the layer slice (64-bit).
  const int64_t tgt_key_base = static_cast<int64_t>(token_idx) * key_stride;
  const int64_t tgt_value_base =
      static_cast<int64_t>(token_idx) * value_stride;

  const int64_t block_idx = slot_idx / block_size;
  const int64_t block_offset = slot_idx % block_size;

  const int n = num_heads * head_size;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / head_size;
    const int head_offset = i % head_size;
    const int x_idx = head_offset / x;
    const int x_offset = head_offset % x;

    // Key cache: [block, head, head_size/x, block_size, x].
    const int64_t src_key_idx =
        block_idx * num_heads * (head_size / x) * block_size * x +
        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
        block_offset * x + x_offset;
    // Value cache: [block, head, head_size, block_size].
    const int64_t src_value_idx =
        block_idx * num_heads * head_size * block_size +
        head_idx * head_size * block_size + head_offset * block_size +
        block_offset;

    const cache_t cached_key = key_cache[src_key_idx];
    const cache_t cached_value = value_cache[src_value_idx];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      key[tgt_key_base + i] = cached_key;
      value[tgt_value_base + i] = cached_value;
    } else {
      // Unit scale: the conversion changes only the storage format.
      key[tgt_key_base + i] =
          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(cached_key, 1.0f);
      value[tgt_value_base + i] =
          fp8::scaled_convert<scalar_t, cache_t, kv_dt>(cached_value, 1.0f);
    }
  }
}
} // namespace vllm
// KV_T is the stored data type of kv-cache.
......@@ -329,6 +456,151 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH);
}
// Launches vllm::read_cache_kernel for one dtype combination.
// KV_T is the element type the `keys` / `values` tensors are reinterpreted
// as (forwarded as the kernel's scalar_t).
// CACHE_T is forwarded as the kernel's cache_t, i.e. the element type the
// per-layer cache pointers are read as.
// KV_DTYPE is forwarded as the kernel's kv_dt (the Fp8KVCacheDataType tag).
// NOTE(review): the previous comment described KV_T and CACHE_T the other way
// round relative to how they are used below.
// The expansion relies on these names being defined in the calling scope:
// grid, block, stream, keys, values, key_cache_ptrs_tensor,
// value_cache_ptrs_tensor, slot_mapping, key_stride, value_stride, num_heads,
// head_size, block_size, x and num_tokens.
#define CALL_READ_CACHE(KV_T, CACHE_T, KV_DTYPE) \
vllm::read_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(keys.data_ptr()), \
reinterpret_cast<KV_T*>(values.data_ptr()), \
key_cache_ptrs_tensor.data_ptr<int64_t>(), \
value_cache_ptrs_tensor.data_ptr<int64_t>(), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, value_stride, \
num_heads, head_size, block_size, x, num_tokens);
// Reads the keys/values stored at `slot_mapping` out of the paged KV caches of
// all layers into the dense `keys` / `values` buffers.
//
// One kernel block is launched per (layer, token) pair on the current stream
// of slot_mapping's device. The per-layer cache base addresses are packed into
// small int64 tensors and copied to the cache device first (this copy
// synchronizes host and device).
//
// Head count, head size, block size and x are taken from the first cache
// tensor of each list; all layers are assumed to share that layout.
//
// Raises (via TORCH_CHECK) when the cache lists differ in length, the caches
// are not on a CUDA device, the buffer shapes do not match what the kernel
// indexes, or the kernel launch is rejected.
void read_cache(
    torch::Tensor& keys,    // [num_layers, num_tokens, num_heads, head_size]
    torch::Tensor& values,  // [num_layers, num_tokens, num_heads, head_size]
    std::vector<torch::Tensor> const& key_caches,    // [num_blocks, num_heads, head_size/x, block_size, x]
    std::vector<torch::Tensor> const& value_caches,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype) {
  int num_layers = static_cast<int>(key_caches.size());
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "read_cache: key_caches and value_caches must have the same "
              "number of layers.");
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  int num_tokens = keys.size(1);
  if (num_tokens == 0) {
    // Nothing to read; a zero-sized grid would be an invalid launch.
    return;
  }
  // The kernel indexes keys/values by (layer, token) and slot_mapping by
  // token; reject shapes that would make it read or write out of bounds.
  TORCH_CHECK(keys.size(0) >= num_layers && values.size(0) >= num_layers,
              "read_cache: keys/values must provide one slice per layer.");
  TORCH_CHECK(values.size(1) == num_tokens,
              "read_cache: keys and values must hold the same number of "
              "tokens.");
  TORCH_CHECK(slot_mapping.numel() >= num_tokens,
              "read_cache: slot_mapping is shorter than the number of "
              "tokens.");
  // The kernel derives the layer offset as num_tokens * token stride.
  TORCH_CHECK(
      num_layers == 1 ||
          (keys.stride(0) == static_cast<int64_t>(num_tokens) * keys.stride(1) &&
           values.stride(0) ==
               static_cast<int64_t>(num_tokens) * values.stride(1)),
      "read_cache: keys/values must be densely packed across layers.");

  // Create an array of pointers to the key and value caches.
  // std::vector replaces the variable-length arrays, which are a compiler
  // extension rather than standard C++.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  auto kv_dtype = keys.dtype();
  torch::Tensor key_cache = key_caches[0];
  torch::Tensor value_cache = value_caches[0];
  int key_stride = keys.stride(1);
  int value_stride = values.stride(1);
  int num_heads = value_cache.size(1);
  int head_size = value_cache.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  dim3 grid(num_layers, num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype,
                             CALL_READ_CACHE);

  // Launch-configuration errors (e.g. num_tokens above the grid.y limit) are
  // only reported through cudaGetLastError(); surface them instead of
  // silently returning unfilled buffers.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "read_cache: kernel launch failed: ",
              cudaGetErrorString(launch_status));
}
// Launches vllm::write_cache_multi_layers_kernel for one dtype combination.
// KV_T is the element type the `keys` / `values` tensors are reinterpreted
// as (first template argument of the kernel).
// CACHE_T is forwarded as the kernel's second template argument
// (presumably the cache element type, mirroring CALL_READ_CACHE - confirm
// against the kernel definition).
// KV_DTYPE is forwarded as the kernel's third template argument (the
// Fp8KVCacheDataType tag).
// NOTE(review): the previous comment described KV_T and CACHE_T the other way
// round relative to how they are used below.
// The expansion relies on these names being defined in the calling scope:
// grid, block, stream, keys, values, key_cache_ptrs_tensor,
// value_cache_ptrs_tensor, slot_mapping, key_stride, value_stride, num_heads,
// head_size, block_size, x and num_tokens.
#define CALL_WRITE_CACHE_MULTI_LAYERS(KV_T, CACHE_T, KV_DTYPE) \
vllm::write_cache_multi_layers_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(keys.data_ptr()), \
reinterpret_cast<KV_T*>(values.data_ptr()), \
key_cache_ptrs_tensor.data_ptr<int64_t>(), \
value_cache_ptrs_tensor.data_ptr<int64_t>(), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, value_stride, \
num_heads, head_size, block_size, x, num_tokens);
// Writes the dense `keys` / `values` buffers of all layers into the paged KV
// caches at the slots given by `slot_mapping`.
//
// One kernel block is launched per (layer, token) pair on the current stream
// of slot_mapping's device. The per-layer cache base addresses are packed into
// small int64 tensors and copied to the cache device first (this copy
// synchronizes host and device).
//
// Head count, head size, block size and x are taken from the first cache
// tensor of each list; all layers are assumed to share that layout.
//
// Raises (via TORCH_CHECK) when the cache lists differ in length, the caches
// are not on a CUDA device, the buffer shapes do not match what the launch
// indexes, or the kernel launch is rejected.
void write_cache_multi_layers(
    torch::Tensor& keys,    // [num_layers, num_tokens, num_heads, head_size]
    torch::Tensor& values,  // [num_layers, num_tokens, num_heads, head_size]
    std::vector<torch::Tensor> const& key_caches,    // [num_blocks, num_heads, head_size/x, block_size, x]
    std::vector<torch::Tensor> const& value_caches,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype) {
  int num_layers = static_cast<int>(key_caches.size());
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "write_cache_multi_layers: key_caches and value_caches must "
              "have the same number of layers.");
  if (num_layers == 0) {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  int num_tokens = keys.size(1);
  if (num_tokens == 0) {
    // Nothing to write; a zero-sized grid would be an invalid launch.
    return;
  }
  // The launch covers (layer, token) pairs and one slot_mapping entry per
  // token; reject shapes that would make the kernel read out of bounds.
  TORCH_CHECK(keys.size(0) >= num_layers && values.size(0) >= num_layers,
              "write_cache_multi_layers: keys/values must provide one slice "
              "per layer.");
  TORCH_CHECK(values.size(1) == num_tokens,
              "write_cache_multi_layers: keys and values must hold the same "
              "number of tokens.");
  TORCH_CHECK(slot_mapping.numel() >= num_tokens,
              "write_cache_multi_layers: slot_mapping is shorter than the "
              "number of tokens.");

  // Create an array of pointers to the key and value caches.
  // std::vector replaces the variable-length arrays, which are a compiler
  // extension rather than standard C++.
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  auto kv_dtype = keys.dtype();
  torch::Tensor key_cache = key_caches[0];
  torch::Tensor value_cache = value_caches[0];
  int key_stride = keys.stride(1);
  int value_stride = values.stride(1);
  int num_heads = value_cache.size(1);
  int head_size = value_cache.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  // Move the data structures to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  dim3 grid(num_layers, num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(slot_mapping));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_BY_KV_CACHE_DTYPE(kv_dtype, kv_cache_dtype,
                             CALL_WRITE_CACHE_MULTI_LAYERS);

  // Launch-configuration errors (e.g. num_tokens above the grid.y limit) are
  // only reported through cudaGetLastError(); surface them instead of
  // silently leaving the caches unwritten.
  const cudaError_t launch_status = cudaGetLastError();
  TORCH_CHECK(launch_status == cudaSuccess,
              "write_cache_multi_layers: kernel launch failed: ",
              cudaGetErrorString(launch_status));
}
namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
......
......@@ -136,3 +136,24 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}
void read_cache(
std::vector<torch::Tensor> const& keys,
std::vector<torch::Tensor> const& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype) {
TORCH_CHECK(false, "read_cache is unsupported on CPU.")
}
void write_cache_multi_layers(
std::vector<torch::Tensor> const& keys,
std::vector<torch::Tensor> const& values,
std::vector<torch::Tensor> const& key_caches,
std::vector<torch::Tensor> const& value_caches,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype) {
TORCH_CHECK(false, "write_cache_multi_layers is unsupported on CPU.")
}
......@@ -13,7 +13,9 @@ void paged_attention_v1(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
......@@ -24,7 +26,9 @@ void paged_attention_v2(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v1_opt(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
......@@ -34,7 +38,9 @@ void paged_attention_v1_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_opt(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
......@@ -45,7 +51,9 @@ void paged_attention_v2_opt(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v1_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
......@@ -55,7 +63,9 @@ void paged_attention_v1_opt_tc(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void paged_attention_v2_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
......@@ -66,7 +76,9 @@ void paged_attention_v2_opt_tc(
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
double epsilon);
......
......@@ -30,7 +30,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
// PagedAttention V2.
......@@ -44,7 +46,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
// Compute the attention between an input query and the cached
......@@ -58,7 +62,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1_opt", torch::kCUDA, &paged_attention_v1_opt);
// PagedAttention V2 (opt).
......@@ -72,7 +78,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt", torch::kCUDA, &paged_attention_v2_opt);
// Compute the attention between an input query and the cached
......@@ -86,7 +94,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v1_opt_tc", torch::kCUDA, &paged_attention_v1_opt_tc);
// PagedAttention V2 (opt).
......@@ -100,7 +110,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" str kv_cache_dtype, float k_scale, float v_scale,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
" int blocksparse_head_sliding_step,"
" Tensor? attn_masks,"
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt_tc", torch::kCUDA, &paged_attention_v2_opt_tc);
// Activation ops
......@@ -479,6 +491,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Read keys and values from the KV cache.
cache_ops.def(
"read_cache(Tensor keys, Tensor values,"
" Tensor[]! key_caches, Tensor[]! value_caches,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()");
cache_ops.impl("read_cache", torch::kCUDA, &read_cache);
// write multi-layers key and value to kv cache
cache_ops.def(
"write_cache_multi_layers(Tensor keys, Tensor values,"
" Tensor[]! key_caches, Tensor[]! value_caches,"
" Tensor slot_mapping,"
" str kv_cache_dtype) -> ()");
cache_ops.impl("write_cache_multi_layers", torch::kCUDA, &write_cache_multi_layers);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
......
import os
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
from addict import Dict
import yaml
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from safetensors.torch import save_model, safe_open
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
# Vocabulary sizes are rounded up to a multiple of this value.
DEFAULT_VOCAB_PADDING_SIZE = 64

# Parameter names inside the trained Medusa checkpoint.
# Block weights are formatted with (head index, layer index); head weights
# with the head index only.
TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE = (
    'base_model.model.medusa_head.{}.{}.linear.weight')
TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE = 'base_model.model.medusa_head.{}.1.weight'

# Matching parameter names in the Medusa module defined in this file.
VLLM_BLOCK_WEIGHT_NAME_TEMPLATE = 'blocks.{}.layers.{}.weight'
VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE = 'lm_heads.{}.weight'

# Candidate paths stored as `medusa_choices` in the exported config.
MEDUSA_CHOICES = [
    (0,), (0, 0), (0, 0, 0), (1,),
    (0, 1), (1, 0), (0, 2), (2,),
    (0, 0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 3),
    (2, 0), (1, 0, 0), (3,), (0, 0, 2),
    (0, 4), (0, 2, 0), (0, 5), (4,),
    (1, 1), (0, 0, 3), (3, 0), (0, 6),
    (0, 0, 0, 1), (0, 3, 0), (0, 0, 4), (0, 0, 1, 0),
    (2, 0, 0), (5,), (0, 1, 0, 0), (0, 7),
]
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` in place.

    Both tensors must have exactly the same shape; this is enforced with an
    assertion before the copy.
    """
    target_shape, source_shape = param.size(), loaded_weight.size()
    assert target_shape == source_shape
    param.data.copy_(loaded_weight)
def pad_vocab_size(vocab_size: int,
                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    """Round ``vocab_size`` up to the nearest multiple of ``pad_to``."""
    # Number of pad_to-sized chunks needed to hold vocab_size entries.
    chunk_count = (vocab_size + pad_to - 1) // pad_to
    return chunk_count * pad_to
class MedusaConfig(PretrainedConfig):
    """Configuration of the Medusa draft heads.

    Stores the hidden size, vocabulary size, number of heads, number of
    layers per head, ``max_paths`` and ``topk``. ``truncated_vocab_size``
    falls back to ``vocab_size`` when not given, ``max_seq_len`` is fixed to
    2**20, and ``architectures`` defaults to ``["MedusaModel"]`` unless the
    caller supplies it through ``kwargs``.
    """

    model_type = "medusa"

    def __init__(self,
                 hidden_size: int = 4096,
                 vocab_size: int = 32001,
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 max_paths: int = 64,
                 topk: int = 10,
                 truncated_vocab_size: Optional[int] = None,
                 **kwargs):
        if truncated_vocab_size is None:
            # No truncation requested: use the full vocabulary.
            truncated_vocab_size = vocab_size

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_paths = max_paths
        self.topk = topk
        self.max_seq_len = int(2**20)
        self.truncated_vocab_size = truncated_vocab_size

        kwargs.setdefault("architectures", ["MedusaModel"])
        super().__init__(**kwargs)

    @property
    def num_attention_heads(self):
        """Always 0."""
        return 0

    @property
    def num_lookahead_tokens(self):
        """Alias of ``num_heads``."""
        return self.num_heads

    @num_lookahead_tokens.setter
    def num_lookahead_tokens(self, num_lookahead_tokens: int):
        self.num_heads = num_lookahead_tokens
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding table whose vocabulary dimension is padded.

    The original vocabulary (``org_num_embeddings``, or ``num_embeddings``
    when it is not given) is first padded up to a multiple of
    ``padding_size``; any additional embeddings
    (``num_embeddings - org_vocab_size``) are appended and the total is padded
    again. For example, with an original vocab of 1010, 16 added embeddings
    and a padding size of 64, the padded original size is 1024 and the total
    padded size is 1088.

    This class performs no tensor-parallel sharding itself: the weight is
    created with the full padded vocabulary size through the quantization
    method's ``create_weights`` (``UnquantizedLinearMethod`` unless
    ``quant_config`` provides one).

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        params_dtype: type of the parameters (torch default dtype if None).
        org_num_embeddings: original vocabulary size (without added tokens).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        # Vocabulary bookkeeping.
        self.num_embeddings = num_embeddings
        self.padding_size = padding_size
        self.org_vocab_size = org_num_embeddings or num_embeddings
        added_count = num_embeddings - self.org_vocab_size
        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
                                                    self.padding_size)
        self.num_embeddings_padded = pad_vocab_size(
            self.org_vocab_size_padded + added_count, self.padding_size)
        assert self.org_vocab_size_padded <= self.num_embeddings_padded
        self.embedding_dim = embedding_dim

        # Choose the method that creates the weights; fall back to the
        # unquantized one when the quant config offers nothing.
        chosen_method = (quant_config.get_quant_method(self, prefix=prefix)
                         if quant_config is not None else None)
        if chosen_method is None:
            chosen_method = UnquantizedLinearMethod()
        self.linear_method: QuantizeMethodBase = chosen_method

        dtype = (torch.get_default_dtype()
                 if params_dtype is None else params_dtype)
        self.linear_method.create_weights(self,
                                          self.embedding_dim,
                                          [self.num_embeddings_padded],
                                          self.embedding_dim,
                                          self.num_embeddings_padded,
                                          params_dtype=dtype,
                                          weight_loader=self.weight_loader)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param``; shapes must match exactly."""
        assert param.data.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)

    def forward(self, input_):
        """Look up the embeddings for the token ids in ``input_``."""
        token_ids = input_.long()
        return F.embedding(token_ids, self.weight)
class ParallelLMHead(VocabParallelEmbedding):
    """LM head whose weight (and optional bias) cover the padded vocabulary.

    The weights are meant to be consumed by the sampler; calling the module
    directly raises ``RuntimeError``.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        bias: whether to use bias.
        params_dtype: type of the parameters.
        org_num_embeddings: original vocabulary size (without added tokens).
        padding_size: padding size for the vocabulary.
        quant_config: quant config for the layer.
        prefix: full name of the layer in the state dict.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 bias: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 org_num_embeddings: Optional[int] = None,
                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(num_embeddings, embedding_dim, params_dtype,
                         org_num_embeddings, padding_size, quant_config,
                         prefix)
        if bias:
            # The base class in this file does not shard the vocabulary and
            # never defines `num_embeddings_per_partition`; the bias must
            # match the weight's output dimension, which is the full padded
            # vocabulary size.
            self.bias = Parameter(
                torch.empty(self.num_embeddings_padded, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def forward(self, input_):
        """Not supported: the head's weights are used by the sampler."""
        del input_
        raise RuntimeError("LMHead's weights should be used in the sampler.")
class ResidualBlock(nn.Module):
    """Stack of bias-free square linear layers with SiLU residual updates.

    Each layer applies ``x = x + SiLU(linear(x))``.
    """

    def __init__(self, hidden_size: int, num_layers: int) -> None:
        super().__init__()
        stack = nn.ModuleList()
        for _ in range(num_layers):
            stack.append(nn.Linear(hidden_size, hidden_size, bias=False))
        self.layers = stack
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every residual layer in order."""
        hidden = x
        for linear in self.layers:
            hidden = hidden + self.act(linear(hidden))
        return hidden
class Medusa(nn.Module):
    """Medusa draft heads: one residual block and one LM head per head.

    ``forward`` only runs the residual blocks; the ``lm_heads`` modules hold
    the projection weights and are not applied here.
    """

    def __init__(self, config: MedusaConfig, **_) -> None:
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(self.config.num_heads)
        ])

        # Optional index tensor selecting the rows of the full-vocabulary
        # head weights that are kept when the vocabulary is truncated.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        """Return the output of every residual block for ``hidden_states``."""
        return [block(hidden_states) for block in self.blocks]

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this module.

        A ``medusa_heads.`` substring is removed from every name. An entry
        named ``token_map`` is kept only when the vocabulary is truncated;
        LM-head weights with more rows than the token map are then reduced
        to the mapped rows. Names that match no parameter are ignored.

        Raises:
            AssertionError: if the vocabulary is truncated but no
                ``token_map`` was provided.
        """
        params_dict = dict(self.named_parameters())

        # First pass: pick up the token map and collect matching weights, so
        # the map is available regardless of the order of `weights`.
        weights_map = {}
        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        # Second pass: copy the weights, truncating LM heads when needed.
        for name, loaded_weight in weights_map.items():
            if "lm_head" in name and self.token_map is not None and\
                loaded_weight.shape[0] > self.token_map.shape[0]:

                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.token_map is not None:
            # Tensor.to() returns a new tensor instead of moving in place, so
            # the result has to be stored back for the move to take effect.
            self.token_map.data = self.token_map.data.to(
                device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size
                == self.orig_vocab_size) or (self.token_map is not None)
class CustomMedusaConfig(PretrainedConfig):
    """Config object written next to the converted Medusa weights.

    All arguments are stored as attributes of the same name (``name_or_path``
    is stored as ``_name_or_path``) after the base class has been initialised
    with the remaining ``kwargs``.
    """

    model_type = "medusa"

    def __init__(self,
                 name_or_path: str = "sugon/vllm-medusa-qwen1.5-7b-chat",
                 architectures: List[str] = ["MedusaModel"],
                 hidden_size: int = 4096,
                 model_type: str = "medusa",
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 transformers_version: str = "4.41.2",
                 truncated_vocab_size: Optional[int] = None,
                 vocab_size: int = 151936,
                 medusa_choices: Optional[List[List[int]]] = None,
                 **kwargs):
        super().__init__(**kwargs)
        self._name_or_path = name_or_path
        # Store a copy so that instances never share (and cannot mutate) the
        # default list object of the signature.
        self.architectures = (list(architectures)
                              if architectures is not None else None)
        self.hidden_size = hidden_size
        self.model_type = model_type
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.transformers_version = transformers_version
        self.truncated_vocab_size = truncated_vocab_size
        self.vocab_size = vocab_size
        self.medusa_choices = medusa_choices
def main(args):
    """Convert a trained Medusa checkpoint into the layout used by this file.

    Reads ``medusa_num_heads`` and ``medusa_num_layers`` from the YAML file at
    ``args.medusa_config_path``, builds a ``Medusa`` module of that size,
    copies the head and block weights from ``args.medusa_model_path`` into it,
    and writes ``model.safetensors`` plus the config into ``args.output_dir``.

    Raises:
        ValueError: if the YAML file lacks an integer ``medusa_num_heads`` or
            ``medusa_num_layers``.
    """
    # load the medusa config from the yaml file
    medusa_config_path = args.medusa_config_path
    with open(medusa_config_path, encoding="utf-8") as file:
        medusa_cfg: Dict = Dict(yaml.safe_load(file))

    medusa_head_num = medusa_cfg.medusa_num_heads
    medusa_num_layers = medusa_cfg.medusa_num_layers
    # addict returns an empty Dict for missing keys; fail with a clear
    # message instead of an obscure TypeError further down.
    for key, value in (("medusa_num_heads", medusa_head_num),
                       ("medusa_num_layers", medusa_num_layers)):
        if not isinstance(value, int):
            raise ValueError(
                f"'{key}' must be an integer in {medusa_config_path}")

    # The layer count must be forwarded: otherwise the model is built with the
    # default single layer and the lookups for layers >= 1 fail below.
    config = MedusaConfig(hidden_size=args.hidden_size,
                          vocab_size=args.vocab_size,
                          num_heads=medusa_head_num,
                          num_hidden_layers=medusa_num_layers)
    medusa_model = Medusa(config)
    params_dict = dict(medusa_model.named_parameters())

    # NOTE(review): torch.load unpickles the file - only use it on trusted
    # checkpoints. map_location="cpu" lets checkpoints saved from a GPU be
    # converted on machines without that device.
    trained_medusa_model = torch.load(args.medusa_model_path,
                                      map_location="cpu")

    # (name in this model, name in the trained checkpoint) for every weight.
    name_pairs = [(VLLM_MEDUSA_HEADS_WEIGHT_NAME_TEMPLATE.format(i),
                   TRAINED_MEDUSA_HEADS_NEMA_TEMPLATE.format(i))
                  for i in range(medusa_head_num)]
    name_pairs += [(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j),
                    TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(i, j))
                   for i in range(medusa_head_num)
                   for j in range(medusa_num_layers)]

    for vllm_name, trained_name in name_pairs:
        vllm_param = params_dict[vllm_name]
        trained_param = trained_medusa_model[trained_name]
        weight_loader = getattr(vllm_param, "weight_loader",
                                default_weight_loader)
        weight_loader(vllm_param, trained_param)

    os.makedirs(args.output_dir, exist_ok=True)

    save_model(medusa_model, os.path.join(args.output_dir,
                                          "model.safetensors"))
    to_save_config = CustomMedusaConfig(
        name_or_path=os.path.join(args.output_dir, "config.json"),
        hidden_size=args.hidden_size,
        num_heads=medusa_head_num,
        num_hidden_layers=medusa_num_layers,
        vocab_size=args.vocab_size,
        medusa_choices=MEDUSA_CHOICES)
    to_save_config.save_pretrained(args.output_dir)

    # validate weight
    # with safe_open("model.safetensors", framework="pt") as f:
    #     param = f.get_tensor(VLLM_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0))
    #     trained_param = trained_medusa_model[TRAINED_BLOCK_WEIGHT_NAME_TEMPLATE.format(0, 0)]
    #     mse_value = torch.nn.functional.mse_loss(param.cpu(), trained_param.cpu())
    #     print("weight mes:", mse_value)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Medusa Model Evaluator")

    # (flag, value type, help text); every option is required.
    cli_options = (
        ("--medusa_config_path", str, "Path to the medusa config file."),
        ("--medusa_model_path", str, "Path to the medusa model file."),
        ("--vocab_size", int, "Vocab size"),
        ("--hidden_size", int, "Hidden size"),
        ("--output_dir", str, "Output dir"),
    )
    for flag, value_type, help_text in cli_options:
        parser.add_argument(flag, type=value_type, required=True,
                            help=help_text)

    args = parser.parse_args()
    main(args)
......@@ -108,13 +108,16 @@ def paged_attention_v1(
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
blocksparse_head_sliding_step,attn_masks,
attn_masks_stride)
def paged_attention_v2(
......@@ -140,13 +143,16 @@ def paged_attention_v2(
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
# page attention ops (opt)
......@@ -170,13 +176,16 @@ def paged_attention_v1_opt(
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
blocksparse_head_sliding_step, attn_masks,
attn_masks_stride)
def paged_attention_v2_opt(
......@@ -202,13 +211,16 @@ def paged_attention_v2_opt(
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
# page attention ops (opt)
......@@ -232,12 +244,15 @@ def paged_attention_v1_opt_tc(
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v1_opt_tc(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
def paged_attention_v2_opt_tc(
......@@ -263,13 +278,16 @@ def paged_attention_v2_opt_tc(
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
attn_masks: Optional[torch.Tensor] = None,
attn_masks_stride: int = 0,
) -> None:
torch.ops._C.paged_attention_v2_opt_tc(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step)
blocksparse_block_size, blocksparse_head_sliding_step,
attn_masks, attn_masks_stride)
def paged_attention_rocm(
......@@ -1142,6 +1160,31 @@ def register_graph_buffers(fa: int, handles: List[str],
offsets: List[List[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
def read_cache(
    keys: torch.Tensor,
    values: torch.Tensor,
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str
) -> None:
    """Forward all arguments, unchanged and in order, to the
    ``_C_cache_ops.read_cache`` custom op.

    NOTE(review): the op itself is not visible here. Judging by the
    ``move_cache`` callers, it presumably gathers the cache entries addressed
    by ``slot_mapping`` from every layer's key/value cache into the ``keys``
    and ``values`` buffers -- confirm against the kernel.

    Args:
        keys: Buffer handed to the op as its key argument.
        values: Buffer handed to the op as its value argument.
        key_caches: One key cache tensor per list entry.
        value_caches: One value cache tensor per list entry.
        slot_mapping: Slot indices handed to the op.
        kv_cache_dtype: KV cache dtype string, passed through verbatim.
    """
    cache_read_op = torch.ops._C_cache_ops.read_cache
    cache_read_op(keys, values, key_caches, value_caches, slot_mapping,
                  kv_cache_dtype)
def write_cache_multi_layers(
    keys: torch.Tensor,
    values: torch.Tensor,
    key_caches: List[torch.Tensor],
    value_caches: List[torch.Tensor],
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str
) -> None:
    """Forward all arguments, unchanged and in order, to the
    ``_C_cache_ops.write_cache_multi_layers`` custom op.

    NOTE(review): the op itself is not visible here. Judging by the
    ``move_cache`` callers, it presumably scatters ``keys`` / ``values`` into
    every layer's key/value cache at the slots given by ``slot_mapping`` --
    confirm against the kernel.

    Args:
        keys: Buffer handed to the op as its key argument.
        values: Buffer handed to the op as its value argument.
        key_caches: One key cache tensor per list entry.
        value_caches: One value cache tensor per list entry.
        slot_mapping: Slot indices handed to the op.
        kv_cache_dtype: KV cache dtype string, passed through verbatim.
    """
    cache_write_op = torch.ops._C_cache_ops.write_cache_multi_layers
    cache_write_op(keys, values, key_caches, value_caches, slot_mapping,
                   kv_cache_dtype)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
......
......@@ -82,6 +82,17 @@ class AttentionBackend(ABC):
src_to_dists: torch.Tensor,
) -> None:
raise NotImplementedError
@staticmethod
@abstractmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Abstract hook that concrete attention backends override.

    This base version only raises ``NotImplementedError``.

    Args:
        kv_caches: List of KV cache tensors.
        src_to_dists: Tensor describing what to move where
            (NOTE(review): presumably rows of (source slot, destination
            slot) -- confirm against the concrete backends).
        kv_cache_dtype: KV cache dtype string.
        num_kv_heads: Number of KV heads.
        head_size: Size of each attention head.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
def advance_step(self, model_input: "ModelRunnerInputBase",
sampled_token_ids: Optional[torch.Tensor],
......@@ -195,7 +206,8 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
          cuda_graph_pad_size: int, batch_size: int,
          tree_attention_masks_tensor: Optional[torch.Tensor] = None) -> T:
    """Build attention metadata with on-device tensors.

    Abstract: concrete metadata builders override this; the base version
    only raises.

    Args:
        seq_lens: Sequence lengths, one int per list entry.
        query_lens: Query lengths, one int per list entry.
        cuda_graph_pad_size: The padding size for cuda graph.
            -1 if cuda graph is not used.
        batch_size: The maybe padded batch size.
        tree_attention_masks_tensor: Attention mask used in tree-style
            attention. Defaults to ``None``, so call sites that pass only
            the first four arguments keep working.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
......
......@@ -10,6 +10,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm import _custom_ops as ops
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
......@@ -128,6 +129,50 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV cache entries from source slots to destination slots.

    Column 0 of ``src_to_dists`` is passed to ``ops.read_cache`` as the slot
    mapping to stage entries into temporary buffers; column 1 is then passed
    to ``ops.write_cache_multi_layers`` to write those buffers back. Both
    ops receive the key/value caches of every entry in ``kv_caches``.

    Args:
        kv_caches: One KV cache tensor per list entry; each is split into
            key and value caches via ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor whose rows pair a source slot (column 0)
            with a destination slot (column 1).
        kv_cache_dtype: KV cache dtype string, forwarded to both ops.
        num_kv_heads: Number of KV heads; sizes the staging buffers.
        head_size: Size of each head; sizes the staging buffers.
    """
    num_tokens = src_to_dists.shape[0]
    # Nothing to move. Returning here avoids indexing kv_caches[0] on an
    # empty list and avoids invoking the cache ops with zero tokens.
    if not kv_caches or num_tokens == 0:
        return

    # Staging buffers: one (num_kv_heads, head_size) entry per token per
    # cache in kv_caches, in the cache's own dtype and on its device.
    staging_shape = (len(kv_caches), num_tokens, num_kv_heads, head_size)
    keys = torch.empty(staging_shape,
                       dtype=kv_caches[0].dtype,
                       device=kv_caches[0].device)
    values = torch.empty_like(keys)

    key_caches = []
    value_caches = []
    for kv_cache in kv_caches:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_size)
        key_caches.append(key_cache)
        value_caches.append(value_cache)

    # All sources are read before any destination is written.
    ops.read_cache(keys, values, key_caches, value_caches,
                   src_to_dists[:, 0].contiguous(), kv_cache_dtype)
    ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                 src_to_dists[:, 1].contiguous(),
                                 kv_cache_dtype)
@dataclass
class BlocksparseFlashAttentionMetadata(AttentionMetadata):
......@@ -190,6 +235,8 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
"BlocksparseFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional[
"BlocksparseFlashAttentionMetadata"] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None
@property
def prefill_metadata(
......@@ -222,6 +269,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
)
return self._cached_prefill_metadata
......@@ -250,6 +298,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
)
return self._cached_decode_metadata
......
......@@ -475,7 +475,8 @@ class FlashAttentionMetadataBuilder(
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
cuda_graph_pad_size: int, batch_size: int,
tree_attention_masks_tensor: Optional[torch.Tensor] = None):
"""Build attention metadata with on-device tensors.
Args:
......@@ -484,6 +485,7 @@ class FlashAttentionMetadataBuilder(
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
tree_attention_masks_tensor: attention mask used in tree style attention.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
......
......@@ -92,6 +92,16 @@ class FlashInferBackend(AttentionBackend):
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Unsupported on this backend: always raises ``NotImplementedError``.

    The previous body was the bare expression ``NotImplementedError`` with
    no ``raise``, so a call returned ``None`` without moving anything.

    Args:
        kv_caches: List of KV cache tensors (unused).
        src_to_dists: Source/destination tensor (unused).
        kv_cache_dtype: KV cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each head (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "move_cache is not implemented for this attention backend")
class FlashInferState(AttentionState):
......@@ -574,7 +584,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.paged_kv_last_page_len.append(last_page_len)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
cuda_graph_pad_size: int, batch_size: int,
tree_attention_masks_tensor: Optional[torch.Tensor] = None):
"""Build attention metadata with on-device tensors.
Args:
......@@ -583,6 +594,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
tree_attention_masks_tensor: attention mask used in tree style attention.
"""
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
......
......@@ -62,6 +62,16 @@ class IpexAttnBackend(AttentionBackend):
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Unsupported on this backend: always raises ``NotImplementedError``.

    The previous body was the bare expression ``NotImplementedError`` with
    no ``raise``, so a call returned ``None`` without moving anything.

    Args:
        kv_caches: List of KV cache tensors (unused).
        src_to_dists: Source/destination tensor (unused).
        kv_cache_dtype: KV cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each head (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "move_cache is not implemented for this attention backend")
@dataclass
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
......
......@@ -62,6 +62,16 @@ class OpenVINOAttentionBackend(AttentionBackend):
key_cache.data[dst, :] = key_cache.data[src, :]
value_cache.data[dst, :] = value_cache.data[src, :]
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Unsupported on this backend: always raises ``NotImplementedError``.

    The previous body was the bare expression ``NotImplementedError`` with
    no ``raise``, so a call returned ``None`` without moving anything.

    Args:
        kv_caches: List of KV cache tensors (unused).
        src_to_dists: Source/destination tensor (unused).
        kv_cache_dtype: KV cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each head (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "move_cache is not implemented for this attention backend")
@dataclass
class OpenVINOAttentionMetadata:
......
......@@ -53,6 +53,16 @@ class PallasAttentionBackend(AttentionBackend):
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Unsupported on this backend: always raises ``NotImplementedError``.

    The previous body was the bare expression ``NotImplementedError`` with
    no ``raise``, so a call returned ``None`` without moving anything.

    Args:
        kv_caches: List of KV cache tensors (unused).
        src_to_dists: Source/destination tensor (unused).
        kv_cache_dtype: KV cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each head (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "move_cache is not implemented for this attention backend")
@dataclass
class PallasMetadata(AttentionMetadata):
......
......@@ -13,6 +13,7 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
if TYPE_CHECKING:
......@@ -71,6 +72,50 @@ class ROCmFlashAttentionBackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Move KV cache entries from source slots to destination slots.

    Column 0 of ``src_to_dists`` is passed to ``ops.read_cache`` as the slot
    mapping to stage entries into temporary buffers; column 1 is then passed
    to ``ops.write_cache_multi_layers`` to write those buffers back. Both
    ops receive the key/value caches of every entry in ``kv_caches``.

    Args:
        kv_caches: One KV cache tensor per list entry; each is split into
            key and value caches via ``PagedAttention.split_kv_cache``.
        src_to_dists: 2-D tensor whose rows pair a source slot (column 0)
            with a destination slot (column 1).
        kv_cache_dtype: KV cache dtype string, forwarded to both ops.
        num_kv_heads: Number of KV heads; sizes the staging buffers.
        head_size: Size of each head; sizes the staging buffers.
    """
    num_tokens = src_to_dists.shape[0]
    # Nothing to move. Returning here avoids indexing kv_caches[0] on an
    # empty list and avoids invoking the cache ops with zero tokens.
    if not kv_caches or num_tokens == 0:
        return

    # Staging buffers: one (num_kv_heads, head_size) entry per token per
    # cache in kv_caches, in the cache's own dtype and on its device.
    staging_shape = (len(kv_caches), num_tokens, num_kv_heads, head_size)
    keys = torch.empty(staging_shape,
                       dtype=kv_caches[0].dtype,
                       device=kv_caches[0].device)
    values = torch.empty_like(keys)

    key_caches = []
    value_caches = []
    for kv_cache in kv_caches:
        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, num_kv_heads, head_size)
        key_caches.append(key_cache)
        value_caches.append(value_cache)

    # All sources are read before any destination is written.
    ops.read_cache(keys, values, key_caches, value_caches,
                   src_to_dists[:, 0].contiguous(), kv_cache_dtype)
    ops.write_cache_multi_layers(keys, values, key_caches, value_caches,
                                 src_to_dists[:, 1].contiguous(),
                                 kv_cache_dtype)
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
......@@ -122,6 +167,8 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
tree_attention_masks_tensor: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_prefills == 0:
......@@ -152,6 +199,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
)
return self._cached_prefill_metadata
......@@ -180,6 +228,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
tree_attention_masks_tensor=self.tree_attention_masks_tensor
)
return self._cached_decode_metadata
......@@ -613,6 +662,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v_scale,
)
else:
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
......@@ -626,6 +676,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.alibi_slopes,
k_scale,
v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
)
# Reshape the output tensor.
......
......@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu
from vllm import _custom_ops as ops
if is_cpu():
try:
......@@ -64,6 +65,16 @@ class TorchSDPABackend(AttentionBackend):
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def move_cache(
    kv_caches: List[torch.Tensor],
    src_to_dists: torch.Tensor,
    kv_cache_dtype: str,
    num_kv_heads: int,
    head_size: int,
) -> None:
    """Unsupported on this backend: always raises ``NotImplementedError``.

    The previous body was the bare expression ``NotImplementedError`` with
    no ``raise``, so a call returned ``None`` without moving anything.

    Args:
        kv_caches: List of KV cache tensors (unused).
        src_to_dists: Source/destination tensor (unused).
        kv_cache_dtype: KV cache dtype string (unused).
        num_kv_heads: Number of KV heads (unused).
        head_size: Size of each head (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "move_cache is not implemented for this attention backend")
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
......
"""Attention backend utils"""
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, Optional
import numpy as np
import torch
......@@ -188,7 +188,8 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
self.block_size, inter_data.block_tables)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
cuda_graph_pad_size: int, batch_size: int,
tree_attention_masks_tensor: Optional[torch.Tensor] = None):
"""Build attention metadata with on-device tensors.
Args:
......@@ -271,6 +272,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
tree_attention_masks_tensor=tree_attention_masks_tensor
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment