Commit 5d3fb1d4 authored by zhuwenwen
Browse files

解决PA崩溃

parent 9d5187eb
...@@ -965,19 +965,15 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -965,19 +965,15 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
) )
max_num_partitions=1; max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads; int blocks=max_num_partitions*batchsize*qheads;
if(device_name=="gfx928"||batchsize>100){ if(batchsize>100&&max_seq_len>=2000){
if(batchsize*qheads>1024&&max_seq_len>=2000){ if(max_seq_len<3900)reusekv=4;
max_num_partitions=1;
if(max_seq_len<2000)reusekv=8;
else if(max_seq_len<3900)reusekv=4;
else{ else{
PARTITION_SIZE=2048; PARTITION_SIZE=1024;
reusekv=8; reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
} }
return; return;
} }
}
if(max_num_partitions==1){ if(max_num_partitions==1){
if(max_seq_len<512){ if(max_seq_len<512){
int bytes=max_seq_len*qheads*batchsize; int bytes=max_seq_len*qheads*batchsize;
...@@ -995,9 +991,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -995,9 +991,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(blocks<600||qheads<=kvheads*4){reusekv=4;return;} if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
reusekv=8;return; reusekv=8;return;
} }
if(device_name=="gfx928"){ if(batchsize>100&&max_seq_len>=2000){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<3900)reusekv=4; if(max_seq_len<3900)reusekv=4;
else{ else{
PARTITION_SIZE=2048; PARTITION_SIZE=2048;
...@@ -1006,7 +1000,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -1006,7 +1000,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
} }
return; return;
} }
}
if(max_seq_len<=1000|| if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64)) max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1; max_num_partitions=1;
...@@ -1068,6 +1061,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1068,6 +1061,7 @@ void paged_attention_v2_launcher_opt_tc(
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512; int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256; if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks); get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(num_seqs>100&&max_num_partitions>16)max_num_partitions=16;
if(PA_PARTITION_SIZE!=0){ if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE; PARTITION_SIZE=PA_PARTITION_SIZE;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment