"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "a04a94012c9cfffdaa914c9d84ef3bb515a325f8"
Commit 4f8d38c8 authored by zhangshao
Browse files

解决bf16精度问题,解决cudagraph精度问题 (Fix the bf16 precision issue; fix the CUDA graph precision issue)

parent b3992aad
...@@ -175,7 +175,7 @@ __global__ void paged_attention_kernel_TC( ...@@ -175,7 +175,7 @@ __global__ void paged_attention_kernel_TC(
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0; const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) return; if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
...@@ -275,7 +275,6 @@ __global__ void paged_attention_kernel_TC( ...@@ -275,7 +275,6 @@ __global__ void paged_attention_kernel_TC(
float alibi=alibi_slope[i]* (token_idx - seq_len + 1); float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
qk_vec[i] += alibi; qk_vec[i] += alibi;
} }
const bool mask = (token_idx >= seq_len); const bool mask = (token_idx >= seq_len);
if(mask){ if(mask){
from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f);
......
...@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4 ...@@ -107,11 +107,11 @@ inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4
{ {
if constexpr (is_half){ if constexpr (is_half){
asm volatile("v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" : asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
} }
else{ else{
asm volatile("v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" : asm volatile("\n s_nop 1 \n v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" :
"=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c));
} }
} }
...@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -159,7 +159,7 @@ __global__ void paged_attention_kernel_TC_with_mask(
const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]);
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0; const bool USE_PARTITIONING = PARTITION_SIZE<num_seq_blocks * BLOCK_SIZE && PARTITION_SIZE>0;
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) return; if (partition_idx * PARTITION_SIZE >= seq_len) return;
constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value; constexpr bool is_half = std::is_same<scalar_t, uint16_t>::value;
static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS");
const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
...@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -209,10 +209,12 @@ __global__ void paged_attention_kernel_TC_with_mask(
for(int i=0;i<q_boundary;i++){ for(int i=0;i<q_boundary;i++){
if(thread_idx<16){ if(thread_idx<16){
half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8); half4x2 temp = *reinterpret_cast<const half4x2*>(q_ptr+i*HEAD_SIZE+thread_idx*8);
#pragma unroll if constexpr(is_half){
for(int k=0;k<4;k++){ scalar_t *t=reinterpret_cast<scalar_t*>(&temp);
temp.data[0][k]=((float)temp.data[0][k])*scale; #pragma unroll
temp.data[1][k]=((float)temp.data[1][k])*scale; for(int k=0;k<8;k++){
from_float(t[k],to_float(t[k])*scale);
}
} }
q_vecs[i][thread_idx]=temp; q_vecs[i][thread_idx]=temp;
} }
...@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask( ...@@ -249,6 +251,9 @@ __global__ void paged_attention_kernel_TC_with_mask(
int reuse_kv_idx=rows+i*4; int reuse_kv_idx=rows+i*4;
if(reuse_kv_idx<REUSE_KV_TIMES){ if(reuse_kv_idx<REUSE_KV_TIMES){
if(reuse_kv_idx>=q_boundary)qk_vec[i]=0; if(reuse_kv_idx>=q_boundary)qk_vec[i]=0;
else {
if constexpr(!is_half) qk_vec[i]*=scale;
}
const int token_idx = block_idx * BLOCK_SIZE+rowid; const int token_idx = block_idx * BLOCK_SIZE+rowid;
if(alibi_slope[i] != 0){ if(alibi_slope[i] != 0){
float alibi=alibi_slope[i]* (token_idx - seq_len + 1); float alibi=alibi_slope[i]* (token_idx - seq_len + 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment