Commit b3992aad authored by zhangshao
Browse files

优化bf16精度

parent 228a714a
...@@ -124,11 +124,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4 ...@@ -124,11 +124,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
{ {
if constexpr (is_half){ if constexpr (is_half){
asm volatile("v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" : asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
} }
else{ else{
asm volatile("v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" : asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
} }
} }
...@@ -147,7 +147,7 @@ inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& ...@@ -147,7 +147,7 @@ inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t&
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES> // Zero means no partitioning. bool IS_BLOCK_SPARSE,int REUSE_KV_TIMES> // Zero means no partitioning.
__global__ void paged_attention_kernel_TC_with_mask( __global__ void paged_attention_kernel_TC(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads,head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads,head_size]
...@@ -225,10 +225,12 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -225,10 +225,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
for(int i=0;i<q_boundary;i++){ for(int i=0;i<q_boundary;i++){
if(thread_idx<16){ if(thread_idx<16){
half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8); half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
#pragma unroll if constexpr(is_half){
for(int k=0;k<4;k++){ scalar_t *t=reinterpret_cast<scalar_t*>(&temp);
temp.data[0][k]=((float)temp.data[0][k])*scale; #pragma unroll
temp.data[1][k]=((float)temp.data[1][k])*scale; for(int k=0;k<8;k++){
from_float(t[k],to_float(t[k])*scale);
}
} }
q_vecs[i][thread_idx]=temp; q_vecs[i][thread_idx]=temp;
} }
...@@ -265,6 +267,9 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -265,6 +267,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
int reuse_kv_idx=rows+i*4; int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<REUSE_KV_TIMES){ if(reuse_kv_idx<REUSE_KV_TIMES){
if(reuse_kv_idx>=q_boundary)qk_vec[i]=0; if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
else {
if constexpr(!is_half) qk_vec[i]*=scale;
}
const int token_idx = block_idx * BLOCK_SIZE+rowid; const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){ if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1); float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
...@@ -764,16 +769,16 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern ...@@ -764,16 +769,16 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_v2_reduce_kern
#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \ #define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
(vllm::paged_attention_kernel_TC_with_mask< \ (vllm::paged_attention_kernel_TC< \
T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \ T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \ IS_BLOCK_SPARSE, REUSE_KV_TIMES>), \
dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \ dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \
max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\ max_logits_ptr,out_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr,\
num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \ num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \ kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_vert_stride, blocksparse_block_size, \
blocksparse_head_sliding_step,PARTITION_SIZE);\ blocksparse_head_sliding_step,PARTITION_SIZE); \
if (max_num_partitions<=64&&max_num_partitions>1){ \ if (max_num_partitions<=64&&max_num_partitions>1){ \
hipLaunchKernelGGL( \ hipLaunchKernelGGL( \
(vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \ (vllm::paged_attention_v2_reduce_kernel_opt_tc<T, HEAD_SIZE, 64>), \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment