Unverified Commit 978aed53 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Kernel][Attention] Separate `Attention.kv_scale` into `k_scale` and `v_scale` (#6081)

parent 160e1d8c
...@@ -100,7 +100,7 @@ def main( ...@@ -100,7 +100,7 @@ def main(
start_time = time.perf_counter() start_time = time.perf_counter()
# Using default kv_scale # Using default kv_scale
kv_scale = 1.0 k_scale = v_scale = 1.0
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
...@@ -117,7 +117,8 @@ def main( ...@@ -117,7 +117,8 @@ def main(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
elif version == "v2": elif version == "v2":
ops.paged_attention_v2( ops.paged_attention_v2(
...@@ -136,7 +137,8 @@ def main( ...@@ -136,7 +137,8 @@ def main(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
......
...@@ -105,9 +105,9 @@ __device__ void paged_attention_kernel( ...@@ -105,9 +105,9 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
...@@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( ...@@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale); k_vec_quant, k_scale);
} }
} }
...@@ -415,7 +415,7 @@ __device__ void paged_attention_kernel( ...@@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
kv_scale); v_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel( ...@@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>( KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
...@@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel( ...@@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>( KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, kv_scale, tp_rank, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
...@@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
kv_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
...@@ -694,8 +694,8 @@ void paged_attention_v1_launcher( ...@@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -770,7 +770,7 @@ void paged_attention_v1_launcher( ...@@ -770,7 +770,7 @@ void paged_attention_v1_launcher(
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); blocksparse_block_size, blocksparse_head_sliding_step);
...@@ -815,8 +815,8 @@ void paged_attention_v1( ...@@ -815,8 +815,8 @@ void paged_attention_v1(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
...@@ -833,7 +833,7 @@ void paged_attention_v1( ...@@ -833,7 +833,7 @@ void paged_attention_v1(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \ vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
...@@ -850,8 +850,8 @@ void paged_attention_v2_launcher( ...@@ -850,8 +850,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale, const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
const int tp_rank, const int blocksparse_local_blocks, float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -932,8 +932,9 @@ void paged_attention_v2_launcher( ...@@ -932,8 +932,9 @@ void paged_attention_v2_launcher(
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_block_size, blocksparse_head_sliding_step); blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
...@@ -980,8 +981,8 @@ void paged_attention_v2( ...@@ -980,8 +981,8 @@ void paged_attention_v2(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
......
...@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches, ...@@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype, const double k_scale,
const double kv_scale); const double v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
......
...@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel( ...@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
// block_size] // block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x, const int head_size, const int block_size, const int x, const float k_scale,
const float kv_scale) { const float v_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) { if (slot_idx < 0) {
...@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( ...@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else { } else {
key_cache[tgt_key_idx] = key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
value_cache[tgt_value_idx] = value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale); fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
} }
} }
} }
...@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, x, kv_scale); num_heads, head_size, block_size, x, k_scale, v_scale);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
...@@ -258,7 +258,8 @@ void reshape_and_cache( ...@@ -258,7 +258,8 @@ void reshape_and_cache(
torch::Tensor& torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size] value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const double kv_scale) { const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
...@@ -318,13 +319,13 @@ namespace vllm { ...@@ -318,13 +319,13 @@ namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
Tout* __restrict__ dst_cache, Tout* __restrict__ dst_cache,
const float kv_scale, const float scale,
const int64_t block_stride) { const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
dst_cache[idx] = dst_cache[idx] =
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale); fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
} }
} }
...@@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, ...@@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride); reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
// Only for testing. // Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double kv_scale, const std::string& kv_cache_dtype) { const double scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device(); torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device(); torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
......
...@@ -423,11 +423,11 @@ void paged_attention_v1( ...@@ -423,11 +423,11 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
...@@ -742,11 +742,11 @@ void paged_attention_v2( ...@@ -742,11 +742,11 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1, TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet."); "CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
......
...@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches, ...@@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype, double kv_scale) { const std::string& kv_cache_dtype, double k_scale,
TORCH_CHECK(kv_scale == 1.0f); double v_scale) {
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
......
...@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank," " str kv_cache_dtype, float k_scale, float v_scale,"
" int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
...@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank," " str kv_cache_dtype, float k_scale, float v_scale,"
" int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
...@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache," " Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping," " Tensor slot_mapping,"
" str kv_cache_dtype," " str kv_cache_dtype,"
" float kv_scale) -> ()"); " float k_scale, float v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
} }
......
...@@ -8,8 +8,8 @@ void paged_attention_v1( ...@@ -8,8 +8,8 @@ void paged_attention_v1(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -19,8 +19,8 @@ void paged_attention_v2( ...@@ -19,8 +19,8 @@ void paged_attention_v2(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, double k_scale, double v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
......
...@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank," " str kv_cache_dtype, float k_scale, float v_scale,"
" int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
...@@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float kv_scale, int tp_rank," " str kv_cache_dtype, float k_scale, float v_scale,"
" int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
...@@ -223,7 +223,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -223,7 +223,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! key_cache, Tensor! value_cache," " Tensor! key_cache, Tensor! value_cache,"
" Tensor slot_mapping," " Tensor slot_mapping,"
" str kv_cache_dtype," " str kv_cache_dtype,"
" float kv_scale) -> ()"); " float k_scale, float v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them. // Reshape the key and value tensors and cache them.
......
...@@ -175,7 +175,7 @@ def test_paged_attention( ...@@ -175,7 +175,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale # Using default kv_scale
kv_scale = 1.0 k_scale = v_scale = 1.0
# Call the paged attention kernel. # Call the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
...@@ -193,7 +193,8 @@ def test_paged_attention( ...@@ -193,7 +193,8 @@ def test_paged_attention(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
elif version == "v2": elif version == "v2":
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
...@@ -224,7 +225,8 @@ def test_paged_attention( ...@@ -224,7 +225,8 @@ def test_paged_attention(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
else: else:
raise AssertionError(f"Unknown version: {version}") raise AssertionError(f"Unknown version: {version}")
......
...@@ -212,7 +212,7 @@ def test_paged_attention( ...@@ -212,7 +212,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale # Using default kv_scale
kv_scale = 1.0 k_scale = v_scale = 1.0
tp_rank = 0 tp_rank = 0
# Call the paged attention kernel. # Call the paged attention kernel.
...@@ -231,7 +231,8 @@ def test_paged_attention( ...@@ -231,7 +231,8 @@ def test_paged_attention(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
tp_rank=tp_rank, tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride, blocksparse_vert_stride=blocksparse_vert_stride,
...@@ -267,7 +268,8 @@ def test_paged_attention( ...@@ -267,7 +268,8 @@ def test_paged_attention(
max_seq_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale, k_scale,
v_scale,
tp_rank=tp_rank, tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks, blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride, blocksparse_vert_stride=blocksparse_vert_stride,
......
...@@ -155,11 +155,11 @@ def test_reshape_and_cache( ...@@ -155,11 +155,11 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
# Using default kv_scale # Using default kv_scale
kv_scale = 1.0 k_scale = v_scale = 1.0
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
kv_cache_dtype, kv_scale) kv_cache_dtype, k_scale, v_scale)
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
......
...@@ -7,19 +7,49 @@ import torch ...@@ -7,19 +7,49 @@ import torch
from tests.quantization.utils import is_quant_method_supported from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
Fp8LinearMethod)
MODELS = [ MODELS = [
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8", "neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
"nm-testing/Phi-3-mini-128k-instruct-FP8", "nm-testing/Phi-3-mini-128k-instruct-FP8",
] ]
@pytest.mark.skipif(not is_quant_method_supported("fp8"), @pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.") reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model_id", MODELS)
def test_model_load_and_run(vllm_runner, model: str): def test_model_load_and_run(vllm_runner, model_id: str):
with vllm_runner(model) as llm: with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)
print(outputs[0][1])
KV_CACHE_MODELS = [
# Deprecated AutoFP8 format using .kv_scale
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
# AutoFP8 format using separate .k_scale and .v_scale
"nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
]
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
attn = model.model.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
# these checkpoints have scales < 1.0
assert 0.0 < attn._k_scale < 1.0
assert 0.0 < attn._v_scale < 1.0
# note: this does not test accuracy, just that we can run through # note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy # see lm-eval tests for accuracy
outputs = llm.generate_greedy(prompts=["Hello my name is"], outputs = llm.generate_greedy(prompts=["Hello my name is"],
......
...@@ -84,7 +84,8 @@ def paged_attention_v1( ...@@ -84,7 +84,8 @@ def paged_attention_v1(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -94,8 +95,9 @@ def paged_attention_v1( ...@@ -94,8 +95,9 @@ def paged_attention_v1(
torch.ops._C.paged_attention_v1( torch.ops._C.paged_attention_v1(
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step)
def paged_attention_v2( def paged_attention_v2(
...@@ -114,7 +116,8 @@ def paged_attention_v2( ...@@ -114,7 +116,8 @@ def paged_attention_v2(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -124,7 +127,7 @@ def paged_attention_v2( ...@@ -124,7 +127,7 @@ def paged_attention_v2(
torch.ops._C.paged_attention_v2( torch.ops._C.paged_attention_v2(
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_local_blocks, blocksparse_vert_stride,
blocksparse_block_size, blocksparse_head_sliding_step) blocksparse_block_size, blocksparse_head_sliding_step)
...@@ -374,11 +377,12 @@ def reshape_and_cache( ...@@ -374,11 +377,12 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
) -> None: ) -> None:
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
value_cache, slot_mapping, value_cache, slot_mapping,
kv_cache_dtype, kv_scale) kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_flash( def reshape_and_cache_flash(
......
...@@ -59,7 +59,8 @@ class ipex_ops: ...@@ -59,7 +59,8 @@ class ipex_ops:
max_context_len: int, max_context_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -99,7 +100,8 @@ class ipex_ops: ...@@ -99,7 +100,8 @@ class ipex_ops:
max_context_len: int, max_context_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -227,7 +229,8 @@ class ipex_ops: ...@@ -227,7 +229,8 @@ class ipex_ops:
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, k_scale: float,
v_scale: float,
) -> None: ) -> None:
assert kv_cache_dtype == "auto" assert kv_cache_dtype == "auto"
ipex.llm.modules.PagedAttention.reshape_and_cache( ipex.llm.modules.PagedAttention.reshape_and_cache(
......
...@@ -134,7 +134,8 @@ class AttentionImpl(ABC, Generic[T]): ...@@ -134,7 +134,8 @@ class AttentionImpl(ABC, Generic[T]):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -327,7 +327,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -327,7 +327,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata, attn_metadata: BlocksparseFlashAttentionMetadata,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
...@@ -368,7 +369,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -368,7 +369,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value_cache, value_cache,
attn_metadata.slot_mapping, attn_metadata.slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
...@@ -405,7 +407,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl): ...@@ -405,7 +407,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, k_scale,
v_scale,
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
blocksparse_local_blocks=self.local_blocks, blocksparse_local_blocks=self.local_blocks,
blocksparse_vert_stride=self.vert_stride, blocksparse_vert_stride=self.vert_stride,
......
...@@ -256,7 +256,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -256,7 +256,8 @@ class FlashAttentionImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
...@@ -277,7 +278,8 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -277,7 +278,8 @@ class FlashAttentionImpl(AttentionImpl):
"FlashAttentionImpl") "FlashAttentionImpl")
# NOTE(woosuk): FlashAttention does not support FP8 KV cache. # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention." assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
......
...@@ -223,10 +223,12 @@ class FlashInferImpl(AttentionImpl): ...@@ -223,10 +223,12 @@ class FlashInferImpl(AttentionImpl):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: FlashInferMetadata, attn_metadata: FlashInferMetadata,
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
assert kv_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashInfer.")
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
......
...@@ -156,7 +156,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -156,7 +156,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: IpexAttnMetadata, # type: ignore attn_metadata: IpexAttnMetadata, # type: ignore
kv_scale: float = 1.0, k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention. """Forward pass with IPEX varlen_attention and PagedAttention.
...@@ -170,7 +171,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -170,7 +171,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert kv_scale == 1.0 assert k_scale == 1.0 and v_scale == 1.0
if attn_type != AttentionType.DECODER: if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and " raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention " "encoder/decoder cross-attention "
...@@ -192,7 +193,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -192,7 +193,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache, value_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
if attn_metadata.is_prompt: if attn_metadata.is_prompt:
...@@ -273,7 +275,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -273,7 +275,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
else: else:
# Run PagedAttention V2. # Run PagedAttention V2.
...@@ -305,7 +308,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -305,7 +308,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len, max_seq_len,
self.alibi_slopes, self.alibi_slopes,
self.kv_cache_dtype, self.kv_cache_dtype,
kv_scale, k_scale,
v_scale,
) )
# Reshape the output tensor. # Reshape the output tensor.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment