Commit 98955c1f authored by zhangshao's avatar zhangshao
Browse files

提升bf16 pa精度

parent 09e372e7
...@@ -237,7 +237,7 @@ __global__ void paged_attention_kernel_TC( ...@@ -237,7 +237,7 @@ __global__ void paged_attention_kernel_TC(
} }
__syncthreads(); __syncthreads();
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS]; // __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS]; __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS]; __shared__ float s_logit[NUM_WARPS];
...@@ -277,10 +277,10 @@ __global__ void paged_attention_kernel_TC( ...@@ -277,10 +277,10 @@ __global__ void paged_attention_kernel_TC(
} }
const bool mask = (token_idx >= seq_len); const bool mask = (token_idx >= seq_len);
if(mask){ if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
} }
else{ else{
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]); logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i];
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]); qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
} }
} }
...@@ -313,15 +313,15 @@ __global__ void paged_attention_kernel_TC( ...@@ -313,15 +313,15 @@ __global__ void paged_attention_kernel_TC(
} }
qk_max_tmp = __shfl(qk_max_tmp, 0); qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp); float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val); logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum += val; exp_sum += val;
} }
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum); exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum); logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum;
} }
if(USE_PARTITIONING&&thread_idx == 0){ if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp; max_out[reuse_kv_idx] = qk_max_tmp;
...@@ -349,7 +349,11 @@ __global__ void paged_attention_kernel_TC( ...@@ -349,7 +349,11 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){ if(rowid<4*q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx); auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16; kv_head_idx * kv_head_stride + rows*4+rowid*16;
...@@ -441,7 +445,11 @@ __global__ void paged_attention_kernel_TC( ...@@ -441,7 +445,11 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16; kv_head_idx * kv_head_stride + rows*4+rowid*16;
...@@ -542,7 +550,11 @@ __global__ void paged_attention_kernel_TC( ...@@ -542,7 +550,11 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16; kv_head_idx * kv_head_stride + rows*4+rowid*16;
...@@ -837,8 +849,8 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -837,8 +849,8 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(device_name=="gfx928"){ if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){ if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1; max_num_partitions=1;
if(max_seq_len<3900)reusekv=8; if(max_seq_len<2000)reusekv=8;
else if(max_seq_len<7800)reusekv=4; else if(max_seq_len<3900)reusekv=4;
else{ else{
PARTITION_SIZE=2048; PARTITION_SIZE=2048;
reusekv=8; reusekv=8;
...@@ -867,7 +879,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -867,7 +879,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(device_name=="gfx928"){ if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){ if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1; max_num_partitions=1;
if(max_seq_len<7800)reusekv=4; if(max_seq_len<3900)reusekv=4;
else{ else{
PARTITION_SIZE=2048; PARTITION_SIZE=2048;
reusekv=4; reusekv=4;
...@@ -880,7 +892,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -880,7 +892,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64)) max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1; max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads; int blocks=max_num_partitions*batchsize*qheads;
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||max_seq_len>=2000))reusekv=4; if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||(max_seq_len>=2000&&max_seq_len<3900)))reusekv=4;
} }
template <typename T, typename CACHE_T, int BLOCK_SIZE, template <typename T, typename CACHE_T, int BLOCK_SIZE,
...@@ -948,7 +960,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -948,7 +960,7 @@ void paged_attention_v2_launcher_opt_tc(
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4;
if(max_num_partitions==1)PARTITION_SIZE=0; if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid; dim3 grid;
...@@ -1051,4 +1063,4 @@ void paged_attention_v1_opt_tc( ...@@ -1051,4 +1063,4 @@ void paged_attention_v1_opt_tc(
#undef WARP_SIZE #undef WARP_SIZE
#undef MAX #undef MAX
#undef MIN #undef MIN
#undef DIVIDE_ROUND_UP #undef DIVIDE_ROUND_UP
\ No newline at end of file
...@@ -221,7 +221,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -221,7 +221,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
__syncthreads(); __syncthreads();
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS]; // __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS]; __shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS]; __shared__ float s_logit[NUM_WARPS];
...@@ -268,10 +268,10 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -268,10 +268,10 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
const bool mask = (token_idx >= seq_len); const bool mask = (token_idx >= seq_len);
if(mask){ if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
} }
else{ else{
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]); logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i];
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]); qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
} }
} }
...@@ -304,15 +304,15 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -304,15 +304,15 @@ __global__ void paged_attention_kernel_TC_with_mask(
} }
qk_max_tmp = __shfl(qk_max_tmp, 0); qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp); float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val); logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum += val; exp_sum += val;
} }
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum); exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum); logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum;
} }
if(USE_PARTITIONING&&thread_idx == 0){ if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp; max_out[reuse_kv_idx] = qk_max_tmp;
...@@ -340,7 +340,11 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -340,7 +340,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){ if(rowid<4*q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx); auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16; kv_head_idx * kv_head_stride + rows*4+rowid*16;
...@@ -432,7 +436,11 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -432,7 +436,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16; kv_head_idx * kv_head_stride + rows*4+rowid*16;
...@@ -533,7 +541,11 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -533,7 +541,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4; const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0}; half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){ if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx); auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
} }
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16; kv_head_idx * kv_head_stride + rows*4+rowid*16;
...@@ -856,7 +868,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -856,7 +868,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
REUSEKV_SWITCH(reusekv,[&] { REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] { NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2; int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4;
if(max_num_partitions==1)PARTITION_SIZE=0; if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid; dim3 grid;
...@@ -883,13 +895,10 @@ void paged_attention_v2_launcher_opt_tc_with_mask( ...@@ -883,13 +895,10 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
blocksparse_head_sliding_step,attn_masks, attn_masks_stride); blocksparse_head_sliding_step,attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \ if (is_block_sparse) { \
case true: \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ } else { \
break; \ CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
} }
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes // NOTE(woosuk): To reduce the compilation time, we omitted block sizes
...@@ -933,8 +942,8 @@ void paged_attention_v2_opt_tc_with_mask( ...@@ -933,8 +942,8 @@ void paged_attention_v2_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE) CALL_V2_LAUNCHER_BLOCK_SIZE)
} }
...@@ -958,10 +967,10 @@ void paged_attention_v1_opt_tc_with_mask( ...@@ -958,10 +967,10 @@ void paged_attention_v1_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step, const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len] const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride); blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
} }
#undef WARP_SIZE #undef WARP_SIZE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment