Commit 0d2ccb8d authored by zhangshao's avatar zhangshao
Browse files

调整pa tc和非tc调用关系

parent 4f8d38c8
...@@ -999,28 +999,6 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -999,28 +999,6 @@ void paged_attention_v2_launcher_opt_tc(
break; \ break; \
} }
void paged_attention_v2(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v2_opt_tc( void paged_attention_v2_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
...@@ -1043,37 +1021,10 @@ void paged_attention_v2_opt_tc( ...@@ -1043,37 +1021,10 @@ void paged_attention_v2_opt_tc(
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
}
else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
}
} }
void paged_attention_v1(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step);
void paged_attention_v1_opt_tc( void paged_attention_v1_opt_tc(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size]
...@@ -1091,20 +1042,10 @@ void paged_attention_v1_opt_tc( ...@@ -1091,20 +1042,10 @@ void paged_attention_v1_opt_tc(
const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step) { const int64_t blocksparse_head_sliding_step) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step);
}
else{
paged_attention_v2_opt_tc(out,out,out,out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt_tc(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step); blocksparse_block_size,blocksparse_head_sliding_step);
}
} }
#undef WARP_SIZE #undef WARP_SIZE
......
...@@ -910,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -910,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
break; \ break; \
} }
void paged_attention_v2_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,// [num_seqs, max_seq_len]
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride);
void paged_attention_v2_opt_tc_with_mask( void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
...@@ -958,38 +934,10 @@ void paged_attention_v2_opt_tc_with_mask( ...@@ -958,38 +934,10 @@ void paged_attention_v2_opt_tc_with_mask(
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_with_mask(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
}
} }
void paged_attention_v1_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride);
void paged_attention_v1_opt_tc_with_mask( void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& out, // [num_seqs, num_heads, head_size]
...@@ -1010,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask( ...@@ -1010,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_with_mask(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
else{
paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride); blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
} }
#undef WARP_SIZE #undef WARP_SIZE
......
...@@ -16,6 +16,7 @@ if HAS_TRITON: ...@@ -16,6 +16,7 @@ if HAS_TRITON:
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
support_tc = gpuname.startswith('K100_AI') or gpuname.startswith('BW') support_tc = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and support_tc
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
...@@ -131,7 +132,7 @@ class PagedAttention: ...@@ -131,7 +132,7 @@ class PagedAttention:
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
if envs.VLLM_USE_TC_PAGED_ATTN and support_tc: if use_tc and head_size==128:
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:") print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment