"cacheflow/vscode:/vscode.git/clone" did not exist on "897cb2ae28e93de1b22ecfbffcccfb9493f8f4d9"
Commit 98955c1f authored by zhangshao's avatar zhangshao
Browse files

提升bf16 pa精度

parent 09e372e7
......@@ -237,7 +237,7 @@ __global__ void paged_attention_kernel_TC(
}
__syncthreads();
extern __shared__ char shared_mem[];
scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem);
float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
......@@ -277,10 +277,10 @@ __global__ void paged_attention_kernel_TC(
}
const bool mask = (token_idx >= seq_len);
if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
}
else{
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i];
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
}
}
......@@ -313,15 +313,15 @@ __global__ void paged_attention_kernel_TC(
}
qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp);
logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum;
}
if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp;
......@@ -349,7 +349,11 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
......@@ -441,7 +445,11 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
......@@ -542,7 +550,11 @@ __global__ void paged_attention_kernel_TC(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
......@@ -837,8 +849,8 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<3900)reusekv=8;
else if(max_seq_len<7800)reusekv=4;
if(max_seq_len<2000)reusekv=8;
else if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=8;
......@@ -867,7 +879,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<7800)reusekv=4;
if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
......@@ -880,7 +892,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||max_seq_len>=2000))reusekv=4;
if(blocks>=150||batchsize>=16||qheads>=8&&(batchsize>=4||(max_seq_len>=2000&&max_seq_len<3900)))reusekv=4;
}
template <typename T, typename CACHE_T, int BLOCK_SIZE,
......@@ -948,7 +960,7 @@ void paged_attention_v2_launcher_opt_tc(
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4;
if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
......@@ -1051,4 +1063,4 @@ void paged_attention_v1_opt_tc(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
#undef DIVIDE_ROUND_UP
......@@ -221,7 +221,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
__syncthreads();
extern __shared__ char shared_mem[];
scalar_t* logits = reinterpret_cast<scalar_t*>(shared_mem);
float* logits = reinterpret_cast<float*>(shared_mem);
// __shared__ float red_smem[2 * NUM_WARPS];
__shared__ float s_max[REUSE_KV_TIMES][NUM_WARPS];
__shared__ float s_logit[NUM_WARPS];
......@@ -268,10 +268,10 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
const bool mask = (token_idx >= seq_len);
if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] = 0.f;
}
else{
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]);
logits[partition_size*reuse_kv_idx+token_idx - start_token_idx]=qk_vec[i];
qk_max[i] = fmaxf(qk_max[i], qk_vec[i]);
}
}
......@@ -304,15 +304,15 @@ __global__ void paged_attention_kernel_TC_with_mask(
}
qk_max_tmp = __shfl(qk_max_tmp, 0);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp);
from_float(logits[(reuse_kv_idx * partition_size) + i] , val);
float val = __expf(logits[(reuse_kv_idx * partition_size) + i] - qk_max_tmp);
logits[(reuse_kv_idx * partition_size) + i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(s_logit, exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum);
logits[(reuse_kv_idx * partition_size) + i] = logits[(reuse_kv_idx * partition_size) + i]*inv_sum;
}
if(USE_PARTITIONING&&thread_idx == 0){
max_out[reuse_kv_idx] = qk_max_tmp;
......@@ -340,7 +340,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<4*q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid/4 * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
......@@ -432,7 +436,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
......@@ -533,7 +541,11 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int token_idx = block_idx * BLOCK_SIZE +rows*4;
half4_t logits_vec={0,0,0,0};
if(rowid<q_boundary){
logits_vec=*reinterpret_cast<half4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
auto f_logits = *reinterpret_cast<float4_t*>(logits + rowid * partition_size+token_idx - start_token_idx);
scalar_t * p = reinterpret_cast<scalar_t*>(&logits_vec);
for(int i=0;i<4;i++){
from_float(p[i],f_logits[i]);
}
}
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + rows*4+rowid*16;
......@@ -856,7 +868,7 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
REUSEKV_SWITCH(reusekv,[&] {
NUM_THREADS_SWITCH(num_thread , [&] {
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2;
int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 4;
if(max_num_partitions==1)PARTITION_SIZE=0;
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
dim3 grid;
......@@ -883,13 +895,10 @@ void paged_attention_v2_launcher_opt_tc_with_mask(
blocksparse_head_sliding_step,attn_masks, attn_masks_stride);
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
switch (is_block_sparse) { \
case true: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
break; \
case false: \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
break; \
if (is_block_sparse) { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
} else { \
CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
}
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
......@@ -933,8 +942,8 @@ void paged_attention_v2_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
const bool is_block_sparse = (blocksparse_vert_stride > 1);
DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
......@@ -958,10 +967,10 @@ void paged_attention_v1_opt_tc_with_mask(
const int64_t blocksparse_head_sliding_step,
const c10::optional<torch::Tensor>& attn_masks, // [num_seqs, max_seq_len]
const int64_t attn_masks_stride) {
paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
paged_attention_v2_opt_tc_with_mask(out,out,out,out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
blocksparse_block_size,blocksparse_head_sliding_step,attn_masks,attn_masks_stride);
}
#undef WARP_SIZE
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment