Commit ac811e51 authored by zhuwenwen
Browse files

解决PA部分size计算错误的问题

优化bf16精度
解决bf16精度问题,解决cudagraph精度问题
parent a5b976df
This diff is collapsed.
......@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
{
if constexpr (is_half){
asm volatile("v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
}
else{
asm volatile("v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
}
}
......@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) return;
if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
......@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
for(int i=0;i<q_boundary;i++){
if(thread_idx<16){
half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
if constexpr(is_half){
scalar_t *t=reinterpret_cast<scalar_t*>(&temp);
#pragma unroll
for(int k=0;k<4;k++){
temp.data[0][k]=((float)temp.data[0][k])*scale;
temp.data[1][k]=((float)temp.data[1][k])*scale;
for(int k=0;k<8;k++){
from_float(t[k],to_float(t[k])*scale);
}
}
q_vecs[i][thread_idx]=temp;
}
......@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<REUSE_KV_TIMES){
if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
else {
if constexpr(!is_half) qk_vec[i]*=scale;
}
const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
......@@ -316,13 +321,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
__syncthreads();
constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2
if (q_boundary<=2){
constexpr int acc_size = REUSE_KV_TIMES==1?1:2;
float accs[acc_size][NUM_ROWS_PER_THREAD];
if constexpr(REUSE_KV_TIMES<=2){
float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
#pragma unroll
for(int k=0;k<acc_size;k++)
for(int k=0;k<REUSE_KV_TIMES;k++)
{
accs[k][i] = 0.f;
}
......@@ -356,7 +360,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
float4_t out_vec={0,0,0,0};
builtin_amdgcn_mmac<is_half>(v_vec,logits_vec,out_vec);
if(rows==k){
for(int resuseid=0;resuseid<acc_size;resuseid++){
for(int resuseid=0;resuseid<REUSE_KV_TIMES;resuseid++){
accs[resuseid][i]+=out_vec[resuseid];
}
}
......@@ -366,8 +370,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
__syncthreads();
using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float;
// Perform reduction across warps.
#pragma unroll
for(int reuse_kv_idx=0; reuse_kv_idx<acc_size; reuse_kv_idx++) {
for(int reuse_kv_idx=0; reuse_kv_idx<q_boundary; reuse_kv_idx++) {
if constexpr (NUM_THREADS>64){
floatV_t* out_smem = reinterpret_cast<floatV_t*>(shared_mem);
#pragma unroll
......@@ -780,97 +783,9 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
max_num_partitions,PARTITION_SIZE);}
static void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks)
{
reusekv=1;
num_thread=256;
PARTITION_SIZE=512;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
if(max_seq_len==8192&&num_blocks==1024){//ali test
if(batchsize==1&&qheads==16&&kvheads==16){num_thread=128;return;}
if(batchsize==1&&qheads==32&&kvheads==32){num_thread=64;return;}
if(batchsize==1){
if(qheads==52){reusekv=8;return;}
if(qheads==13){reusekv=2;return;}
reusekv=4;return;
}
if(batchsize==64){
if(qheads==13){PARTITION_SIZE=256;num_thread=128;reusekv=8;}
else if(qheads==32){PARTITION_SIZE=1024;reusekv=8;}
else if(qheads==52||qheads==26){reusekv=16;}
else reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
return;
}
}
if(qheads==kvheads){
if(max_seq_len<=8192){
if(batchsize*qheads>=512){
max_num_partitions=1;
num_thread=64;
}
if(qheads==32&&max_seq_len<=1024)max_num_partitions=1;
}
return;
}
if(max_seq_len<800)max_num_partitions=1;
if(qheads>kvheads*4){
if(max_seq_len<=1000||
max_seq_len<1500&&(batchsize>=8&&qheads>=8||batchsize>=64)||
max_seq_len<1900&&batchsize>=8&&qheads==28
)
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<3900)reusekv=8;
else if(max_seq_len<7800)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
}
if(max_num_partitions==1){
if(max_seq_len<512){
int bytes=max_seq_len*qheads*batchsize;
if(bytes<51200)reusekv=1;
else if(bytes<256000)reusekv=4;
else reusekv=8;
return;
}
if(batchsize<4||batchsize==4&&qheads==8)reusekv=1;
else if(batchsize<32||batchsize<=64&&qheads==8)reusekv=4;
else reusekv=8;
return;
}
if(blocks<150)return;
if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
reusekv=8;return;
}
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<7800)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
}
if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||max_seq_len>=2000))reusekv=4;
void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITION_SIZE,int &max_num_partitions,
int batchsize,int max_seq_len,int qheads,int kvheads,int num_blocks);
}
template <typename T, typename CACHE_T, int BLOCK_SIZE,
vllm::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE>
void paged_attention_v2_launcher_opt_tc_with_mask(
......@@ -995,30 +910,6 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
break; \
}
void paged_attention_v2_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale,
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,// [num_seqs, max_seq_len]
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride);
void paged_attention_v2_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
......@@ -1043,38 +934,10 @@ void paged_attention_v2_opt_tc_with_mask(
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v2_with_mask(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
else{
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
}
void paged_attention_v1_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, // [num_heads]
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& seq_lens, // [num_seqs]
int64_t block_size, int64_t max_seq_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const int64_t tp_rank,
const int64_t blocksparse_local_blocks,
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride);
void paged_attention_v1_opt_tc_with_mask(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
......@@ -1095,20 +958,10 @@ void paged_attention_v1_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||(device_name!="gfx928" && device_name!="gfx936")){
paged_attention_v1_with_mask(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
else{
paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
}
#undef WARP_SIZE
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment