#include #include "utils.h" #include "params.h" #include "config.h" #include "traits.h" #include "softmax.h" using namespace cute; namespace gfx93 { // template // __device__ void // compute_attn_1rowblock_splitkv_mla_block_m_64_gfx936(const DenseAttnDecodeParams& params, // const int bidb, const int bidh, const int m_block, // const int n_split_idx, const int seqlen_k, // const int n_block_min, const int n_block_max, const bool NoSplit) // { // extern __shared__ char shared_memory[]; // const int tidx = threadIdx.x; // constexpr int kBlockM = T::BLOCK_SIZE_M; // constexpr int kBlockN = T::PAGE_BLOCK_SIZE; // constexpr int kHeadDim = T::HEAD_DIM_K; // constexpr int kHeadDimV = T::HEAD_DIM_V; // using Element = T::InputT; // using index_t = int64_t; // const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64); // const int lane_idx = tidx % 64; // Element* q_lds = (Element*)&(shared_memory); // Element* k_lds = q_lds; // Element* v_lds = q_lds; // const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), // Shape, Int>{}, // make_stride(params.q_row_stride, _1{})); // // const index_t row_offset_k = (0) * params.k_head_stride; // // Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), // // Shape, Int>{}, // // make_stride(params.k_row_stride, _1{})); // typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8))); // typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4))); // typedef __bf16 __fp16x2_t __attribute__((ext_vector_type(2))); // union Bf16_storage { // __fp16x8_t data_128; // __fp16x4_t data_64[2]; // __fp16x2_t data_32[4]; // uint16_t data_array[8]; // }; // union Bf16_storage_x4 { // __fp16x4_t data_64; // __fp16x2_t data_32[2]; // uint16_t data[4]; // }; // struct PtrWrapper { // uint32_t former; // uint32_t latter; // }; // PtrWrapper glob_ptr_q; 
// *(uint64_t*)&glob_ptr_q = reinterpret_cast(gQ.data().get()); // glob_ptr_q.latter |= ((params.q_row_stride * 2) << 16); // glob_ptr_q.latter |= 0x40000000; // uint32x4_t global_addr_q = {0}; // global_addr_q[0] = (glob_ptr_q.former); // global_addr_q[1] = (glob_ptr_q.latter); // global_addr_q[2] = params.q_seq_per_hk - m_block * kBlockM; // global_addr_q[3] = 0x00020000; // int virtual_row_ = lane_idx / 8;//0 // int virtual_col_ = lane_idx % 8;//0 // int swizzle_col_ = virtual_row_ ^ virtual_col_; // int row_ = lane_idx / 4;//0 // // 8->9 9->8 // // row_ = (row_ >= 8 ) ^ row_; // int col_ = swizzle_col_ % 4; // auto calc_row_and_col_k = [&]() -> std::tuple { // constexpr int elements_per_thread = 8; // #if defined(__gfx938__) // int row_offset = row_ + warp_idx * 16; // int col_offset = col_ * 8; // #else // int row_offset = row_ * 4 + warp_idx; // int col_offset = col_ * 8; // #endif // return {row_offset, col_offset}; // }; // auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx, int block_idx, index_t offset_k) { // constexpr int element_size = 2; // PtrWrapper glob_ptr_k; // *(uint64_t*)&glob_ptr_k = reinterpret_cast(params.k_ptr) + offset_k * 2; // glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16); // glob_ptr_k.latter |= 0x40000000; // uint32x4_t global_addr_k = {0}; // global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former); // global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter); // global_addr_k[2] = seqlen_k - block_idx * kBlockN; // global_addr_k[3] = 0x00020000; // constexpr int elements_per_thread = 8; // int col_offset = col; // int offset_v = col_offset * 2; // int ldsAddrPerWave = reinterpret_cast(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 4) * 64 * 32 * 2; // typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); // uint32x2_t index_offset = {0}; // index_offset[0] = row_offset; // index_offset[1] = offset_v; // const int offset_s = k_idx * 32 * 2; // __builtin_amdgcn_sched_barrier(0); // 
asm volatile( // "s_mov_b32 m0, %1 \n\t" // "s_nop 0 \n\t" // "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), // "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s) // :); // __builtin_amdgcn_sched_barrier(0); // }; // auto k_lds_read_offset = [&] () -> int { // int row = lane_idx % 16; // int col = lane_idx / 16; // col = (row / 2) ^ col; // col = col % 4; // // row = (row >= 8) ^ row; // const auto lds_offset = row * 32 + col * 8; // // #endif // return lds_offset; // }; // auto calc_row_and_col_v = [&](int i) -> int { // int row = lane_idx / 4; // // int col = lane_idx % 4; // int row_offset = row + i * 16; // // int col_offset = col * 8 + warp_idx * 32; // return row_offset; // }; // const int v_lds_read_ptr = reinterpret_cast(v_lds + lane_idx * 8); // Element* k_lds_read_ptr = (k_lds + k_lds_read_offset()); // int col_offset_v = (lane_idx % 4) * 8 + warp_idx * 32; // auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx, int block_idx, index_t offset_k) { // constexpr int element_size = 2; // PtrWrapper glob_ptr_k; // *(uint64_t*)&glob_ptr_k = reinterpret_cast(params.k_ptr) + offset_k * 2; // glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16); // glob_ptr_k.latter |= 0x40000000; // uint32x4_t global_addr_k = {0}; // global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former); // global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter); // global_addr_k[2] = seqlen_k - block_idx * kBlockN; // global_addr_k[3] = 0x00020000; // constexpr int elements_per_thread = 8; // int col_offset = col; // // int v_idx = row_offset; // int offset_v = col_offset * 2; // int ldsAddrPerWave = reinterpret_cast(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2; // typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); // uint32x2_t index_offset = {0}; // index_offset[0] = row_offset; // index_offset[1] = offset_v; // const int offset_s = n_idx * 128 * 2; // 
__builtin_amdgcn_sched_barrier(0); // asm volatile( // "s_mov_b32 m0, %1 \n\t" // "s_nop 0 \n\t" // "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), // "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s) // :); // __builtin_amdgcn_sched_barrier(0); // }; // Bf16_storage q_reg[18]; // for (int i = 0; i < 18; i++) // { // constexpr int elements_per_thread = 8; // int row = lane_idx % 16; // int col = lane_idx / 16; // int row_offset = row + warp_idx * 16; // int col_offset = col * 8; // int offset_v = col_offset * 2 + i * 32 * 2; // q_reg[i].data_128 = __builtin_amdgcn_buffer_load_dwordx4(global_addr_q, row_offset, offset_v, false, false); // } // __syncthreads(); // v4f acco_f32[32]; // for (int i = 0; i < 32; i++) // { // acco_f32[i].x = 0.0f; // acco_f32[i].y = 0.0f; // acco_f32[i].z = 0.0f; // acco_f32[i].w = 0.0f; // } // const int *block_table = params.block_table + bidb * params.block_table_batch_stride; // auto float2bf16 = [] (float s) -> uint16_t { // uint32_t x32 = reinterpret_cast(s); // #ifndef FLASH_MLA_BF16_TYPE // #define FLASH_MLA_BF16_TYPE 0 // #endif // #if FLASH_MLA_BF16_TYPE == 1 // x32 += 0x8000u; // #endif // return uint16_t(x32 >> 16); // }; // // int block_idx = 0; // // int cur_block_table = block_table[block_idx]; // // index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride; // // auto [row_offset, col] = calc_row_and_col_k(block_idx); // // buffer_load_lds_k(row_offset, col, 0, offset_k); // // __syncthreads(); // { // // if (thread0()) // // { // // int k = 0; // // for (int i = 0; i < 64; i++) // // { // // for (int j = 0; j < 32; j++) // // { // // printf(" %.3f ", float(k_lds[k])); // // k++; // // } // // printf("\n"); // // } // // } // // if (block0() && threadIdx.x < 64) // // { // // cutlass::bfloat16_t q[8]; // // for (int i = 0; i < 8; i++) // // { // // q[i].storage = v_reg[0].data_array[i]; // // // q[i].storage = q_reg[0].data_array[i]; // // } // // printf("tidx %d %.3f 
%.3f %.3f %.3f %.3f %.3f %.3f %.3f %d\n ", threadIdx.x, // // float(q[0]), // // float(q[1]), // // float(q[2]), // // float(q[3]), // // float(q[4]), // // float(q[5]), // // float(q[6]), // // float(q[7]), // // v_lds_read_ptr // // ); // // } // } // struct IsMaskBlock {}; // struct IsFirstMaskBlock {}; // struct IsNoMaskBlock {}; // flash::Softmax<1> softmax; // auto process_one_block = [&] (int block_idx, auto is_mask_block_t) { // static constexpr bool IS_MASK_BLOCK = std::is_same_v; // static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v; // static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v; // int cur_block_table = block_table[block_idx]; // v4f accs_f32[4]; // for (int i = 0; i < 4; i++) // { // accs_f32[i].x = 0.0f; // accs_f32[i].y = 0.0f; // accs_f32[i].z = 0.0f; // accs_f32[i].w = 0.0f; // } // index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride; // auto [row_offset, col] = calc_row_and_col_k(); // #define LOAD_K_AND_QK_GEMM(k) \ // { \ // constexpr int k_val = (k); \ // buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k); \ // flash::qk_gemm(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \ // __builtin_amdgcn_sched_barrier(0); \ // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ // __builtin_amdgcn_sched_barrier(0); \ // } // { // constexpr int k_val = (17); // buffer_load_lds_k(row_offset, col, k_val, block_idx, offset_k); // buffer_load_lds_k(row_offset, col, k_val - 1, block_idx, offset_k); // buffer_load_lds_k(row_offset, col, k_val - 2, block_idx, offset_k); // buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // flash::qk_gemm(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // LOAD_K_AND_QK_GEMM(16); // 
LOAD_K_AND_QK_GEMM(15); // LOAD_K_AND_QK_GEMM(14); // LOAD_K_AND_QK_GEMM(13); // LOAD_K_AND_QK_GEMM(12); // LOAD_K_AND_QK_GEMM(11); // LOAD_K_AND_QK_GEMM(10); // LOAD_K_AND_QK_GEMM(9); // LOAD_K_AND_QK_GEMM(8); // LOAD_K_AND_QK_GEMM(7); // LOAD_K_AND_QK_GEMM(6); // LOAD_K_AND_QK_GEMM(5); // LOAD_K_AND_QK_GEMM(4); // LOAD_K_AND_QK_GEMM(3); // flash::qk_gemm(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // flash::qk_gemm(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // flash::qk_gemm(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // } // // if (block0() && tidx < 64) // // { // // printf(" %.3f %.3f \n", accs_f32[0][0], accs_f32[0][1]); // // } // if constexpr (!IS_NO_MASK_BLOCK) { // for (int i = 0; i < 16; ++i) { // int idx = i; // if constexpr (!T::Is_causal) { // if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 >= int(seqlen_k - block_idx * kBlockN)) // { // #if defined(__gfx938__) // accs_f32[i/4][i%4] = -INFINITY; // #else // accs_f32[i%4][i/4] = -INFINITY; // #endif // } // } else { // // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // int row = (lane_idx % 16) + warp_idx * 16;; // int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; // if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 > col_limit_right) { // #if defined(__gfx938__) // accs_f32[i/4][i%4] = -INFINITY; // #else // accs_f32[i%4][i/4] = -INFINITY; // 
#endif // } // } // } // } // Tensor scores = make_tensor(Shape<_1, _16>{}); // for (int i = 0; i < 16; i++) { // #if defined(__gfx938__) // scores(0, i) = accs_f32[i/4][i%4]; // #else // scores(0, i) = accs_f32[i%4][i/4]; // #endif // } // softmax.template softmax_rescale_o_prefill_4x1(scores, acco_f32, params.scale_softmax_log2); // Bf16_storage_x4 p[4]; // for (int i = 0; i < 4; i++) // { // #if defined(__gfx938__) // p[i].data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4), 0, scores(0, i * 4 + 1), 0); // p[i].data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4 + 2), 0, scores(0, i * 4 + 3), 0); // #else // p[i].data[0] = float2bf16(scores(0, i * 4)); // p[i].data[1] = float2bf16(scores(0, i * 4 + 1)); // p[i].data[2] = float2bf16(scores(0, i * 4 + 2)); // p[i].data[3] = float2bf16(scores(0, i * 4 + 3)); // #endif // } // int row_offset_v[4]; // for (int i = 0; i < 4; i++) // { // row_offset_v[i] = calc_row_and_col_v(i); // } // // __syncthreads(); // // #if 1 // { // constexpr int k_val = (0); // buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0, block_idx, offset_k); // buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0, block_idx, offset_k); // buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0, block_idx, offset_k); // buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0, block_idx, offset_k); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // buffer_load_lds_v(row_offset_v[k_val], 
col_offset_v, k_val, 1, block_idx, offset_k); // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1, block_idx, offset_k); // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1, block_idx, offset_k); // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1, block_idx, offset_k); // } // #define LOAD_V_AND_PV_GEMM(n) \ // { \ // constexpr int k_val = (0); \ // constexpr int n_val = (n); \ // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ // __builtin_amdgcn_sched_barrier(0); \ // asm volatile("s_waitcnt 
vmcnt(2) \n\t s_barrier\n\t"); \ // __builtin_amdgcn_sched_barrier(0); \ // buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1, block_idx, offset_k); \ // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ // __builtin_amdgcn_sched_barrier(0); \ // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ // __builtin_amdgcn_sched_barrier(0); \ // buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1, block_idx, offset_k); \ // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ // __builtin_amdgcn_sched_barrier(0); \ // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ // __builtin_amdgcn_sched_barrier(0); \ // buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1, block_idx, offset_k); \ // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ // flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ // __builtin_amdgcn_sched_barrier(0); \ // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ // __builtin_amdgcn_sched_barrier(0); \ // buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1, block_idx, offset_k); \ // } // LOAD_V_AND_PV_GEMM(1); // LOAD_V_AND_PV_GEMM(2); // { // constexpr int n_val = (3); // flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<0, 14>(p[0].data_64, 
v_lds_read_ptr, acco_f32); // flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<2, 15>(p[2].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32); // flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32); // __builtin_amdgcn_sched_barrier(0); // asm volatile("s_barrier\n\t"); // __builtin_amdgcn_sched_barrier(0); // } // }; // constexpr int n_masking_steps = !T::Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; // int n_block = n_block_max - 1; // if constexpr (n_masking_steps == 1) { // if (n_block >= n_block_min) { // process_one_block(n_block, IsFirstMaskBlock{}); // } // n_block--; // } else { // int masking_step = 1; // if (n_block >= n_block_min) { // process_one_block(n_block, IsFirstMaskBlock{}); // } // n_block--; // for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) { // process_one_block(n_block, IsMaskBlock{}); // } // } // for(; n_block >= n_block_min; --n_block) { // process_one_block(n_block, IsNoMaskBlock{}); // } // using ElementAccum = float; // if constexpr (true) // { // using ElementO = Element; // const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; // const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; // constexpr bool Split = false; // Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)), // Shape, Int>{}, // make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); // Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1(acco_f32, params.scale_softmax); // // if (block0() && tidx < 64) // // { // // printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1))); // // } // Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)), // Shape>{}, Stride<_1>{}); // { // // using result_type = cutlass::Array; // // int tidx = threadIdx.x; // int row, col; // // int warpid = tidx / 64; // for (int mi = 0; mi < 1; mi++) { // row = mi * kBlockM + lane_idx % 16 + warp_idx * 16; // if (row < params.q_seq_per_hk - m_block * kBlockM) { // for (int ni = 0; ni < 16; ++ni) { // #if defined(__gfx938__) // Bf16_storage res; // col = (lane_idx / 16) * 8 + ni * 32 ; // res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0], 0, acco_f32[ni * 2 + 1][0], 0); // res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1], 0, acco_f32[ni * 2 + 1][1], 0); // res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2], 0, acco_f32[ni * 2 + 1][2], 0); // res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3], 0, acco_f32[ni * 2 + 1][3], 0); // *(__fp16x8_t*)(&gOaccum(row, col)) = res.data_128; // #else // col = (lane_idx / 16) * 2 + ni * 32 ; // using result_type = cutlass::Array; // for (int ei = 0; ei < 4; ei++) // { // result_type res; // Element e0, e1; // e0.storage = float2bf16(acco_f32[ni * 2][ei]); // e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei]); // res[0] = e0; // res[1] = e1; // // gO(row, col) = res[0]; // // gO(row, col + 1) = res[1]; // *(result_type*)(&gOaccum(row, col)) = res; // col += 8; // } // #endif // } // // for (int n = 0; n < 1; n++) { // // col = (tidx % 64 / 16) + warpid * 32 + n * 128; // // for (int ei = 0; ei < 8; ei ++) { // // gOaccum(row, col) = rO(ei, m, n); // // col += 4; // // } // // } // gLSEaccum(row) = lse(mi); // } // } // } // } // } template __device__ void compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit) { extern __shared__ char shared_memory[]; using 
SharedMemoryPlan = typename T::SharedMemoryPlan; SharedMemoryPlan &plan = *reinterpret_cast(shared_memory); const int tidx = threadIdx.x; constexpr int kBlockM = T::BLOCK_SIZE_M; constexpr int kBlockN = T::PAGE_BLOCK_SIZE; constexpr int kHeadDim = T::HEAD_DIM_K; constexpr int kHeadDimV = T::HEAD_DIM_V; using Element = T::InputT; using index_t = int64_t; const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); const index_t row_offset_k = (bidh) * params.k_head_stride; Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), typename T::SmemLayoutQ{}); Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutV{}); Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutK{}); Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), typename T::SmemLayoutP{}); Tensor sVt = make_tensor(sV.data(), typename T::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle{}); Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), typename T::SmemLayoutRow{}); Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), typename T::SmemLayoutRow{}); typename T::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); typename T::TiledMma_O tiled_mma_o; auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); #if 1 typename T::GmemTiledCopyQ gmem_tiled_copy_Q; auto gmem_thr_copy_Q = 
gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tQgQ))); if (tidx < 128) flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.q_seq_per_hk - m_block * kBlockM); __syncthreads(); auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); Tensor tSrQ = thr_mma.partition_fragment_A(sQ); Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); __syncthreads(); #else auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma); auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); Tensor tSgQ = gmem_thr_copy_Q.partition_S(gQ); Tensor tSrQ = thr_mma.partition_fragment_A(gQ); Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ))); Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); Tensor tQpQ = make_tensor(make_shape(size<2>(tSgQ))); flash::copy(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.q_seq_per_hk - m_block * kBlockM); __syncthreads(); #endif // if (block0() && tidx < 64) // { // printf(" %.3f %.3f \n", float(tSrQ(0)), float(tSrQ(1))); // } auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma); auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); Tensor tSgK = smem_thr_copy_K.partition_S(gK); Tensor tSsK = smem_thr_copy_K.partition_S(sK); Tensor tSrK = thr_mma.partition_fragment_B(sK); auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o); auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle); constexpr int n_masking_steps = !T::Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int n_block = n_block_max - 1; Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{}); clear(acc_o); flash::Softmax(acc_o)> softmax; Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); // Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); // Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); // Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK); Tensor tKpK_smem = make_tensor(make_shape(size<2>(tSgK))); Tensor tSrK_smem = thr_mma.partition_fragment_B(gK); constexpr static int k0_lds_loops = 15; constexpr static int k0_loops = size<2>(tSrK_smem); constexpr static int k1_loops = size<2>(tOrVt); constexpr static int STAGE = 15; for (int masking_step = 0; masking_step < n_masking_steps && n_block >= n_block_min; ++masking_step, --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); // asm volatile("s_barrier\n\t"); // 这个也做过循环2类似的修改,但是性能不如现在的好,所以保持不变 int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #pragma unroll for (int i = 0; i < STAGE; i++) { flash::lds_direct_copy(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } constexpr static int BUFFER_SIZE = 3; uint128_t buffer[BUFFER_SIZE]; flash::buffer_load_copy(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); flash::buffer_load_copy(gK, buffer[1], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); flash::buffer_load_copy(gK, buffer[2], 17, params.k_row_stride, 
offset_k, seqlen_k - n_block * kBlockN); // if constexpr (STAGE == 15) { int k_idx = 0; // k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), 
tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm 
volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); } __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); flash::buffer_to_tensor(buffer[0], tSrK_smem, 15); cute::gemm(tiled_mma, tSrQ(_, _, 15), tSrK_smem(_, _, 15), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); flash::buffer_to_tensor(buffer[1], tSrK_smem, 16); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); flash::buffer_to_tensor(buffer[2], tSrK_smem, 17); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); // asm volatile("s_barrier\n\t"); // if (block0() && tidx < 64) // { // printf(" %.3f %.3f \n", acc_s(0), acc_s(1)); // } Tensor cS = make_identity_tensor(Shape, Int>{}); Tensor tScS = thr_mma.partition_C(cS); for (int i = 0; i < size(acc_s); ++i) { if constexpr (!T::Is_causal) { if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY; } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = int(get<0>(tScS(i))); int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - 
(m_block * kBlockM + row)) / params.q_head_per_hk; if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY; } } // We have key_padding_mask so we'll need to Check_inf if constexpr (n_masking_steps == 1) { softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } else { const bool is_first_masking_step = masking_step == 0; is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2) : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); } Tensor rP = flash::convert_type(acc_s); // Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP); Tensor tOrP = flash::convert_layout_acc_Aregs_dense(tiled_mma, tiled_mma_o, rP, sP); __syncthreads(); flash::lds_direct_copy(gK, sK, 15, params.k_row_stride, seqlen_k - n_block * kBlockN); // asm_ds_write(buffer[0], tVsV, 15); // asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); gK.data() = gK.data() + (-offset_k); #pragma unroll for (int i = 0; i < k1_loops; i++) { cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i)); cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o); } // asm volatile("s_barrier\n\t"); } for (; n_block >= n_block_min; --n_block) { Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{}); clear(acc_s); // asm volatile("s_barrier\n\t"); int cur_block_table; const int *cur_block_table_ptr = block_table + n_block; // cur_block_table = block_table[n_block - 1]; __builtin_amdgcn_sched_barrier(0); asm volatile("s_load_dword %1, %0, 0x0\n\t" "s_waitcnt lgkmcnt(0)\n\t": "+s"(cur_block_table_ptr), "=s"(cur_block_table)); __builtin_amdgcn_sched_barrier(0); index_t offset_k = cur_block_table * params.k_batch_stride; gK.data() = gK.data() + (offset_k); #pragma unroll for (int i = 0; i < 16; i++) { 
flash::lds_direct_copy(gK, sK, i, params.k_row_stride, seqlen_k - n_block * kBlockN); } constexpr static int BUFFER_SIZE = 2; uint128_t buffer[BUFFER_SIZE]; // buffer_load_copy(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); flash::buffer_load_copy(gK, buffer[0], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); flash::buffer_load_copy(gK, buffer[1], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN); // if constexpr (STAGE == 15) { int k_idx = 0; // k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(14 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(13 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(12 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(11 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(10 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm 
volatile("s_waitcnt vmcnt(9+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(8+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(7+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(6+ 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(5 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(4 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); k_idx++; __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, 
tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0 + 2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); k_idx++; cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx)); __builtin_amdgcn_sched_barrier(0); flash::__ds_read_m32x16_row_col<3, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<3, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<3, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<3, 3>(tOsVt, tOrVt_copy_view); __builtin_amdgcn_sched_barrier(0); cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s); // asm volatile("s_barrier\n\t"); } __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); flash::buffer_to_tensor(buffer[0], tSrK_smem, 16); cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t \n\t"); __builtin_amdgcn_sched_barrier(0); flash::buffer_to_tensor(buffer[1], 
tSrK_smem, 17); cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s); // asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); gK.data() = gK.data() + (-offset_k); // We have key_padding_mask so we'll need to Check_inf softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2); Tensor rP = flash::convert_type(acc_s); // Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP); Tensor tOrP = flash::convert_layout_acc_Aregs_dense(tiled_mma, tiled_mma_o, rP, sP); flash::__ds_read_m32x16_row_col<0, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<1, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<2, 0>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<0, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<1, 1>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<2, 1>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 0), tOrVt(_, _, 0), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 1), tOrVt(_, _, 1), acc_o); flash::__ds_read_m32x16_row_col<0, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<1, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<2, 2>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<0, 3>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<1, 3>(tOsVt, tOrVt_copy_view); flash::__ds_read_m32x16_row_col<2, 3>(tOsVt, tOrVt_copy_view); cute::gemm(tiled_mma_o, tOrP(_, _, 2), tOrVt(_, _, 2), acc_o); cute::gemm(tiled_mma_o, tOrP(_, _, 3), tOrVt(_, _, 3), acc_o); // asm volatile("s_barrier\n\t"); } using ElementAccum = float; if (NoSplit) { using ElementO = Element; const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; constexpr bool Split = 
false; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor lse = softmax.template normalize_softmax_lse(acc_o, sRow_sum_reduce_buffer, params.scale_softmax); // if (block0() && tidx < 64) // { // printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1))); // } Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)), Shape>{}, Stride<_1>{}); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); Tensor rO = flash::convert_type(acc_o); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } { // using result_type = cutlass::Array; // int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16; if (row < params.q_seq_per_hk - m_block * kBlockM) { for (int n = 0; n < size<2>(acc_o); n++) { col = (tidx % 64 / 16) + warpid * 32 + n * 128; for (int ei = 0; ei < 8; ei ++) { gOaccum(row, col) = rO(ei, m, n); col += 4; } } } } } } else { using ElementO = float; int split_idx = params.num_splits_ptr[bidb] + n_split_idx; constexpr bool Split = true; const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)), Shape, Int>{}, make_stride(Split ? 
kHeadDimV : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)), Shape>{}, Stride<_1>{}); Tensor lse = softmax.template normalize_softmax_lse(acc_o, sRow_sum_reduce_buffer, params.scale_softmax); Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) Tensor taccOcO = thr_mma_o.partition_C(caccO); if (get<1>(taccOcO(0)) == 0) { #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccOcO(0, mi, 0)); if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } } } { // using result_type = cutlass::Array; // int tidx = threadIdx.x; int col = 0; int warpid = tidx / 64; for (int m = 0; m < 1; m++) { const int row = tidx % 16; if (row < params.q_seq_per_hk - m_block * kBlockM) { for (int n = 0; n < size<2>(acc_o); n++) { col = (tidx % 64 / 16) + warpid * 32 + n * 128; for (int ei = 0; ei < 8; ei ++) { gOaccum(row, col) = acc_o(ei, m, n); col += 4; } } } } } } }

// Computes one (batch `bidb`, head `bidh`, query row block `m_block`) tile of
// paged MLA decode attention on gfx928. KV page blocks are visited from
// `n_block_max - 1` down to `n_block_min`.
//
// Per KV block, as written below:
//   1. The page's global offset is block_table[n_block] * params.k_batch_stride.
//      K chunks 0..15 (the third mode of the K fragments) are fetched into the
//      4-entry register ring `buffer`. They are written to shared memory with
//      flash::asm_ds_write, read back into tSrK, and multiplied into acc_s (S = Q * K^T).
//   2. K chunks 16 and 17 are fetched into registers and handed to the MMA
//      directly through flash::buffer_to_tensor, without going through shared memory.
//   3. During the first `n_masking_steps` iterations the scores are masked with
//      -INFINITY: columns past seqlen_k, or past the causal limit when
//      T::Is_causal is set. The online softmax then rescales acc_o.
//   4. sK and sV share plan.smem_v. Chunk 15, still held in buffer[3], is
//      written again through the V partition (tVsV). P * V is then accumulated
//      into acc_o while chunks 0..2 of the next page are prefetched.
//
// Epilogue:
//   - NoSplit: O goes to params.o_ptr (ElementO = Element) and LSE goes to
//     params.softmax_lse_ptr.
//   - Otherwise: O goes to params.oaccum_ptr (ElementO = float, row stride
//     kHeadDimV) and LSE goes to params.softmax_lseaccum_ptr. Both use split
//     slot params.num_splits_ptr[bidb] + n_split_idx.
//
// Parameters:
//   params                    decode parameters (pointers, strides, block table, scales).
//   bidb, bidh                batch index and head index.
//   m_block                   query row-block index (rows start at m_block * kBlockM).
//   n_split_idx               added to params.num_splits_ptr[bidb] on the split path.
//   seqlen_k                  number of valid keys for this batch entry.
//   n_block_min, n_block_max  half-open range [n_block_min, n_block_max) of KV page blocks.
//   NoSplit                   selects the final-output epilogue over the split-accumulator one.
//
// NOTE(review): the template parameter list and many `<...>` template arguments
// appear to be missing from this copy (e.g. `template __device__`,
// `reinterpret_cast(...)`, `Shape, Int>{}`, `Copy_Atom{}`,
// `flash::Softmax(acc_o)>`). Restore them before compiling.
// NOTE(review): k0_lds_loops = 16 and the literal chunk indices 16/17 presume
// size<2>(tSrK_smem) == 18. TODO confirm against T::HEAD_DIM_K.
// NOTE(review): no `s_waitcnt vmcnt` is issued before buffer[] registers are
// consumed by flash::asm_ds_write / flash::buffer_to_tensor (the vmcnt wait in
// the first K loop is commented out). This is only safe if
// flash::buffer_load_copy is compiler-visible rather than inline asm. TODO confirm.
template __device__ void
compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
                                          const int bidb, const int bidh, const int m_block,
                                          const int n_split_idx, const int seqlen_k,
                                          const int n_block_min, const int n_block_max,
                                          const bool NoSplit)
{
    extern __shared__ char shared_memory[];
    using SharedMemoryPlan = typename T::SharedMemoryPlan;
    SharedMemoryPlan &plan = *reinterpret_cast(shared_memory);
    const int tidx = threadIdx.x;
    constexpr int kBlockM = T::BLOCK_SIZE_M;
    constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
    constexpr int kHeadDim = T::HEAD_DIM_K;
    constexpr int kHeadDimV = T::HEAD_DIM_V;
    using Element = T::InputT;
    using index_t = int64_t;

    // ---- Global-memory views ------------------------------------------------
    const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q),
                            Shape, Int>{},
                            make_stride(params.q_row_stride, _1{}));
    // gK and gV start at the same address (params.k_ptr + row_offset_k), so V
    // is read from the same cache rows as K. The per-page offset is added to
    // gK.data() later, inside the block loop.
    const index_t row_offset_k = (bidh) * params.k_head_stride;
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k),
                            Shape, Int>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k),
                            Shape, Int>{},
                            make_stride(params.k_row_stride, _1{}));

    // ---- Shared-memory views. Note that sK and sV both alias plan.smem_v ----
    Tensor sQ = make_tensor(make_smem_ptr(plan.smem_q.data()), typename T::SmemLayoutQ{});
    Tensor sV = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutV{});
    Tensor sK = make_tensor(make_smem_ptr(plan.smem_v.data()), typename T::SmemLayoutK{});
    Tensor sP = make_tensor(make_smem_ptr(plan.smem_p.data()), typename T::SmemLayoutP{});
    Tensor sVt = make_tensor(sV.data(), typename T::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename T::SmemLayoutVtransposedNoSwizzle{});
    Tensor sRow_max_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_max.data()), typename T::SmemLayoutRow{});
    Tensor sRow_sum_reduce_buffer = make_tensor(make_smem_ptr(plan.smem_row_sum.data()), typename T::SmemLayoutRow{});

    // ---- MMA setup: tiled_mma computes S = Q * K^T; T::TiledMma_O computes O += P * V ----
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v,
        MMA_Atom,
        MMA_Atom
    >;
    using ValLayoutMNK = Layout>;
    using TiledMma = TiledMMA<
        MMA_Atom_Arch,
        Layout, _1>>,  // 1x4x1 or 1x8x1 thread group
        ValLayoutMNK>;
    TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    typename T::TiledMma_O tiled_mma_o;
    auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);

#if 1
    // ---- Stage Q: global -> shared (threads 0..127 issue the copy) -> registers ----
    typename T::GmemTiledCopyQ gmem_tiled_copy_Q;
    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
    Tensor cQ = make_identity_tensor(make_shape(size<0>(gQ), size<1>(gQ)));
    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);
    Tensor tQpQ = make_tensor(make_shape(size<2>(tQgQ)));
    // The last argument is presumably the number of valid query rows in this block;
    // the epilogue uses the same bound for its row guards.
    if (tidx < 128) flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, params.q_seq_per_hk - m_block * kBlockM);
    __syncthreads();
    auto smem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);
    Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
    // Q stays in registers (tSrQ) for every KV block processed below.
    cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
    __syncthreads();
#else
#endif

    // ---- K / V copy partitions ----------------------------------------------
    // NOTE(review): tKcK, tKpK, tKcK_smem, tKpK_smem, tVcV, tVpV and k0_loops
    // appear unused in the visible text. Some template arguments are missing,
    // so confirm before removing them.
    typename T::GmemTiledCopyK gmem_tiled_copy_K;
    auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
    Tensor tKsK = gmem_thr_copy_K.partition_D(sK);  // destination of flash::asm_ds_write for K chunks
    Tensor cK = make_identity_tensor(make_shape(size<0>(gK), size<1>(gK)));
    Tensor tKcK = gmem_thr_copy_K.partition_S(cK);
    Tensor tKpK = make_tensor(make_shape(size<2>(tKgK)));
    auto smem_tiled_copy_K = make_tiled_copy_B(Copy_Atom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSgK = smem_thr_copy_K.partition_S(gK);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);
    Tensor tSrK = thr_mma.partition_fragment_B(sK);
    Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK);
    Tensor tKpK_smem = make_tensor(make_shape(size<2>(tSgK)));
    // Register fragment used for the K chunks (16, 17) that bypass shared memory.
    Tensor tSrK_smem = thr_mma.partition_fragment_B(gK);
    typename T::GmemTiledCopyV gmem_tiled_copy_V;
    auto gmem_thr_copy_V = gmem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tVgV = gmem_thr_copy_V.partition_S(gV);
    Tensor tVsV = gmem_thr_copy_V.partition_D(sV);  // destination used to rewrite chunk 15 through the V layout
    Tensor cV = make_identity_tensor(make_shape(size<0>(gV), size<1>(gV)));
    Tensor tVcV = gmem_thr_copy_V.partition_S(cV);
    Tensor tVpV = make_tensor(make_shape(size<2>(tVgV)));
    auto smem_tiled_copy_V = make_tiled_copy_B(Copy_Atom{}, tiled_mma_o);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
    Tensor tOrVt = thr_mma_o.partition_fragment_B(sVtNoSwizzle);

    // Number of leading loop iterations that apply masking.
    // Non-causal: 1. Causal: ceil_div(kBlockM, kBlockN) + 1.
    constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
    const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
    int n_block = n_block_max - 1;
    // constexpr static int k0_lds_loops = 0;
    // Number of K chunks per block that are staged through shared memory (chunks 0..15).
    constexpr static int k0_lds_loops = 16;
    constexpr static int k0_loops = size<2>(tSrK_smem);
    constexpr static int k1_loops = size<2>(tOrVt);
    Tensor acc_o = partition_fragment_C(tiled_mma_o, Shape, Int>{});
    clear(acc_o);
    flash::Softmax(acc_o)> softmax;

    int cur_block_table;
    index_t offset_k;
    // Register ring holding K chunks that are in flight from global memory.
    constexpr static int BUFFER_SIZE = 4;
    uint128_t buffer[BUFFER_SIZE];
    // Prologue: locate the first page through the block table and prefetch its
    // K chunks 0..2 into buffer[0..2]. The last argument presumably bounds the
    // valid rows of a partial final page.
    if (n_block >= n_block_min) {
        cur_block_table = block_table[n_block];
        offset_k = cur_block_table * params.k_batch_stride;
        gK.data() = gK.data() + (offset_k);
        flash::buffer_load_copy(gK, buffer[0], 0, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
        flash::buffer_load_copy(gK, buffer[1], 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
        flash::buffer_load_copy(gK, buffer[2], 2, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
    }

#if 1
    #pragma unroll
    for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
        // sched_barrier(0) stops the compiler from moving instructions across the
        // inline asm. s_barrier synchronizes the workgroup.
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_barrier\n\t");
        __builtin_amdgcn_sched_barrier(0);
        Tensor acc_s = partition_fragment_C(tiled_mma, Shape, Int>{});
        clear(acc_s);
        Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
        // Compute K chunks 0..12. The original comment said "0~11", but the bound
        // k0_lds_loops - BUFFER_SIZE + 1 == 13 gives i = 0..12. Each step:
        //   - write chunk i from buffer[i % 4] to shared memory;
        //   - wait for the LDS write (s_waitcnt lgkmcnt(0)) and barrier;
        //   - read chunk i into tSrK;
        //   - prefetch chunk i + 3 into the slot whose chunk was written in the
        //     previous step (buffer[3] on the first step);
        //   - accumulate Q * K^T for chunk i.
#if 1
        #pragma unroll
        for (int i = 0; i < k0_lds_loops - BUFFER_SIZE + 1; i++) {
            // asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
            flash::asm_ds_write(buffer[i % BUFFER_SIZE], tKsK, i);
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
            __builtin_amdgcn_sched_barrier(0);
            cute::copy(smem_tiled_copy_K, tSsK(_, _, i), tSrK_copy_view(_, _, i));
            flash::buffer_load_copy(gK, buffer[(i + BUFFER_SIZE - 1) % BUFFER_SIZE], i + BUFFER_SIZE - 1, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
            cute::gemm(tiled_mma, tSrQ(_, _, i), tSrK(_, _, i), acc_s);
            // asm volatile("s_barrier\n\t");
        }
        // asm volatile("s_barrier\n\t");
#endif
#if 0
#else
        // Compute K chunks 13-15, which are already in buffer[1], buffer[2] and
        // buffer[3]. No further prefetch is issued for them.
        const int k_idx = k0_lds_loops - BUFFER_SIZE + 1;
        flash::asm_ds_write(buffer[k_idx % BUFFER_SIZE], tKsK, k_idx);
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
        __builtin_amdgcn_sched_barrier(0);
        cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx), tSrK_copy_view(_, _, k_idx));
        cute::gemm(tiled_mma, tSrQ(_, _, k_idx), tSrK(_, _, k_idx), acc_s);
        flash::asm_ds_write(buffer[(k_idx + 1) % BUFFER_SIZE], tKsK, k_idx + 1);
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
        __builtin_amdgcn_sched_barrier(0);
        cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 1), tSrK_copy_view(_, _, k_idx + 1));
        cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 1), tSrK(_, _, k_idx + 1), acc_s);
        flash::asm_ds_write(buffer[(k_idx + 2) % BUFFER_SIZE], tKsK, k_idx + 2);
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
        __builtin_amdgcn_sched_barrier(0);
        cute::copy(smem_tiled_copy_K, tSsK(_, _, k_idx + 2), tSrK_copy_view(_, _, k_idx + 2));
        cute::gemm(tiled_mma, tSrQ(_, _, k_idx + 2), tSrK(_, _, k_idx + 2), acc_s);
        // asm volatile("s_barrier\n\t");
        // Load K chunks 16-17 into buffer[1] and buffer[2]. Both slots were already
        // written to shared memory above. The chunks feed the MMA straight from
        // registers via flash::buffer_to_tensor and never go through shared memory.
        flash::buffer_load_copy(gK, buffer[1], 16, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
        flash::buffer_load_copy(gK, buffer[2], 17, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
        flash::buffer_to_tensor(buffer[1], tSrK_smem, 16);
        cute::gemm(tiled_mma, tSrQ(_, _, 16), tSrK_smem(_, _, 16), acc_s);
        flash::buffer_to_tensor(buffer[2], tSrK_smem, 17);
        cute::gemm(tiled_mma, tSrQ(_, _, 17), tSrK_smem(_, _, 17), acc_s);
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_barrier\n\t");
        __builtin_amdgcn_sched_barrier(0);
#endif
        const bool is_masking_step = masking_step > 0;
        const bool is_first_masking_step = masking_step == n_masking_steps;
        if (is_masking_step) {
            // Set to -INFINITY every score whose column is past seqlen_k
            // (non-causal) or past the per-row causal limit (causal).
            Tensor cS = make_identity_tensor(Shape, Int>{});
            Tensor tScS = thr_mma.partition_C(cS);
            for (int i = 0; i < size(acc_s); ++i) {
                if constexpr (!T::Is_causal) {
                    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) acc_s(i) = -INFINITY;
                } else {
                    // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                    // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
                    int row = int(get<0>(tScS(i)));
                    int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
                    if (int(get<1>(tScS(i))) > col_limit_right) acc_s(i) = -INFINITY;
                }
            }
        }
        // We have key_padding_mask so we'll need to Check_inf
        // The three softmax_rescale_o calls differ only in template arguments,
        // which are missing from this copy (presumably Is_first / Check_inf flags).
        is_first_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2)
        : is_masking_step ? softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2)
                          : softmax.template softmax_rescale_o(acc_s, acc_o, sRow_max_reduce_buffer, params.scale_softmax_log2);
        Tensor rP = flash::convert_type(acc_s);
        Tensor tOrP = flash::convert_layout_acc_Aregs(tiled_mma, tiled_mma_o, rP, sP);
        // Block-wide barrier before writing through tVsV, because sV shares storage with sK.
        __syncthreads();
#if 1
        // Chunk 15 is already in buffer[3]. Write it again through the V
        // partition (tVsV) before the P * V product reads sVt.
        flash::asm_ds_write(buffer[3], tVsV, 15);
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
        __builtin_amdgcn_sched_barrier(0);
#endif
        // Remove this page's offset from gK. If another block remains, locate the
        // next page and prefetch its K chunks 0..2 into buffer[0..2]. No row bound
        // is passed here; blocks below n_block_max - 1 are presumably full pages.
        gK.data() = gK.data() + (-offset_k);
        if (n_block > n_block_min) {
            cur_block_table = block_table[n_block - 1];
            offset_k = cur_block_table * params.k_batch_stride;
            gK.data() = gK.data() + (offset_k);
            flash::buffer_load_copy(gK, buffer[0], 0, params.k_row_stride, offset_k);
            flash::buffer_load_copy(gK, buffer[1], 1, params.k_row_stride, offset_k);
            flash::buffer_load_copy(gK, buffer[2], 2, params.k_row_stride, offset_k);
        }
        // O += P * V, reading V^T from shared memory one k-slice at a time.
        Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
        #pragma unroll
        for (int i = 0; i < k1_loops; i++) {
            cute::copy(smem_tiled_copy_V, tOsVt(_, _, i), tOrVt_copy_view(_, _, i));
            cute::gemm(tiled_mma_o, tOrP(_, _, i), tOrVt(_, _, i), acc_o);
        }
        // Barrier before the next iteration overwrites sK (same storage as sV).
        __builtin_amdgcn_sched_barrier(0);
        asm volatile(" s_barrier\n\t");
        __builtin_amdgcn_sched_barrier(0);
    }
#endif
    using ElementAccum = float;
    if (NoSplit) {
        // Final-output path: O is converted and written to params.o_ptr
        // (row stride params.o_row_stride); LSE goes to params.softmax_lse_ptr.
        using ElementO = Element;
        const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
        const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
        constexpr bool Split = false;
        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)),
                                     Shape, Int>{},
                                     make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
        Tensor lse = softmax.template normalize_softmax_lse(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
        // if (block0() && tidx < 64)
        // {
        // printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
        // }
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
                                       Shape>{}, Stride<_1>{});
        Tensor caccO = make_identity_tensor(Shape, Int>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor taccOcO = thr_mma_o.partition_C(caccO);
        Tensor rO = flash::convert_type(acc_o);
        // Only threads whose first accumulator element lies in column 0 write LSE:
        // one value per owned row, skipping rows beyond the valid query count.
        if (get<1>(taccOcO(0)) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO(0, mi, 0));
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    gLSEaccum(row) = lse(mi);
                }
            }
        }
        // Element-wise O store with a hand-written mapping:
        //   row = tidx % 16
        //   col = (tidx % 64 / 16) + warpid * 32 + n * 128 + 4 * ei
        // NOTE(review): this mapping is hard-coded rather than derived from
        // tiled_mma_o and presumes the T::TiledMma_O register layout. TODO confirm.
        {
            // using result_type = cutlass::Array;
            // int tidx = threadIdx.x;
            int col = 0;
            int warpid = tidx / 64;
            for (int m = 0; m < 1; m++) {
                const int row = tidx % 16;
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    for (int n = 0; n < size<2>(acc_o); n++) {
                        col = (tidx % 64 / 16) + warpid * 32 + n * 128;
                        for (int ei = 0; ei < 8; ei ++) {
                            gOaccum(row, col) = rO(ei, m, n);
                            col += 4;
                        }
                    }
                }
            }
        }
    } else {
        // Split path: unconverted acc_o goes to params.oaccum_ptr (ElementO = float,
        // row stride kHeadDimV). The partial LSE goes to params.softmax_lseaccum_ptr.
        // Both use split slot split_idx = params.num_splits_ptr[bidb] + n_split_idx.
        using ElementO = float;
        int split_idx = params.num_splits_ptr[bidb] + n_split_idx;
        constexpr bool Split = true;
        const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V;  // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
        const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM;
        Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)),
                                     Shape, Int>{},
                                     make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
        Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)),
                                       Shape>{}, Stride<_1>{});
        Tensor lse = softmax.template normalize_softmax_lse(acc_o, sRow_sum_reduce_buffer, params.scale_softmax);
        Tensor caccO = make_identity_tensor(Shape, Int>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
        Tensor taccOcO = thr_mma_o.partition_C(caccO);
        // Same LSE ownership rule as the NoSplit path.
        if (get<1>(taccOcO(0)) == 0) {
            #pragma unroll
            for (int mi = 0; mi < size(lse); ++mi) {
                const int row = get<0>(taccOcO(0, mi, 0));
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    gLSEaccum(row) = lse(mi);
                }
            }
        }
        // Same hard-coded (row, col) store mapping as the NoSplit path, writing acc_o directly.
        {
            // using result_type = cutlass::Array;
            // int tidx = threadIdx.x;
            int col = 0;
            int warpid = tidx / 64;
            for (int m = 0; m < 1; m++) {
                const int row = tidx % 16;
                if (row < params.q_seq_per_hk - m_block * kBlockM) {
                    for (int n = 0; n < size<2>(acc_o); n++) {
                        col = (tidx % 64 / 16) + warpid * 32 + n * 128;
                        for (int ei = 0; ei < 8; ei ++) {
                            gOaccum(row, col) = acc_o(ei, m, n);
                            col += 4;
                        }
                    }
                }
            }
        }
    }
}

// Split-KV MLA decode kernel driven by the tile scheduler metadata.
//
// Grid mapping:
//   blockIdx.x  query row block (m_block)
//   blockIdx.y  head (bidh)
//   blockIdx.z  index into params.tile_scheduler_metadata_ptr
//
// Each CTA handles requests sched_meta.begin_req_idx .. end_req_idx (inclusive):
//   - the first request starts at begin_block_idx / begin_split_idx;
//   - the last request stops at end_block_idx;
//   - other requests cover all ceil_div(seqlen_k, kBlockN) page blocks.
//
// Per-request work goes to compute_attn_1rowblock_splitkv_mla_gfx928 when
// compiling for __gfx928__, and to compute_attn_1rowblock_splitkv_mla_gfx936
// otherwise.
// NOTE(review): template parameter lists / arguments (e.g. the `T` of
// `template` and of the compute_attn_* calls) appear to be missing from this copy.
template __global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;
    DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
    // Partitions whose first request index is out of range have nothing to do.
    if (sched_meta.begin_req_idx >= params.b) return;
    for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
        constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
        const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
        int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
        const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
        int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
        // A request writes the final output directly unless it is the partition's
        // first (or last) request and the scheduler marked that request as split.
        const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted
                               : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
        // Barrier between consecutive requests, presumably so that shared memory
        // used by the previous request is not overwritten while still being read.
        if (batch_idx > sched_meta.begin_req_idx) {
            __syncthreads();
        }
#if defined(__gfx928__)
        compute_attn_1rowblock_splitkv_mla_gfx928(params, batch_idx, bidh, m_block, n_split_idx, seqlen_k, start_block_idx, end_block_idx, is_no_split );
#else
        compute_attn_1rowblock_splitkv_mla_gfx936(params, batch_idx, bidh, m_block, n_split_idx, seqlen_k, start_block_idx, end_block_idx, is_no_split );
#endif
    }
}

template __global__ void __launch_bounds__(T::NUM_THREADS, 1) flash_fwd_splitkv_mla_block_m_64_kernel(const DenseAttnDecodeParams params) { #if defined(__gfx936__) || defined(__gfx938__) constexpr int kBlockN = T::PAGE_BLOCK_SIZE; const int m_block = blockIdx.x; const int bidh = blockIdx.y; int bidb; int seqlen_k; int n_block_min; int n_block_max; const int tidx = threadIdx.x; const int lane_idx = tidx % 64; bool is_split = use_split_kv; if constexpr (use_split_kv) { int num_splits = params.total_num_splits / params.b; bidb = blockIdx.z % params.b; // bidb = blockIdx.z / num_splits; seqlen_k = __ldg(params.seqlens_k_ptr + bidb); int split_id = blockIdx.z / params.b; n_block_min = split_id * params.partition_block_nums; n_block_max = split_id == (num_splits - 1) ?
cute::ceil_div(seqlen_k, kBlockN) : std::min((split_id + 1) * params.partition_block_nums, cute::ceil_div(seqlen_k, kBlockN)); if (split_id == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN)) { is_split = false; } // if (tidx == 0 && bidb == 61) // { // printf("bidb = %d split_id = %d n_block_min = %d n_block_max = %d num_splits = %d params.partition_block_nums %d is_split = %d \n", bidb, split_id, n_block_min, n_block_max, num_splits, params.partition_block_nums, is_split); // } if (n_block_max <= n_block_min) return; } else { bidb = blockIdx.z; seqlen_k = __ldg(params.seqlens_k_ptr + bidb); n_block_min = 0; n_block_max = cute::ceil_div(seqlen_k, kBlockN); } extern __shared__ char shared_memory[]; constexpr int kBlockM = T::BLOCK_SIZE_M; // constexpr int kBlockN = T::PAGE_BLOCK_SIZE; constexpr int kHeadDim = T::HEAD_DIM_K; constexpr int kHeadDimV = T::HEAD_DIM_V; using Element = T::InputT; using index_t = int64_t; const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64); Element* q_lds = (Element*)&(shared_memory); Element* k_lds = q_lds; Element* v_lds = q_lds; const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, make_stride(params.q_row_stride, _1{})); // const index_t row_offset_k = (0) * params.k_head_stride; // Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), // Shape, Int>{}, // make_stride(params.k_row_stride, _1{})); typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8))); typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4))); typedef __bf16 __fp16x2_t __attribute__((ext_vector_type(2))); union Bf16_storage { __fp16x8_t data_128; __fp16x4_t data_64[2]; __fp16x2_t data_32[4]; uint16_t data_array[8]; }; union Bf16_storage_x4 { __fp16x4_t data_64; __fp16x2_t data_32[2]; uint16_t data[4]; }; struct PtrWrapper { uint32_t 
former; uint32_t latter; }; PtrWrapper glob_ptr_q; *(uint64_t*)&glob_ptr_q = reinterpret_cast(gQ.data().get()); glob_ptr_q.latter |= ((params.q_row_stride * 2) << 16); glob_ptr_q.latter |= 0x40000000; uint32x4_t global_addr_q = {0}; global_addr_q[0] = (glob_ptr_q.former); global_addr_q[1] = (glob_ptr_q.latter); global_addr_q[2] = params.q_seq_per_hk - m_block * kBlockM; global_addr_q[3] = 0x00020000; int virtual_row_ = lane_idx / 8;//0 int virtual_col_ = lane_idx % 8;//0 int swizzle_col_ = virtual_row_ ^ virtual_col_; int row_ = lane_idx / 4;//0 // 8->9 9->8 // row_ = (row_ >= 8 ) ^ row_; int col_ = swizzle_col_ % 4; auto calc_row_and_col_k = [&]() -> std::tuple { constexpr int elements_per_thread = 8; #if defined(__gfx938__) int row_offset = row_ + warp_idx * 16; int col_offset = col_ * 8; #else int row_offset = row_ * 4 + warp_idx; int col_offset = col_ * 8; #endif return {row_offset, col_offset}; }; auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx, int block_idx, index_t offset_k) { constexpr int element_size = 2; PtrWrapper glob_ptr_k; *(uint64_t*)&glob_ptr_k = reinterpret_cast(params.k_ptr) + offset_k * 2; glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16); glob_ptr_k.latter |= 0x40000000; uint32x4_t global_addr_k = {0}; global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former); global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter); global_addr_k[2] = seqlen_k - block_idx * kBlockN; global_addr_k[3] = 0x00020000; constexpr int elements_per_thread = 8; int col_offset = col; int offset_v = col_offset * 2; int ldsAddrPerWave = reinterpret_cast(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 5) * 64 * 32 * 2; typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); uint32x2_t index_offset = {0}; index_offset[0] = row_offset; index_offset[1] = offset_v; const int offset_s = k_idx * 32 * 2; __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 , idxen 
offen offset:0, lds \n" ::"v"(index_offset), "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s) :); __builtin_amdgcn_sched_barrier(0); }; auto k_lds_read_offset = [&] () -> int { int row = lane_idx % 16; int col = lane_idx / 16; col = (row / 2) ^ col; col = col % 4; // row = (row >= 8) ^ row; const auto lds_offset = row * 32 + col * 8; // #endif return lds_offset; }; auto calc_row_and_col_v = [&](int i) -> int { int row = lane_idx / 4; // int col = lane_idx % 4; int row_offset = row + i * 16; // int col_offset = col * 8 + warp_idx * 32; return row_offset; }; const int v_lds_read_ptr = reinterpret_cast(v_lds + lane_idx * 8); Element* k_lds_read_ptr = (k_lds + k_lds_read_offset()); int col_offset_v = (lane_idx % 4) * 8 + warp_idx * 32; auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx, int block_idx, index_t offset_k) { constexpr int element_size = 2; PtrWrapper glob_ptr_k; *(uint64_t*)&glob_ptr_k = reinterpret_cast(params.k_ptr) + offset_k * 2; glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16); glob_ptr_k.latter |= 0x40000000; uint32x4_t global_addr_k = {0}; global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former); global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter); global_addr_k[2] = seqlen_k - block_idx * kBlockN; global_addr_k[3] = 0x00020000; constexpr int elements_per_thread = 8; int col_offset = col; // int v_idx = row_offset; int offset_v = col_offset * 2; int ldsAddrPerWave = reinterpret_cast(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2; typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2))); uint32x2_t index_offset = {0}; index_offset[0] = row_offset; index_offset[1] = offset_v; const int offset_s = n_idx * 128 * 2; __builtin_amdgcn_sched_barrier(0); asm volatile( "s_mov_b32 m0, %1 \n\t" "s_nop 0 \n\t" "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset), "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s) :); 
__builtin_amdgcn_sched_barrier(0); }; Bf16_storage q_reg[18]; for (int i = 0; i < 18; i++) { constexpr int elements_per_thread = 8; int row = lane_idx % 16; int col = lane_idx / 16; int row_offset = row + warp_idx * 16; int col_offset = col * 8; int offset_v = col_offset * 2 + i * 32 * 2; q_reg[i].data_128 = __builtin_amdgcn_buffer_load_dwordx4(global_addr_q, row_offset, offset_v, false, false); } __syncthreads(); v4f acco_f32[32]; for (int i = 0; i < 32; i++) { acco_f32[i].x = 0.0f; acco_f32[i].y = 0.0f; acco_f32[i].z = 0.0f; acco_f32[i].w = 0.0f; } const int *block_table = params.block_table + bidb * params.block_table_batch_stride; auto float2bf16 = [] (float s) -> uint16_t { uint32_t x32 = reinterpret_cast(s); #ifndef FLASH_MLA_BF16_TYPE #define FLASH_MLA_BF16_TYPE 0 #endif #if FLASH_MLA_BF16_TYPE == 1 x32 += 0x8000u; #endif return uint16_t(x32 >> 16); }; // int block_idx = 0; // int cur_block_table = block_table[block_idx]; // index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride; // auto [row_offset, col] = calc_row_and_col_k(block_idx); // buffer_load_lds_k(row_offset, col, 0, offset_k); // __syncthreads(); { // if (thread0()) // { // int k = 0; // for (int i = 0; i < 64; i++) // { // for (int j = 0; j < 32; j++) // { // printf(" %.3f ", float(k_lds[k])); // k++; // } // printf("\n"); // } // } // if (block0() && threadIdx.x < 64) // { // cutlass::bfloat16_t q[8]; // for (int i = 0; i < 8; i++) // { // q[i].storage = v_reg[0].data_array[i]; // // q[i].storage = q_reg[0].data_array[i]; // } // printf("tidx %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %d\n ", threadIdx.x, // float(q[0]), // float(q[1]), // float(q[2]), // float(q[3]), // float(q[4]), // float(q[5]), // float(q[6]), // float(q[7]), // v_lds_read_ptr // ); // } } struct IsMaskBlock {}; struct IsFirstMaskBlock {}; struct IsNoMaskBlock {}; flash::Softmax<1> softmax; auto process_one_block = [&] (int block_idx, auto is_mask_block_t) { static constexpr bool IS_MASK_BLOCK = 
std::is_same_v; static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v; static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v; int cur_block_table = block_table[block_idx]; v4f accs_f32[4]; for (int i = 0; i < 4; i++) { accs_f32[i].x = 0.0f; accs_f32[i].y = 0.0f; accs_f32[i].z = 0.0f; accs_f32[i].w = 0.0f; } index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride; auto [row_offset, col] = calc_row_and_col_k(); #define LOAD_K_AND_QK_GEMM(k) \ { \ constexpr int k_val = (k); \ buffer_load_lds_k(row_offset, col, k_val - 4, block_idx, offset_k); \ flash::qk_gemm(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ } { constexpr int k_val = (17); buffer_load_lds_k(row_offset, col, k_val, block_idx, offset_k); buffer_load_lds_k(row_offset, col, k_val - 1, block_idx, offset_k); buffer_load_lds_k(row_offset, col, k_val - 2, block_idx, offset_k); buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k); buffer_load_lds_k(row_offset, col, k_val - 4, block_idx, offset_k); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::qk_gemm(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); LOAD_K_AND_QK_GEMM(16); LOAD_K_AND_QK_GEMM(15); LOAD_K_AND_QK_GEMM(14); LOAD_K_AND_QK_GEMM(13); LOAD_K_AND_QK_GEMM(12); LOAD_K_AND_QK_GEMM(11); LOAD_K_AND_QK_GEMM(10); LOAD_K_AND_QK_GEMM(9); LOAD_K_AND_QK_GEMM(8); LOAD_K_AND_QK_GEMM(7); LOAD_K_AND_QK_GEMM(6); LOAD_K_AND_QK_GEMM(5); LOAD_K_AND_QK_GEMM(4); // LOAD_K_AND_QK_GEMM(3); flash::qk_gemm(q_reg[k_val - 14].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); 
flash::qk_gemm(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::qk_gemm(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::qk_gemm(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); } // if (block0() && tidx < 64) // { // printf(" %.3f %.3f \n", accs_f32[0][0], accs_f32[0][1]); // } if constexpr (!IS_NO_MASK_BLOCK) { for (int i = 0; i < 16; ++i) { int idx = i; if constexpr (!T::Is_causal) { if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 >= int(seqlen_k - block_idx * kBlockN)) { #if defined(__gfx938__) accs_f32[i/4][i%4] = -INFINITY; #else accs_f32[i%4][i/4] = -INFINITY; #endif } } else { // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups int row = (lane_idx % 16) + warp_idx * 16;; int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 > col_limit_right) { #if defined(__gfx938__) accs_f32[i/4][i%4] = -INFINITY; #else accs_f32[i%4][i/4] = -INFINITY; #endif } } } } Tensor scores = make_tensor(Shape<_1, _16>{}); for (int i = 0; i < 16; i++) { #if defined(__gfx938__) scores(0, i) = accs_f32[i/4][i%4]; #else scores(0, i) = accs_f32[i%4][i/4]; #endif } softmax.template softmax_rescale_o_prefill_4x1(scores, acco_f32, params.scale_softmax_log2); Bf16_storage_x4 p[4]; for (int i = 0; i < 4; i++) { #if defined(__gfx938__) p[i].data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4), 0, scores(0, i * 4 + 1), 
0); p[i].data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4 + 2), 0, scores(0, i * 4 + 3), 0); #else p[i].data[0] = float2bf16(scores(0, i * 4)); p[i].data[1] = float2bf16(scores(0, i * 4 + 1)); p[i].data[2] = float2bf16(scores(0, i * 4 + 2)); p[i].data[3] = float2bf16(scores(0, i * 4 + 3)); #endif } int row_offset_v[4]; for (int i = 0; i < 4; i++) { row_offset_v[i] = calc_row_and_col_v(i); } // __syncthreads(); // #if 1 { constexpr int k_val = (0); buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0, block_idx, offset_k); buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0, block_idx, offset_k); buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0, block_idx, offset_k); buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0, block_idx, offset_k); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1, block_idx, offset_k); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1, block_idx, offset_k); flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 2].data_64, 
v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1, block_idx, offset_k); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1, block_idx, offset_k); } #define LOAD_V_AND_PV_GEMM(n) \ { \ constexpr int k_val = (0); \ constexpr int n_val = (n); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1, block_idx, offset_k); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1, block_idx, offset_k); \ flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ 
flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1, block_idx, offset_k); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ flash::pv_gemm(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \ __builtin_amdgcn_sched_barrier(0); \ asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \ __builtin_amdgcn_sched_barrier(0); \ buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1, block_idx, offset_k); \ } LOAD_V_AND_PV_GEMM(1); LOAD_V_AND_PV_GEMM(2); { constexpr int n_val = (3); flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<0, 14>(p[0].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<2, 15>(p[2].data_64, 
v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32); flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32); __builtin_amdgcn_sched_barrier(0); asm volatile("s_barrier\n\t"); __builtin_amdgcn_sched_barrier(0); } }; constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; int n_block = n_block_max - 1; if constexpr (n_masking_steps == 1) { if (n_block >= n_block_min) { process_one_block(n_block, IsFirstMaskBlock{}); } n_block--; } else { int masking_step = 1; if (n_block >= n_block_min) { process_one_block(n_block, IsFirstMaskBlock{}); } n_block--; for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) { process_one_block(n_block, IsMaskBlock{}); } } for(; n_block >= n_block_min; --n_block) { process_one_block(n_block, IsNoMaskBlock{}); } using ElementAccum = float; // if constexpr (!use_split_kv) if (!is_split) { using ElementO = Element; const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; constexpr bool Split = false; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1(acco_f32, params.scale_softmax); // if (block0() && tidx < 64) // { // printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1))); // } Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)), Shape>{}, Stride<_1>{}); { // using result_type = cutlass::Array; // int tidx = threadIdx.x; int row, col; // int warpid = tidx / 64; for (int mi = 0; mi < 1; mi++) { row = mi * kBlockM + lane_idx % 16 + warp_idx * 16; if (row < params.q_seq_per_hk - m_block * kBlockM) { for (int ni = 0; ni < 16; ++ni) { #if defined(__gfx938__) Bf16_storage res; col = (lane_idx / 16) * 8 + ni * 32 ; res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0], 0, acco_f32[ni * 2 + 1][0], 0); res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1], 0, acco_f32[ni * 2 + 1][1], 0); res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2], 0, acco_f32[ni * 2 + 1][2], 0); res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3], 0, acco_f32[ni * 2 + 1][3], 0); *(__fp16x8_t*)(&gOaccum(row, col)) = res.data_128; #else col = (lane_idx / 16) * 2 + ni * 32 ; using result_type = cutlass::Array; for (int ei = 0; ei < 4; ei++) { result_type res; Element e0, e1; e0.storage = float2bf16(acco_f32[ni * 2][ei]); e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei]); res[0] = e0; res[1] = e1; // gO(row, col) = res[0]; // gO(row, col + 1) = res[1]; *(result_type*)(&gOaccum(row, col)) = res; col += 8; } #endif } // for (int n = 0; n < 1; n++) { // col = (tidx % 64 / 16) + warpid * 32 + n * 128; // for (int ei = 0; ei < 8; ei ++) { // gOaccum(row, col) = rO(ei, m, n); // col += 4; // } // } gLSEaccum(row) = lse(mi); } } } } else { using ElementO = float; int num_splits = params.total_num_splits / params.b; int split_idx = (blockIdx.z / params.b) + bidb * num_splits; constexpr bool Split = true; const index_t row_offset_oaccum = ((split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1) const index_t row_offset_lseaccum = (split_idx*params.h_k + bidh)*params.q_seq_per_hk + m_block * kBlockM; Tensor gOaccum = 
make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (row_offset_oaccum)), Shape, Int>{}, make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lseaccum)), Shape>{}, Stride<_1>{}); Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1(acco_f32, params.scale_softmax); { // using result_type = cutlass::Array; // int tidx = threadIdx.x; int row, col; // int warpid = tidx / 64; for (int mi = 0; mi < 1; mi++) { row = mi * kBlockM + lane_idx % 16 + warp_idx * 16; if (row < params.q_seq_per_hk - m_block * kBlockM) { for (int ni = 0; ni < 16; ++ni) { #if defined(__gfx938__) col = (lane_idx / 16) * 8 + ni * 32 ; for (int ei = 0; ei < 4; ei++) { gOaccum(row, col) = acco_f32[ni * 2][ei]; gOaccum(row, col + 1) = acco_f32[ni * 2 + 1][ei]; col += 2; } #else col = (lane_idx / 16) * 2 + ni * 32 ; for (int ei = 0; ei < 4; ei++) { gOaccum(row, col) = acco_f32[ni * 2][ei]; gOaccum(row, col + 1) = acco_f32[ni * 2 + 1][ei]; col += 8; } #endif } // for (int n = 0; n < 1; n++) { // col = (tidx % 64 / 16) + warpid * 32 + n * 128; // for (int ei = 0; ei < 8; ei ++) { // gOaccum(row, col) = rO(ei, m, n); // col += 4; // } // } gLSEaccum(row) = lse(mi); } } } } #endif } template void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms) { FLASH_ASSERT(params.d == Config::HEAD_DIM_K); FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V); BOOL_SWITCH(params.is_causal, Is_causal, [&] { if (params.h_q >= 64 && params.h_k == 1 && !params.is_gfx928) { using T = Traits_Block_M_64; constexpr size_t smem_size = 16384 + 4096; if (params.use_split_kv) { auto mla_kernel = &flash_fwd_splitkv_mla_block_m_64_kernel; const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); mla_kernel<<>>(params); } else { auto mla_kernel = &flash_fwd_splitkv_mla_block_m_64_kernel; const int num_m_block = 
cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); mla_kernel<<>>(params); } } else { constexpr size_t smem_size = 65536; using T = Traits; const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); auto mla_kernel = &flash_fwd_splitkv_mla_kernel; mla_kernel<<>>(params); } }); CHECK_CUDA_KERNEL_LAUNCH(); } }