Commit fbde1e5a authored by zhuwenwen
Browse files

Merge remote-tracking branch 'origin/0.7.2-zhangshao' into v0.7.2-pa

parents 146eb9d3 228a714a
This diff is collapsed.
...@@ -316,13 +316,12 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -316,13 +316,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
__syncthreads(); __syncthreads();
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2 constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if (q_boundary<=2){ if constexpr(REUSE_KV_TIMES<=2){
constexpr int acc_size = REUSE_KV_TIMES==1?1:2; float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
float accs[acc_size][NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll #pragma unroll
for(int k=0;k<acc_size;k++) for(int k=0;k<REUSE_KV_TIMES;k++)
{ {
accs[k][i] = 0.f; accs[k][i] = 0.f;
} }
...@@ -356,7 +355,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -356,7 +355,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
float4_t out_vec={0,0,0,0}; float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec); builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
if(rows==k){ if(rows==k){
for(int resuseid=0;resuseid<acc_size;resuseid++){ for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){
accs[resuseid][i]+=out_vec[resuseid]; accs[resuseid][i]+=out_vec[resuseid];
} }
} }
...@@ -366,8 +365,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -366,8 +365,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
__syncthreads(); __syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float; using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps. // Perform reduction across warps.
#pragma unroll for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
for(int reuse_kv_idx=0; reuse_kv_idx<acc_size; reuse_kv_idx++) {
if constexpr (NUM_THREADS>64){ if constexpr (NUM_THREADS>64){
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem); floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll #pragma unroll
...@@ -780,97 +778,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern ...@@ -780,97 +778,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE);} max_num_partitions,PARTITION_SIZE);}
// Selects launch parameters for the paged-attention V2 path from the problem
// shape. Results are written through the four reference parameters:
//   num_thread         - starts at 256; some branches lower it to 128 or 64.
//   reusekv            - starts at 1; branches raise it to 2, 4, 8 or 16.
//   PARTITION_SIZE     - starts at 512; some branches use 256, 1024 or 2048.
//   max_num_partitions - DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE), or forced
//                        to 1 by several branches.
// Inputs: batchsize, max_seq_len, qheads, kvheads, num_blocks (all by value),
// plus `device_name`, which is not a parameter and is defined outside this view.
// Every threshold below is a hard-coded tuning value; branches are evaluated in
// order and most of them return early, so statement order is significant.
// NOTE(review): the two signature lines below carry both columns of a
// side-by-side diff - the old `static` definition header followed by the
// non-static prototype (ending in `;`) that replaces it. The body that follows
// belongs to the old definition.
static void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions, void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks) int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks);
{
// Defaults; any branch that returns without touching an output keeps these.
reusekv=1;
num_thread=256;
PARTITION_SIZE=512;
// DIVIDE_ROUND_UP is defined elsewhere; presumably ceiling division - confirm.
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
// Exact-match special case for one specific configuration
// (max_seq_len 8192 with 1024 blocks).
if(max_seq_len==8192&&num_blocks==1024){//ali test
// Batch 1 with 16/16 or 32/32 heads: only the thread count is lowered.
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
// Every remaining batch-1 shape returns from this block with reusekv 8, 2 or 4.
if(batchsize==1){
if(qheads==52){reusekv=8;return;}
if(qheads==13){reusekv=2;return;}
reusekv=4;return;
}
if(batchsize==64){
if(qheads==13){PARTITION_SIZE=256;num_thread=128;reusekv=8;}
else if(qheads==32){PARTITION_SIZE=1024;reusekv=8;}
else if(qheads==52||qheads==26){reusekv=16;}
else reusekv=8;
// PARTITION_SIZE may have changed above, so the partition count is recomputed.
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
return;
}
// Batch sizes other than 1 and 64 fall through to the general rules below.
}
// Equal query and KV head counts: reusekv stays 1; only the partition count and
// thread count may change, and only when max_seq_len <= 8192.
if(qheads==kvheads){
if(max_seq_len<=8192){
if(batchsize*qheads>=512){
max_num_partitions=1;
num_thread=64;
}
if(qheads==32&&max_seq_len<=1024)max_num_partitions=1;
}
return;
}
// From here on qheads != kvheads. Short sequences use a single partition.
if(max_seq_len<800)max_num_partitions=1;
// Case: more than four query heads per KV head.
if(qheads>kvheads*4){
// NOTE(review): `&&` binds tighter than `||`, so this reads as
// A || (B && (...)) || (C && D && E); explicit parentheses would make that clear.
if(max_seq_len<=1000||
max_seq_len<1500&&(batchsize>=8&&qheads>=8||batchsize>=64)||
max_seq_len<1900&&batchsize>=8&&qheads==28
)
max_num_partitions=1;
// Product of partition count, batch size and query heads, used by the
// thresholds at the end of this branch. It is computed before the gfx928 block
// can change max_num_partitions, but that block returns whenever it does so.
int blocks=max_num_partitions*batchsize*qheads;
// NOTE(review): the type of device_name is not visible here; if it is a
// `const char*` this `==` compares pointers rather than contents - confirm.
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<3900)reusekv=8;
else if(max_seq_len<7800)reusekv=4;
else{
// max_seq_len >= 7800: switch to 2048-wide partitions and recompute the count.
PARTITION_SIZE=2048;
reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
}
// Single-partition case: reusekv is chosen from the workload size alone.
if(max_num_partitions==1){
if(max_seq_len<512){
// NOTE(review): `bytes` is the product max_seq_len*qheads*batchsize; no element
// size is applied, so the name looks like a misnomer - confirm the intent.
int bytes=max_seq_len*qheads*batchsize;
if(bytes<51200)reusekv=1;
else if(bytes<256000)reusekv=4;
else reusekv=8;
return;
}
if(batchsize<4||batchsize==4&&qheads==8)reusekv=1;
else if(batchsize<32||batchsize<=64&&qheads==8)reusekv=4;
else reusekv=8;
return;
}
// Multi-partition case: fewer than 150 blocks keeps reusekv at 1.
if(blocks<150)return;
// NOTE(review): this whole branch is guarded by qheads>kvheads*4 and neither
// value is modified, so `qheads<=kvheads*4` is always false here; the test
// reduces to `blocks<600`.
if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
reusekv=8;return;
}
// Remaining case: qheads != kvheads and qheads <= kvheads*4.
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<7800)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
}
// NOTE(review): mixed `&&`/`||` without parentheses; parses as
// A || (B && ((C && D) || E)).
if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
// Otherwise reusekv keeps its default of 1.
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||max_seq_len>=2000))reusekv=4;
}
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v2_launcher_opt_tc_with_mask( void paged_attention_v2_launcher_opt_tc_with_mask(
......
...@@ -14,7 +14,8 @@ if HAS_TRITON: ...@@ -14,7 +14,8 @@ if HAS_TRITON:
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512 _PARTITION_SIZE = 512
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
support_tc = gpuname.startswith('K100_AI') or gpuname.startswith('BW')
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
...@@ -128,22 +129,13 @@ class PagedAttention: ...@@ -128,22 +129,13 @@ class PagedAttention:
# to parallelize. # to parallelize.
# TODO(woosuk): Tune this heuristic. # TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage. # For context len > 8192, use V2 kernel to avoid shared memory shortage.
if envs.VLLM_USE_TC_PAGED_ATTN:
use_v1 = (max_seq_len < 8192
and (max_seq_len<(1024 if num_kv_heads == num_heads else 600) or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512)))
else:
use_v1 = (max_seq_len <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1. if envs.VLLM_USE_TC_PAGED_ATTN and support_tc:
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:") print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}") print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN:
if attn_masks is None: if attn_masks is None:
ops.paged_attention_v1_opt_tc( ops.paged_attention_v1_opt_tc(
output, output,
...@@ -190,7 +182,18 @@ class PagedAttention: ...@@ -190,7 +182,18 @@ class PagedAttention:
attn_masks, attn_masks,
attn_masks_stride attn_masks_stride
) )
else: return output
use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA V1 SIZE:")
print(f"query.shape = {query.shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}")
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP:
if attn_masks is None: if attn_masks is None:
ops.paged_attention_v1_opt( ops.paged_attention_v1_opt(
output, output,
...@@ -306,60 +309,6 @@ class PagedAttention: ...@@ -306,60 +309,6 @@ class PagedAttention:
print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}") print(f"num_kv_heads = {num_kv_heads}, scale = {scale:.3f}, block_tables.shape = {block_tables.shape}, seq_lens.shape = {seq_lens.shape}, block_size = {block_size}, max_seq_len = {max_seq_len}")
if envs.VLLM_USE_OPT_OP: if envs.VLLM_USE_OPT_OP:
if envs.VLLM_USE_TC_PAGED_ATTN:
if attn_masks is None:
ops.paged_attention_v2_opt_tc(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step
)
else:
ops.paged_attention_v2_opt_tc_with_mask(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
tp_rank,
blocksparse_local_blocks,
blocksparse_vert_stride,
blocksparse_block_size,
blocksparse_head_sliding_step,
attn_masks,
attn_masks_stride
)
else:
if attn_masks is None: if attn_masks is None:
ops.paged_attention_v2_opt( ops.paged_attention_v2_opt(
output, output,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment