Commit 9687a0a3 authored by zhanghj2's avatar zhanghj2
Browse files

add __builtin_amdgcn_sched_barrier(0);

parent 142846b5
...@@ -33,9 +33,11 @@ ...@@ -33,9 +33,11 @@
#define FLASH_DEVICE_ASSERT(cond) \ #define FLASH_DEVICE_ASSERT(cond) \
do { \ do { \
if (not (cond)) { \ if (not (cond)) {
__builtin_amdgcn_sched_barrier(0); \
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
asm volatile("s_trap 0 \n\t"); \ asm volatile("s_trap 0 \n\t");
__builtin_amdgcn_sched_barrier(0); \
} \ } \
} while(0) } while(0)
...@@ -477,14 +479,14 @@ lds_direct_copy( ...@@ -477,14 +479,14 @@ lds_direct_copy(
if (!Is_even_K && col_offset >= 576) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
} }
...@@ -541,13 +543,14 @@ lds_direct_copy( ...@@ -541,13 +543,14 @@ lds_direct_copy(
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
} }
} }
...@@ -613,12 +616,14 @@ lds_direct_copy_for_prefill_sparse_mla( ...@@ -613,12 +616,14 @@ lds_direct_copy_for_prefill_sparse_mla(
uint32x2_t index_offset = {0}; uint32x2_t index_offset = {0};
index_offset[0] = row_offset == -1 ? max_MN : row_offset; index_offset[0] = row_offset == -1 ? max_MN : row_offset;
index_offset[1] = offset_v; index_offset[1] = offset_v;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
} }
template< template<
...@@ -681,11 +686,13 @@ buffer_load_copy( ...@@ -681,11 +686,13 @@ buffer_load_copy(
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
...@@ -712,11 +719,13 @@ buffer_load_copy( ...@@ -712,11 +719,13 @@ buffer_load_copy(
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1; if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
...@@ -850,13 +859,14 @@ lds_direct_copy_qkvfp8( ...@@ -850,13 +859,14 @@ lds_direct_copy_qkvfp8(
if (!Is_even_K && col_offset >= 576) offset_v = -1; if (!Is_even_K && col_offset >= 576) offset_v = -1;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + k_idx * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
"s_nop 0 \n\t" "s_nop 0 \n\t"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
...@@ -916,7 +926,7 @@ lds_direct_copy_qkvfp8( ...@@ -916,7 +926,7 @@ lds_direct_copy_qkvfp8(
//int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size; //int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx) * mma_k * element_size; int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx) * mma_k * element_size;
__builtin_amdgcn_sched_barrier(0);
#if defined(__gfx938__) #if defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
...@@ -925,6 +935,7 @@ lds_direct_copy_qkvfp8( ...@@ -925,6 +935,7 @@ lds_direct_copy_qkvfp8(
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
#endif #endif
__builtin_amdgcn_sched_barrier(0);
} }
} }
...@@ -978,11 +989,13 @@ buffer_load_copy_qkvfp8( ...@@ -978,11 +989,13 @@ buffer_load_copy_qkvfp8(
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
...@@ -1074,7 +1087,7 @@ lds_direct_copy_fp8( ...@@ -1074,7 +1087,7 @@ lds_direct_copy_fp8(
// { // {
// printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size); // printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size);
// } // }
__builtin_amdgcn_sched_barrier(0);
#if defined(__gfx936__) || defined(__gfx938__) #if defined(__gfx936__) || defined(__gfx938__)
asm volatile( asm volatile(
"s_mov_b32 m0, %1 \n\t" "s_mov_b32 m0, %1 \n\t"
...@@ -1082,6 +1095,7 @@ lds_direct_copy_fp8( ...@@ -1082,6 +1095,7 @@ lds_direct_copy_fp8(
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v), "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds \n" ::"v"(offset_v),
"s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s) "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
:); :);
__builtin_amdgcn_sched_barrier(0);
#endif #endif
} }
} }
...@@ -1137,9 +1151,11 @@ __forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) { ...@@ -1137,9 +1151,11 @@ __forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) {
template <int N> template <int N>
CUTE_HOST_DEVICE CUTE_HOST_DEVICE
void wait_vmcnt() { void wait_vmcnt() {
__builtin_amdgcn_sched_barrier(0);
asm volatile("s_waitcnt vmcnt(%0) ;\n\t" asm volatile("s_waitcnt vmcnt(%0) ;\n\t"
"s_barrier; \n\t" "s_barrier; \n\t"
:: "n"(N)); :: "n"(N));
__builtin_amdgcn_sched_barrier(0);
} }
#if 0 #if 0
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4, template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
...@@ -1412,11 +1428,13 @@ buffer_load_copy_fp8( ...@@ -1412,11 +1428,13 @@ buffer_load_copy_fp8(
if constexpr(use_asm) { if constexpr(use_asm) {
__builtin_amdgcn_sched_barrier(0);
asm volatile( asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n" "buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
" \n\t" :"=v"(dst), " \n\t" :"=v"(dst),
"+v"(offset_v), "+s"(global_addr) "+v"(offset_v), "+s"(global_addr)
); );
__builtin_amdgcn_sched_barrier(0);
} }
else { else {
auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false); auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment