#pragma once #include "config.h" #include "utils.h" #include "softmax.h" #include "../../helpers.h" namespace gfx93::fwd { #define CUDART_L2E_F 1.442695041F using namespace cute; template __device__ void KernelTemplate_B_H_64::devfunc(const SparseAttnFwdParams ¶ms) { const int tidx = threadIdx.x; static constexpr int kBlockM = B_H; static constexpr int kBlockN = B_TOPK; static constexpr int kHeadDim = D_QK; static constexpr int kHeadDimV = D_V; const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64); const int s_q_idx = blockIdx.y; const int bidh = blockIdx.x; const int lane_idx = tidx % 64; extern __shared__ Element smem[]; Element* q_lds = (Element*)&(smem); Element* k_lds = q_lds; Element* v_lds = q_lds; int* sIndices = (int *)(q_lds + 8192); const index_t row_offset_q = s_q_idx * static_cast(params.stride_q_s_q) + bidh * kBlockM * params.stride_q_h_q; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q) + row_offset_q), Shape, Int>{}, make_stride(params.stride_q_h_q, _1{})); const index_t row_offset_k = 0 * params.stride_kv_h_kv; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.kv) + row_offset_k), Shape, Int>{}, make_stride(params.stride_kv_s_kv, _1{})); const index_t row_offset_topk = s_q_idx * params.stride_indices_s_q; int* gIndices = reinterpret_cast(params.indices) + row_offset_topk; typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4))); typedef __bf16 __fp16x2_t __attribute__((ext_vector_type(2))); union Bf16_storage { __fp16x8_t data_128; __fp16x4_t data_64[2]; __fp16x2_t data_32[4]; uint16_t data_array[8]; }; union Bf16_storage_x4 { __fp16x4_t data_64; __fp16x2_t data_32[2]; uint16_t data[4]; }; const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk; const int num_topk_blocks = IS_TOPK_2048? 2048 / B_TOPK : HAVE_TOPK_LENGTH ? 
ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK); // TiledMMA tiled_mma = TiledMma{}; // auto thr_mma = tiled_mma.get_thread_slice(tidx); flash::Softmax<1> softmax; // #if 1 // #if defined(__gfx938__) // #else int virtual_row_ = lane_idx / 8;//0 int virtual_col_ = lane_idx % 8;//0 int swizzle_col_ = virtual_row_ ^ virtual_col_; int row_ = lane_idx / 4;//0 // 8->9 9->8 // row_ = (row_ >= 8 ) ^ row_; int col_ = swizzle_col_ % 4; // #endif auto calc_row_and_col_k = [&](const int block_idx) -> std::tuple { constexpr int elements_per_thread = 8; // int row = lane_idx % 16; // int col = lane_idx / 16; // int row_offset = row * 4 + warp_idx + block_idx * kBlockN; #if defined(__gfx938__) // int row = lane_idx / 4; // int col = lane_idx % 4; // col = (col + (4 - (row / 2) % 4)) % 4; // int row_offset = row + warp_idx * 16 + block_idx * kBlockN; // int col_offset = col * 8; int row_offset = row_ + warp_idx * 16 + block_idx * kBlockN; int col_offset = col_ * 8; #else int row_offset = row_ * 4 + warp_idx + block_idx * kBlockN; int col_offset = col_ * 8; #endif // int row_offset = row + warp_idx * 16 + block_idx * kBlockN; if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { row_offset = sIndices[row_offset % 1024]; } else { row_offset = gIndices[row_offset]; } return {row_offset, col_offset}; }; auto calc_row_and_col_v = [&](const int block_idx, int i) -> int { int row = lane_idx / 4; // int col = lane_idx % 4; int row_offset = row + i * 16 + block_idx * kBlockN;; // int col_offset = col * 8 + warp_idx * 32; if (HAVE_TOPK_LENGTH && row_offset >= topk_length) { return params.s_kv; } if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { row_offset = sIndices[row_offset % 1024]; } else { row_offset = gIndices[row_offset]; } row_offset = row_offset == -1 ? 
params.s_kv : row_offset; return row_offset; }; struct PtrWrapper { uint32_t former; uint32_t latter; }; PtrWrapper glob_ptr_q; *(uint64_t*)&glob_ptr_q = reinterpret_cast(gQ.data().get()); glob_ptr_q.latter |= ((params.stride_q_h_q * 2) << 16); glob_ptr_q.latter |= 0x40000000; uint32x4_t global_addr_q = {0}; global_addr_q[0] = (glob_ptr_q.former); global_addr_q[1] = (glob_ptr_q.latter); global_addr_q[2] = 64; global_addr_q[3] = 0x00020000; auto buffer_load_lds_indices = [&] (int n, int num_indices) { if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { PtrWrapper glob_ptr_indices; *(uint64_t*)&glob_ptr_indices = reinterpret_cast(gIndices); glob_ptr_indices.latter |= 0x40000000; uint32x4_t global_addr_indices = {0}; global_addr_indices[0] = (glob_ptr_indices.former); global_addr_indices[1] = (glob_ptr_indices.latter); global_addr_indices[2] = 0x80000000; global_addr_indices[3] = 0x00020000; int ldsAddrPerWave = reinterpret_cast(sIndices) + warp_idx * 64 * 4 * 4; const int offset_v = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4; const int offset_s = n * 1024 * 4; const int first_index = warp_idx * 256 + lane_idx * 4; if (first_index < num_indices) { __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 offen offset:0, lds \n" ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr_indices), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); } } }; if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { buffer_load_lds_indices(0, IS_TOPK_2048 ? 
1024 : params.topk); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); } PtrWrapper glob_ptr_k; *(uint64_t*)&glob_ptr_k = reinterpret_cast(gK.data().get()); glob_ptr_k.latter |= ((params.stride_kv_s_kv * 2) << 16); glob_ptr_k.latter |= 0x40000000; uint32x4_t global_addr_k = {0}; global_addr_k[0] = (glob_ptr_k.former); global_addr_k[1] = (glob_ptr_k.latter); global_addr_k[2] = params.s_kv; global_addr_k[3] = 0x00020000; auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx) { constexpr int element_size = 2; // int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); // struct PtrWrapper { // uint32_t former; // uint32_t latter; // }; // PtrWrapper glob_ptr; // *(uint64_t*)&glob_ptr = reinterpret_cast(gK.data().get()); // glob_ptr.latter |= ((row_stride * 2) << 16); // uint32x4_t global_addr = {0}; // global_addr[0] = (glob_ptr.former); // global_addr[1] = (glob_ptr.latter); // global_addr[2] = max_MN; // global_addr[3] = 0x00020000; constexpr int elements_per_thread = 8; int col_offset = col; int offset_v = col_offset * 2; int ldsAddrPerWave = reinterpret_cast(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 4) * 64 * 32 * 2; typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); uint32x2_t index_offset = {0}; index_offset[0] = row_offset; index_offset[1] = offset_v; const int offset_s = k_idx * 32 * 2; __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); }; auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx) { constexpr int element_size = 2; // int k_idx = __builtin_amdgcn_readfirstlane(k_idx_); // struct PtrWrapper { // uint32_t former; // uint32_t latter; // }; // PtrWrapper glob_ptr; // *(uint64_t*)&glob_ptr = reinterpret_cast(gK.data().get()); 
// glob_ptr.latter |= ((row_stride * 2) << 16); // uint32x4_t global_addr = {0}; // global_addr[0] = (glob_ptr.former); // global_addr[1] = (glob_ptr.latter); // global_addr[2] = max_MN; // global_addr[3] = 0x00020000; constexpr int elements_per_thread = 8; int col_offset = col; // int v_idx = row_offset; int offset_v = col_offset * 2; int ldsAddrPerWave = reinterpret_cast(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2; typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); uint32x2_t index_offset = {0}; index_offset[0] = row_offset; index_offset[1] = offset_v; const int offset_s = n_idx * 128 * 2; __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); }; const int v_lds_read_ptr = reinterpret_cast(v_lds + lane_idx * 8); auto k_lds_read_offset = [&] () -> int { // #if defined(__gfx938__) // int row = lane_idx % 16; // int col = lane_idx / 16; // col = (col + (row / 2) % 4) % 4; // const auto lds_offset = row * 32 + col * 8; // #else int row = lane_idx % 16; int col = lane_idx / 16; col = (row / 2) ^ col; col = col % 4; // row = (row >= 8) ^ row; const auto lds_offset = row * 32 + col * 8; // #endif return lds_offset; }; Element* q_lds_read_ptr = (q_lds + warp_idx * 16 * 32 + lane_idx * 8); Element* k_lds_read_ptr = (k_lds + k_lds_read_offset()); Bf16_storage q_reg[18]; static constexpr int kQkChunks = D_QK / 32; for (int i = 0; i < kQkChunks; i++) { constexpr int elements_per_thread = 8; int row = lane_idx % 16; int col = lane_idx / 16; int row_offset = row + warp_idx * 16; int col_offset = col * 8; int offset_v = col_offset * 2 + i * 32 * 2; q_reg[i].data_128 = __builtin_amdgcn_buffer_load_dwordx4(global_addr_q, row_offset, offset_v, false, false); } __syncthreads(); v4f acco_f32[32]; for (int i = 0; i < 32; i++) { acco_f32[i].x = 0.0f; 
acco_f32[i].y = 0.0f; acco_f32[i].z = 0.0f; acco_f32[i].w = 0.0f; } int col_offset_v = (lane_idx % 4) * 8 + warp_idx * 32; struct IsFirstBlock {}; struct IsOtherBlock {}; auto float2bf16 = [] (float s) -> uint16_t { uint32_t x32 = reinterpret_cast(s); #ifndef FLASH_MLA_BF16_TYPE #define FLASH_MLA_BF16_TYPE 0 #endif #if FLASH_MLA_BF16_TYPE == 1 x32 += 0x8000u; #endif return uint16_t(x32 >> 16); }; auto process_one_block = [&] (int block_idx, auto is_block_t) { static constexpr bool IS_FIRST_BLOCK = std::is_same_v; static constexpr bool IS_OTHER_BLOCK = std::is_same_v; v4f accs_f32[4]; for (int i = 0; i < 4; i++) { accs_f32[i].x = 0.0f; accs_f32[i].y = 0.0f; accs_f32[i].z = 0.0f; accs_f32[i].w = 0.0f; } auto [row_offset, col] = calc_row_and_col_k(block_idx); const int row_in_topk = row_ + warp_idx * 16 + block_idx * kBlockN; if (HAVE_TOPK_LENGTH && row_in_topk >= topk_length) { row_offset = -1; } row_offset = row_offset == -1 ? params.s_kv : row_offset; #if 1 #define LOAD_K_AND_QK_GEMM(k) \ { \ constexpr int k_val = (k); \ if constexpr (k_val < kQkChunks - 1) { \ buffer_load_lds_k(row_offset, col, k_val - 3); \ flash::qk_gemm(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ } \ } { constexpr int k_val = kQkChunks - 1; buffer_load_lds_k(row_offset, col, k_val); buffer_load_lds_k(row_offset, col, k_val - 1); buffer_load_lds_k(row_offset, col, k_val - 2); buffer_load_lds_k(row_offset, col, k_val - 3); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::qk_gemm(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); LOAD_K_AND_QK_GEMM(16); LOAD_K_AND_QK_GEMM(15); LOAD_K_AND_QK_GEMM(14); LOAD_K_AND_QK_GEMM(13); LOAD_K_AND_QK_GEMM(12); 
LOAD_K_AND_QK_GEMM(11); LOAD_K_AND_QK_GEMM(10); LOAD_K_AND_QK_GEMM(9); LOAD_K_AND_QK_GEMM(8); LOAD_K_AND_QK_GEMM(7); LOAD_K_AND_QK_GEMM(6); LOAD_K_AND_QK_GEMM(5); LOAD_K_AND_QK_GEMM(4); LOAD_K_AND_QK_GEMM(3); flash::qk_gemm(q_reg[2].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::qk_gemm(q_reg[1].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::qk_gemm(q_reg[0].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); } #undef LOAD_K_AND_QK_GEMM #else #endif auto is_valid_token = [&](const int idx) -> bool { const int n_idx = (lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16; int offs = n_idx + block_idx * kBlockN; int t; if constexpr (IS_TOPK_2048 || CACHE_INDICES_IN_LDS) { t = sIndices[offs % 1024]; } else { t = gIndices[offs]; } bool is_cur_token_valid = t >= 0 && t < params.s_kv; if constexpr (HAVE_TOPK_LENGTH) { is_cur_token_valid = is_cur_token_valid && (offs < topk_length); } return is_cur_token_valid; }; for (int i = 0; i < 16; ++i) { #if defined(__gfx938__) if (!is_valid_token(i)) accs_f32[i/4][i%4] = -INFINITY; #else if (!is_valid_token(i)) accs_f32[i%4][i/4] = -INFINITY; #endif } // Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); // Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); Tensor scores = make_tensor(Shape<_1, _16>{}); for (int i = 0; i < 16; i++) { #if defined(__gfx938__) scores(0, i) = accs_f32[i/4][i%4]; #else scores(0, i) = accs_f32[i%4][i/4]; #endif } softmax.template softmax_rescale_o_prefill_4x1(scores, acco_f32, params.sm_scale_div_log2); Bf16_storage_x4 p[4]; for (int i = 0; i < 4; i++) { #if defined(__gfx938__) p[i].data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, 
scores(0, i * 4), 0, scores(0, i * 4 + 1), 0); p[i].data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4 + 2), 0, scores(0, i * 4 + 3), 0); #else p[i].data[0] = float2bf16(scores(0, i * 4)); p[i].data[1] = float2bf16(scores(0, i * 4 + 1)); p[i].data[2] = float2bf16(scores(0, i * 4 + 2)); p[i].data[3] = float2bf16(scores(0, i * 4 + 3)); #endif } int row_offset_v[4]; for (int i = 0; i < 4; i++) { row_offset_v[i] = calc_row_and_col_v(block_idx, i); } __syncthreads(); #if 1 { constexpr int k_val = (0); buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0); buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0); buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0); buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1); flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, 
acco_f32); flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1); } #define LOAD_V_AND_PV_GEMM(n) \ { \ constexpr int k_val = (0); \ constexpr int n_val = (n); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1); \ flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 
2].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1); \ } LOAD_V_AND_PV_GEMM(1); LOAD_V_AND_PV_GEMM(2); { constexpr int n_val = (3); flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<0, 14>(p[0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<2, 15>(p[2].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, 
acco_f32); flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); } #else #define LOAD_V_AND_PV_GEMM(k) \ { \ constexpr int k_val = (k); \ buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0); \ buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1); \ buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 2); \ buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 3); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ flash::pv_gemm(p[k_val + 0].data_64, 
v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_barrier \n\t"); \ __builtin_amdgcn_sched_barrier(0); \ } LOAD_V_AND_PV_GEMM(0); LOAD_V_AND_PV_GEMM(1); LOAD_V_AND_PV_GEMM(2); LOAD_V_AND_PV_GEMM(3); #endif }; if constexpr (IS_TOPK_2048) { process_one_block(0, IsFirstBlock{}); for (int block_idx = 1; block_idx < 1024 / B_TOPK; block_idx ++) { process_one_block(block_idx, IsOtherBlock{}); } buffer_load_lds_indices(1, 1024); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); for (int block_idx = 1024/B_TOPK; block_idx < 2048 / B_TOPK; block_idx ++) { process_one_block(block_idx, IsOtherBlock{}); } } else { if (num_topk_blocks > 0) process_one_block(0, IsFirstBlock{}); for (int block_idx = 1; block_idx < num_topk_blocks; block_idx ++) { process_one_block(block_idx, IsOtherBlock{}); } } Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1(acco_f32, params.sm_scale); const index_t row_offset_o = s_q_idx * static_cast(params.h_q * params.d_v) + bidh * kBlockM * params.d_v; Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.out) + row_offset_o), Shape, Int>{}, make_stride(params.d_v, _1{})); const index_t row_offset_lse = s_q_idx * params.h_q + bidh * kBlockM; float* gLSE = reinterpret_cast(params.lse) + row_offset_lse; // const index_t row_offset_lse = m_block * params.h_q; float* gMax_logits = reinterpret_cast(params.max_logits) + row_offset_lse; float attn_sink_o_scale = 1.0f; if constexpr (USE_ATTN_SINK) { float rAttn_sink = __ldg((float*)params.attn_sink + bidh * kBlockM + lane_idx % 16 + warp_idx * 16); if (flash::is_positive_infinity(rAttn_sink)) { attn_sink_o_scale = 0.0f; } else if (!flash::is_positive_infinity(lse(0))) { float lse_exp2 = 
__builtin_amdgcn_exp2f(lse[0] * CUDART_L2E_F); float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F); attn_sink_o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2); } } auto maybe_apply_attn_sink = [&] (float value) -> float { if constexpr (USE_ATTN_SINK) { return value * attn_sink_o_scale; } else { return value; } }; { // store O and gLSE // auto rO = flash::convert_type(acc_o); int row, col; // const int warpId = tidx / 64; // const int laneId = tidx % 64; for (int mi = 0; mi < 1; ++mi) { row = mi * kBlockM + lane_idx % 16 + warp_idx * 16; // if (row < params.h_q) { for (int ni = 0; ni < 16; ++ni) { #if defined(__gfx938__) Bf16_storage res; col = (lane_idx / 16) * 8 + ni * 32 ; res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][0]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][0]), 0); res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][1]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][1]), 0); res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][2]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][2]), 0); res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, maybe_apply_attn_sink(acco_f32[ni * 2][3]), 0, maybe_apply_attn_sink(acco_f32[ni * 2 + 1][3]), 0); *(__fp16x8_t*)(&gO(row, col)) = res.data_128; #else col = (lane_idx / 16) * 2 + ni * 32 ; using result_type = cutlass::Array; for (int ei = 0; ei < 4; ei++) { result_type res; Element e0, e1; e0.storage = float2bf16(maybe_apply_attn_sink(acco_f32[ni * 2][ei])); e1.storage = float2bf16(maybe_apply_attn_sink(acco_f32[ni * 2 + 1][ei])); res[0] = e0; res[1] = e1; // gO(row, col) = res[0]; // gO(row, col + 1) = res[1]; *(result_type*)(&gO(row, col)) = res; col += 8; } #endif } gLSE[row] = lse(mi); if constexpr (HAVE_TOPK_LENGTH) { gMax_logits[row] = topk_length == 0 ? 
-INFINITY : softmax.row_max(mi) * params.sm_scale; } else { gMax_logits[row] = softmax.row_max(mi) * params.sm_scale; } } } } } template __device__ void KernelTemplate::devfunc(const SparseAttnFwdParams ¶ms) { extern __shared__ char smem_[]; SharedMemoryPlan &plan = *reinterpret_cast(smem_); const int tidx = threadIdx.x; static constexpr int kBlockM = B_H; static constexpr int kBlockN = B_TOPK; static constexpr int kHeadDim = D_QK; static constexpr int kHeadDimV = D_V; const int warp_idx = tidx / 64; const int s_q_idx = blockIdx.x; const int bidh = blockIdx.y; const int lane_idx = tidx % 64; const index_t row_offset_q = s_q_idx * static_cast(params.stride_q_s_q) + bidh * kBlockM * params.stride_q_h_q; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q) + row_offset_q), Shape, Int>{}, make_stride(params.stride_q_h_q, _1{})); const index_t row_offset_k = 0 * params.stride_kv_h_kv; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.kv) + row_offset_k), Shape, Int>{}, make_stride(params.stride_kv_s_kv, _1{})); const index_t row_offset_topk = s_q_idx * params.stride_indices_s_q; int* gIndices = reinterpret_cast(params.indices) + row_offset_topk; Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), SmemLayoutQ{}); Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), SmemLayoutP{}); Tensor sVt = make_tensor(sV.data(), SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), SmemLayoutRow{}); Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), SmemLayoutRow{}); TiledMMA tiled_mma = TiledMma{}; auto thr_mma = tiled_mma.get_thread_slice(tidx); TiledMMA tiled_mma_o = TiledMma_O{}; auto thr_mma_o = 
tiled_mma_o.get_thread_slice(tidx); flash::lds_direct_copy(gQ, sQ, 0, params.stride_q_h_q, params.h_q - bidh * kBlockM); flash::lds_direct_copy(gQ, sQ, 1, params.stride_q_h_q, params.h_q - bidh * kBlockM); flash::lds_direct_copy(gQ, sQ, 2, params.stride_q_h_q, params.h_q - bidh * kBlockM); flash::lds_direct_copy(gQ, sQ, 3, params.stride_q_h_q, params.h_q - bidh * kBlockM); if constexpr (D_QK == 576) { flash::lds_direct_copy(gQ, sQ, 4, params.stride_q_h_q, params.h_q - bidh * kBlockM); } auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); if constexpr (D_QK == 576) { asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11)); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), 
tSrQ_copy_view(_, _, 12)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15)); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 16), tSrQ_copy_view(_, _, 16)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 17), tSrQ_copy_view(_, _, 17)); } else { asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 0), tSrQ_copy_view(_, _, 0)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 1), tSrQ_copy_view(_, _, 1)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 2), tSrQ_copy_view(_, _, 2)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 3), tSrQ_copy_view(_, _, 3)); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 4), tSrQ_copy_view(_, _, 4)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 5), tSrQ_copy_view(_, _, 5)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 6), tSrQ_copy_view(_, _, 6)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 7), tSrQ_copy_view(_, _, 7)); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 8), tSrQ_copy_view(_, _, 8)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 9), tSrQ_copy_view(_, _, 9)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 10), tSrQ_copy_view(_, _, 10)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 11), tSrQ_copy_view(_, _, 11)); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 12), tSrQ_copy_view(_, _, 12)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 13), tSrQ_copy_view(_, _, 13)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 14), tSrQ_copy_view(_, _, 14)); cute::copy(smem_tiled_copy_Q, tSsQ(_, _, 15), tSrQ_copy_view(_, _, 15)); } __syncthreads(); const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk; const int num_topk_blocks = HAVE_TOPK_LENGTH ? 
ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK); auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(sK); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; auto calc_row_and_col = [&](const int block_idx) -> std::tuple { // 计算swizzle后的全局显存访存地址 int virtual_row = lane_idx / 8; int virtual_col = lane_idx % 8; int swizzle_col = virtual_row ^ virtual_col; int row = lane_idx / 4; row = (row >= 8 ) ^ row; int col = swizzle_col % 4; int warp_id = tidx / 64; int row_offset = block_idx * kBlockN + row + (warp_idx * 16) ; // row_offset = row_offset < params.topk ? 
gIndices[row_offset] : -1; row_offset = gIndices[row_offset]; return {row_offset, col}; }; for (int block_idx = 0; block_idx < num_topk_blocks; block_idx++) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); auto [row_offset, col] = calc_row_and_col(block_idx); if constexpr (D_QK == 576) { for (int i = 16; i < 18; i++) { flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, i, params.stride_kv_s_kv, params.s_kv); } asm volatile("s_waitcnt vmcnt(1) \n s_barrier"); cute::copy(smem_tiled_copy_K, tSsK(_, _, 0), tSrK_copy_view(_, _, 0)); cute::gemm(tiled_mma, tSrQ(_, _, 0 + 16), tSrK(_, _, 0), acc_s); asm volatile("s_waitcnt vmcnt(0) \n s_barrier"); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 0, params.stride_kv_s_kv, params.s_kv); cute::copy(smem_tiled_copy_K, tSsK(_, _, 1), tSrK_copy_view(_, _, 1)); cute::gemm(tiled_mma, tSrQ(_, _, 1 + 16), tSrK(_, _, 1), acc_s); } else { flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 0, params.stride_kv_s_kv, params.s_kv); } for (int i = 1; i < 4; i++) { flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, i, params.stride_kv_s_kv, params.s_kv); } int k_idx = 0; asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, 
k_idx)); flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 0>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 4, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 5, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 6, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 7, params.stride_kv_s_kv, params.s_kv); k_idx++; asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 1>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma, tSrQ(_, 
_, k_idx), tSrK(_, _, k_idx), acc_s); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 8, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 9, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 10, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 11, params.stride_kv_s_kv, params.s_kv); k_idx++; asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 2>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 12, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 13, params.stride_kv_s_kv, params.s_kv); 
flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 14, params.stride_kv_s_kv, params.s_kv); flash::lds_direct_copy_for_prefill_sparse_mla(gK, sK, row_offset, col, 15, params.stride_kv_s_kv, params.s_kv); k_idx++; asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx % 4), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); asm volatile("s_barrier\n\t"); // if (block0()) // { // printf(" %.2f %.2f %.2f \n ", acc_s(0), acc_s(1), acc_s(2)); // } Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); auto is_valid_token = [&](const int idx) -> bool { int offs = int(get<1>(tScS(idx))) + block_idx * kBlockN; int t = gIndices[offs]; bool is_cur_token_valid = t >= 0 && t < params.s_kv; if constexpr (HAVE_TOPK_LENGTH) { is_cur_token_valid = is_cur_token_valid && (offs < topk_length); } return is_cur_token_valid; }; { for (int i = 0; i < size(acc_s); ++i) { // idx = idx < params.topk ? gIndices[idx] : -1; if (!is_valid_token(i)) acc_s(i) = -INFINITY; } } block_idx == 0 ? 
softmax.template softmax_rescale_o_prefill(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2): softmax.template softmax_rescale_o_prefill(acc_s, acc_o, sRow_max_reduce_buffer, params.sm_scale_div_log2); // if (block0()) // { // printf(" %.2f %.2f %.2f %.2f %.2f %.2f \n ", acc_s(0), acc_s(1), acc_s(2), acc_s(3), softmax.row_max(0), params.sm_scale_div_log2); // } Tensor rP = flash::convert_type(acc_s); Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP); { flash::__ds_read_m32x16_row_col_rrow_alt<0, 0, 3>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o); flash::__ds_read_m32x16_row_col_rrow_alt<0, 1, 3>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o); // __ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view); // __ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col_rrow_alt<0, 2, 3>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o); flash::__ds_read_m32x16_row_col_rrow_alt<0, 3, 3>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o); // for (int i = 0; i < size(tOrP); i++) // { // tOrP(i) = Element(1.0f); // } // cute::copy(smem_tiled_copy_V, tOsVt(_, 0, 0), tOrVt_copy_view(_, 0, 0)); // for (int i = 0; i < 4; i++) { // cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), 
tOrVt_copy_view(_, _, i)); // // if (tOrVt(_, _, i) ) // cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o); // } // for (int i = 0; i < 8 * 2 * 16; i++) // { // } // asm volatile("s_barrier"); // if (thread0()) { // for (int i = 0; i < 64; i++) { // for (int j = 0; j < 512; j++) { // printf(" %.2f ", float(sK(i, j))); // } // printf("\n"); // } // } // if (block0()) // { // print("tidx %d acc_s %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n", // tidx, acc_o(0), acc_o(1), acc_o(2), acc_o(3), // acc_o(4), acc_o(5), acc_o(6), acc_o(7), // acc_o(8), acc_o(9), acc_o(10), acc_o(11), // acc_o(12), acc_o(13), acc_o(14), acc_o(15) // ); // } } // asm volatile("s_barrier\n\t"); } Tensor lse = softmax.template normalize_softmax_lse_prefill(acc_o, sRow_sum_reduce_buffer, params.sm_scale); const index_t row_offset_o = s_q_idx * static_cast(params.h_q * params.d_v) + bidh * kBlockM * params.d_v; Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.out) + row_offset_o), Shape, Int>{}, make_stride(params.d_v, _1{})); // lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat)); const index_t row_offset_lse = s_q_idx * params.h_q + bidh * kBlockM; float* gLSE = reinterpret_cast(params.lse) + row_offset_lse; // const index_t row_offset_lse = m_block * params.h_q; float* gMax_logits = reinterpret_cast(params.max_logits) + row_offset_lse; if (params.attn_sink != nullptr) { float rAttn_sink = __ldg((float*)params.attn_sink + bidh * kBlockM + lane_idx % 16); if (flash::is_positive_infinity(rAttn_sink)) { for (int i = 0; i < size(acc_o); i++) { acc_o(i) = 0.0f; } } else { if (!flash::is_positive_infinity(lse(0))) { float lse_exp2 = __builtin_amdgcn_exp2f(lse[0] * CUDART_L2E_F); float rAttn_sink_exp2 = __builtin_amdgcn_exp2f(rAttn_sink * CUDART_L2E_F); float o_scale = lse_exp2 / (lse_exp2 + rAttn_sink_exp2); for (int i = 0; i < size(acc_o); i++) { acc_o(i) *= o_scale; } } } } // if (block0()) // { // print("tidx %d acc_s %.2f %.2f 
%.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f \n", // tidx, acc_o(0), acc_o(1), acc_o(2), acc_o(3), // acc_o(4), acc_o(5), acc_o(6), acc_o(7), // acc_o(8), acc_o(9), acc_o(10), acc_o(11), // acc_o(12), acc_o(13), acc_o(14), acc_o(15) // ); // } { // store O and gLSE // auto rO = flash::convert_type(acc_o); auto float2bf16 = [] (float s) -> uint16_t { uint32_t x32 = reinterpret_cast(s); #ifndef FLASH_MLA_BF16_TYPE #define FLASH_MLA_BF16_TYPE 0 #endif #if FLASH_MLA_BF16_TYPE == 1 x32 += 0x8000u; #endif return uint16_t(x32 >> 16); }; int row, col; const int warpId = tidx / 64; const int laneId = tidx % 64; for (int mi = 0; mi < size<1>(acc_o); ++mi) { row = mi * kBlockM + laneId % 16; if (row < params.h_q) { for (int ni = 0; ni < size<2>(acc_o); ++ni) { col = (laneId / 16) * 2 + ni * 128 + warpId * 32 ; using result_type = cutlass::Array; for (int ei = 0; ei < 4; ei++) { #if defined(__gfx938__) auto d = __builtin_hcu_cvt_pk_bf16_f32(0, acc_o(ei, mi, ni), 0, acc_o(ei + 4, mi, ni), 0); auto res = reinterpret_cast(d); #else result_type res; Element e0, e1; e0.storage = float2bf16(acc_o(ei, mi, ni)); e1.storage = float2bf16(acc_o(ei + 4, mi, ni)); res[0] = e0; res[1] = e1; #endif // gO(row, col) = res[0]; // gO(row, col + 1) = res[1]; *(result_type*)(&gO(row, col)) = res; col += 8; } // gO(row, col) = rO(0, mi, ni); // gO(row, col + 1) = rO(1, mi, ni); // col += 8; // gO(row, col) = rO(2, mi, ni); // gO(row, col + 1) = rO(3, mi, ni); // col += 8; // gO(row, col) = rO(4, mi, ni); // gO(row, col + 1) = rO(5, mi, ni); // col += 8; // gO(row, col) = rO(6, mi, ni); // gO(row, col + 1) = rO(7, mi, ni); // gO(row, col) = rO(0, mi, ni); // gO(row, col + 1) = rO(4, mi, ni); // col += 8; // gO(row, col) = rO(1, mi, ni); // gO(row, col + 1) = rO(5, mi, ni); // col += 8; // gO(row, col) = rO(2, mi, ni); // gO(row, col + 1) = rO(6, mi, ni); // col += 8; // gO(row, col) = rO(3, mi, ni); // gO(row, col + 1) = rO(7, mi, ni); // for (int ei = 0; ei < 
size<0>(acc_o); ei += 2) {
                //     gO(row, col) = rO(ei, mi, ni);
                //     col += 4;
                // }
            }
            // Per-head outputs: LSE as returned by normalize_softmax_lse_prefill, and the scaled
            // row max (forced to -inf when topk_length == 0, i.e. nothing was attended).
            // NOTE(review): `row` is block-local (mi * kBlockM + laneId % 16), yet the enclosing
            // `row < params.h_q` guard compares it with the global head count, while gO/gLSE are
            // already offset by bidh * kBlockM. If h_q is not a multiple of kBlockM (the h_q % B_H
            // assert in run() is commented out and the grid rounds up), the last head block may
            // write rows past head h_q - 1. TODO confirm; the guard likely wants
            // `bidh * kBlockM + row < params.h_q`.
            gLSE[row] = lse(mi);
            gMax_logits[row] = topk_length == 0 ? -INFINITY : softmax.row_max(mi) * params.sm_scale;
        }
    }
}
}

// Kernel entry point: forwards to Kernel::devfunc. At most Kernel::NUM_THREADS threads per block
// (min 1 block per CU). The body is compiled only for gfx936/gfx938; on any other target the
// kernel is empty and launching it does nothing.
template __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1) sparse_attn_fwd_kernel(const SparseAttnFwdParams params) {
#if defined(__gfx936__) || defined(__gfx938__)
    Kernel::devfunc(params);
#endif
}

// Host launcher for the generic path.
// Grid: x spans params.s_q, y spans ceil(h_q / B_H) head blocks.
// Preconditions (KU_ASSERT): a single KV head, topk a positive multiple of 2 * B_TOPK.
template void KernelTemplate::run(const SparseAttnFwdParams ¶ms) {
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundary checks
    KU_ASSERT(params.topk > 0);
    // KU_ASSERT(params.h_q % B_H == 0);
    auto kernel = &sparse_attn_fwd_kernel>;
    // LDS buffers are reused across stages, which keeps the dynamic shared-memory size small.
    constexpr size_t smem_size = 16384 + 4096;
    dim3 grid(params.s_q, (params.h_q + B_H - 1) / B_H, 1);
    kernel<<>>(params);
    KU_CHECK_KERNEL_LAUNCH();
}

// Host launcher for the B_H = 64 path.
// Grid: x spans ceil(h_q / B_H) head blocks, y spans params.s_q; devfunc reads
// bidh = blockIdx.x and s_q_idx = blockIdx.y.
template void KernelTemplate_B_H_64::run(const SparseAttnFwdParams ¶ms) {
    KU_ASSERT(params.h_kv == 1);
    // NOTE(review): with this assert disabled, and neither HAVE_TOPK_LENGTH nor IS_TOPK_2048 set,
    // devfunc computes num_topk_blocks = topk / B_TOPK with truncating division, so a topk that
    // is not a multiple of B_TOPK would silently skip the trailing indices. TODO confirm callers
    // guarantee divisibility.
    // KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundary checks
    KU_ASSERT(params.topk > 0);
    // KU_ASSERT(params.h_q % B_H == 0);
    auto kernel = &sparse_attn_fwd_kernel>;
    // LDS reuse: Q, K and V alias the same 16384-byte region in devfunc (8192 elements, assuming a
    // 2-byte Element), followed by 4096 bytes for the cached top-k indices (sIndices, 1024 ints).
    constexpr size_t smem_size = 16384 + 4096;
    dim3 grid((params.h_q + B_H - 1) / B_H, params.s_q, 1);
    kernel<<>>(params);
    KU_CHECK_KERNEL_LAUNCH();
}

// Picks a B_H = 64 instantiation by topk: one for exactly 2048, one for topk <= 1024, and one
// for everything else. Presumably the variants differ in IS_TOPK_2048 / CACHE_INDICES_IN_LDS
// (devfunc caches indices in a 1024-entry LDS buffer). Verify against the template arguments.
template static void run_h64_fast_path(const SparseAttnFwdParams& params) {
    if (params.topk == 2048) {
        KernelTemplate_B_H_64::run(params);
    } else if (params.topk <= 1024) {
        KernelTemplate_B_H_64::run(params);
    } else {
        KernelTemplate_B_H_64::run(params);
    }
}

// Phase-1 forward dispatch. h_q of 64 or 128 takes the B_H = 64 fast path, choosing a variant by
// whether attn_sink is set. Every other head count falls back to the generic KernelTemplate.
template void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
    if (params.h_q == 64 || params.h_q == 128) {
        if (params.attn_sink) {
            run_h64_fast_path(params);
        } else {
            run_h64_fast_path(params);
        }
        return;
    }
    KernelTemplate::run(params);
}

}  // namespace gfx93::fwd