"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "9d72daf4ced05a5fec1ad8ea2914a39296f402da"
Commit 2b91ac93 authored by zhuwenwen's avatar zhuwenwen
Browse files

优化pa小batch性能(pa_v2),优化pa小seq性能(pa_v1),reusekv=16优化

parent de7d9456
...@@ -43,7 +43,16 @@ inline std::string get_device_name() ...@@ -43,7 +43,16 @@ inline std::string get_device_name()
const std::string raw_name(props.gcnArchName); const std::string raw_name(props.gcnArchName);
return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str.
} }
static inline int get_env_(const char *env_var) {
if (char *value = std::getenv(env_var)) {
return atoi(value);
}
return 0;
}
static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES");
static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE");
static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM");
namespace vllm { namespace vllm {
// Utility function for attention softmax. // Utility function for attention softmax.
...@@ -344,7 +353,7 @@ __device__ void paged_attention_kernel_TC( ...@@ -344,7 +353,7 @@ __device__ void paged_attention_kernel_TC(
} }
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2 constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if constexpr(REUSE_KV_TIMES<=2&&(NUM_WARPS>64||USE_PARTITIONING)){ if constexpr(REUSE_KV_TIMES<=2){
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
...@@ -723,10 +732,64 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt( ...@@ -723,10 +732,64 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt(
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){
//mha
reusekv=1;
if(qheads==kvheads){
//llama 7B ,其他模型未可知
if(seq<=16||batchsize>=32)num_thread=64;
else if(batchsize<=2)num_thread=256;
else if(batchsize<8)num_thread=128;
else num_thread=64;
return;
}
// mqa
if(qheads>kvheads*4){
if(seq<64){
if(batchsize<=64){reusekv=1;num_thread=64;}
else if(batchsize<128){reusekv=2;num_thread=64;}
else {reusekv=4;num_thread=64;}
}
else if(seq<=400){
if(batchsize<16){reusekv=1;num_thread=256;}
else if(batchsize<64){reusekv=2;num_thread=256;}
else if(batchsize<=128){
reusekv=4;
if(qheads%7==0)num_thread=64;//qwen7b
else num_thread=256;//llama70b
}
else {reusekv=8;num_thread=64;}
}
else if(seq<=1000){
if(batchsize<16){reusekv=1;num_thread=256;}
else if(qheads%7==0&&batchsize<=128){//qwen7b
if(batchsize<64){reusekv=4;num_thread=256;}
else{reusekv=4;num_thread=64;}
}
else if(batchsize<=64){reusekv=4;num_thread=256;}
else {reusekv=8;num_thread=128;}
}
else if(seq<3900) {reusekv=8;num_thread=256;}
else if(seq<7800) {reusekv=4;num_thread=256;}
else {reusekv=2;num_thread=256;}
return;
}
if(qheads/kvheads >4 && seq<3900)reusekv=8;
else if(qheads/kvheads >2 && seq<7800)reusekv=4;
else if(qheads/kvheads >=2 && seq<15600)reusekv=2;
if(seq<=64){
num_thread=64;
if(batchsize<=64)reusekv=1;
}
else num_thread=256;
}
// TODO(woosuk): Tune NUM_THREADS. // TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
...@@ -750,7 +813,7 @@ void paged_attention_v1_launcher_opt( ...@@ -750,7 +813,7 @@ void paged_attention_v1_launcher_opt(
if (num_heads != num_kv_heads) { if (num_heads != num_kv_heads) {
num_threads = 256; num_threads = 256;
} }
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0); assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
...@@ -769,44 +832,40 @@ void paged_attention_v1_launcher_opt( ...@@ -769,44 +832,40 @@ void paged_attention_v1_launcher_opt(
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){
// if(head_size==128&&get_device_name()=="gfx928"){ constexpr int HEAD_SIZE=128;
REUSEKV_SWITCH_V1([&] { constexpr static int use_vmac = false;
constexpr int HEAD_SIZE=128; int reusekv, num_thread;
// constexpr int REUSE_KV_TIMES=8; get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads);
int num_thread=64; if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(REUSE_KV_TIMES>1){ if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
if(padded_max_seq_len>1024||num_heads * num_seqs/REUSE_KV_TIMES<600)num_thread=256; if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
else num_thread=128; REUSEKV_SWITCH(reusekv,[&] {
} NUM_THREADS_SWITCH(num_thread , [&] {
else if(num_heads * num_seqs<800)num_thread=128; //constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES;
NUM_THREADS_SWITCH(num_thread , [&] { constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
constexpr static int use_vmac = false; int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2;
//constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES; int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; if(REUSE_KV_TIMES==1)outputs_size=0;
int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2; // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
int outputs_size = REUSE_KV_TIMES * (NUM_WARPS / 2) * head_size * sizeof(float); // Keep that in sync with the logic here!
if(REUSE_KV_TIMES==1)outputs_size=0; int shared_mem_size = ::max(logits_size, outputs_size);
// Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size);
// Keep that in sync with the logic here! // int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size));
int shared_mem_size = ::max(logits_size, outputs_size); // std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl;
if(num_heads == num_kv_heads) shared_mem_size = ::max(12 * 1024, shared_mem_size); // printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac);
// int shared_mem_size = ::max(31*1024, ::max(logits_size, outputs_size)); dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
// std::cout<<"shared_mem_size = "<<shared_mem_size<<std::endl; dim3 block(NUM_THREADS);
// printf("REUSE_KV_TIMES=%d,use_vmac=%d\n",REUSE_KV_TIMES,(int)use_vmac); LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE);
dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs);
dim3 block(NUM_THREADS);
LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE);
});
}); });
});
} }
// }
} }
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ paged_attention_v1_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_local_blocks, blocksparse_vert_stride, \
blocksparse_block_size, blocksparse_head_sliding_step); blocksparse_block_size, blocksparse_head_sliding_step);
...@@ -902,10 +961,43 @@ void paged_attention_v1_opt( ...@@ -902,10 +961,43 @@ void paged_attention_v1_opt(
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \ (vllm::paged_attention_v2_reduce_kernel_opt<T, HEAD_SIZE, NUM_THREADS, \
PARTITION_SIZE>), \ PARTITION_SIZE>), \
dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \ dim3(reduce_grid), dim3(128), reduce_shared_mem_size, stream, out_ptr, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
max_num_partitions); max_num_partitions);
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads){
reusekv=1;
int blocks=batchsize*qheads*max_num_partitions;
if(qheads==kvheads){
if(blocks<=80||blocks>8000){num_thread=256;}
else if(blocks<=160){num_thread=128;}
else num_thread=64;
return;
}
if(qheads/kvheads>8&&blocks>4000){
reusekv=16;
if(blocks>40000)num_thread=64;
else num_thread=128;
}
else if(qheads/kvheads==5||qheads/kvheads==7){
if(blocks<=160){reusekv=1;num_thread=256;}
else if(blocks<640/5*qheads/kvheads){reusekv=4;num_thread=256;}
else if(blocks<1920){reusekv=8;num_thread=128;}
else {reusekv=8;num_thread=64;}
}
else if(qheads>kvheads*4){
if(blocks<=128){reusekv=1;num_thread=256;}
else if(blocks<1536){reusekv=4;num_thread=256;}
else if(blocks<6144){reusekv=8;num_thread=128;}
else {reusekv=8;num_thread=64;}
}
else {
if(blocks<=128){reusekv=1;num_thread=256;}
else if(blocks<3000){reusekv=4;num_thread=256;}
else {reusekv=4;num_thread=64;}
}
}
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512> vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE, int PARTITION_SIZE = 512>
void paged_attention_v2_launcher_opt( void paged_attention_v2_launcher_opt(
...@@ -953,17 +1045,12 @@ void paged_attention_v2_launcher_opt( ...@@ -953,17 +1045,12 @@ void paged_attention_v2_launcher_opt(
//if(head_size==128&&get_device_name()=="gfx928"){ //if(head_size==128&&get_device_name()=="gfx928"){
constexpr int HEAD_SIZE=128; constexpr int HEAD_SIZE=128;
constexpr static int use_vmac = false; constexpr static int use_vmac = false;
REUSEKV_SWITCH_V2([&] { int reusekv, num_thread;
int num_thread; get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads);
if(REUSE_KV_TIMES>1){ if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES;
if(num_seqs<16)num_thread=256; if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE;
else if(max_num_partitions*num_seqs*num_heads/REUSE_KV_TIMES>4000)num_thread=64; if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d\n",reusekv,num_thread);
else num_thread=128; REUSEKV_SWITCH(reusekv,[&] {
}
else{
if(num_seqs<16&&max_num_partitions<10)num_thread=256;
else num_thread=64;
}
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
...@@ -982,7 +1069,7 @@ void paged_attention_v2_launcher_opt( ...@@ -982,7 +1069,7 @@ void paged_attention_v2_launcher_opt(
} }
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \ paged_attention_v2_launcher_opt<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
IS_BLOCK_SPARSE>( \ IS_BLOCK_SPARSE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
......
...@@ -52,47 +52,18 @@ ...@@ -52,47 +52,18 @@
} \ } \
}() }()
#define REUSEKV_SWITCH(num_blocks , ...) \ #define REUSEKV_SWITCH(reusekv,...) \
[&] { \ [&] { \
if (num_heads % 2 == 0 && num_heads / num_kv_heads >= 4 && num_blocks >= 1200){ \ if (reusekv==16){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else if (num_heads / num_kv_heads >= 2 && num_blocks >= 1200){\
constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V2( ...) \
[&] { \
if (num_heads / num_kv_heads > 8 ){ \
constexpr static int REUSE_KV_TIMES = 16; \ constexpr static int REUSE_KV_TIMES = 16; \
return __VA_ARGS__(); \ return __VA_ARGS__();} \
}else if (num_heads / num_kv_heads > 4 ){ \ else if (reusekv==8){ \
constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \
}else if (num_heads / num_kv_heads > 2 ){ \
constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \
} else { \
constexpr static int REUSE_KV_TIMES = 1; \
return __VA_ARGS__(); \
} \
}()
#define REUSEKV_SWITCH_V1( ...) \
[&] { \
if (num_heads/num_kv_heads >4 && padded_max_seq_len<3900){ \
constexpr static int REUSE_KV_TIMES = 8; \ constexpr static int REUSE_KV_TIMES = 8; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads >2 && padded_max_seq_len<7800){ \ }else if (reusekv==4){ \
constexpr static int REUSE_KV_TIMES = 4; \ constexpr static int REUSE_KV_TIMES = 4; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else if (num_heads/num_kv_heads ==2 && padded_max_seq_len<15600){ \ }else if (reusekv==2){ \
constexpr static int REUSE_KV_TIMES = 2; \ constexpr static int REUSE_KV_TIMES = 2; \
return __VA_ARGS__(); \ return __VA_ARGS__(); \
}else { \ }else { \
......
...@@ -127,7 +127,7 @@ class PagedAttention: ...@@ -127,7 +127,7 @@ class PagedAttention:
# use_v1 = (max_seq_len <= 8192 # use_v1 = (max_seq_len <= 8192
# and (max_num_partitions == 1 or num_seqs * num_heads > 512)) # and (max_num_partitions == 1 or num_seqs * num_heads > 512))
use_v1 = (max_seq_len < 8192 use_v1 = (max_seq_len < 8192
and (max_seq_len<1000 or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512))) and (max_seq_len<(1024 if num_kv_heads == num_heads else 600) or num_seqs * num_heads > (1024 if num_kv_heads < num_heads else 512)))
if use_v1: if use_v1:
# Run PagedAttention V1. # Run PagedAttention V1.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment