Commit 1dea361d authored by zhanghj2's avatar zhanghj2
Browse files

fix buffer load lds data hazard

parent 8a69b46c
...@@ -710,6 +710,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co ...@@ -710,6 +710,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
+ offset_k * 3 * bytes_per_block; + offset_k * 3 * bytes_per_block;
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
......
...@@ -713,6 +713,7 @@ lds_direct_copy_qkvfp8_pe( ...@@ -713,6 +713,7 @@ lds_direct_copy_qkvfp8_pe(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -779,6 +780,7 @@ lds_direct_copy_qkvfp8( ...@@ -779,6 +780,7 @@ lds_direct_copy_qkvfp8(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -845,6 +847,7 @@ lds_direct_copy_qkvfp8( ...@@ -845,6 +847,7 @@ lds_direct_copy_qkvfp8(
#if defined(__gfx938__) #if defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -932,6 +935,7 @@ lds_direct_copy_fp8( ...@@ -932,6 +935,7 @@ lds_direct_copy_fp8(
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -1003,6 +1007,7 @@ lds_direct_copy_tp1( ...@@ -1003,6 +1007,7 @@ lds_direct_copy_tp1(
// } // }
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -1183,6 +1188,7 @@ lds_direct_copy_sparse_k( ...@@ -1183,6 +1188,7 @@ lds_direct_copy_sparse_k(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -1251,6 +1257,7 @@ lds_direct_copy( ...@@ -1251,6 +1257,7 @@ lds_direct_copy(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -1314,6 +1321,7 @@ lds_direct_copy( ...@@ -1314,6 +1321,7 @@ lds_direct_copy(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -1403,6 +1411,7 @@ lds_direct_copy_for_prefill_sparse_mla( ...@@ -1403,6 +1411,7 @@ lds_direct_copy_for_prefill_sparse_mla(
index_offset[1] = offset_v; index_offset[1] = offset_v;
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -2296,6 +2305,7 @@ lds_direct_copy_qkvfp8_q_tp1( ...@@ -2296,6 +2305,7 @@ lds_direct_copy_qkvfp8_q_tp1(
#if defined(__gfx938__) #if defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -2354,6 +2364,7 @@ lds_direct_copy_qkvfp8_q_tp4( ...@@ -2354,6 +2364,7 @@ lds_direct_copy_qkvfp8_q_tp4(
#if defined(__gfx938__) #if defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -2435,6 +2446,7 @@ lds_direct_copy_qkvfp8_tp1( ...@@ -2435,6 +2446,7 @@ lds_direct_copy_qkvfp8_tp1(
#if defined(__gfx938__) #if defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -2782,6 +2794,7 @@ lds_direct_copy_qkvfp8_zero_lds( ...@@ -2782,6 +2794,7 @@ lds_direct_copy_qkvfp8_zero_lds(
#if defined(__gfx938__) #if defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
......
...@@ -120,6 +120,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -120,6 +120,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
+ offset_k * 3 * bytes_per_block; + offset_k * 3 * bytes_per_block;
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -137,6 +138,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp ...@@ -137,6 +138,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
+ offset_k * 2 * bytes_per_block; + offset_k * 2 * bytes_per_block;
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
......
...@@ -458,6 +458,7 @@ lds_direct_copy( ...@@ -458,6 +458,7 @@ lds_direct_copy(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -521,6 +522,7 @@ lds_direct_copy( ...@@ -521,6 +522,7 @@ lds_direct_copy(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -591,6 +593,7 @@ lds_direct_copy_for_prefill_sparse_mla( ...@@ -591,6 +593,7 @@ lds_direct_copy_for_prefill_sparse_mla(
index_offset[1] = offset_v; index_offset[1] = offset_v;
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -818,6 +821,7 @@ lds_direct_copy_qkvfp8( ...@@ -818,6 +821,7 @@ lds_direct_copy_qkvfp8(
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -884,6 +888,7 @@ lds_direct_copy_qkvfp8( ...@@ -884,6 +888,7 @@ lds_direct_copy_qkvfp8(
#if defined(__gfx938__) #if defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
...@@ -1041,6 +1046,7 @@ lds_direct_copy_fp8( ...@@ -1041,6 +1046,7 @@ lds_direct_copy_fp8(
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment