Commit f776086d authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.7.2-dev' into v0.7.2-dev

parents 1fbdf957 bed03c5a
...@@ -161,7 +161,7 @@ def benchmark_config( ...@@ -161,7 +161,7 @@ def benchmark_config(
nn_moe = False nn_moe = False
block_shape=[0, group_size] block_shape=[0, group_size]
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
print(f"shape: {x.shape[0]} | config: {config}")
def prepare(i: int): def prepare(i: int):
input_gating.copy_(gating_output[i]) input_gating.copy_(gating_output[i])
...@@ -187,6 +187,7 @@ def benchmark_config( ...@@ -187,6 +187,7 @@ def benchmark_config(
a2_scale=a2_scale, a2_scale=a2_scale,
use_nn_moe=nn_moe, use_nn_moe=nn_moe,
block_shape=block_shape, block_shape=block_shape,
moe_ep_size=1,
) )
# JIT compilation & warmup # JIT compilation & warmup
...@@ -221,8 +222,7 @@ def benchmark_config( ...@@ -221,8 +222,7 @@ def benchmark_config(
end_event.record() end_event.record()
end_event.synchronize() end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event)) latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us avg = sum(latencies) / (num_iters) * 1000 # us
print(f"avg: {avg}")
# graph.reset() # graph.reset()
return avg return avg
...@@ -694,7 +694,7 @@ if __name__ == "__main__": ...@@ -694,7 +694,7 @@ if __name__ == "__main__":
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument("--model", parser.add_argument("--model",
type=str, type=str,
default="/home/yang/llm-models/vllm-awq-models/DeepSeek-R1-AWQ/") default="")
parser.add_argument("--tp-size", parser.add_argument("--tp-size",
"-tp", "-tp",
"--tensor-parallel-size", "--tensor-parallel-size",
......
...@@ -92,7 +92,7 @@ __device__ void paged_attention_kernel_opt( ...@@ -92,7 +92,7 @@ __device__ void paged_attention_kernel_opt(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
...@@ -316,7 +316,7 @@ __device__ void paged_attention_kernel_opt( ...@@ -316,7 +316,7 @@ __device__ void paged_attention_kernel_opt(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale); k_vec_quant, *k_scale_ptr);
} }
} }
} }
...@@ -483,7 +483,7 @@ __device__ void paged_attention_kernel_opt( ...@@ -483,7 +483,7 @@ __device__ void paged_attention_kernel_opt(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale); *v_scale_ptr);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -610,7 +610,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt( ...@@ -610,7 +610,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_kernel_opt(
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
...@@ -650,7 +650,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt( ...@@ -650,7 +650,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_kernel_opt(
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_opt<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
...@@ -783,20 +783,10 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt( ...@@ -783,20 +783,10 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
// vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
// NUM_THREADS, KV_DTYPE, REUSE_KV_TIMES, IS_BLOCK_SPARSE, odd_nheads> \
// <<<dim3(grid), dim3(block)>>>( \
// out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
// scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
// alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
// kv_scale, tp_rank, blocksparse_local_blocks, \
// blocksparse_vert_stride, blocksparse_block_size, \
// blocksparse_head_sliding_step);
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
...@@ -805,8 +795,8 @@ void paged_attention_v1_launcher( ...@@ -805,8 +795,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -829,7 +819,8 @@ void paged_attention_v1_launcher( ...@@ -829,7 +819,8 @@ void paged_attention_v1_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -910,7 +901,7 @@ void paged_attention_v1_opt( ...@@ -910,7 +901,7 @@ void paged_attention_v1_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
...@@ -928,7 +919,7 @@ void paged_attention_v1_opt( ...@@ -928,7 +919,7 @@ void paged_attention_v1_opt(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); \ blocksparse_block_size, blocksparse_head_sliding_step); \
hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \ hipLaunchKernelGGL(( vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
...@@ -945,8 +936,8 @@ void paged_attention_v2_launcher( ...@@ -945,8 +936,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -965,7 +956,8 @@ void paged_attention_v2_launcher( ...@@ -965,7 +956,8 @@ void paged_attention_v2_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
...@@ -1058,7 +1050,7 @@ void paged_attention_v2_opt( ...@@ -1058,7 +1050,7 @@ void paged_attention_v2_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
......
...@@ -162,7 +162,7 @@ __device__ void paged_attention_kernel_TC( ...@@ -162,7 +162,7 @@ __device__ void paged_attention_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
...@@ -639,7 +639,7 @@ __global__ void paged_attention_v1_kernel_TC( ...@@ -639,7 +639,7 @@ __global__ void paged_attention_v1_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__) #if defined(__gfx936__) || defined(__gfx928__)
...@@ -678,7 +678,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( ...@@ -678,7 +678,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
#if defined(__gfx936__) || defined(__gfx928__) #if defined(__gfx936__) || defined(__gfx928__)
...@@ -814,7 +814,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t ...@@ -814,7 +814,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
...@@ -908,8 +908,8 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -908,8 +908,8 @@ void paged_attention_v1_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -932,7 +932,8 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -932,7 +932,8 @@ void paged_attention_v1_launcher_opt_tc(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -1001,7 +1002,7 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -1001,7 +1002,7 @@ void paged_attention_v1_launcher_opt_tc(
break; \ break; \
} }
void paged_attention_v1_opt( void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
...@@ -1014,7 +1015,7 @@ void paged_attention_v1_opt( ...@@ -1014,7 +1015,7 @@ void paged_attention_v1_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -1032,14 +1033,14 @@ void paged_attention_v1_opt_tc( ...@@ -1032,14 +1033,14 @@ void paged_attention_v1_opt_tc(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){ block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads, paged_attention_v1(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step); blocksparse_block_size,blocksparse_head_sliding_step);
...@@ -1059,7 +1060,7 @@ void paged_attention_v1_opt_tc( ...@@ -1059,7 +1060,7 @@ void paged_attention_v1_opt_tc(
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \ max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \ num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); \ blocksparse_head_sliding_step); \
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
...@@ -1133,8 +1134,8 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1133,8 +1134,8 @@ void paged_attention_v2_launcher_opt_tc(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) { const int blocksparse_head_sliding_step) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
...@@ -1156,6 +1157,8 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1156,6 +1157,8 @@ void paged_attention_v2_launcher_opt_tc(
: nullptr; : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr()); T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
...@@ -1231,7 +1234,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1231,7 +1234,7 @@ void paged_attention_v2_launcher_opt_tc(
break; \ break; \
} }
void paged_attention_v2_opt( void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
...@@ -1248,7 +1251,7 @@ void paged_attention_v2_opt( ...@@ -1248,7 +1251,7 @@ void paged_attention_v2_opt(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -1270,14 +1273,14 @@ void paged_attention_v2_opt_tc( ...@@ -1270,14 +1273,14 @@ void paged_attention_v2_opt_tc(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){ block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step); blocksparse_block_size,blocksparse_head_sliding_step);
......
...@@ -105,7 +105,7 @@ __device__ void paged_attention_with_mask_kernel( ...@@ -105,7 +105,7 @@ __device__ void paged_attention_with_mask_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -286,7 +286,7 @@ __device__ void paged_attention_with_mask_kernel( ...@@ -286,7 +286,7 @@ __device__ void paged_attention_with_mask_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale); k_vec_quant, *k_scale_ptr);
} }
} }
...@@ -424,7 +424,7 @@ __device__ void paged_attention_with_mask_kernel( ...@@ -424,7 +424,7 @@ __device__ void paged_attention_with_mask_kernel(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale); *v_scale_ptr);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -522,7 +522,7 @@ __global__ void paged_attention_v1_with_mask_kernel( ...@@ -522,7 +522,7 @@ __global__ void paged_attention_v1_with_mask_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -559,7 +559,7 @@ __global__ void paged_attention_v2_with_mask_kernel( ...@@ -559,7 +559,7 @@ __global__ void paged_attention_v2_with_mask_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -693,7 +693,7 @@ __global__ void paged_attention_v2_reduce_kernel( ...@@ -693,7 +693,7 @@ __global__ void paged_attention_v2_reduce_kernel(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride); attn_masks_stride);
...@@ -706,8 +706,8 @@ void paged_attention_v1_launcher( ...@@ -706,8 +706,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -728,7 +728,8 @@ void paged_attention_v1_launcher( ...@@ -728,7 +728,8 @@ void paged_attention_v1_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -842,7 +843,7 @@ void paged_attention_v1_with_mask( ...@@ -842,7 +843,7 @@ void paged_attention_v1_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -862,7 +863,7 @@ void paged_attention_v1_with_mask( ...@@ -862,7 +863,7 @@ void paged_attention_v1_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \ attn_masks_ptr, attn_masks_stride); \
...@@ -880,8 +881,8 @@ void paged_attention_v2_launcher( ...@@ -880,8 +881,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -902,7 +903,8 @@ void paged_attention_v2_launcher( ...@@ -902,7 +903,8 @@ void paged_attention_v2_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
...@@ -1023,7 +1025,7 @@ void paged_attention_v2_with_mask( ...@@ -1023,7 +1025,7 @@ void paged_attention_v2_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
......
...@@ -92,7 +92,7 @@ __device__ void paged_attention_with_mask_kernel_opt( ...@@ -92,7 +92,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale_ptr, const float* v_scale_ptr, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -317,7 +317,7 @@ __device__ void paged_attention_with_mask_kernel_opt( ...@@ -317,7 +317,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, k_scale); k_vec_quant, *k_scale_ptr);
} }
} }
} }
...@@ -498,7 +498,7 @@ __device__ void paged_attention_with_mask_kernel_opt( ...@@ -498,7 +498,7 @@ __device__ void paged_attention_with_mask_kernel_opt(
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
v_scale); *v_scale_ptr);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
...@@ -625,7 +625,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt ...@@ -625,7 +625,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v1_with_mask_kernel_opt
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -666,7 +666,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt ...@@ -666,7 +666,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_with_mask_kernel_opt
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
...@@ -800,7 +800,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt( ...@@ -800,7 +800,7 @@ __global__ __launch_bounds__(256,1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads, num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step, attn_masks_ptr, \
attn_masks_stride); attn_masks_stride);
...@@ -823,8 +823,8 @@ void paged_attention_v1_launcher( ...@@ -823,8 +823,8 @@ void paged_attention_v1_launcher(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -849,7 +849,8 @@ void paged_attention_v1_launcher( ...@@ -849,7 +849,8 @@ void paged_attention_v1_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
...@@ -940,7 +941,7 @@ void paged_attention_v1_opt_with_mask( ...@@ -940,7 +941,7 @@ void paged_attention_v1_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -960,7 +961,7 @@ void paged_attention_v1_opt_with_mask( ...@@ -960,7 +961,7 @@ void paged_attention_v1_opt_with_mask(
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \ value_cache_ptr, num_heads, num_kv_heads, scale, block_tables_ptr, \
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_block_size, blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \ attn_masks_ptr, attn_masks_stride); \
...@@ -978,8 +979,8 @@ void paged_attention_v2_launcher( ...@@ -978,8 +979,8 @@ void paged_attention_v2_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
...@@ -1000,7 +1001,8 @@ void paged_attention_v2_launcher( ...@@ -1000,7 +1001,8 @@ void paged_attention_v2_launcher(
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
...@@ -1100,7 +1102,7 @@ void paged_attention_v2_opt_with_mask( ...@@ -1100,7 +1102,7 @@ void paged_attention_v2_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
......
...@@ -6,51 +6,31 @@ ...@@ -6,51 +6,31 @@
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM #define WARP_SIZE 64
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "static_switch_tc.h" #include "static_switch_tc.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
inline std::string get_device_name() std::string get_device_name();
{
hipDeviceProp_t props{}; static const std::string device_name=get_device_name();
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return std::string();
}
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return std::string();
}
const std::string raw_name(props.gcnArchName);
return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
static inline int get_env_(const char *env_var) { static inline int get_env_(const char *env_var) {
if (char *value = std::getenv(env_var)) { if (char *value = std::getenv(env_var)) {
return atoi(value); return atoi(value);
} }
return 0; return 0;
} }
static const int PA_USE_V1 = get_env_("PA_USE_V1");
static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES"); static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES");
static const int PA_PARTITION_SIZE = get_env_("PA_PARTITION_SIZE");
static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE"); static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM"); static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
namespace vllm { namespace vllm {
...@@ -94,10 +74,16 @@ inline __device__ float block_sum(float* red_smem, float sum) { ...@@ -94,10 +74,16 @@ inline __device__ float block_sum(float* red_smem, float sum) {
using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16; using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16;
using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short; using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short;
using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float;
using float2_t = __attribute__( (__vector_size__(2 * sizeof(float)) )) float;
struct half4x2{ struct half4x2{
half4_t data[2]; half4_t data[2];
}; };
template<typename scalar_t>
struct vec2data{
scalar_t data[2];
};
template<bool is_half> template<bool is_half>
inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src) inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src)
{ {
...@@ -130,29 +116,25 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4 ...@@ -130,29 +116,25 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
} }
} }
template<bool is_half,bool use_vmac> template<bool is_half>
inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c)
{ {
if constexpr (use_vmac){v_mmac_f32_16x16x16_f16<is_half>(reg_a,reg_b,reg_c);}
else{
if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);} if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);}
else{ else{
reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c); reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)&reg_a,*(v4bh*)&reg_b,reg_c);
} }
}
} }
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES,bool use_vmac,int PARTITION_SIZE = 0> // Zero means no partitioning. bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES> // Zero means no partitioning.
__device__ void paged_attention_with_mask_kernel_TC( __global__ void paged_attention_kernel_TC_with_mask(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
// max_num_partitions] scalar_t* __restrict__ out, // [num_seqs, num_heads,head_size]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, scalar_t* __restrict__ out_tmp, // [num_seqs, num_heads, max_num_partitions,head_size]
// head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
...@@ -166,36 +148,30 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -166,36 +148,30 @@ __device__ void paged_attention_with_mask_kernel_TC(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0,int PARTITION_SIZE=0) {
#if defined(__gfx936__) || defined(__gfx928__)
const int seq_idx = blockIdx.z; const int seq_idx = blockIdx.z;
const int partition_idx = blockIdx.y; const int partition_idx = blockIdx.x;
const int max_num_partitions = gridDim.y; const int max_num_partitions = gridDim.x;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
// No work to do. Terminate the thread block. const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
return; if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) return;
}
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE;
// [start_block_idx, end_block_idx) is the range of blocks to process. const int start_block_idx = partition_idx * num_blocks_per_partition;
const int start_block_idx = partition_idx * num_blocks_per_partition;//0,64,128… const int end_block_idx =MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int end_block_idx =MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);//64,128,192… const int num_blocks = end_block_idx - start_block_idx;
const int num_blocks = end_block_idx - start_block_idx;//64 or 1-63 const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
// [start_token_idx, end_token_idx) is the range of tokens to process. const int num_tokens = end_token_idx - start_token_idx;
const int start_token_idx = start_block_idx * BLOCK_SIZE;//0,1024,2048… constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);//1024,2048,3072… constexpr int x = 16 / sizeof(cache_t);
const int num_tokens = end_token_idx - start_token_idx;//1024 or 1-1023
// divides NUM_THREADS
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;//4
constexpr int x = 16 / sizeof(cache_t);//8
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE); const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
...@@ -204,14 +180,10 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -204,14 +180,10 @@ __device__ void paged_attention_with_mask_kernel_TC(
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES); const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES);
const int odd_tg_round = (((blockIdx.z * gridDim.y * gridDim.x) + blockIdx.y * gridDim.x) / 128) % 2; const int head_idx=(blockIdx.y / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.y % num_blocks_per_kv) * REUSE_KV_TIMES;
const int mid_x = gridDim.x / 2;
const int blockIdx_shift = (odd_tg_round | (gridDim.x & 1)) ? blockIdx.x : (blockIdx.x < mid_x ? (blockIdx.x + mid_x) : (blockIdx.x - mid_x));
const int head_idx = (blockIdx_shift / num_blocks_per_kv) * num_queries_per_kv + (blockIdx_shift % num_blocks_per_kv) * REUSE_KV_TIMES;
//const int head_idx=(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES;
int q_boundary=REUSE_KV_TIMES; int q_boundary=REUSE_KV_TIMES;
if(num_heads < REUSE_KV_TIMES*gridDim.x && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv) if(num_heads < REUSE_KV_TIMES*gridDim.y && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv)
q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES; q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES;
const int kv_head_idx = head_idx / num_queries_per_kv; const int kv_head_idx = head_idx / num_queries_per_kv;
constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1; constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1;
...@@ -234,51 +206,42 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -234,51 +206,42 @@ __device__ void paged_attention_with_mask_kernel_TC(
q_vec.data[1]={0,0,0,0}; q_vec.data[1]={0,0,0,0};
__shared__ half4x2 q_vecs[REUSE_KV_TIMES][16]; __shared__ half4x2 q_vecs[REUSE_KV_TIMES][16];
//if(thread_idx==0)printf("blockIdx.x==%d,q_boundary=%d,head_idx=%d,kv_head_idx=%d\n",blockIdx.x,q_boundary,head_idx,kv_head_idx);
for(int i=0;i<q_boundary;i++){ for(int i=0;i<q_boundary;i++){
if(thread_idx<16){ if(thread_idx<16){
q_vecs[i][thread_idx]=*reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8); half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
#pragma unroll
for(int k=0;k<4;k++){
temp.data[0][k]=((float)temp.data[0][k])*scale;
temp.data[1][k]=((float)temp.data[1][k])*scale;
}
q_vecs[i][thread_idx]=temp;
} }
} }
__syncthreads(); __syncthreads();
// Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem); scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem);
// Workspace for reduction. // __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float red_smem[2 * NUM_WARPS]; __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// blocksparse specific vars
int bs_block_offset;
int q_bs_block_id;
const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8; const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8;
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;block_idx += NUM_WARPS) {
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]); const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride; const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride;
float4_t qk_vec={0,0,0,0}; float4_t qk_vec={0,0,0,0};
half4x2 k_vec[2]; half4x2 k_vec[2];
k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr); k_vec[0]=*reinterpret_cast<const half4x2*>(k_ptr);
#pragma unroll #pragma unroll
for(int i=0;i<3;i++){ for(int i=0;i<3;i++){
if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows]; if(rowid<q_boundary)q_vec=q_vecs[rowid][i*4+rows];
k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512); k_vec[1-i%2]=*reinterpret_cast<const half4x2*>(k_ptr+(i+1)*512);
builtin_amdgcn_mmac<is_half,use_vmac>(k_vec[i%2].data[0],q_vec.data[0],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[0],q_vec.data[0],qk_vec);
builtin_amdgcn_mmac<is_half,use_vmac>(k_vec[i%2].data[1],q_vec.data[1],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[i%2].data[1],q_vec.data[1],qk_vec);
} }
//tail //tail
{ {
if(rowid<q_boundary)q_vec=q_vecs[rowid][3*4+rows]; if(rowid<q_boundary)q_vec=q_vecs[rowid][3*4+rows];
builtin_amdgcn_mmac<is_half,use_vmac>(k_vec[1].data[0],q_vec.data[0],qk_vec); builtin_amdgcn_mmac<is_half>(k_vec[1].data[0],q_vec.data[0],qk_vec);
v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec); v_mmac_f32_16x16x16_f16<is_half>(k_vec[1].data[1],q_vec.data[1],qk_vec);
} }
#pragma unroll #pragma unroll
...@@ -286,13 +249,11 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -286,13 +249,11 @@ __device__ void paged_attention_with_mask_kernel_TC(
int reuse_kv_idx=rows+i*4; int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<REUSE_KV_TIMES){ if(reuse_kv_idx<REUSE_KV_TIMES){
if(reuse_kv_idx>=q_boundary)qk_vec[i]=0; if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
else qk_vec[i]*=scale;
const int token_idx = block_idx * BLOCK_SIZE+rowid; const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){ if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1); float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
qk_vec[i] += alibi; qk_vec[i] += alibi;
} }
// used for tree-style attention // used for tree-style attention
if (attn_masks != nullptr && token_idx < seq_len) { if (attn_masks != nullptr && token_idx < seq_len) {
const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride; const int* attn_masks_ptr = attn_masks + seq_idx * attn_masks_stride;
...@@ -300,7 +261,6 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -300,7 +261,6 @@ __device__ void paged_attention_with_mask_kernel_TC(
qk_vec[i] = -FLT_MAX; qk_vec[i] = -FLT_MAX;
} }
} }
const bool mask = (token_idx >= seq_len); const bool mask = (token_idx >= seq_len);
if(mask){ if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
...@@ -312,68 +272,60 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -312,68 +272,60 @@ __device__ void paged_attention_with_mask_kernel_TC(
} }
} }
} }
// if(blockIdx.x==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]); // compute max
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
const int head_idx_ = head_idx + reuse_kv_idx;
float qk_max_tmp=qk_max[reuse_kv_idx/4];
float exp_sum = 0.f;
#pragma unroll #pragma unroll
for (int mask = 8; mask >= 1; mask /= 2) { for (int mask = 8; mask >= 1; mask /= 2) {
qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask)); #pragma unroll
for(int r=0;r<reuse_group;r++){
qk_max[r]=fmaxf(qk_max[r],__shfl_xor(qk_max[r],mask));
}
}
#pragma unroll
for(int r=0;r<reuse_group;r++){
if(rowid==0&&r*4+rows<q_boundary){
s_max[r*4+rows][warp_idx] = qk_max[r];
} }
if (rowid==0 && reuse_kv_idx%4==rows) {
red_smem[warp_idx] = qk_max_tmp;
} }
__syncthreads(); __syncthreads();
__shared__ float max_out[REUSE_KV_TIMES];
// TODO(woosuk): Refactor this part. __shared__ float expsum_out[REUSE_KV_TIMES];
// Get the max qk value for the sequence. for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
qk_max_tmp = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; const int head_idx_ = head_idx + reuse_kv_idx;
float qk_max_tmp = lane < NUM_WARPS ? s_max[reuse_kv_idx][lane] : -FLT_MAX;
float exp_sum = 0.f;
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask)); qk_max_tmp = fmaxf(qk_max_tmp, __shfl_xor(qk_max_tmp, mask));
} }
// Broadcast the max qk value to all threads. qk_max_tmp = __shfl(qk_max_tmp, 0);
qk_max_tmp = VLLM_SHFL_SYNC(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp); float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val); from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
exp_sum += val; exp_sum += val;
} }
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum); exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum); from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
} }
__syncthreads(); if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp;
// If partitioning is enabled, store the max logit and exp_sum. expsum_out[reuse_kv_idx]=exp_sum;
if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx_ * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max_tmp;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx_ * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum;
} }
} }
__syncthreads();
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2 constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if constexpr(REUSE_KV_TIMES<=2){ if (q_boundary<=2){
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; constexpr int acc_size = REUSE_KV_TIMES==1?1:2;
float accs[acc_size][NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll #pragma unroll
for(int k=0;k<REUSE_KV_TIMES;k++) for(int k=0;k<acc_size;k++)
{ {
accs[k][i] = 0.f; accs[k][i] = 0.f;
} }
} }
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
...@@ -402,9 +354,9 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -402,9 +354,9 @@ __device__ void paged_attention_with_mask_kernel_TC(
} }
} }
float4_t out_vec={0,0,0,0}; float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half,use_vmac>(v_vec,logits_vec,out_vec); builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
if(rows==k){ if(rows==k){
for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){ for(int resuseid=0;resuseid<acc_size;resuseid++){
accs[resuseid][i]+=out_vec[resuseid]; accs[resuseid][i]+=out_vec[resuseid];
} }
} }
...@@ -414,8 +366,8 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -414,8 +366,8 @@ __device__ void paged_attention_with_mask_kernel_TC(
__syncthreads(); __syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float; using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps. // Perform reduction across warps.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) { for(int reuse_kv_idx=0; reuse_kv_idx<acc_size; reuse_kv_idx++) {
if constexpr (NUM_THREADS>64){ if constexpr (NUM_THREADS>64){
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem); floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll #pragma unroll
...@@ -437,11 +389,16 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -437,11 +389,16 @@ __device__ void paged_attention_with_mask_kernel_TC(
__syncthreads(); __syncthreads();
} }
} }
// Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = scalar_t* out_ptr;
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + int out_offset;
(head_idx+reuse_kv_idx) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_ptr=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane + i * WARP_SIZE; const int row_idx = lane + i * WARP_SIZE;
...@@ -450,6 +407,7 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -450,6 +407,7 @@ __device__ void paged_attention_with_mask_kernel_TC(
} }
} }
} }
#if defined __gfx928__
else{ else{
constexpr int GROUPS=reuse_group*4; constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
...@@ -489,7 +447,7 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -489,7 +447,7 @@ __device__ void paged_attention_with_mask_kernel_TC(
} }
} }
float4_t out_vec={0,0,0,0}; float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half,use_vmac>(v_vec,logits_vec,out_vec); builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
for(int g=0;g<reuse_group;g++){ for(int g=0;g<reuse_group;g++){
accs[g*4+k][i]+=out_vec[g]; accs[g*4+k][i]+=out_vec[g];
} }
...@@ -525,12 +483,20 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -525,12 +483,20 @@ __device__ void paged_attention_with_mask_kernel_TC(
} }
} }
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr_base;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
for(int g=0;g<reuse_group;g++){ for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows; int reusekvid=g*4+rows;
if(reusekvid<q_boundary){ if(reusekvid<q_boundary){
scalar_t* out_ptr = scalar_t* out_ptr = out_ptr_base + reusekvid * out_offset;
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
(head_idx+reusekvid) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int k=0;k<4;k++){ for(int k=0;k<4;k++){
...@@ -542,86 +508,120 @@ __device__ void paged_attention_with_mask_kernel_TC( ...@@ -542,86 +508,120 @@ __device__ void paged_attention_with_mask_kernel_TC(
} }
} }
} }
} #else
else{
constexpr int GROUPS=reuse_group*4;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float4_t accs[4][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++)
{
accs[k][i] = {0.f,0.f,0.f,0.f};
}
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) {
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<4;k++){
int offset=i*1024+k*256;
half4_t v_vec=*reinterpret_cast<const half4_t*>(v_ptr + offset);
if (block_idx == num_seq_blocks - 1) {
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < 4; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
}
}
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,accs[k][i]);
}
}
}
if constexpr (NUM_THREADS>64){
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(reuse_group * sizeof(float)) )) float;
// Perform reduction across warps.
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, for(int m=0; m<4; m++) {
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES,bool use_vmac> #pragma unroll
__global__ void paged_attention_v1_with_mask_kernel_TC( for (int i = NUM_WARPS; i > 1; i /= 2) {
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] int mid = i / 2;
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] // Upper warps write to shared memory.
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, if (warp_idx >= mid && warp_idx < i) {
// head_size/x, block_size, x] for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, out_smem[((warp_idx - mid) * 64+lane)*NUM_ROWS_PER_THREAD+k]=*(floatV_t*)(&(accs[m][k]));
// head_size, block_size]
const int num_heads,
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#if defined(__gfx936__) || defined(__gfx928__)
paged_attention_with_mask_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_heads,num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#endif
} }
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
for(int k=0;k<NUM_ROWS_PER_THREAD;k++){
floatV_t tmp=out_smem[thread_idx*NUM_ROWS_PER_THREAD+k];
#pragma unroll
for (int i = 0; i < reuse_group; i++) {
accs[m][k][i] += tmp[i];
}
}
}
__syncthreads();
}
}
}
if (warp_idx == 0) {
scalar_t* out_ptr_base;
int out_offset;
if(USE_PARTITIONING){
out_offset=max_num_partitions*HEAD_SIZE;
out_ptr_base=out_tmp + seq_idx * num_heads * out_offset + head_idx*out_offset+partition_idx * HEAD_SIZE;
}
else{
out_offset=HEAD_SIZE;
out_ptr_base=out + seq_idx * num_heads * HEAD_SIZE + head_idx*HEAD_SIZE;
}
for(int g=0;g<reuse_group;g++){
int reusekvid=g*4+rows;
if(reusekvid<q_boundary){
scalar_t* out_ptr = out_ptr_base + reusekvid*out_offset;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
for(int k=0;k<4;k++){
const int row_idx = rowid+16*k + i * WARP_SIZE;
from_float(*(out_ptr + row_idx), accs[k][i][g]);
}
}
}
}
}
}
#endif
if (USE_PARTITIONING&&thread_idx < q_boundary){
int offset = seq_idx * num_heads * max_num_partitions + (head_idx+thread_idx) * max_num_partitions + partition_idx;
*(max_logits+offset)=max_out[thread_idx];
*(exp_sums+offset)=expsum_out[thread_idx];
}
#endif
// Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, int REUSE_KV_TIMES,bool use_vmac, int PARTITION_SIZE,
bool odd_nheads = false>
__global__ __launch_bounds__(256, 1) void paged_attention_v2_with_mask_kernel_TC(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_heads, // [num_heads]
const int num_kv_heads, // [num_kv_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#if defined(__gfx936__) || defined(__gfx928__)
paged_attention_with_mask_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac,
PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads,
num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq,
alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#endif
} }
// Grid: (num_heads, num_seqs). // Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS>
__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_tc( __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kernel_opt_tc(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads, const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
...@@ -630,41 +630,71 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t ...@@ -630,41 +630,71 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_partitions) { const int max_num_partitions,int PARTITION_SIZE=512) {
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
if (num_partitions == 1) { if(num_partitions==1)return;
// No need to reduce. Only copy tmp_out to out. constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
scalar_t* out_ptr = const int thread_idx = threadIdx.x;
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE);
const scalar_t* tmp_out_ptr = const int lane = thread_idx % WARP_SIZE;
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE; int offset = seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions;
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { const float* max_logits_ptr = max_logits + offset;
out_ptr[i] = tmp_out_ptr[i]; const float* exp_sums_ptr = exp_sums + offset;
float max_logit = -FLT_MAX;
float global_max_logit = -FLT_MAX;
float global_exp_sum = 0.0f;
if constexpr(NUM_THREADS == 64&& HEAD_SIZE==128){
__shared__ float shared_exp_sums[64];
if(thread_idx<num_partitions){
max_logit = max_logits_ptr[thread_idx];
global_exp_sum = exp_sums_ptr[thread_idx];
global_max_logit = max_logit;
} }
// Terminate the thread block. #pragma unroll
return; for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_max_logit = fmaxf(global_max_logit, VLLM_SHFL_XOR_SYNC(global_max_logit, mask));
} }
if(thread_idx<num_partitions){
global_exp_sum = global_exp_sum * __expf(max_logit - global_max_logit);
shared_exp_sums[thread_idx] = global_exp_sum;
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_exp_sum += VLLM_SHFL_XOR_SYNC(global_exp_sum, mask);
}
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const int warp_idx = threadIdx.x / WARP_SIZE; const scalar_t* tmp_out_ptr = tmp_out + offset * HEAD_SIZE;
const int lane = threadIdx.x % WARP_SIZE; using half2_t = vec2data<scalar_t>;
float2_t acc = {0.0f, 0.0f};
half2_t acc_half;
for (int j = 0; j < num_partitions; ++j) {
half2_t tout= *(half2_t*)(tmp_out_ptr + j * HEAD_SIZE + thread_idx*2);
float temp_sum=shared_exp_sums[j]*inv_global_exp_sum;
#pragma unroll
for(int i=0;i<2;i++){
acc[i] += to_float(tout.data[i])*temp_sum;
}
}
#pragma unroll
for(int i=0;i<2;i++){
from_float(acc_half.data[i],acc[i]);
}
*(half2_t*)(out_ptr+thread_idx*2)=acc_half;
}
else{
// Size: 2 * num_partitions. // Size: 2 * num_partitions.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS]; __shared__ float red_smem[2 * NUM_WARPS];
// Load max logits to shared memory. // Load max logits to shared memory.
float* shared_max_logits = reinterpret_cast<float*>(shared_mem); float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float max_logit = -FLT_MAX;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const float l = max_logits_ptr[i]; const float l = max_logits_ptr[i];
shared_max_logits[i] = l; shared_max_logits[i] = l;
...@@ -694,10 +724,6 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t ...@@ -694,10 +724,6 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
// Load rescaled exp sums to shared memory. // Load rescaled exp sums to shared memory.
float* shared_exp_sums = float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions); reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i]; float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
...@@ -707,7 +733,6 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t ...@@ -707,7 +733,6 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
__syncthreads(); __syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum); global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out. // Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr = const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
...@@ -723,95 +748,142 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t ...@@ -723,95 +748,142 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_t
} }
from_float(out_ptr[i], acc); from_float(out_ptr[i], acc);
} }
}
} }
} // namespace vllm } // namespace vllm
#define LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ hipLaunchKernelGGL( \
((void*)vllm::paged_attention_v1_with_mask_kernel_TC<T, CACHE_T, HEAD_SIZE, \ (vllm::paged_attention_kernel_TC_with_mask< \
BLOCK_SIZE, NUM_THREADS, \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>), \ IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \
shared_mem_size); \ dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
vllm::paged_attention_v1_with_mask_kernel_TC<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \ max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\
NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac> \ num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
<<<grid, block, shared_mem_size, stream>>>( \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \ kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks_ptr, \ blocksparse_head_sliding_step,attn_masks_ptr,attn_masks_stride,PARTITION_SIZE);\
attn_masks_stride); if (max_num_partitions<=64&&max_num_partitions>1){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \
dim3(reduce_grid), dim3(64), 0, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE); \
}else if(max_num_partitions>64){ \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 128>), \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions,PARTITION_SIZE);}
void get_number_thread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){
//mha static void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks)
{
reusekv=1; reusekv=1;
num_thread=256;
PARTITION_SIZE=512;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
if(max_seq_len==8192&&num_blocks==1024){//ali test
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
if(batchsize==1){
if(qheads==52){reusekv=8;return;}
if(qheads==13){reusekv=2;return;}
reusekv=4;return;
}
if(batchsize==64){
if(qheads==13){PARTITION_SIZE=256;num_thread=128;reusekv=8;}
else if(qheads==32){PARTITION_SIZE=1024;reusekv=8;}
else if(qheads==52||qheads==26){reusekv=16;}
else reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
return;
}
}
if(qheads==kvheads){ if(qheads==kvheads){
//llama 7B ,其他模型未可知 if(max_seq_len<=8192){
if(seq<=16||batchsize>=32)num_thread=64; if(batchsize*qheads>=512){
else if(batchsize<=2)num_thread=256; max_num_partitions=1;
else if(batchsize<8)num_thread=128; num_thread=64;
else num_thread=64; }
if(qheads==32&&max_seq_len<=1024)max_num_partitions=1;
}
return; return;
} }
// mqa if(max_seq_len<800)max_num_partitions=1;
if(qheads>kvheads*4){ if(qheads>kvheads*4){
if(seq<64){ if(max_seq_len<=1000||
if(batchsize<=64){reusekv=1;num_thread=64;} max_seq_len<1500&&(batchsize>=8&&qheads>=8||batchsize>=64)||
else if(batchsize<128){reusekv=2;num_thread=64;} max_seq_len<1900&&batchsize>=8&&qheads==28
else {reusekv=4;num_thread=64;} )
} max_num_partitions=1;
else if(seq<=400){ int blocks=max_num_partitions*batchsize*qheads;
if(batchsize<16){reusekv=1;num_thread=256;} if(device_name=="gfx928"){
else if(batchsize<64){reusekv=2;num_thread=256;} if(batchsize*qheads>1024&&max_seq_len>=2000){
else if(batchsize<=128){ max_num_partitions=1;
reusekv=4; if(max_seq_len<3900)reusekv=8;
if(qheads%7==0)num_thread=64;//qwen7b else if(max_seq_len<7800)reusekv=4;
else num_thread=256;//llama70b else{
PARTITION_SIZE=2048;
reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
} }
else {reusekv=8;num_thread=64;} return;
} }
else if(seq<=1000){
if(batchsize<16){reusekv=1;num_thread=256;}
else if(qheads%7==0&&batchsize<=128){//qwen7b
if(batchsize<64){reusekv=4;num_thread=256;}
else{reusekv=4;num_thread=64;}
} }
else if(batchsize<=64){reusekv=4;num_thread=256;} if(max_num_partitions==1){
else {reusekv=8;num_thread=128;} if(max_seq_len<512){
int bytes=max_seq_len*qheads*batchsize;
if(bytes<51200)reusekv=1;
else if(bytes<256000)reusekv=4;
else reusekv=8;
return;
} }
else if(seq<3900) {reusekv=8;num_thread=256;} if(batchsize<4||batchsize==4&&qheads==8)reusekv=1;
else if(seq<7800) {reusekv=4;num_thread=256;} else if(batchsize<32||batchsize<=64&&qheads==8)reusekv=4;
else {reusekv=2;num_thread=256;} else reusekv=8;
return; return;
} }
if(blocks<150)return;
if(qheads/kvheads >4 && seq<3900)reusekv=8; if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
else if(qheads/kvheads >2 && seq<7800)reusekv=4; reusekv=8;return;
else if(qheads/kvheads >=2 && seq<15600)reusekv=2;
if(seq<=64){
num_thread=64;
if(batchsize<=64)reusekv=1;
} }
else num_thread=256; if(device_name=="gfx928"){
} if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<7800)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
}
if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||max_seq_len>=2000))reusekv=4;
// TODO(woosuk): Tune NUM_THREADS. }
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v1_launcher_opt_tc( void paged_attention_v2_launcher_opt_tc_with_mask(
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale, const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks, torch::Tensor& v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step, const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0) { const int attn_masks_stride) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
...@@ -819,100 +891,116 @@ void paged_attention_v1_launcher_opt_tc( ...@@ -819,100 +891,116 @@ void paged_attention_v1_launcher_opt_tc(
int q_stride = query.stride(0); int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0); int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
int num_threads = 128; int num_blocks=key_cache.size(0);
// printf("paged_attention_v1\n");
if (num_heads != num_kv_heads) {
num_threads = 256;
}
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
// float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
// T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
static float* exp_sums_ptr = nullptr;
// NOTE: attn_masks is optional. static float* max_logits_ptr = nullptr;
const int* attn_masks_ptr = static T* tmp_out_ptr = nullptr;
attn_masks if(exp_sums_ptr == nullptr){
? attn_masks.value().data_ptr<int>() hipMalloc(&exp_sums_ptr, 1000000); // 1m
: nullptr; hipMalloc(&max_logits_ptr, 1000000); // 1m
hipMalloc(&tmp_out_ptr, 100000000); // 100m
int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; }
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
constexpr int HEAD_SIZE=128; constexpr int HEAD_SIZE=128;
constexpr static int use_vmac = false; int reusekv, num_thread,max_num_partitions,PARTITION_SIZE;
int reusekv, num_thread; get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
get_number_thread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads); if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES; if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE; if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_USE_V1!=0)max_num_partitions=1;
if(max_num_partitions==1)PARTITION_SIZE=max_seq_len;
assert(num_seqs*num_heads*max_num_partitions*head_size<=100000000);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
if (NUM_WARPS==64)outputs_size=0; dim3 grid;
int shared_mem_size = ::max(logits_size, outputs_size); grid.y = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs); grid.x = max_num_partitions;
grid.z = num_seqs;
dim3 block(NUM_THREADS); dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n", if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs); reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE); LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
}); });
}); });
} }
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ paged_attention_v2_launcher_opt_tc_with_mask<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
blocksparse_local_blocks, blocksparse_vert_stride, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_block_size, blocksparse_head_sliding_step, \ blocksparse_vert_stride, blocksparse_block_size, \
attn_masks, attn_masks_stride); blocksparse_head_sliding_step,attn_masks, attn_masks_stride);
#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ switch (is_block_sparse) { \
case true: \ case true: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \ break; \
case false: \ case false: \
CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \ break; \
} }
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256. // 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \ switch (block_size) { \
case 8: \ case 8: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \ break; \
case 16: \ case 16: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \ break; \
case 32: \ case 32: \
CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
void paged_attention_v1_opt_with_mask( void paged_attention_v2_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
...@@ -924,15 +1012,19 @@ void paged_attention_v1_opt_with_mask( ...@@ -924,15 +1012,19 @@ void paged_attention_v1_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,// [num_seqs, max_seq_len]
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride); const int64_t attn_masks_stride);
void paged_attention_v1_opt_tc_with_mask( void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
...@@ -944,201 +1036,28 @@ void paged_attention_v1_opt_tc_with_mask( ...@@ -944,201 +1036,28 @@ void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){ block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_opt_with_mask(out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_with_mask(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step, blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
attn_masks, attn_masks_stride);
} }
else{ else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V1_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
}
}
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_with_mask_kernel_TC< \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, PARTITION_SIZE>), \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, \
attn_masks_ptr, attn_masks_stride); \
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
void get_number_thread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads){
reusekv=1;
int blocks=batchsize*qheads*max_num_partitions;
if(qheads==kvheads){
if(blocks<=80||blocks>8000){num_thread=256;}
else if(blocks<=160){num_thread=128;}
else num_thread=64;
return;
}
if(qheads/kvheads>8&&blocks>4000){
reusekv=16;
if(blocks>40000)num_thread=64;
else num_thread=128;
}
else if(qheads/kvheads==5||qheads/kvheads==7){
if(blocks<=160){reusekv=1;num_thread=256;}
else if(blocks<640/5*qheads/kvheads){reusekv=4;num_thread=256;}
else if(blocks<1920){reusekv=8;num_thread=128;}
else {reusekv=8;num_thread=64;}
}
else if(qheads>kvheads*4){
if(blocks<=128){reusekv=1;num_thread=256;}
else if(blocks<1536){reusekv=4;num_thread=256;}
else if(blocks<6144){reusekv=8;num_thread=128;}
else {reusekv=8;num_thread=64;}
}
else {
if(blocks<=128){reusekv=1;num_thread=256;}
else if(blocks<3000){reusekv=4;num_thread=256;}
else {reusekv=4;num_thread=64;}
}
}
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher_opt_tc(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks,
const int attn_masks_stride) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// printf("paged_attention_v2\n");
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* seq_lens_ptr = seq_lens.data_ptr<int>();
// NOTE: attn_masks is optional.
const int* attn_masks_ptr =
attn_masks ? attn_masks.value().data_ptr<int>() : nullptr;
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 reduce_grid(num_heads, num_seqs);
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
//if(head_size==128&&get_device_name()=="gfx928"){
constexpr int HEAD_SIZE=128;
constexpr static int use_vmac = false;
int reusekv, num_thread;
get_number_thread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads;
grid.y = max_num_partitions;
grid.z = num_seqs;
dim3 block(NUM_THREADS);
int shared_mem_size = ::max(logits_size, outputs_size);
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n",
reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs);
LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE);
});
});
} }
//}
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ void paged_attention_v1_with_mask(
paged_attention_v2_launcher_opt_tc<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step, attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
switch (block_size) { \
case 8: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
break; \
case 16: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
break; \
case 32: \
CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void paged_attention_v2_opt_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
...@@ -1150,19 +1069,15 @@ void paged_attention_v2_opt_with_mask( ...@@ -1150,19 +1069,15 @@ void paged_attention_v2_opt_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride); const int64_t attn_masks_stride);
void paged_attention_v2_opt_tc_with_mask( void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
...@@ -1174,24 +1089,25 @@ void paged_attention_v2_opt_tc_with_mask( ...@@ -1174,24 +1089,25 @@ void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& seq_lens, // [num_seqs] torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len, int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){ block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_opt_with_mask(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, paged_attention_v1_with_mask(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step, attn_masks, blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
attn_masks_stride);
} }
else{ else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
CALL_V2_LAUNCHER_BLOCK_SIZE) scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
} }
} }
......
...@@ -57,7 +57,7 @@ void paged_attention_v1_opt( ...@@ -57,7 +57,7 @@ void paged_attention_v1_opt(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -68,7 +68,7 @@ void paged_attention_v2_opt( ...@@ -68,7 +68,7 @@ void paged_attention_v2_opt(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -78,7 +78,7 @@ void paged_attention_v1_opt_tc( ...@@ -78,7 +78,7 @@ void paged_attention_v1_opt_tc(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -89,7 +89,7 @@ void paged_attention_v2_opt_tc( ...@@ -89,7 +89,7 @@ void paged_attention_v2_opt_tc(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step); const int64_t blocksparse_head_sliding_step);
...@@ -101,7 +101,7 @@ void paged_attention_v1_with_mask( ...@@ -101,7 +101,7 @@ void paged_attention_v1_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -114,7 +114,7 @@ void paged_attention_v2_with_mask( ...@@ -114,7 +114,7 @@ void paged_attention_v2_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -126,7 +126,7 @@ void paged_attention_v1_opt_with_mask( ...@@ -126,7 +126,7 @@ void paged_attention_v1_opt_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -139,7 +139,7 @@ void paged_attention_v2_opt_with_mask( ...@@ -139,7 +139,7 @@ void paged_attention_v2_opt_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -151,7 +151,7 @@ void paged_attention_v1_opt_tc_with_mask( ...@@ -151,7 +151,7 @@ void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
...@@ -164,7 +164,7 @@ void paged_attention_v2_opt_tc_with_mask( ...@@ -164,7 +164,7 @@ void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes, int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
......
...@@ -58,7 +58,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -58,7 +58,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -72,7 +72,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -72,7 +72,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -86,7 +86,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -86,7 +86,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -100,7 +100,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -100,7 +100,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"); " int blocksparse_head_sliding_step) -> ()");
...@@ -114,7 +114,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -114,7 +114,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -146,7 +146,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -146,7 +146,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -162,7 +162,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -178,7 +178,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -178,7 +178,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
...@@ -194,7 +194,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -194,7 +194,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor value_cache, int num_kv_heads, float scale," " Tensor value_cache, int num_kv_heads, float scale,"
" Tensor block_tables, Tensor seq_lens, int block_size," " Tensor block_tables, Tensor seq_lens, int block_size,"
" int max_seq_len, Tensor? alibi_slopes," " int max_seq_len, Tensor? alibi_slopes,"
" str kv_cache_dtype, float k_scale, float v_scale," " str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
" int tp_rank, int blocksparse_local_blocks," " int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step," " int blocksparse_head_sliding_step,"
......
...@@ -15,5 +15,5 @@ setuptools_scm>=8 ...@@ -15,5 +15,5 @@ setuptools_scm>=8
torch == 2.4.1 torch == 2.4.1
triton == 3.0.0 triton == 3.0.0
flash_attn == 2.6.1 flash_attn == 2.6.1
lmslim == 0.3.0 lmslim == 0.2.1
numa numa
...@@ -117,8 +117,8 @@ def paged_attention_v1_with_mask( ...@@ -117,8 +117,8 @@ def paged_attention_v1_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -152,8 +152,8 @@ def paged_attention_v2_with_mask( ...@@ -152,8 +152,8 @@ def paged_attention_v2_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -185,8 +185,8 @@ def paged_attention_v1_opt( ...@@ -185,8 +185,8 @@ def paged_attention_v1_opt(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -217,8 +217,8 @@ def paged_attention_v2_opt( ...@@ -217,8 +217,8 @@ def paged_attention_v2_opt(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -246,8 +246,8 @@ def paged_attention_v1_opt_with_mask( ...@@ -246,8 +246,8 @@ def paged_attention_v1_opt_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -281,8 +281,8 @@ def paged_attention_v2_opt_with_mask( ...@@ -281,8 +281,8 @@ def paged_attention_v2_opt_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -314,8 +314,8 @@ def paged_attention_v1_opt_tc( ...@@ -314,8 +314,8 @@ def paged_attention_v1_opt_tc(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -345,8 +345,8 @@ def paged_attention_v2_opt_tc( ...@@ -345,8 +345,8 @@ def paged_attention_v2_opt_tc(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -375,8 +375,8 @@ def paged_attention_v1_opt_tc_with_mask( ...@@ -375,8 +375,8 @@ def paged_attention_v1_opt_tc_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -409,8 +409,8 @@ def paged_attention_v2_opt_tc_with_mask( ...@@ -409,8 +409,8 @@ def paged_attention_v2_opt_tc_with_mask(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
......
...@@ -324,8 +324,7 @@ class ModelConfig: ...@@ -324,8 +324,7 @@ class ModelConfig:
# Set enforce_eager to False if the value is unset. # Set enforce_eager to False if the value is unset.
if self.enforce_eager is None: if self.enforce_eager is None:
# self.enforce_eager = False self.enforce_eager = False
self.enforce_eager = True
sliding_window = getattr(self.hf_text_config, "sliding_window", None) sliding_window = getattr(self.hf_text_config, "sliding_window", None)
has_interleaved_attention = (sliding_window is not None) and ( has_interleaved_attention = (sliding_window is not None) and (
......
...@@ -565,7 +565,6 @@ class EngineArgs: ...@@ -565,7 +565,6 @@ class EngineArgs:
'parsed into a dictionary.') 'parsed into a dictionary.')
parser.add_argument('--enforce-eager', parser.add_argument('--enforce-eager',
action='store_true', action='store_true',
default=True,
help='Always use eager-mode PyTorch. If False, ' help='Always use eager-mode PyTorch. If False, '
'will use eager mode and CUDA graph in hybrid ' 'will use eager mode and CUDA graph in hybrid '
'for maximal performance and flexibility.') 'for maximal performance and flexibility.')
......
...@@ -666,8 +666,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -666,8 +666,6 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
# 暂时awq不支持cutlass
envs.VLLM_USE_TRITON_AWQ = True
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -871,13 +869,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -871,13 +869,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight.data.copy_(_weight) weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
# 暂时不支持TN
if self.quant_method == "awq" and not envs.VLLM_USE_TRITON_AWQ: if self.config.quantization_config["quant_method"] == "awq" and not envs.VLLM_USE_TRITON_AWQ:
lay_key_words = [ lay_key_words = [
"self_attn.q_a_proj.qweight", "self_attn.q_a_proj.qweight",
"self_attn.q_b_proj.qweight", "self_attn.q_b_proj.qweight",
"self_attn.kv_a_proj_with_mqa.qweight",
"self_attn.kv_b_proj.qweight", "self_attn.kv_b_proj.qweight",
"self_attn.kv_a_proj_with_mqa.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
"mlp.gate_up_proj.qweight", "mlp.gate_up_proj.qweight",
"mlp.down_proj.qweight", "mlp.down_proj.qweight",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment