Commit 2b91ac93 authored by zhuwenwen's avatar zhuwenwen
Browse files

优化pa小batch性能(pa_v2),优化pa小seq性能(pa_v1),reusekv=16优化

parent de7d9456
......@@ -43,7 +43,16 @@ inline std::string get_device_name()
const std::string raw_name(props.gcnArchName);
return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
}
// Reads an integer tuning knob from the process environment.
//
// Parameters:
//   env_var - name of the environment variable to look up (may be null).
//
// Returns the parsed decimal value, or 0 when the name is null, the variable
// is unset, its value does not start with a number, or the number does not
// fit in an int. Leading whitespace and trailing non-numeric characters are
// tolerated, matching what atoi() accepted for well-formed input.
static inline int get_env_(const char *env_var) {
    if (env_var == nullptr) {
        return 0;
    }
    const char *value = std::getenv(env_var);
    if (value == nullptr) {
        return 0;
    }
    try {
        // std::stoi reports malformed and out-of-range input by throwing,
        // whereas atoi() has undefined behaviour when the value overflows int.
        return std::stoi(value);
    } catch (...) {
        // invalid_argument / out_of_range: treat the variable as "not set".
        return 0;
    }
}
// Debug/tuning overrides, each read from the environment variable of the same
// name via get_env_(); a value of 0 (e.g. the variable is unset) leaves the
// built-in heuristics in charge.
//
// Non-zero: replaces the heuristic reusekv value in the v1/v2 launchers, but
// only when num_heads > num_kv_heads.
static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES");
// Non-zero: replaces the heuristic num_thread (threads per block) in the
// v1/v2 launchers. NOTE(review): despite the name, the launchers use this as
// a thread count, not as the BLOCK_SIZE template parameter.
static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE");
// Non-zero: the launchers printf the final reusekv / num_thread choice.
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
namespace vllm {
// Utility function for attention softmax.
......@@ -344,7 +353,7 @@ __device__ void paged_attention_kernel_TC(
}
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if constexpr(REUSE_KV_TIMES<=2&&(NUM_WARPS>64||USE_PARTITIONING)){
if constexpr(REUSE_KV_TIMES<=2){
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
......@@ -723,10 +732,64 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step);
// Heuristic launch configuration for the paged-attention V1 path.
//
// Outputs (both are always written):
//   num_thread - threads per block to launch with (64, 128 or 256).
//   reusekv    - KV reuse factor to dispatch on (1, 2, 4 or 8).
// Inputs:
//   batchsize  - number of sequences in the batch.
//   seq        - sequence length the caller is sizing for (the V1 launcher
//                passes its padded maximum sequence length).
//   qheads     - number of query heads.
//   kvheads    - number of KV heads.
//
// The thresholds are empirical tuning values; the model names in the comments
// are the ones recorded by the original author.
void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){
    const auto pick = [&](int reuse, int threads) {
        reusekv = reuse;
        num_thread = threads;
    };

    reusekv = 1;

    // MHA: query and KV head counts match, so there is nothing to reuse.
    // Thresholds were tuned on LLaMA-7B; other models have not been checked.
    if (qheads == kvheads) {
        if (seq <= 16 || batchsize >= 8) {
            num_thread = 64;
        } else {
            num_thread = (batchsize <= 2) ? 256 : 128;
        }
        return;
    }

    // Labelled "mqa" by the original author: more than four query heads per
    // KV head.
    if (qheads > kvheads * 4) {
        const bool heads_multiple_of_7 = (qheads % 7 == 0);  // qwen7b

        if (seq < 64) {
            const int reuse = (batchsize <= 64) ? 1 : ((batchsize < 128) ? 2 : 4);
            pick(reuse, 64);
        } else if (seq <= 400) {
            if (batchsize < 16) {
                pick(1, 256);
            } else if (batchsize < 64) {
                pick(2, 256);
            } else if (batchsize <= 128) {
                // qwen7b -> 64 threads, llama70b -> 256 threads
                pick(4, heads_multiple_of_7 ? 64 : 256);
            } else {
                pick(8, 64);
            }
        } else if (seq <= 1000) {
            if (batchsize < 16) {
                pick(1, 256);
            } else if (heads_multiple_of_7 && batchsize <= 128) {
                // qwen7b
                pick(4, (batchsize < 64) ? 256 : 64);
            } else if (batchsize <= 64) {
                pick(4, 256);
            } else {
                pick(8, 128);
            }
        } else if (seq < 3900) {
            pick(8, 256);
        } else if (seq < 7800) {
            pick(4, 256);
        } else {
            pick(2, 256);
        }
        return;
    }

    // Remaining grouped-query shapes. For positive head counts the first test
    // cannot fire (a ratio above 4 was handled by the branch above); it is
    // kept so the decision table stays exactly as written.
    const int heads_per_kv = qheads / kvheads;
    if (heads_per_kv > 4 && seq < 3900) {
        reusekv = 8;
    } else if (heads_per_kv > 2 && seq < 7800) {
        reusekv = 4;
    } else if (heads_per_kv >= 2 && seq < 15600) {
        reusekv = 2;
    }

    if (seq > 64) {
        num_thread = 256;
        return;
    }

    // Short sequences: one wavefront-sized block, and no reuse for small batches.
    num_thread = 64;
    if (batchsize <= 64) {
        reusekv = 1;
    }
}
// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
......@@ -750,7 +813,7 @@ void paged_attention_v1_launcher_opt(
if (num_heads != num_kv_heads) {
num_threads = 256;
}
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
......@@ -769,44 +832,40 @@ void paged_attention_v1_launcher_opt(
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
// if(head_size==128&&get_device_name()=="gfx928"){
REUSEKV_SWITCH_V1([&] {
constexpr int HEAD_SIZE=128;
// constexpr int REUSE_KV_TIMES=8;
int num_thread=64;
if(REUSE_KV_TIMES>1){
if(padded_max_seq_len>1024||num_heads * num_seqs/REUSE_KV_TIMES<600)num_thread=256;
else num_thread=128;
}
else if(num_heads * num_seqs<800)num_thread=128;
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr static int use_vmac = false;
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2;
int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float);
if(REUSE_KV_TIMES==1)outputs_size=0;
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 block(NUM_THREADS);
LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE);
});
constexpr int HEAD_SIZE=128;
constexpr static int use_vmac = false;
int reusekv, num_thread;
get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2;
int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float);
if(REUSE_KV_TIMES==1)outputs_size=0;
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
// Keep that in sync with the logic here!
int shared_mem_size = ::max(logits_size, outputs_size);
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 block(NUM_THREADS);
LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE);
});
});
}
// }
}
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step);
......@@ -902,10 +961,43 @@ void paged_attention_v1_opt(
hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \
dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions);
// Heuristic launch configuration for the paged-attention V2 (partitioned) path.
//
// Outputs (both are always written):
//   num_thread - threads per block to launch with (64, 128 or 256).
//   reusekv    - KV reuse factor to dispatch on (1, 4, 8 or 16).
// Inputs:
//   batchsize          - number of sequences in the batch.
//   max_num_partitions - partitions per sequence.
//   qheads             - number of query heads.
//   kvheads            - number of KV heads.
//
// Decisions are keyed on blocks = batchsize * qheads * max_num_partitions and
// on the integer ratio qheads / kvheads. The thresholds are empirical tuning
// values.
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads){
    const auto choose = [&](int reuse, int threads) {
        reusekv = reuse;
        num_thread = threads;
    };

    const int blocks = batchsize * qheads * max_num_partitions;
    reusekv = 1;

    // Equal head counts: only the thread count is tuned.
    if (qheads == kvheads) {
        if (blocks <= 80 || blocks > 8000) {
            num_thread = 256;
        } else {
            num_thread = (blocks <= 160) ? 128 : 64;
        }
        return;
    }

    const int heads_per_kv = qheads / kvheads;

    // Very wide groups with plenty of work: use the largest reuse factor.
    if (heads_per_kv > 8 && blocks > 4000) {
        choose(16, (blocks > 40000) ? 64 : 128);
        return;
    }

    // Ratios of exactly 5 or 7 get their own table; the second threshold
    // scales with the head ratio (640 / 5 == 128 blocks per unit of ratio).
    if (heads_per_kv == 5 || heads_per_kv == 7) {
        if (blocks <= 160) {
            choose(1, 256);
        } else if (blocks < 128 * qheads / kvheads) {
            choose(4, 256);
        } else if (blocks < 1920) {
            choose(8, 128);
        } else {
            choose(8, 64);
        }
        return;
    }

    // More than four query heads per KV head.
    if (qheads > kvheads * 4) {
        if (blocks <= 128) {
            choose(1, 256);
        } else if (blocks < 1536) {
            choose(4, 256);
        } else if (blocks < 6144) {
            choose(8, 128);
        } else {
            choose(8, 64);
        }
        return;
    }

    // Everything else (unequal head counts, at most four query heads per KV head).
    if (blocks <= 128) {
        choose(1, 256);
    } else {
        choose(4, (blocks < 3000) ? 256 : 64);
    }
}
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher_opt(
......@@ -953,17 +1045,12 @@ void paged_attention_v2_launcher_opt(
//if(head_size==128&&get_device_name()=="gfx928"){
constexpr int HEAD_SIZE=128;
constexpr static int use_vmac = false;
REUSEKV_SWITCH_V2([&] {
int num_thread;
if(REUSE_KV_TIMES>1){
if(num_seqs<16)num_thread=256;
else if(max_num_partitions*num_seqs*num_heads/REUSE_KV_TIMES>4000)num_thread=64;
else num_thread=128;
}
else{
if(num_seqs<16&&max_num_partitions<10)num_thread=256;
else num_thread=64;
}
int reusekv, num_thread;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads);
if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
......@@ -982,7 +1069,7 @@ void paged_attention_v2_launcher_opt(
}
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
......
......@@ -52,47 +52,18 @@
} \
}()
#define REUSEKV_SWITCH(num_blocks , ...) \
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V2( ...) \
[&] { \
if (num_heads / num_kv_heads > 8 ){ \
if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 4 ){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
return __VA_ARGS__();} \
else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads >2 && padded_max_seq_len<7800){ \
}else if (reusekv==4){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads ==2 && padded_max_seq_len<15600){ \
}else if (reusekv==2){ \
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
}else { \
......
......@@ -127,7 +127,7 @@ class PagedAttention:
# use_v1 = (max_seq_len <= 8192
# and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1 = (max_seq_len < 8192
and (max_seq_len<1000 or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512)))
and (max_seq_len<(1024 if num_kv_heads == num_heads else 600) or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512)))
if use_v1:
# Run PagedAttention V1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment