Commit d365012a authored by zhangshao's avatar zhangshao
Browse files

解决cudagraph计算错误的问题

parent 4d4c6fe3
...@@ -92,7 +92,7 @@ __device__ void paged_attention_kernel_opt( ...@@ -92,7 +92,7 @@ __device__ void paged_attention_kernel_opt(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
...@@ -316,7 +316,7 @@ __device__ void paged_attention_kernel_opt( ...@@ -316,7 +316,7 @@ __device__ void paged_attention_kernel_opt(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale); k_vec_quant, *k_scale_ptr);
} }
} }
} }
...@@ -483,7 +483,7 @@ __device__ void paged_attention_kernel_opt( ...@@ -483,7 +483,7 @@ __device__ void paged_attention_kernel_opt(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale); *v_scale_ptr);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -610,7 +610,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt( ...@@ -610,7 +610,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
...@@ -650,7 +650,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt( ...@@ -650,7 +650,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
...@@ -770,7 +770,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt( ...@@ -770,7 +770,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
} }
} // namespace vllm } // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \ ((void*)vllm::paged_attention_v1_kernel_opt<T, CACHE_T, HEAD_SIZE, \
...@@ -783,20 +783,10 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt( ...@@ -783,20 +783,10 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
...@@ -805,8 +795,8 @@ void paged_attention_v1_launcher( ...@@ -805,8 +795,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -829,7 +819,8 @@ void paged_attention_v1_launcher( ...@@ -829,7 +819,8 @@ void paged_attention_v1_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -910,7 +901,7 @@ void paged_attention_v1_opt( ...@@ -910,7 +901,7 @@ void paged_attention_v1_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
...@@ -928,7 +919,7 @@ void paged_attention_v1_opt( ...@@ -928,7 +919,7 @@ void paged_attention_v1_opt(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
...@@ -945,8 +936,8 @@ void paged_attention_v2_launcher( ...@@ -945,8 +936,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -965,7 +956,8 @@ void paged_attention_v2_launcher( ...@@ -965,7 +956,8 @@ void paged_attention_v2_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
...@@ -1058,7 +1050,7 @@ void paged_attention_v2_opt( ...@@ -1058,7 +1050,7 @@ void paged_attention_v2_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
...@@ -1070,4 +1062,4 @@ void paged_attention_v2_opt( ...@@ -1070,4 +1062,4 @@ void paged_attention_v2_opt(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -162,7 +162,7 @@ __device__ void paged_attention_kernel_TC( ...@@ -162,7 +162,7 @@ __device__ void paged_attention_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
...@@ -639,7 +639,7 @@ __global__ void paged_attention_v1_kernel_TC( ...@@ -639,7 +639,7 @@ __global__ void paged_attention_v1_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__) #if defined(__gfx936__) || defined(__gfx928__)
...@@ -678,7 +678,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( ...@@ -678,7 +678,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__) #if defined(__gfx936__) || defined(__gfx928__)
...@@ -814,7 +814,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t ...@@ -814,7 +814,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
...@@ -908,8 +908,8 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -908,8 +908,8 @@ void paged_attention_v1_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -932,7 +932,8 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -932,7 +932,8 @@ void paged_attention_v1_launcher_opt_tc(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -1001,7 +1002,7 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -1001,7 +1002,7 @@ void paged_attention_v1_launcher_opt_tc(
break; \ break; \
} }
void paged_attention_v1_opt( void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
...@@ -1014,7 +1015,7 @@ void paged_attention_v1_opt( ...@@ -1014,7 +1015,7 @@ void paged_attention_v1_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -1032,14 +1033,14 @@ void paged_attention_v1_opt_tc( ...@@ -1032,14 +1033,14 @@ void paged_attention_v1_opt_tc(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){ block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads, paged_attention_v1(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step); blocksparse_block_size,blocksparse_head_sliding_step);
...@@ -1059,7 +1060,7 @@ void paged_attention_v1_opt_tc( ...@@ -1059,7 +1060,7 @@ void paged_attention_v1_opt_tc(
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \ max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \ num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \ blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
...@@ -1133,8 +1134,8 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1133,8 +1134,8 @@ void paged_attention_v2_launcher_opt_tc(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -1156,6 +1157,8 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1156,6 +1157,8 @@ void paged_attention_v2_launcher_opt_tc(
: nullptr; : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr()); T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
...@@ -1231,7 +1234,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1231,7 +1234,7 @@ void paged_attention_v2_launcher_opt_tc(
break; \ break; \
} }
void paged_attention_v2_opt( void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
...@@ -1248,7 +1251,7 @@ void paged_attention_v2_opt( ...@@ -1248,7 +1251,7 @@ void paged_attention_v2_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -1270,14 +1273,14 @@ void paged_attention_v2_opt_tc( ...@@ -1270,14 +1273,14 @@ void paged_attention_v2_opt_tc(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){ block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step); blocksparse_block_size,blocksparse_head_sliding_step);
......
...@@ -105,7 +105,7 @@ __device__ void paged_attention_with_mask_kernel( ...@@ -105,7 +105,7 @@ __device__ void paged_attention_with_mask_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -286,7 +286,7 @@ __device__ void paged_attention_with_mask_kernel( ...@@ -286,7 +286,7 @@ __device__ void paged_attention_with_mask_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale); k_vec_quant, *k_scale_ptr);
} }
} }
...@@ -424,7 +424,7 @@ __device__ void paged_attention_with_mask_kernel( ...@@ -424,7 +424,7 @@ __device__ void paged_attention_with_mask_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale); *v_scale_ptr);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -522,7 +522,7 @@ __global__ void paged_attention_v1_with_mask_kernel( ...@@ -522,7 +522,7 @@ __global__ void paged_attention_v1_with_mask_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -559,7 +559,7 @@ __global__ void paged_attention_v2_with_mask_kernel( ...@@ -559,7 +559,7 @@ __global__ void paged_attention_v2_with_mask_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -693,7 +693,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -693,7 +693,7 @@ __global__ void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride); attn_masks_stride);
...@@ -706,8 +706,8 @@ void paged_attention_v1_launcher( ...@@ -706,8 +706,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -728,7 +728,8 @@ void paged_attention_v1_launcher( ...@@ -728,7 +728,8 @@ void paged_attention_v1_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -842,7 +843,7 @@ void paged_attention_v1_with_mask( ...@@ -842,7 +843,7 @@ void paged_attention_v1_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -862,7 +863,7 @@ void paged_attention_v1_with_mask( ...@@ -862,7 +863,7 @@ void paged_attention_v1_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \ attn_masks_ptr, attn_masks_stride); \
...@@ -880,8 +881,8 @@ void paged_attention_v2_launcher( ...@@ -880,8 +881,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -902,7 +903,8 @@ void paged_attention_v2_launcher( ...@@ -902,7 +903,8 @@ void paged_attention_v2_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
...@@ -1023,7 +1025,7 @@ void paged_attention_v2_with_mask( ...@@ -1023,7 +1025,7 @@ void paged_attention_v2_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -1037,4 +1039,4 @@ void paged_attention_v2_with_mask( ...@@ -1037,4 +1039,4 @@ void paged_attention_v2_with_mask(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -92,7 +92,7 @@ __device__ void paged_attention_with_mask_kernel_opt( ...@@ -92,7 +92,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -317,7 +317,7 @@ __device__ void paged_attention_with_mask_kernel_opt( ...@@ -317,7 +317,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale); k_vec_quant, *k_scale_ptr);
} }
} }
} }
...@@ -498,7 +498,7 @@ __device__ void paged_attention_with_mask_kernel_opt( ...@@ -498,7 +498,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale); *v_scale_ptr);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -625,7 +625,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt ...@@ -625,7 +625,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -666,7 +666,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt ...@@ -666,7 +666,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -800,7 +800,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt( ...@@ -800,7 +800,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride); attn_masks_stride);
...@@ -823,8 +823,8 @@ void paged_attention_v1_launcher( ...@@ -823,8 +823,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -849,7 +849,8 @@ void paged_attention_v1_launcher( ...@@ -849,7 +849,8 @@ void paged_attention_v1_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -940,7 +941,7 @@ void paged_attention_v1_opt_with_mask( ...@@ -940,7 +941,7 @@ void paged_attention_v1_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -960,7 +961,7 @@ void paged_attention_v1_opt_with_mask( ...@@ -960,7 +961,7 @@ void paged_attention_v1_opt_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \ attn_masks_ptr, attn_masks_stride); \
...@@ -978,8 +979,8 @@ void paged_attention_v2_launcher( ...@@ -978,8 +979,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -1000,7 +1001,8 @@ void paged_attention_v2_launcher( ...@@ -1000,7 +1001,8 @@ void paged_attention_v2_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
...@@ -1100,7 +1102,7 @@ void paged_attention_v2_opt_with_mask( ...@@ -1100,7 +1102,7 @@ void paged_attention_v2_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -1114,4 +1116,4 @@ void paged_attention_v2_opt_with_mask( ...@@ -1114,4 +1116,4 @@ void paged_attention_v2_opt_with_mask(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -57,7 +57,7 @@ void paged_attention_v1_opt( ...@@ -57,7 +57,7 @@ void paged_attention_v1_opt(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -68,7 +68,7 @@ void paged_attention_v2_opt( ...@@ -68,7 +68,7 @@ void paged_attention_v2_opt(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -78,7 +78,7 @@ void paged_attention_v1_opt_tc( ...@@ -78,7 +78,7 @@ void paged_attention_v1_opt_tc(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -89,7 +89,7 @@ void paged_attention_v2_opt_tc( ...@@ -89,7 +89,7 @@ void paged_attention_v2_opt_tc(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -101,7 +101,7 @@ void paged_attention_v1_with_mask( ...@@ -101,7 +101,7 @@ void paged_attention_v1_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -114,7 +114,7 @@ void paged_attention_v2_with_mask( ...@@ -114,7 +114,7 @@ void paged_attention_v2_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -126,7 +126,7 @@ void paged_attention_v1_opt_with_mask( ...@@ -126,7 +126,7 @@ void paged_attention_v1_opt_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -139,7 +139,7 @@ void paged_attention_v2_opt_with_mask( ...@@ -139,7 +139,7 @@ void paged_attention_v2_opt_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -151,7 +151,7 @@ void paged_attention_v1_opt_tc_with_mask( ...@@ -151,7 +151,7 @@ void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -164,7 +164,7 @@ void paged_attention_v2_opt_tc_with_mask( ...@@ -164,7 +164,7 @@ void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
......
...@@ -58,7 +58,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -58,7 +58,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -72,7 +72,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -72,7 +72,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -86,7 +86,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -86,7 +86,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -100,7 +100,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -100,7 +100,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -114,7 +114,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -114,7 +114,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -146,7 +146,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -146,7 +146,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -178,7 +178,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -178,7 +178,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -194,7 +194,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -194,7 +194,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
......
...@@ -117,8 +117,8 @@ def paged_attention_v1_with_mask( ...@@ -117,8 +117,8 @@ def paged_attention_v1_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -152,8 +152,8 @@ def paged_attention_v2_with_mask( ...@@ -152,8 +152,8 @@ def paged_attention_v2_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -185,8 +185,8 @@ def paged_attention_v1_opt( ...@@ -185,8 +185,8 @@ def paged_attention_v1_opt(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -217,8 +217,8 @@ def paged_attention_v2_opt( ...@@ -217,8 +217,8 @@ def paged_attention_v2_opt(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -246,8 +246,8 @@ def paged_attention_v1_opt_with_mask( ...@@ -246,8 +246,8 @@ def paged_attention_v1_opt_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -281,8 +281,8 @@ def paged_attention_v2_opt_with_mask( ...@@ -281,8 +281,8 @@ def paged_attention_v2_opt_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -314,8 +314,8 @@ def paged_attention_v1_opt_tc( ...@@ -314,8 +314,8 @@ def paged_attention_v1_opt_tc(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -345,8 +345,8 @@ def paged_attention_v2_opt_tc( ...@@ -345,8 +345,8 @@ def paged_attention_v2_opt_tc(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -375,8 +375,8 @@ def paged_attention_v1_opt_tc_with_mask( ...@@ -375,8 +375,8 @@ def paged_attention_v1_opt_tc_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -409,8 +409,8 @@ def paged_attention_v2_opt_tc_with_mask( ...@@ -409,8 +409,8 @@ def paged_attention_v2_opt_tc_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment