Commit e2e0225c authored by zhanghj2's avatar zhanghj2
Browse files

空kernel可以编译通过

parent 48c6dc42
#pragma once
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
namespace sm100::fwd::head64 {
using namespace cute;
template<
typename Shape_Q_NoPE, typename TMA_Q_NoPE,
typename Shape_Q_RoPE, typename TMA_Q_RoPE,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q_NoPE shape_Q_nope; TMA_Q_NoPE tma_Q_nope;
Shape_Q_RoPE shape_Q_rope; TMA_Q_RoPE tma_Q_rope;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_kv_nope;
};
struct float2x2 {
float2 lo, hi;
};
constexpr int D_Q = 576;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan
constexpr int B_H = 64;
constexpr int B_TOPK = 64;
constexpr int NUM_BUFS = 3;
constexpr int NUM_THREADS = 128 + 128 + 128; // 128 scale & exp threads, 128 TMA threads, 32 UTCMMA threads
// Tensor memory columns
namespace tmem_cols {
// 0 ~ 256: output
// 256 ~ 400: Q
// 400 ~ 464: P
constexpr int O = 0;
constexpr int Q = 256;
constexpr int Q_RoPE = 256 + 128;
constexpr int P = 400;
}
using SmemLayoutQNoPE = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_V>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutQRoPE = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H>, Int<D_Q-D_V>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutO = SmemLayoutOTiles<8>;
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutKNoPE = SmemLayoutKTiles<8>;
using SmemLayoutV = decltype(coalesce(
composition(
SmemLayoutKNoPE{},
Layout<Shape<Int<D_V>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{}
)
, Shape<_1, _1>{}));
using SmemLayoutKRoPE = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_TOPK>, Int<64>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutKNoPE_TiledMMA = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK*2>, Int<D_V/2>>{},
Step<_1, _2>{}
), Shape<_1, _1>{})); // Re-view K-NoPE as B_TOPK*2 x D_V/2 for dual gemm
using SmemLayoutKRoPE_TiledMMA = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_TOPK*2>, Int<64/2>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutS = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
struct SharedMemoryPlan {
union {
struct {
array_aligned<bf16, cosize_v<SmemLayoutKRoPE>> _k_rope_pad;
array_aligned<bf16, cosize_v<SmemLayoutKNoPE>> _k_pad[2]; // So that q_nope covers k[2]
array_aligned<bf16, cosize_v<SmemLayoutQNoPE>> q_nope;
} q_full;
struct {
array_aligned<bf16, cosize_v<SmemLayoutKRoPE>> k_rope;
array_aligned<bf16, cosize_v<SmemLayoutKNoPE>> k_nope[NUM_BUFS];
} k;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} u;
float p_exchange_buf[4][32 * (B_TOPK/2)];
union {
bf16 s[B_H*B_TOPK];
array_aligned<bf16, cosize_v<SmemLayoutQRoPE>> q_rope;
} s_q_rope;
char is_k_valid[NUM_BUFS][B_TOPK/8];
transac_bar_t bar_prologue_q_nope, bar_prologue_q_rope, bar_prologue_utccp_nope, bar_prologue_utccp_rope;
transac_bar_t bar_qk_nope_done[NUM_BUFS], bar_qk_rope_done; // Pi = QKi^T (the nope part) done
transac_bar_t bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. O, Si and Vi are free)
transac_bar_t bar_kv_nope_ready[NUM_BUFS][2], bar_kv_rope_ready;
transac_bar_t bar_p_free;
transac_bar_t bar_so_ready; // S and O are ready
transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];
array_aligned<uint32_t, 1> tmem_start_addr;
float rowwise_max_buf[128], rowwise_li_buf[128];
};
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, 128, UMMA::Major::K, UMMA::Major::K>{} // Here we use N = 128 = 2*B_TOPK since we're going to use implicit dual gemm: <TODO Fill link here>
));
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}
));
enum NamedBarriers : int {
wg0_sync = 0,
wg0_warp02_sync = 1,
wg0_warp13_sync = 2,
pepi_sync = 3,
};
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head64 {
template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head64 {
template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params);
}
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include <cutlass/cuda_host_adapter.hpp>
#include <kerutils/kerutils.cuh>
#include "params.h"
#include "utils.h"
#include "sm100/helpers.h"
#include "sm100/prefill/sparse/common_subroutine.h"
#include "config.h"
namespace sm100::fwd::head64 {
using namespace cute;
/*
Pipeline Overview:
| Copy | MMA | Scale & Exp |
KV0
KV1
KV2
P0 = QK0^T
S0 = exp(P0)
scale(O) w.r.t P0
P1 = QK1^T
S1 = exp(P1)
O += S0V0
KV3 scale(O) w.r.t P1
P2 = QK2^T
S2 = exp(P2)
O += S1V1
KV4 scale(O) w.r.t P2
P3 = QK3^T
S3 = exp(P3)
O += S2V2
KV5 scale(O) w.r.t P3
...
O += S(n-3)V(n-3)
scale(O) w.r.t P(n-2)
P(n-1) = QK(n-1)^T
S(n-1) = exp(P(n-1))
O += S(n-2)V(n-2)
scale(O) w.r.t P(n-1)
O += S(n-1)V(n-1)
*/
using FwdMode = SparseAttnFwdMode;
template<bool HAVE_ROPE, typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 1)
sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
// Grid shape: [s_q, 1, 1]
const int s_q_idx = blockIdx.x;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const int idx_in_warpgroup = threadIdx.x % 128;
const int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]
// Allocate tmem tensors
TiledMMA tiled_mma_P = TiledMMA_P{};
TiledMMA tiled_mma_O = TiledMMA_O{};
// NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm)
Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<B_H>, _128>{});
Tensor tQ_nope_part0 = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<(D_V/2)/2>>{})
);
Tensor tQ_nope_part1 = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<(D_V/2)/2>>{})
);
Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<64/2>>{})
);
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H>, Int<D_V>>{});
tP.data().get() = tmem_cols::P;
tQ_nope_part0.data().get() = tmem_cols::Q;
tQ_nope_part1.data().get() = tmem_cols::Q + 64;
tQ_rope.data().get() = tmem_cols::Q_RoPE;
tO.data().get() = tmem_cols::O;
if (warp_idx == 0) {
if (elect_one_sync()) {
// Copy Q
if constexpr (HAVE_ROPE) {
cute::prefetch_tma_descriptor(tma_params.tma_Q_rope.get_tma_descriptor());
}
cute::prefetch_tma_descriptor(tma_params.tma_Q_nope.get_tma_descriptor());
plan.bar_prologue_q_nope.init(1);
plan.bar_prologue_q_rope.init(1);
fence_barrier_init();
if constexpr (HAVE_ROPE) {
Tensor gQ_rope = tma_params.tma_Q_rope.get_tma_tensor(tma_params.shape_Q_rope)(_, _, s_q_idx);
Tensor sQ_rope = make_tensor(make_smem_ptr(plan.s_q_rope.q_rope.data()), SmemLayoutQRoPE{});
ku::launch_tma_copy(tma_params.tma_Q_rope, gQ_rope, sQ_rope, plan.bar_prologue_q_rope, TMA::CacheHintSm90::EVICT_FIRST);
}
Tensor gQ_nope = tma_params.tma_Q_nope.get_tma_tensor(tma_params.shape_Q_nope)(_, _, s_q_idx);
Tensor sQ_nope = make_tensor(make_smem_ptr(plan.u.q_full.q_nope.data()), SmemLayoutQNoPE{});
ku::launch_tma_copy(tma_params.tma_Q_nope, gQ_nope, sQ_nope, plan.bar_prologue_q_nope, TMA::CacheHintSm90::EVICT_FIRST);
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv_nope));
// Initialize other barriers
plan.bar_prologue_utccp_rope.init(1);
plan.bar_prologue_utccp_nope.init(1);
CUTE_UNROLL
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_qk_nope_done[i].init(1);
plan.bar_sv_done[i].init(1);
plan.bar_kv_nope_ready[i][0].init(1);
plan.bar_kv_nope_ready[i][1].init(1);
plan.bar_k_valid_ready[i].init(B_TOPK/8);
plan.bar_k_valid_free[i].init(128);
}
plan.bar_p_free.init(128);
plan.bar_so_ready.init(128);
plan.bar_qk_rope_done.init(1);
plan.bar_kv_rope_ready.init(64);
fence_barrier_init();
}
// Initialize TMEM
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());
TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator1Sm().release_allocation_lock();
}
__syncthreads();
if (warpgroup_idx == 0) {
// Scale & Exp warps
// The following three numbers are
// - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V)
// - li: sumexp, i.e. li := sum(exp(Pi*scale - mi))
// - real_mi: real max logits, i.e. real_mi := max(Pi*scale)
// where Pi is the i-th row of P, P := QK^T
// mi and real_mi are always consistent within the two threads that
// controls one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
bf16* sS_base = plan.s_q_rope.s + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2);
static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2;
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
// Wait for P
NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1));
plan.bar_qk_nope_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1); // Put the barrier wait here for more code reordering space
ku::tcgen05_after_thread_sync();
// Load P
float p[NUM_ELEMS_PER_THREAD];
retrieve_mask_and_reduce_p<
NUM_ELEMS_PER_THREAD,
tmem_cols::P,
NamedBarriers::wg0_warp02_sync,
NamedBarriers::wg0_warp13_sync,
false
>(
plan.is_k_valid[k%NUM_BUFS],
warp_idx, lane_idx,
[&]() {plan.bar_p_free.arrive();},
plan.p_exchange_buf,
p
);
plan.bar_k_valid_free[k%NUM_BUFS].arrive();
// Get rowwise max of Pi
float cur_pi_max = get_max<NUM_ELEMS_PER_THREAD>(p);
cur_pi_max *= params.sm_scale_div_log2;
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// By this point:
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. among threads 0~31+64~95; and is identical among threads 32~63+96~127)
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
// Calculate S
nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2];
float cur_sum = get_s_from_p<NUM_ELEMS_PER_THREAD>(s, p, params.sm_scale_div_log2, new_max);
li = fma(li, scale_for_old, cur_sum);
// Wait for last SV gemm, write S
if (k > 0) {
plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; i += 1) {
*(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4);
}
// Scale O
if (k > 0 && should_scale_o) {
// plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before
ku::tcgen05_after_thread_sync();
rescale_O<D_V, 32, tmem_cols::O>(scale_for_old);
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
plan.bar_so_ready.arrive();
}
// Epilogue
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Exchange li
plan.rowwise_li_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
li += plan.rowwise_li_buf[idx_in_warpgroup^64];
// Store mi and li
if (idx_in_warpgroup < 64) {
int global_index = s_q_idx*params.h_q + idx_in_warpgroup;
float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
params.max_logits[global_index] = real_mi*CUDART_LN2_F;
params.lse[global_index] = cur_lse;
}
// Wait for the last GEMM
plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
// Fetch dO if necessary
// Store O
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + (idx_in_warpgroup%64))*CUDART_L2E_F;
float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi));
Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});
constexpr int B_EPI = 64;
Tensor tma_gO = flat_divide(
tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx),
Shape<Int<B_H>, Int<B_EPI>>{}
)(_, _, _0{}, _);
Tensor sO_divided = flat_divide(
sO,
Shape<Int<B_H>, Int<B_EPI>>{}
)(_, _, _0{}, _);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
float2 o[B_EPI/2];
bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld
if (!have_valid_indices) {
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
CUTE_UNROLL
for (int i = 0; i < B_EPI/2; ++i)
o[i].x = o[i].y = 0.0f;
output_scale = 1.0f;
}
float2 output_scale_float2 = make_float2(output_scale, output_scale);
bf16* sO_addrs[8];
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
sO_addrs[i] = &sO(idx_in_warpgroup%64, i*8);
}
CUTE_UNROLL
for (int c = 0; c < 2; ++c) {
// Each tile: 64 x 256
CUTE_UNROLL
for (int k = 0; k < (D_V/4)/B_EPI; ++k) {
// Load O from tO
if (have_valid_indices) {
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + c*128 + k*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
}
// Convert and store
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
nv_bfloat162 o_bf16[4];
CUTE_UNROLL
for (int j = 0; j < 4; ++j) {
o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[i*4+j]);
}
*(uint128_t*)(sO_addrs[i] + (c*(D_V/2) + (idx_in_warpgroup/64)*(D_V/4) + k*B_EPI)*64) = *(uint128_t*)(o_bf16);
}
// Sync
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
if (warp_idx == 0 && elect_one_sync()) {
int epi_chunk_idx = c*(D_V/2/B_EPI) + k;
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)),
thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx))
);
}
if (warp_idx == 1 && elect_one_sync()) {
int epi_chunk_idx = c*(D_V/2/B_EPI) + (D_V/B_EPI/4) + k;
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)),
thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx))
);
}
}
}
if (warp_idx == 0) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
// Producer warp for KV
int warp_idx = cutlass::canonical_warp_idx_sync() - 4;
constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/4)/NUM_WARPS;
if (elect_one_sync()) {
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int4 indices[NUM_LOCAL_ROWS_PER_WARP];
int max_indices = -1, min_indices = params.s_kv;
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);
max_indices = max(max_indices, int4_max(indices[local_row]));
min_indices = min(min_indices, int4_min(indices[local_row]));
}
bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1;
bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS;
if (k == 2) {
plan.bar_prologue_utccp_nope.wait(0); // Since q_nope coincidences with k[2]
}
// Copy NoPE
int cur_buf = k%NUM_BUFS;
plan.bar_sv_done[cur_buf].wait((k/NUM_BUFS)&1^1);
bf16* sK_nope_base = plan.u.k.k_nope[cur_buf].data() + warp_idx*4*64;
auto load_kv_nope_part = [&](int part_idx) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
CUTE_UNROLL
for (int local_col = part_idx*(D_V/2/64); local_col < (part_idx+1)*(D_V/2/64); ++local_col) {
ku::tma_gather4(
&(tma_params.tensor_map_kv_nope),
plan.bar_kv_nope_ready[cur_buf][part_idx],
sK_nope_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),
local_col*64,
indices[local_row],
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
}
};
if (!should_skip_tma) {
load_kv_nope_part(0);
load_kv_nope_part(1);
} else {
// NOTE See head128/phase1.cuh for this TMA skipping technique
CUTE_UNROLL
for (int part_idx = 0; part_idx < 2; ++part_idx)
plan.bar_kv_nope_ready[cur_buf][part_idx].complete_transaction(NUM_LOCAL_ROWS_PER_WARP*4*D_V/2*sizeof(bf16));
}
}
}
} else {
// MMA warp
if (warp_idx == 8 && elect_one_sync()) {
// S -> T copy for Q
UMMA::SmemDescriptor sQ_nope_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.q_full.q_nope.data()),
tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64>>{} // We use this shape for dual gemm (TODO Link)
)
)
);
UMMA::SmemDescriptor sQ_rope_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.s_q_rope.q_rope.data()),
tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32>>{}
)
)
);
if constexpr (HAVE_ROPE) {
// Copy the RoPE tile: 128 rows * 32 cols (64B) (in UTCCP's view), or 64 rows * 64 cols (in our view)
plan.bar_prologue_q_rope.arrive_and_expect_tx(B_H*(D_Q-D_V)*sizeof(bf16));
plan.bar_prologue_q_rope.wait(0);
ku::tcgen05_after_thread_sync();
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 2; ++subtile_idx) {
// A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_rope_desc + (subtile_idx*32) / 16,
tmem_cols::Q_RoPE + subtile_idx*8
);
}
ku::umma_arrive_noelect(plan.bar_prologue_utccp_rope);
}
plan.bar_prologue_q_nope.arrive_and_expect_tx(B_H*D_V*sizeof(bf16));
plan.bar_prologue_q_nope.wait(0);
ku::tcgen05_after_thread_sync();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_V/64/2; ++tile_idx) {
// A tile is 128 rows * 64 cols (128B) (in UTCCP's view), or 64 rows * 128 cols (in our view)
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) {
// A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_nope_desc + (tile_idx*(B_H*128*2) + subtile_idx*32) / 16, // Remember that 4 LSBs are not included
tmem_cols::Q + tile_idx*32 + subtile_idx*8
);
}
}
ku::umma_arrive_noelect(plan.bar_prologue_utccp_nope);
if constexpr (HAVE_ROPE) {
plan.bar_prologue_utccp_rope.wait(0);
}
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks+1; ++k) {
if (k < num_k_blocks) {
// Pi = QKi^T
int cur_buf = k%NUM_BUFS;
Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutKNoPE_TiledMMA{});
Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE_TiledMMA{});
plan.bar_p_free.wait(k&1^1);
ku::tcgen05_after_thread_sync();
// Wait for K (RoPE)
// P = Q(rope) @ K(rope)^T
if constexpr (HAVE_ROPE) {
plan.bar_kv_rope_ready.wait(k&1);
ku::tcgen05_after_thread_sync();
ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true);
ku::umma_arrive_noelect(plan.bar_qk_rope_done);
}
// Wait for K (NoPE)
if (k == 0) {
plan.bar_prologue_utccp_nope.wait(0);
}
Tensor sK_nope_divided = flat_divide(sK_nope, Tile<Int<B_TOPK*2>, Int<D_V/4>>{})(_, _, _0{}, _);
CUTE_UNROLL
for (int kv_nope_part_idx = 0; kv_nope_part_idx < 2; ++kv_nope_part_idx) {
plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].arrive_and_expect_tx(B_TOPK*D_V/2*sizeof(bf16));
plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].wait((k/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
// P += Q(nope) @ K(nope)^T
bool clear_accum = (!HAVE_ROPE) && kv_nope_part_idx == 0;
ku::utcmma_ts(tiled_mma_P, kv_nope_part_idx ? tQ_nope_part1 : tQ_nope_part0, sK_nope_divided(_, _, kv_nope_part_idx), tP, clear_accum);
}
ku::umma_arrive_noelect(plan.bar_qk_nope_done[cur_buf]);
}
if (k > 0) {
// O += S(i-1)V(i-1)
int cur_buf = (k-1)%NUM_BUFS;
Tensor sS = make_tensor(make_smem_ptr(plan.s_q_rope.s), SmemLayoutS{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutV{});
// Wait for S(i-1) and O to be scaled
plan.bar_so_ready.wait((k-1)&1);
ku::tcgen05_after_thread_sync();
// O += sS @ sV
ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == 1);
ku::umma_arrive_noelect(plan.bar_sv_done[cur_buf]);
}
}
} else if (warp_idx == 9) {
// KV valid loading warp
if (lane_idx < B_TOPK/8) {
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
char k_validness_mask = load_indices_and_generate_mask(
lane_idx,
gIndices + k*B_TOPK,
params.s_kv,
k*B_TOPK,
topk_length
);
int cur_buf = k%NUM_BUFS;
plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);
plan.is_k_valid[cur_buf][lane_idx] = k_validness_mask;
plan.bar_k_valid_ready[cur_buf].arrive();
}
}
} else if (warp_idx == 10 || warp_idx == 11) {
if constexpr (HAVE_ROPE) {
int thread_idx = threadIdx.x - 10*32;
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 64/GROUP_SIZE, ROWS_PER_THREAD = B_TOPK/NUM_GROUPS;
int group_idx = thread_idx / GROUP_SIZE, idx_in_group = thread_idx % GROUP_SIZE;
Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE{});
bf16* sK_rope_base = &sK_rope(group_idx, idx_in_group*8);
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int indices[ROWS_PER_THREAD];
CUTE_UNROLL
for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row)
indices[local_row] = __ldg(gIndices + k*B_TOPK + group_idx + local_row*NUM_GROUPS);
plan.bar_qk_rope_done.wait(k&1^1);
CUTE_UNROLL
for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row) {
int index = indices[local_row];
ku::cp_async_cacheglobal<ku::PrefetchSize::B128>(
params.kv + (int64_t)index*params.stride_kv_s_kv + 512 + idx_in_group*8,
sK_rope_base + local_row*NUM_GROUPS*32,
index >= 0 && index < params.s_kv
); // NOTE Using cp.async instead of TMA is faster here
// NOTE Here we only consider the range of `index` instead of also checking against topk_length, as it's noted that under this scenario (i.e. there exists a valid index among indices[topk_length: ] that points to a token who has NaN inside)
}
cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)&(plan.bar_kv_rope_ready));
}
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
}
#endif
}
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk % B_TOPK == 0); // To save some boundry checkings
KU_ASSERT(params.h_q == B_H); // To save some calculation
KU_ASSERT(params.d_qk == D_QK);
static_assert(D_QK == 576 || D_QK == 512);
auto shape_Q_nope = make_shape(params.h_q, D_V, params.s_q);
auto tma_Q_nope = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q_nope,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
SmemLayoutQNoPE{}
);
auto shape_Q_rope = make_shape(params.h_q, D_Q-D_V, params.s_q);
auto tma_Q_rope = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q + D_V),
make_layout(
shape_Q_rope,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
SmemLayoutQRoPE{}
);
auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((bf16*)params.out),
make_layout(
shape_O,
make_stride(params.d_v, _1{}, params.h_q*params.d_v)
)
),
SmemLayoutOTiles<1>{}
);
CUtensorMap tensor_map_kv_nope;
{
uint64_t size[2] = {D_V, (unsigned long)params.s_kv};
uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};
uint32_t box_size[2] = {64, 1};
uint32_t elem_stride[2] = {1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_kv_nope,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
2,
params.kv,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
KU_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
decltype(shape_Q_nope), decltype(tma_Q_nope),
decltype(shape_Q_rope), decltype(tma_Q_rope),
decltype(shape_O), decltype(tma_O)
> tma_params = {
shape_Q_nope, tma_Q_nope,
shape_Q_rope, tma_Q_rope,
shape_O, tma_O,
tensor_map_kv_nope
};
auto kernel = &sparse_attn_fwd_kernel<D_QK == 576, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
kernel<<<params.s_q, NUM_THREADS, smem_size, params.stream>>>(params, tma_params);
KU_CHECK_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace sm100::fwd::head64 {
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);
}
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cutlass/float8.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm100::fwd_for_small_topk::head128 {
using namespace cute;
template<SparseAttnFwdMode FWD_MODE, int D_QK>
struct KernelTemplate {
using ArgT = SparseFwdArgT<FWD_MODE>;
static constexpr bool IS_DECODE = is_decode_v<FWD_MODE>;
static constexpr bool IS_PREFILL = !IS_DECODE;
using fp8_e4m3 = cutlass::float_e4m3_t;
using fp8_e8m0 = __nv_fp8_e8m0;
struct TmaParamsForPrefill {
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_o;
};
struct TmaParamsForDecode {
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_o;
CUtensorMap tensor_map_o_accum;
CUtensorMap tensor_map_kv_nope;
CUtensorMap tensor_map_kv_rope;
CUtensorMap tensor_map_extra_kv_nope; // Only available if extra_kv is enabled
CUtensorMap tensor_map_extra_kv_rope;
};
using TmaParams = std::conditional_t<
IS_DECODE,
TmaParamsForDecode,
TmaParamsForPrefill
>;
static_assert(D_QK == 512);
static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK;
static constexpr int D_V = 512;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan
static constexpr int H_Q = 128; // For 2 CTAs
static constexpr int B_TOPK = 64; // For 2 CTAs
static constexpr int NUM_THREADS = 128*4;
static constexpr int NUM_WORKER_THREADS = IS_PREFILL ? (128 + 4 + (B_TOPK/8) + 1 + 128)*2 + 1 : (128 + 128 + 1 + 32 + 2 + 128)*2;
// For non-decode mode, we have 4 (half-)KV buffers
// For decode mode, we have 3 (half-)KV buffers with two raw KV buffers
static constexpr int NUM_K_BUFS = IS_DECODE ? 3 : 4;
static constexpr int NUM_RAW_K_BUFS = IS_DECODE ? 2 : 0;
static constexpr int NUM_INDEX_BUFS = IS_DECODE ? 4 : 4;
static constexpr int D_NOPE = 448;
static constexpr int D_ROPE = 64;
static constexpr int TMA_K_STRIDE_FOR_DECODING = D_NOPE + 2*D_ROPE;
static constexpr int NUM_SCALES_EACH_TOKEN = 8; // 7 scales + 1 padding
static constexpr int B_EPI = 64; // Epilogue block size for normal case (i.e. prefill or non-splitkv decoding)
static constexpr int B_EPI_SPLITKV = 32; // Epilogue block size for splitkv decoding
static constexpr int NUM_EPI_SPLITKV_BUFS = 4; // The number of epilogue buffers for splitkv decoding
static_assert((H_Q/2)*D_Q*sizeof(bf16) >= NUM_EPI_SPLITKV_BUFS*(H_Q/2)*(B_EPI_SPLITKV*2)*sizeof(float));
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: Output accumulator
// 256 ~ 384: Q
// 384 ~ 448: P
static constexpr int O = 0;
static constexpr int Q = 256;
static constexpr int P = 384;
};
struct SharedMemoryPlan {
array_aligned<bf16, (H_Q/2)*D_Q> Q; // Will be output for epilogue
array_aligned<bf16, B_TOPK*(D_K/2)> K[NUM_K_BUFS];
array_aligned<fp8_e4m3, B_TOPK*(D_K/2)> K_raw[NUM_RAW_K_BUFS];
array_aligned<bf16, (H_Q/2)*B_TOPK> S;
float P_exchange[4][(H_Q/2/2)*(B_TOPK/2)];
float rowwise_max_buf[128], rowwise_li_buf[128];
CUTE_ALIGNAS(16) char is_k_valid[NUM_INDEX_BUFS][B_TOPK/8];
CUTE_ALIGNAS(16) int tma_coord[NUM_INDEX_BUFS][B_TOPK];
CUTE_ALIGNAS(16) fp8_e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN/2];
transac_bar_t bar_sQ_full, bar_tQ_empty, bar_tQ_full;
transac_bar_t bar_tOut_full, bar_tOut_empty;
transac_bar_t bar_KV_full[NUM_K_BUFS], bar_KV_empty[NUM_K_BUFS];
transac_bar_t bar_P_empty;
transac_bar_t bar_QK_done, bar_SV_done;
transac_bar_t bar_S_O_full;
transac_bar_t bar_li_full, bar_li_empty;
// The following barriers are prefill-only
transac_bar_t bar_clc_full, bar_clc_empty;
// The following barriers are decode-only
transac_bar_t bar_raw_KV_full[NUM_RAW_K_BUFS], bar_raw_KV_empty[NUM_RAW_K_BUFS];
transac_bar_t bar_valid_coord_scales_full[NUM_INDEX_BUFS], bar_valid_coord_scales_empty[NUM_INDEX_BUFS];
ku::CLCResponseObj clc_response_obj;
array_aligned<uint32_t, 1> tmem_start_addr;
};
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, H_Q, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}
)); // *2 for dual gemm
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, H_Q, 256, UMMA::Major::K, UMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{},
Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
));
struct barrier_ids {
static constexpr int WG0_SYNC = 0;
static constexpr int WG2_SYNC = 1;
static constexpr int WG2_WARP02_SYNC = 2;
static constexpr int WG2_WARP13_SYNC = 3;
};
static __device__ void
sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params);
static void run(const ArgT& params);
};
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd_for_small_topk::head128 {
template void run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::DecodeWithSplitKV, 512>(const SparseAttnDecodeParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd_for_small_topk::head128 {
template void run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(const SparseAttnFwdParams& params);
}
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include "params.h"
#include "utils.h"
#include "sm100/prefill/sparse/common_subroutine.h"
#include "sm100/helpers.h"
#include "config.h"
namespace sm100::fwd_for_small_topk::head128 {
using namespace cute;
using FwdMode = SparseAttnFwdMode;
template<FwdMode FWD_MODE, int D_QK>
__device__ void
KernelTemplate<FWD_MODE, D_QK>::sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params) {
#ifdef KERUTILS_ENABLE_SM100A
// Grid shape: [2*s_q, 1, 1] for prefilling, [2*s_q, num_sm_parts, 1] for decoding
// Cluster shape: [2, 1, 1]
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int cta_idx = block_id_in_cluster().x;
extern __shared__ char wksp_buf[];
SharedMemoryPlan &smem = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(&tma_params.tensor_map_q);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_o);
if constexpr (IS_DECODE) {
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);
} else {
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv);
}
} else if (warp_idx == 1 && elect_one_sync()) {
smem.bar_sQ_full.init(1);
smem.bar_tQ_empty.init(1);
smem.bar_tQ_full.init(1);
smem.bar_tOut_full.init(1);
smem.bar_tOut_empty.init(256);
smem.bar_P_empty.init(256);
smem.bar_QK_done.init(1);
smem.bar_SV_done.init(1);
smem.bar_S_O_full.init(256);
smem.bar_li_full.init(H_Q/2);
smem.bar_li_empty.init(128);
if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) {
smem.bar_clc_full.init(1);
smem.bar_clc_empty.init(NUM_WORKER_THREADS);
}
fence_barrier_init();
} else if (warp_idx == 2) {
cute::TMEM::Allocator2Sm().allocate(512, smem.tmem_start_addr.data());
KU_TRAP_ONLY_DEVICE_ASSERT(smem.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator2Sm().release_allocation_lock();
} else if (warp_idx == 3 && elect_one_sync()) {
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
smem.bar_KV_full[i].init(IS_PREFILL ? 1 : (128/32)*2+1);
smem.bar_KV_empty[i].init(1);
}
CUTE_UNROLL
for (int i = 0; i < NUM_INDEX_BUFS; ++i) {
smem.bar_valid_coord_scales_full[i].init(IS_PREFILL ? B_TOPK/8 : 32);
smem.bar_valid_coord_scales_empty[i].init(IS_PREFILL ? 128 : (128 + (cta_idx==1) + 2 + 128));
}
if constexpr (IS_DECODE) {
CUTE_UNROLL
for (int i = 0; i < NUM_RAW_K_BUFS; ++i) {
smem.bar_raw_KV_full[i].init(1);
smem.bar_raw_KV_empty[i].init(128);
}
}
fence_barrier_init();
}
ku::barrier_cluster_arrive_relaxed();
ku::barrier_cluster_wait_acquire();
struct OuterloopArgs {
bool outer_loop_phase;
int batch_idx, s_q_idx;
int start_block_idx, end_block_idx;
int topk_length;
int extra_topk_length, num_orig_kv_blocks; // extra-KV related
bool is_no_split; int n_split_idx; // splitkv related
};
auto run_outer_loop = [&](auto loop_body) -> bool {
int outer_loop_phase = false;
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
int s_q_idx = blockIdx.x / 2;
DecodingSchedMeta sched_meta;
KU_LDG_256(
params.tile_scheduler_metadata_ptr + blockIdx.y,
&sched_meta,
".nc",
"no_allocate",
"evict_normal",
"256B"
);
if (sched_meta.begin_req_idx >= params.b) {
return 0;
}
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0
int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK;
bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false);
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx);
// start_block_idx = 0;
// end_block_idx = total_topk_padded / B_TOPK;
// is_split = false;
// n_split_idx = 0;
OuterloopArgs args = {
(bool)outer_loop_phase,
batch_idx, s_q_idx,
start_block_idx, end_block_idx,
topk_length,
extra_topk_length, orig_topk_padded / B_TOPK,
!is_split, n_split_idx
};
loop_body(args);
outer_loop_phase ^= 1;
}
} else {
// Prefill mode. Use CLC to allocate different s_q (for decoding, different batches + s_q) to different workers
ku::CLCResult next_job = {true, (int)blockIdx.x, IS_PREFILL ? 0 : (int)blockIdx.y, 0};
CUTE_NO_UNROLL
while (next_job.is_valid) {
int s_q_idx = next_job.x / 2;
int batch_idx = IS_PREFILL ? 0 : next_job.y;
int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + (IS_PREFILL?s_q_idx:batch_idx)) : params.topk;
if constexpr (IS_PREFILL) {
int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1
OuterloopArgs args = {
(bool)outer_loop_phase,
0, s_q_idx,
0, num_k_blocks,
topk_length
};
loop_body(args);
} else {
int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0
OuterloopArgs args = {
(bool)outer_loop_phase,
batch_idx, s_q_idx,
0, total_topk_padded / B_TOPK,
topk_length,
extra_topk_length, orig_topk_padded / B_TOPK,
false, 0
};
loop_body(args);
}
smem.bar_clc_full.wait(outer_loop_phase);
next_job = ku::get_clc_query_response<true>(smem.clc_response_obj);
smem.bar_clc_empty.arrive(0u);
outer_loop_phase ^= 1;
}
}
return outer_loop_phase;
};
if (warpgroup_idx == 0) {
// Q fetching and O writing back warpgroup
cutlass::arch::warpgroup_reg_alloc<176>();
bf16* sO_addrs[B_EPI/8];
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
Tensor sO = make_tensor(make_smem_ptr(smem.Q.data()), ku::make_umma_canonical_k_major_layout<H_Q/2, D_V, 128>());
sO_addrs[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*(D_V/2) + i*8);
}
float* sO_accum_addrs[B_EPI_SPLITKV/4];
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
// If split-KV is enabled, we need to store back O in float32
// We view Q buffer (with shape 64 x 512, bf16) as 4 buffers with shape (H_Q/2) x (B_EPI_SPLITKV*2), float32
Tensor sO_accum = make_tensor(make_smem_ptr((float*)smem.Q.data()), ku::make_umma_canonical_k_major_layout<H_Q/2, D_V, 128, float>());
CUTE_UNROLL
for (int i = 0; i < B_EPI_SPLITKV/4; ++i) {
sO_accum_addrs[i] = &sO_accum(idx_in_warpgroup%64, i*4) + (idx_in_warpgroup >= 64 ? (H_Q/2)*B_EPI_SPLITKV : 0);
}
}
auto perform_o_copy_out = [&](const OuterloopArgs &args, bool is_last_o) {
// outer_loop_phase is the loop phase corresponding to s_q_idx
// Get li (output_scale actually)
smem.bar_li_full.wait(args.outer_loop_phase);
float output_scale = smem.rowwise_li_buf[idx_in_warpgroup%64];
float2 output_scale_float2 = float2 {output_scale, output_scale};
smem.bar_li_empty.arrive();
// Retrieve and store O, and calculate delta := sum(O*dO, dim=-1) if FWD_MODE is Recompute
smem.bar_tOut_full.wait(args.outer_loop_phase);
if (is_last_o && elect_one_sync()) {
cudaTriggerProgrammaticLaunchCompletion();
}
if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {
CUTE_UNROLL
for (int k = 0; k < (D_V/2)/B_EPI; ++k) {
float2 o[B_EPI/2];
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + k*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
if (k == (D_V/2)/B_EPI-1) {
smem.bar_tOut_empty.arrive(0u);
}
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
nv_bfloat162 o_bf16[4];
CUTE_UNROLL
for (int j = 0; j < 4; ++j) {
o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[i*4+j]);
}
bf16* o_do_addr = sO_addrs[i] + k*B_EPI*(H_Q/2);
if (k == 0 && i == 0) {
smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability
}
ku::st_shared(o_do_addr, *(__int128_t*)o_bf16);
}
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
if (warp_idx == 0 && elect_one_sync()) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o,
smem.Q.data(),
0, cta_idx*(H_Q/2), 0, args.s_q_idx, IS_DECODE ? args.batch_idx : 0
);
cute::tma_store_arrive();
}
} else {
CUTE_UNROLL
for (int k = 0; k < (D_V/2)/B_EPI_SPLITKV; ++k) {
int cur_buf_idx = k % NUM_EPI_SPLITKV_BUFS;
if (k == 0) {
cute::tma_store_wait<0>();
} else {
cute::tma_store_wait<NUM_EPI_SPLITKV_BUFS-1>();
}
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
float o[B_EPI_SPLITKV];
ku::tmem_ld_32dp32bNx<B_EPI_SPLITKV>(tmem_cols::O + k*B_EPI_SPLITKV, o);
cutlass::arch::fence_view_async_tmem_load();
if (k == (D_V/2)/B_EPI_SPLITKV-1) {
smem.bar_tOut_empty.arrive(0u);
}
CUTE_UNROLL
for (int i = 0; i < B_EPI_SPLITKV/4; ++i) {
CUTE_UNROLL
for (int j = 0; j < 4; j += 2) {
*(float2*)(o + i*4 + j) = ku::float2_mul(float2 {o[i*4+j], o[i*4+j+1]}, output_scale_float2);
}
if (k == 0 && i == 0) {
smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability
}
ku::st_shared(
sO_accum_addrs[i] + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2),
*(__int128_t*)(o + i*4)
);
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
if constexpr (IS_DECODE) { // Otherwise nvcc complains about `tma_params` doesn't have `tensor_map_o_accum`
float* cur_buf_base = (float*)smem.Q.data() + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2);
if (warp_idx == 0 && elect_one_sync()) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o_accum,
cur_buf_base,
0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32), args.s_q_idx, args.n_split_idx
);
cute::tma_store_arrive();
} else if (warp_idx == 1 && elect_one_sync()) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o_accum,
cur_buf_base + (H_Q/2)*B_EPI_SPLITKV,
0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32) + (D_V/2)/32, args.s_q_idx, args.n_split_idx
);
cute::tma_store_arrive();
}
}
}
}
};
OuterloopArgs last_args;
last_args.batch_idx = -1;
bool final_outer_loop_phase = \
run_outer_loop([&](const OuterloopArgs &args) {
// Copy Q for this round
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
cute::tma_store_wait<0>();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); // Since we use two warps to issue TMA during FwdMode::DecodeWithSplitKV
}
if (warp_idx == 0 && elect_one_sync()) {
// Wait for sQ to become empty, and issue G -> S copy for Q
if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) {
cute::tma_store_wait<0>(); // This thread must be the same one as o copy out thread (since `elect_one_sync()` always returns the same thread for the same `mask`, according to PTX document)
}
int stride_q_b_div_stride_q_s_q = 0;
if constexpr (IS_DECODE) {
stride_q_b_div_stride_q_s_q = params.stride_q_b / params.stride_q_s_q;
}
SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(
&tma_params.tensor_map_q,
(uint64_t*)&smem.bar_sQ_full,
(uint64_t)TMA::CacheHintSm90::EVICT_FIRST,
smem.Q.data(),
0, cta_idx*(H_Q/2), 0, 0, (IS_DECODE ? args.batch_idx*stride_q_b_div_stride_q_s_q : 0) + args.s_q_idx
);
// Wait for sQ to be ready, and issue S -> T copy for Q
if (cta_idx == 0) {
smem.bar_sQ_full.arrive_and_expect_tx(H_Q*D_Q*sizeof(bf16));
smem.bar_sQ_full.wait(args.outer_loop_phase);
smem.bar_tQ_empty.wait(args.outer_loop_phase^1);
ku::tcgen05_after_thread_sync();
UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(smem.Q.data()),
ku::make_umma_canonical_k_major_layout<(H_Q/2)*2, 64, 128>()
)
);
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_Q/64/2; ++tile_idx) {
// A tile is 128 rows * 64 cols in UTCCP's view, or 64 rows * 128 cols in our view
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) {
// A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)
// NOTE Using `sQ_desc+((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4)` leads to IMA, doesn't know why
UMMA::SmemDescriptor cur_sQ_desc = sQ_desc;
cur_sQ_desc.lo += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4);
// uint64_t cur_sQ_desc = sQ_desc;
// cur_sQ_desc += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4);
SM100_UTCCP_128dp256bit_2cta::copy(
cur_sQ_desc,
tmem_cols::Q + tile_idx*32 + subtile_idx*8
);
}
}
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tQ_full, 1|2);
}
}
if (last_args.batch_idx != -1) {
perform_o_copy_out(last_args, false);
} else {
smem.bar_tQ_full.wait(args.outer_loop_phase); // To prevent double arrive
}
last_args = args;
});
if (last_args.batch_idx != -1) {
cute::tma_store_wait<0>();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
perform_o_copy_out(last_args, true);
}
if (warp_idx == 0) {
cute::TMEM::Allocator2Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
// KV fetching threads for prefill, dequant threads for decoding
cutlass::arch::warpgroup_reg_dealloc<80>();
RingBufferState rs;
if constexpr (!IS_DECODE) {
const int warp_idx = cutlass::canonical_warp_idx(); // Using `warp_idx` without `__shfl_sync` is faster
if (elect_one_sync()) {
// KV fetching threads
run_outer_loop([&](const OuterloopArgs &args) {
int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q;
int64_t cache_hint = ku::create_simple_cache_policy<ku::CacheHint::EVICT_LAST>();
static constexpr int NUM_ROWS_PER_THREAD = B_TOPK / 4;
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
int cur_indices[NUM_ROWS_PER_THREAD];
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/8; local_row += 1) {
int row = local_row*(4*8) + (warp_idx-4)*8;
KU_LDG_256(
gIndices + k*B_TOPK + row,
cur_indices + local_row*8,
".nc",
"no_allocate",
"evict_first",
"256B"
);
}
smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/4; local_row += 1) {
int row = (warp_idx-4)*8 + (local_row/2)*(4*8) + (local_row%2)*4;
int4 indices = *(int4*)(cur_indices+local_row*4);
static_assert(D_K == 512);
CUTE_UNROLL
for (int local_col = 0; local_col < (D_K/64)/2; ++local_col) {
ku::tma_gather4_cta_group_2<true>(
&tma_params.tensor_map_kv,
smem.bar_KV_full[k_buf_idx],
smem.K[k_buf_idx].data() + row*64 + local_col*64*B_TOPK,
local_col*64 + cta_idx*(D_K/2),
indices,
cache_hint
);
}
}
rs.update();
}
});
}
} else {
// 8 threads per token
struct IsCTA0 {};
struct IsCTA1 {};
auto launch_dequant_wg = [&](auto cta_id_t) {
static constexpr bool IS_CTA1 = std::is_same<decltype(cta_id_t), IsCTA1>::value;
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = (IS_CTA1 ? 256-64 : 256) / (GROUP_SIZE*8);
int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE;
Tensor nope0 = make_tensor(make_smem_ptr(smem.K[0].data()), ku::make_umma_canonical_k_major_layout<B_TOPK, D_K/2, 128>());
bf16* nope0_base = &nope0(group_idx, idx_in_group*8);
fp8_e4m3* raw_nope0_base = smem.K_raw[0].data() + group_idx*(D_K/2) + idx_in_group*8;
run_outer_loop([&](const OuterloopArgs &args) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get<NUM_RAW_K_BUFS>();
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
fp8_e4m3* raw_nope_base = raw_nope0_base + raw_k_buf_idx * (B_TOPK*(D_K/2));
auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t {
return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*(D_K/2) + local_col_idx*(GROUP_SIZE*8));
};
bf16* nope_base = nope0_base + k_buf_idx * (B_TOPK*(D_K/2));
uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(nope_base);
auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) {
asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n"
:
: "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16)
); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS
};
smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);
smem.bar_raw_KV_full[raw_k_buf_idx].wait(raw_k_bar_phase);
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {
int row_idx = local_row_idx*NUM_GROUPS + group_idx;
bf16 scales[4];
fp8_e8m0 scales_e8m0[4];
*(uint32_t*)scales_e8m0 = *(uint32_t*)(smem.scales[index_buf_idx][row_idx]);
*(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));
*(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));
uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);
CUTE_UNROLL
for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {
ku::nve4m3x2 data_fp8[4];
ku::nvbf16x2 data_bf16[4];
*(uint64_t*)data_fp8 = cur_data_fp8x8;
if (local_col_idx+1 < COLS_PER_GROUP)
cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);
bf16 scale = scales[local_col_idx];
CUTE_UNROLL
for (int i = 0; i < 4; ++i) {
data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));
}
if (local_row_idx == 0 && local_col_idx == 0) {
smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);
}
st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);
}
}
fence_view_async_shared(); // NOTE Should we use shared::cluster here?
__syncwarp();
smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();
smem.bar_raw_KV_empty[raw_k_buf_idx].arrive();
if (elect_one_sync()) {
smem.bar_KV_full[k_buf_idx].arrive(0u);
}
rs.update();
}
});
};
if (cta_idx == 0) {
launch_dequant_wg(IsCTA0{});
} else {
launch_dequant_wg(IsCTA1{});
}
}
} else if (warpgroup_idx == 2) {
cutlass::arch::warpgroup_reg_dealloc<80>();
RingBufferState rs;
if (warp_idx == 8 && cta_idx == 0 && elect_one_sync()) {
// UMMA thread
TiledMMA tiled_mma_P = TiledMMA_P{};
TiledMMA tiled_mma_O = TiledMMA_O{};
Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<H_Q/2>, Int<B_TOPK*2>>{});
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<H_Q/2>, Int<D_V>>{});
Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<H_Q/2>, Int<D_Q/2>>{})
);
tP.data().get() = tmem_cols::P;
tO.data().get() = tmem_cols::O;
tQ.data().get() = tmem_cols::Q;
run_outer_loop([&](const OuterloopArgs &args) {
smem.bar_tQ_full.wait(args.outer_loop_phase);
// Issue P = Q K^T
auto issue_P = [&](int k, int rs_offset) {
auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get<NUM_K_BUFS>();
auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>();
smem.bar_P_empty.wait(bar_phase^1);
if constexpr (IS_PREFILL) {
smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_K*sizeof(bf16));
} else {
// RoPE only
smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16));
}
smem.bar_KV_full[k_buf_idx].wait(k_bar_phase);
ku::tcgen05_after_thread_sync();
Tensor sK = make_tensor(
make_smem_ptr(smem.K[k_buf_idx].data()),
ku::make_umma_canonical_k_major_layout<B_TOPK, D_K/2, 128>()
);
ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true);
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_QK_done, 1|2);
};
// Issue O += S V
auto issue_O = [&](int k, int rs_offset) {
auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get<NUM_K_BUFS>();
auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>();
smem.bar_S_O_full.wait(bar_phase);
if (k == args.start_block_idx) {
smem.bar_tOut_empty.wait(args.outer_loop_phase^1);
}
ku::tcgen05_after_thread_sync();
Tensor sS = make_tensor(
make_smem_ptr(smem.S.data()),
ku::make_umma_canonical_k_major_layout<H_Q/2, B_TOPK, 0>()
);
Tensor sV = make_tensor(
make_smem_ptr(smem.K[k_buf_idx].data()),
ku::make_umma_canonical_mn_major_layout<D_V/2, B_TOPK, 128>()
);
ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == args.start_block_idx);
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_SV_done, 1|2);
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_KV_empty[k_buf_idx], 1|2);
};
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx+1; ++k) {
if (k < args.end_block_idx) {
issue_P(k, 0);
}
if (k == args.end_block_idx-1) {
ku::umma_arrive_2x1SM_noelect(smem.bar_tQ_empty);
}
if (k > args.start_block_idx) {
issue_O(k-1, -1);
}
if (k != args.end_block_idx) {
rs.update();
}
}
ku::tcgen05_before_thread_sync();
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tOut_full, 1|2);
});
} else if (warp_idx == 8 && cta_idx == 1 && elect_one_sync()) {
if constexpr (IS_DECODE) {
// KV RoPE fetching warp
run_outer_loop([&](const OuterloopArgs &args) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);
smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);
CUTE_UNROLL
for (int row = 0; row < B_TOPK; row += 4) {
int4 cur_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row);
ku::tma_gather4_cta_group_2<true>(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope,
smem.bar_KV_full[k_buf_idx],
smem.K[k_buf_idx].data() + (D_NOPE-D_K/2)*B_TOPK + row*D_ROPE,
0,
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();
rs.update();
}
});
}
} else if (warp_idx == 9) {
// KV validness loading warp (for prefill), Indices transformation warp (for decode, Responsible for generating: TMA coordinates, scale factors, and valid masks)
if constexpr (IS_PREFILL) {
if (lane_idx < B_TOPK/8) {
run_outer_loop([&](const OuterloopArgs &args) {
int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q;
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {
char k_validness_mask = load_indices_and_generate_mask(
lane_idx,
gIndices + k*B_TOPK,
params.s_kv,
k*B_TOPK,
args.topk_length
);
auto [indices_buf_idx, indices_bar_phase] = rs.get<NUM_INDEX_BUFS>();
smem.bar_valid_coord_scales_empty[indices_buf_idx].wait(indices_bar_phase^1);
smem.is_k_valid[indices_buf_idx][lane_idx] = k_validness_mask;
smem.bar_valid_coord_scales_full[indices_buf_idx].arrive();
rs.update();
}
});
}
} else {
static_assert(B_TOPK == 64);
// Each thread is responsible for 2 tokens
static constexpr int tma_coords_step_per_token = 576/TMA_K_STRIDE_FOR_DECODING;
int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE_FOR_DECODING; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE_FOR_DECODING > 512
int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE_FOR_DECODING;
uint8_t* k_scales_ptr = (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE);
uint8_t* extra_k_scales_ptr = (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE);
run_outer_loop([&](const OuterloopArgs &args) {
int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*args.s_q_idx;
int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*args.s_q_idx;
struct IsOrigBlock {};
struct IsExtraBlock {};
auto process_one_block = [&](int block_idx, auto is_extra_block_t) {
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;
int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size;
int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block;
[[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row;
uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr;
int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block;
int abs_pos, my_indices[2];
if (!IS_EXTRA_BLOCK) {
abs_pos = block_idx*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(indices + abs_pos));
} else {
abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos));
}
smem.bar_valid_coord_scales_empty[index_buf_idx].wait(index_bar_phase^1);
int tma_coords[2];
fp8_e8m0 scales[2*(NUM_SCALES_EACH_TOKEN/2)];
char valid_mask = 0;
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int block_idx, idx_in_block;
block_idx = (unsigned int)my_indices[i] / cur_block_size;
idx_in_block = (unsigned int)my_indices[i] % cur_block_size;
bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length));
valid_mask |= is_token_valid << i;
tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN.
int64_t offset = block_idx*cur_k_block_stride + (idx_in_block*8 + (cta_idx == 1 ? 4 : 0)); // Each token has 7 scale factors with an extra 1B padding
uint32_t scalesx4 = is_token_valid ? __ldg((uint32_t*)(cur_k_scales_ptr + offset)) : 0;
*(uint32_t*)(scales+i*(NUM_SCALES_EACH_TOKEN/2)) = scalesx4;
}
valid_mask <<= lane_idx%4*2;
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1);
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2);
*(uint64_t*)(smem.scales[index_buf_idx] + lane_idx*2) = *(uint64_t*)scales;
*(int2*)(smem.tma_coord[index_buf_idx] + lane_idx*2) = *(int2*)tma_coords;
if (lane_idx%4 == 0)
smem.is_k_valid[index_buf_idx][lane_idx/4] = valid_mask;
smem.bar_valid_coord_scales_full[index_buf_idx].arrive();
rs.update();
};
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {
process_one_block(block_idx, IsOrigBlock{});
}
CUTE_NO_UNROLL
for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsExtraBlock{});
}
});
}
} else if (warp_idx >= 10 && elect_one_sync()) {
if constexpr (IS_PREFILL) {
if (warp_idx == 10) {
// CLC Producer thread
run_outer_loop([&](const OuterloopArgs &args) {
if (cta_idx == 0) {
smem.bar_clc_empty.wait(args.outer_loop_phase^1);
ku::issue_clc_query_multicast_cluster_all(smem.bar_clc_full, smem.clc_response_obj);
}
smem.bar_clc_full.arrive_and_expect_tx(sizeof(smem.clc_response_obj));
});
}
} else {
// Raw KV NoPE Producer thread
run_outer_loop([&](const OuterloopArgs &args) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get<NUM_RAW_K_BUFS>();
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);
smem.bar_raw_KV_empty[raw_k_buf_idx].wait(raw_k_bar_phase^1);
int4 nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + (warp_idx == 10 ? 0 : 4));
CUTE_UNROLL
for (int row = (warp_idx == 10 ? 0 : 4); row < B_TOPK; row += 8) {
int4 cur_indices = nxt_indices;
if (row+8 < B_TOPK)
nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row + 8);
ku::tma_gather4(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope,
smem.bar_raw_KV_full[raw_k_buf_idx],
smem.K_raw[raw_k_buf_idx].data() + row*(D_K/2),
cta_idx*(D_K/2),
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
if (warp_idx == 10) {
smem.bar_raw_KV_full[raw_k_buf_idx].arrive_and_expect_tx(B_TOPK*(D_K/2)*sizeof(fp8_e4m3));
}
smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();
rs.update();
}
});
}
}
} else {
// Scale & Exp threads
cutlass::arch::warpgroup_reg_alloc<176>();
int local_warp_idx = warp_idx - 12;
bf16* sS_base = smem.S.data() + (local_warp_idx >= 2 ? (H_Q/2)*(B_TOPK/2) : 0) + (idx_in_warpgroup%64)*8;
RingBufferState rs;
run_outer_loop([&](const OuterloopArgs &args) {
// For definition and consistency about `mi`, `li`, and `real_mi`, plz refer to head64 prefill
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2;
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
auto [indices_buf_idx, indices_bar_phase] = rs.get<NUM_INDEX_BUFS>();
auto [_, bar_phase] = rs.get<1>();
// NOTE We don't need to sync for Prefill mode, since we have two synchronizations inside the loop body (one for p_exchange_buf sync, another one for rowwise_max_buf sync). The latter one guarantees the emptyness of p_exchange_buf and the former one guarantees the emptyness of rowwise_max_buf
smem.bar_valid_coord_scales_full[indices_buf_idx].wait(indices_bar_phase);
// Get P from TMEM
float p[NUM_ELEMS_PER_THREAD];
smem.bar_QK_done.wait(bar_phase);
ku::tcgen05_after_thread_sync();
retrieve_mask_and_reduce_p<
NUM_ELEMS_PER_THREAD,
tmem_cols::P,
barrier_ids::WG2_WARP02_SYNC,
barrier_ids::WG2_WARP13_SYNC,
false
>(
smem.is_k_valid[indices_buf_idx],
local_warp_idx,
lane_idx,
[&]() {smem.bar_P_empty.arrive(0u);},
smem.P_exchange,
p
);
// Get rowwise max of P
float cur_pi_max = get_max<NUM_ELEMS_PER_THREAD>(p);
cur_pi_max *= params.sm_scale_div_log2;
smem.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(64, barrier_ids::WG2_WARP02_SYNC + (local_warp_idx&1));
cur_pi_max = max(cur_pi_max, smem.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
// Calculate S
nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2];
float cur_sum = get_s_from_p<NUM_ELEMS_PER_THREAD>(s, p, params.sm_scale_div_log2, new_max);
li = fmaf(li, scale_for_old, cur_sum);
// Store S
smem.bar_SV_done.wait(bar_phase^1);
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; ++i) {
ku::st_shared(sS_base + i*8*(H_Q/2), *(__int128_t*)(s + i*4));
}
// Rescale O
if (k > 0 && should_scale_o) {
ku::tcgen05_after_thread_sync();
rescale_O<D_V, 32, tmem_cols::O>(scale_for_old);
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
smem.bar_S_O_full.arrive(0u);
smem.bar_valid_coord_scales_empty[indices_buf_idx].arrive();
rs.update();
}
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Reduce li
smem.bar_li_empty.wait(args.outer_loop_phase^1);
smem.rowwise_li_buf[idx_in_warpgroup^64] = li;
NamedBarrier::arrive_and_wait(128, barrier_ids::WG2_SYNC);
li += smem.rowwise_li_buf[idx_in_warpgroup];
if (idx_in_warpgroup < H_Q/2) {
// Calculate output_scale and save
int head_idx = cta_idx*(H_Q/2) + idx_in_warpgroup;
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + head_idx);
float output_scale;
if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {
output_scale = __fdividef(1.0f, li + exp2f(fmaf(attn_sink, CUDART_L2E_F, -mi)));
} else {
output_scale = __fdividef(1.0f, li);
}
smem.rowwise_li_buf[idx_in_warpgroup] = li == 0.0f ? 0.0f : output_scale;
smem.bar_li_full.arrive();
float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
if constexpr (IS_PREFILL) {
int global_index = args.s_q_idx*params.h_q + head_idx;
params.max_logits[global_index] = real_mi*CUDART_LN2_F;
params.lse[global_index] = cur_lse;
} else {
if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {
params.lse[args.batch_idx*params.stride_lse_b + args.s_q_idx*params.stride_lse_s_q + head_idx] = cur_lse;
} else {
float cur_lse_2base = log2f(li) + mi;
params.lse_accum[args.n_split_idx*params.stride_lse_accum_split + args.s_q_idx*params.stride_lse_accum_s_q + head_idx] = cur_lse_2base;
}
}
}
});
}
ku::barrier_cluster_arrive_relaxed();
ku::barrier_cluster_wait_acquire();
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
}
#endif
}
// We have two launchers with different kernel names to distinguish prefill and decode
template<typename Kernel>
static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)
sparse_attn_fwd_for_small_topk_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) {
Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);
}
template<typename Kernel>
static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) {
Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);
}
template<FwdMode FWD_MODE, int D_QK>
void KernelTemplate<FWD_MODE, D_QK>::run(const ArgT& params) {
static_assert(D_QK == 576 || D_QK == 512);
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk % B_TOPK == 0); // To save some boundry checkings
KU_ASSERT(params.h_q == H_Q); // To save some calculation
KU_ASSERT(params.d_qk == D_QK);
static_assert(D_Q == 512);
CUtensorMap tensor_map_q;
if constexpr (IS_DECODE) {
KU_ASSERT(params.stride_q_b % params.stride_q_s_q == 0, "In decode mode for MODEL1 sparse fp8 decoding on sm100f, q.stride(0) (on the batch dimension) must be divisible by q.stride(1) (on the sequence dimension).");
tensor_map_q = ku::make_tensor_map(
{64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.b * (params.stride_q_b / params.stride_q_s_q)},
ku::make_stride_helper<int>({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)),
{64, H_Q/2, 2, (D_Q/64)/2, 1},
params.q,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
} else {
tensor_map_q = ku::make_tensor_map(
{64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.s_q},
ku::make_stride_helper<int>({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)),
{64, H_Q/2, 2, (D_Q/64)/2, 1},
params.q,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
); // We use this layout to group Q[0:64] and Q[256:256+64] together, for UTCCP for dual gemm
}
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_nope, tensor_map_kv_rope, tensor_map_extra_kv_nope = {}, tensor_map_extra_kv_rope = {};
if constexpr (IS_DECODE) {
auto get_kv_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t stride_kv_block, int64_t stride_kv_row) -> std::pair<CUtensorMap, CUtensorMap> {
KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr);
KU_ASSERT(stride_kv_block % TMA_K_STRIDE_FOR_DECODING == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary", is_extra?"extra_":"", stride_kv_block, TMA_K_STRIDE_FOR_DECODING);
CUtensorMap tensor_map_kv_nope = ku::make_tensor_map(
{D_NOPE + D_ROPE*2, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)},
{TMA_K_STRIDE_FOR_DECODING},
{D_K/2, 1},
k_ptr,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
); // NOTE: Here we use `D_NOPE+D_ROPE*2` as the box shape instead of D_NOPE because it's actually faster. I think that's because, if we use `D_NOPE+D_ROPE*2`, we can prefetch part of the RoPE part of the selected tokens.
CUtensorMap tensor_map_kv_rope = ku::make_tensor_map(
{D_ROPE, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)},
{TMA_K_STRIDE_FOR_DECODING},
{64, 1},
(uint8_t*)k_ptr + D_NOPE,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
);
return {tensor_map_kv_nope, tensor_map_kv_rope};
};
std::tie(tensor_map_kv_nope, tensor_map_kv_rope) = get_kv_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block, params.stride_kv_row);
if (params.extra_topk > 0)
std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_kv_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block, params.stride_extra_kv_row);
} else {
tensor_map_kv = ku::make_tensor_map(
{D_QK, (unsigned long)params.s_kv},
{(unsigned long)params.stride_kv_s_kv*sizeof(bf16)},
{64, 1},
params.kv,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
}
CUtensorMap tensor_map_o;
if constexpr (IS_DECODE) {
tensor_map_o = ku::make_tensor_map(
{64, H_Q, D_V/64, (unsigned long)params.s_q, (unsigned long)params.b},
ku::make_stride_helper<int>({params.stride_o_h_q, 64, params.stride_o_s_q, params.stride_o_b}, sizeof(bf16)),
{64, H_Q/2, D_V/64, 1, 1},
params.out,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
} else {
tensor_map_o = ku::make_tensor_map(
{64, H_Q, D_V/64, (unsigned long)params.s_q, 1ul},
ku::make_stride_helper<int>({D_V, 64, H_Q*D_V, H_Q*D_V}, sizeof(bf16)),
{64, H_Q/2, D_V/64, 1, 1},
params.out,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
}
CUtensorMap tensor_map_o_accum = {};
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
tensor_map_o_accum = ku::make_tensor_map(
{32, H_Q, D_V/32, (unsigned long)params.s_q, (unsigned long)params.num_sm_parts + params.b},
ku::make_stride_helper<int>({params.stride_o_accum_h_q, 32, params.stride_o_accum_s_q, params.stride_o_accum_split}, sizeof(float)),
{32, H_Q/2, B_EPI_SPLITKV/32, 1, 1},
params.o_accum,
CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
}
TmaParams tma_params;
if constexpr (IS_DECODE) {
tma_params = {
tensor_map_q,
tensor_map_o,
tensor_map_o_accum,
tensor_map_kv_nope,
tensor_map_kv_rope,
tensor_map_extra_kv_nope,
tensor_map_extra_kv_rope
};
} else {
tma_params = {
tensor_map_q,
tensor_map_kv,
tensor_map_o
};
}
auto kernel = IS_PREFILL ? &sparse_attn_fwd_for_small_topk_kernel<KernelTemplate<FWD_MODE, D_QK>> : &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<FWD_MODE, D_QK>>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
dim3 grid_shape;
if constexpr (IS_DECODE) {
grid_shape = dim3(2*params.s_q, FWD_MODE == FwdMode::DecodeWithSplitKV ? params.num_sm_parts : params.b, 1);
} else {
grid_shape = dim3(2*params.s_q, 1, 1);
}
cutlass::ClusterLaunchParams launch_params = {
grid_shape,
dim3(NUM_THREADS, 1, 1),
dim3(2, 1, 1),
smem_size,
params.stream
};
KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
));
}
template<FwdMode FWD_MODE, int D_QK>
void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT<FWD_MODE>& params) {
using Kernel = KernelTemplate<FWD_MODE, D_QK>;
Kernel::run(params);
}
}
#pragma once
#include "params.h"
namespace sm100::fwd_for_small_topk::head128 {
template<SparseAttnFwdMode FWD_MODE, int D_QK>
void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT<FWD_MODE>& params);
}
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
#include "traits.h" #include "traits.h"
using namespace cute; using namespace cute;
using cutlass::arch::NamedBarrier;
namespace sm90 { namespace sm90 {
...@@ -17,1258 +16,11 @@ namespace sm90 { ...@@ -17,1258 +16,11 @@ namespace sm90 {
static constexpr float MAX_INIT_VAL_SM = -1e30f; static constexpr float MAX_INIT_VAL_SM = -1e30f;
static constexpr float MAX_INIT_VAL = -1e33f; static constexpr float MAX_INIT_VAL = -1e33f;
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx;
}
// Launch TMA copy for a range of KV tile
// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64
template<
int START_HEAD_DIM_TILE_IDX,
int END_HEAD_DIM_TILE_IDX,
typename TMA_K_OneTile,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void launch_kv_tiles_copy_tma(
Tensor<Engine0, Layout0> const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
Tensor<Engine1, Layout1> &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled
TMA_K_OneTile &tma_K,
TMABarrier* barriers_K,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
auto thr_tma = tma_K.get_slice(_0{});
Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});
Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});
cute::copy(tma_K.with(reinterpret_cast<typename TMABarrier::ValueType &>(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV);
if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
launch_kv_tiles_copy_tma<START_HEAD_DIM_TILE_IDX+1, END_HEAD_DIM_TILE_IDX>(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup);
}
}
}
// Prefetch some KV tiles
// Currently this is not used because it leads to performance degradation
template<
int START_HEAD_DIM_TILE_IDX,
int END_HEAD_DIM_TILE_IDX,
typename TMA_K_OneTile,
typename Engine0, typename Layout0
>
__forceinline__ __device__ void prefetch_kv_tiles(
Tensor<Engine0, Layout0> const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
TMA_K_OneTile &tma_K,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
auto thr_tma = tma_K.get_slice(_0{});
Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});
cute::prefetch(tma_K, cur_gKV);
if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
prefetch_kv_tiles<START_HEAD_DIM_TILE_IDX+1, END_HEAD_DIM_TILE_IDX>(gKV, tma_K, idx_in_warpgroup);
}
}
}
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao.
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64)
// The Q-tile should be in shared memory
template<
typename TiledMMA,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void qkt_gemm_one_tile_sQ(
TiledMMA &tiled_mma,
Tensor<Engine0, Layout0> const &thr_mma_sQ_tile, // (MMA, 1, 4)
Tensor<Engine1, Layout1> const &thr_mma_sKV_tile, // (MMA, 1, 4)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
TMABarrier* barrier,
bool &cur_phase,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
barrier->arrive_and_expect_tx(64*64*2);
}
barrier->wait(cur_phase ? 1 : 0);
warpgroup_fence_operand(rP);
warpgroup_arrive();
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
warpgroup_commit_batch();
warpgroup_fence_operand(rP);
}
template<
typename TiledMMA,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void qkt_gemm_one_tile_rQ(
TiledMMA &tiled_mma,
Tensor<Engine0, Layout0> const &thr_mma_rQ_tile, // (MMA, 1, 4)
Tensor<Engine1, Layout1> const &thr_mma_sKV_tile, // (MMA, 1, 4)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
TMABarrier* barrier,
bool &cur_phase,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
barrier->arrive_and_expect_tx(64*64*2);
}
barrier->wait(cur_phase ? 1 : 0);
warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));
warpgroup_fence_operand(rP);
warpgroup_arrive();
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
warpgroup_commit_batch();
warpgroup_fence_operand(rP);
warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));
}
// Pipelined TMA wait and Q K^T gemm
// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the computation as follows:
// - Wait for the 0-th tile to be ready using `barrier.wait()`
// - Compute Q K^T for the 0-th tile
// - Wait for the 1-st tile to be ready
// - Compute Q K^T for the 1-st tile
// ...
// This gives latter tiles more time to be ready, and thus can overlap the memory copy and computation
template<
typename T, // Traits
int PHASE_IDX, // See comments in the code
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3
>
__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm(
Tensor<Engine0, Layout0> &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K)
Tensor<Engine1, Layout1> &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
Tensor<Engine3, Layout3> &rQ8, // The 8-th tile of Q. We store it separately to leave some room for storing sP1
TMABarrier* barriers,
bool &cur_phase,
int idx_in_warpgroup
) {
Tensor sQ_tiled = flat_divide(sQ, Shape<Int<T::BLOCK_SIZE_M>, _64>{})(_, _, _0{}, _); // (BLOCK_SIZE_M, 64, 9)
Tensor sKV_tiled = flat_divide(sKV, Shape<Int<T::PAGE_BLOCK_SIZE>, _64>{})(_, _, _0{}, _); // (PAGE_BLOCK_SIZE, 64, 9)
TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){};
ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup);
Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled); // (MMA, 1, 4, 9)
Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled); // (MMA, 1, 4, 9)
TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){};
#define QKT_GEMM_ONE_TILE(TILE_IDX) \
if constexpr(TILE_IDX != 8) { \
qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int<TILE_IDX>{}), thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
} else { \
qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
}
if constexpr (PHASE_IDX == 0) {
// In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles
tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
QKT_GEMM_ONE_TILE(0);
QKT_GEMM_ONE_TILE(1);
QKT_GEMM_ONE_TILE(2);
QKT_GEMM_ONE_TILE(3);
} else if constexpr (PHASE_IDX == 1) {
// In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles
tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
QKT_GEMM_ONE_TILE(4);
QKT_GEMM_ONE_TILE(5);
QKT_GEMM_ONE_TILE(6);
QKT_GEMM_ONE_TILE(7);
QKT_GEMM_ONE_TILE(8);
QKT_GEMM_ONE_TILE(0);
QKT_GEMM_ONE_TILE(1);
QKT_GEMM_ONE_TILE(2);
QKT_GEMM_ONE_TILE(3);
cur_phase ^= 1;
} else {
// In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles
static_assert(PHASE_IDX == 2);
tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One;
tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
QKT_GEMM_ONE_TILE(4);
QKT_GEMM_ONE_TILE(5);
QKT_GEMM_ONE_TILE(6);
QKT_GEMM_ONE_TILE(7);
QKT_GEMM_ONE_TILE(8);
cur_phase ^= 1;
}
}
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline(
Tensor<Engine0, Layout0> &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K)
Tensor<Engine1, Layout1> &sKV, // (BLOCK_SIZE_M, HEAD_DIM_K)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
int idx_in_warpgroup
) {
TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){};
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ); // (MMA, 1, 576/16=36)
Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV); // (MMA, 1, 576/16=36)
gemm<true, -1>(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP);
}
// Compute O += PV, where P resides in register
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP(
Tensor<Engine0, Layout0> &rP, // ((2, 2, 8), 1, 1), fragment A layout
Tensor<Engine1, Layout1> &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
Tensor<Engine2, Layout2> &rO, // ((2, 2, 32), 1, 1)
int idx_in_warpgroup
) {
TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){};
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor rP_retiled = make_tensor(rP.data(), Layout<
Shape<Shape<_2, _2, _2>, _1, _4>,
Stride<Stride<_1, _2, _4>, _0, _8>
>{});
Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4)
gemm<false, -1>(tiled_mma, rP_retiled, thr_mma_sKV_half, rO);
}
// Compute O += PV, where P resides in shared memory
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP(
Tensor<Engine0, Layout0> &sP,
Tensor<Engine1, Layout1> &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
Tensor<Engine2, Layout2> &rO, // ((2, 2, 32), 1, 1)
int idx_in_warpgroup
) {
TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){};
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP);
Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4)
gemm<false, -1>(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO);
}
template<
typename T,
bool DO_OOB_FILLING,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4
>
__forceinline__ __device__ void wg0_bunch_0(
Tensor<Engine0, Layout0> &rPb, // ((2, 2, 8), 1, 1)
Tensor<Engine1, Layout1> &rP0, // ((2, 2, 8), 1, 1)
Tensor<Engine2, Layout2> &rO0, // ((2, 2, 32), 1, 1)
Tensor<Engine3, Layout3> &sScale0, // (BLOCK_SIZE_M)
Tensor<Engine4, Layout4> &sM, // (BLOCK_SIZE_M)
float rL[2],
int rRightBorderForQSeq[2],
float scale_softmax_log2,
int start_token_idx,
int idx_in_warpgroup
) {
// This piece of code is tightly coupled [Accumulate's layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png)
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
// Mask, and get row-wise max
float cur_max = MAX_INIT_VAL;
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
if constexpr (DO_OOB_FILLING) {
int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;
rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL;
rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL;
}
cur_max = max(cur_max, max(rP0(i), rP0(i+1)));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
// Update sM and sL
cur_max *= scale_softmax_log2;
float new_max = max(sM(row_idx), cur_max);
float scale_for_old = exp2f(sM(row_idx) - new_max);
__syncwarp(); // Make sure all reads have finished before updating sM
if (idx_in_warpgroup%4 == 0) {
sScale0(row_idx) = scale_for_old;
sM(row_idx) = new_max;
}
// Scale-O
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) {
rO0(i) *= scale_for_old;
rO0(i+1) *= scale_for_old;
}
// Scale, exp, and get row-wise expsum
float cur_sum = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max);
rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max);
rPb(i) = (typename T::InputT)rP0(i);
rPb(i+1) = (typename T::InputT)rP0(i+1);
cur_sum += rP0(i) + rP0(i+1);
}
rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;
}
}
template<
typename T,
bool IS_BLK0_LAST,
bool IS_BLK1_LAST,
bool IS_BLK2_LAST,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4,
typename Engine5, typename Layout5
>
__forceinline__ __device__ void wg1_bunch_0(
Tensor<Engine0, Layout0> &rP1b, // ((2, 2, 8), 1, 1)
Tensor<Engine1, Layout1> &sScale1, // (BLOCK_SIZE_M)
Tensor<Engine2, Layout2> &rO1, // ((2, 2, 32), 1, 1)
Tensor<Engine3, Layout3> &sM, // (BLOCK_SIZE_M)
float rL[2],
int rRightBorderForQSeq[2],
Tensor<Engine4, Layout4> const &sScale0, // (BLOCK_SIZE_M)
Tensor<Engine5, Layout5> &rP1, // ((2, 2, 8), 1, 1)
float scale_softmax_log2,
int start_token_idx,
int idx_in_warpgroup
) {
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
// Mask, and get row-wise max
float cur_max = MAX_INIT_VAL;
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) {
if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) {
// Need to apply the mask when either this block is the last one, or
// the next block is the last one (because of the causal mask)
int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;
rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL;
rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP1(i+1) : MAX_INIT_VAL;
} else if constexpr (IS_BLK0_LAST) {
rP1(i) = rP1(i+1) = MAX_INIT_VAL;
}
cur_max = max(cur_max, max(rP1(i), rP1(i+1)));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale_softmax_log2;
float old_max = sM(row_idx);
float new_max = max(old_max, cur_max);
float scale_for_old = exp2f(old_max - new_max);
__syncwarp();
if (idx_in_warpgroup%4 == 0) {
sM(row_idx) = new_max;
sScale1(row_idx) = scale_for_old;
}
// Scale, exp, and get row-wise expsum
float cur_sum = 0;
if constexpr (!IS_BLK0_LAST) {
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) {
rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max);
rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max);
rP1b(i) = (typename T::InputT)rP1(i);
rP1b(i+1) = (typename T::InputT)rP1(i+1);
cur_sum += rP1(i) + rP1(i+1);
}
}
// Scale O
float cur_scale_for_o1 = scale_for_old * sScale0(row_idx);
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rO1); i += 4) {
rO1(i) *= cur_scale_for_o1;
rO1(i+1) *= cur_scale_for_o1;
}
// Update rL
rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum;
}
}
// Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void save_rPb_to_sP(
Tensor<Engine0, Layout0> &rPb,
Tensor<Engine1, Layout1> &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, typename T::InputT>{},
(typename T::TiledMMA_QK_sQ){}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
// Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void retrieve_rP_from_sP(
Tensor<Engine0, Layout0> &rPb,
Tensor<Engine1, Layout1> const &sP,
int idx_in_warpgroup
) {
TiledCopy s2r_copy = make_tiled_copy_A(
Copy_Atom<SM75_U32x4_LDSM_N, typename T::InputT>{},
(typename T::TiledMMA_PV_LocalP){}
);
ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_sP = thr_copy.partition_S(sP);
Tensor thr_copy_rPb = thr_copy.retile_D(rPb);
cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb);
}
// Rescale rP0 and save the result to rPb
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void wg0_scale_rP0(
Tensor<Engine0, Layout0> const &sScale1, // (BLOCK_M)
Tensor<Engine1, Layout1> const &rP0, // ((2, 2, 8), 1, 1)
Tensor<Engine2, Layout2> &rPb, // ((2, 2, 8), 1, 1)
int idx_in_warpgroup
) {
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
float scale_factor = sScale1(row_idx);
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
rPb(i) = (typename T::InputT)(rP0(i)*scale_factor);
rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor);
}
}
}
// Rescale rO0 according to sScale1
template<
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void wg0_rescale_rO0(
Tensor<Engine0, Layout0> &rO0,
Tensor<Engine1, Layout1> &sScale1,
float rL[2],
int idx_in_warpgroup
) {
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
float scale_factor = sScale1(row_idx);
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) {
rO0(i) *= scale_factor;
rO0(i+1) *= scale_factor;
}
rL[local_row_idx] *= scale_factor;
}
}
// Fill out-of-bound V with 0.0
// We must fill it since it may contain NaN, which may propagate to the final result
template<
typename T,
typename Engine0, typename Layout0
>
__forceinline__ __device__ void fill_oob_V(
Tensor<Engine0, Layout0> &sV, // tile_to_shape(GMMA::Layout_MN_SW128_Atom<InputT>{}, Shape<Int<HALF_HEAD_DIM>, Int<T::PAGE_BLOCK_SIZE>>{}, LayoutRight{} );
int valid_window_size,
int idx_in_warpgroup
) {
Tensor sV_int64 = make_tensor(
make_smem_ptr((int64_t*)(sV.data().get().get())),
tile_to_shape(
GMMA::Layout_MN_SW128_Atom<cute::int64_t>{},
Shape<Int<256/(64/16)>, Int<T::PAGE_BLOCK_SIZE>>{},
LayoutRight{}
)
);
valid_window_size = max(valid_window_size, 0);
int head_dim_size = size<0>(sV_int64); // 128%head_dim_size == 0 should holds
for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) {
sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0;
}
}
// Store O / OAccum template<typename T>
template< __global__ void __launch_bounds__(T::NUM_THREADS, 1)
typename T, flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
bool IS_NO_SPLIT,
typename TMAParams,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void store_o(
Tensor<Engine0, Layout0> &rO, // ((2, 2, 32), 1, 1)
Tensor<Engine1, Layout1> &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V)
float rL[2],
char* sO_addr,
TMAParams &tma_params,
int batch_idx,
int k_head_idx,
int m_block_idx,
int num_valid_seq_q,
int warpgroup_idx,
int idx_in_warpgroup
) {
using InputT = typename T::InputT;
if constexpr (IS_NO_SPLIT) {
// Should convert the output to bfloat16 / float16, and save it to O
Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{}
));
Tensor rOb = make_tensor_like<InputT>(rO);
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); ++idx) {
rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]);
}
Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));
TiledCopy r2s_tiled_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, InputT>{},
(typename T::TiledMMA_PV_LocalP){}
);
ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup);
Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb);
Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf);
cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf);
cutlass::arch::fence_view_async_shared();
__syncthreads();
if (threadIdx.x == 0) {
Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM)
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{})(_, _, m_block_idx, _0{});
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sOutputBuf),
thr_tma.partition_D(my_tma_gO)
);
cute::tma_store_arrive();
}
} else {
// Should save the result to OAccum
Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout<
Shape<_64, _512>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>{});
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); idx += 2) {
int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
*(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 {
rO(idx) / rL[idx%4 >= 2],
rO(idx+1) / rL[idx%4 >= 2],
};
}
cutlass::arch::fence_view_async_shared();
__syncthreads();
int row = threadIdx.x;
if (row < num_valid_seq_q) {
SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float));
cute::tma_store_arrive();
}
}
}
template<
typename T,
typename TmaParams, typename Tensor0
>
__forceinline__ __device__ void launch_q_copy(
TmaParams const &tma_params,
int batch_idx,
int m_block_idx,
int k_head_idx,
Tensor0 &sQ,
TMABarrier* barrier_Q
) {
if (threadIdx.x == 0) {
Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM)
auto thr_tma = tma_params.tma_Q.get_slice(_0{});
Tensor my_tma_gQ = flat_divide(tma_gQ, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_K>>{})(_, _, m_block_idx, _0{});
cute::copy(
tma_params.tma_Q.with(reinterpret_cast<typename TMABarrier::ValueType &>(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST),
thr_tma.partition_S(my_tma_gQ),
thr_tma.partition_D(sQ)
);
barrier_Q->arrive_and_expect_tx(64*576*2);
}
}
template<
typename T,
bool IS_R,
typename Engine0, typename Layout0
>
__forceinline__ __device__ auto get_half_V(
Tensor<Engine0, Layout0> &sK
) {
Tensor sV = make_tensor(sK.data(), (typename T::SmemLayoutV){});
return flat_divide(sV, Shape<Int<T::HEAD_DIM_V/2>, Int<T::PAGE_BLOCK_SIZE>>{})(_, _, Int<(int)IS_R>{}, _0{});
}
template<
typename T,
bool IS_BLK0_LAST, // "BLK0" means block_idx+0, "BLK1" means block_idx+1, ...
bool IS_BLK1_LAST,
typename TMAParams,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4,
typename Engine5, typename Layout5,
typename Engine6, typename Layout6,
typename Engine7, typename Layout7,
typename Engine8, typename Layout8,
typename Engine9, typename Layout9,
typename Engine10, typename Layout10,
typename Engine11, typename Layout11
>
__forceinline__ __device__ void wg0_subroutine(
Tensor<Engine0, Layout0> &tma_gK,
Tensor<Engine1, Layout1> &sQ,
Tensor<Engine2, Layout2> &sK0,
Tensor<Engine3, Layout3> &sK1,
Tensor<Engine4, Layout4> &sP0,
Tensor<Engine5, Layout5> &sP1,
Tensor<Engine6, Layout6> &sM,
Tensor<Engine7, Layout7> &sScale0,
Tensor<Engine8, Layout8> &sScale1,
Tensor<Engine9, Layout9> &rQ8,
Tensor<Engine10, Layout10> &rP0,
Tensor<Engine11, Layout11> &rO0,
float rL[2],
int rRightBorderForQSeq[2],
TMABarrier barriers_K0[9],
TMABarrier barriers_K1[9],
bool &cur_phase_K0,
const TMAParams &tma_params,
const DenseAttnDecodeParams &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
int end_block_idx,
int idx_in_warpgroup
) {
int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE;
#define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 0 : __ldg(block_table_ptr + (block_idx)))
int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2);
int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3);
Tensor sV0L = get_half_V<T, 0>(sK0);
Tensor sV1L = get_half_V<T, 0>(sK1);
Tensor rPb = make_tensor<T::InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});
// Calc P0 = softmax(P0)
wg0_bunch_0<T, IS_BLK0_LAST||IS_BLK1_LAST>(rPb, rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup);
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready);
// Issue rO0 += rPb @ sV0L
if constexpr (IS_BLK0_LAST) {
fill_oob_V<T>(sV0L, seqlen_k-start_token_idx, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
warpgroup_cooperative_pv_gemm_localP<T>(rPb, sV0L, rO0, idx_in_warpgroup);
// Wait for rO0, launch TMA for the next V0L
cute::warpgroup_wait<0>();
// Wait for warpgroup 1, rescale P0, notify warpgroup 1
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready);
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
// Put it here seems to be faster, don't know why
launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup);
}
wg0_scale_rP0<T>(sScale1, rP0, rPb, idx_in_warpgroup);
save_rPb_to_sP<T>(rPb, sP0, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready);
// Wait for warpgroup 1, rescale O0, issue rO0 += rPb @ sV1L
if constexpr (!IS_BLK0_LAST) {
if constexpr (IS_BLK1_LAST) {
fill_oob_V<T>(sV1L, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued);
wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup);
warpgroup_cooperative_pv_gemm_remoteP<T>(sP1, sV1L, rO0, idx_in_warpgroup);
}
// Issue P0 = Q @ K0^T
// Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline.
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
warpgroup_cooperative_qkt_gemm<T, 0>(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup);
}
// Wait for rO0 += rPb @ sV1L, launch TMA
if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) {
cute::warpgroup_wait<4>();
launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup);
}
// Issue P0 = Q @ K0^T
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
warpgroup_cooperative_qkt_gemm<T, 2>(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup);
}
// Wait for P0 = Q @ K0^T
cute::warpgroup_wait<0>();
}
template<
typename T,
bool IS_BLK0_LAST,
bool IS_BLK1_LAST,
bool IS_BLK2_LAST,
typename TMAParams,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4,
typename Engine5, typename Layout5,
typename Engine6, typename Layout6,
typename Engine7, typename Layout7,
typename Engine8, typename Layout8,
typename Engine9, typename Layout9,
typename Engine10, typename Layout10,
typename Engine11, typename Layout11
>
__forceinline__ __device__ void wg1_subroutine(
Tensor<Engine0, Layout0> &tma_gK,
Tensor<Engine1, Layout1> &sQ,
Tensor<Engine2, Layout2> &sK0,
Tensor<Engine3, Layout3> &sK1,
Tensor<Engine4, Layout4> &sP0,
Tensor<Engine5, Layout5> &sP1,
Tensor<Engine6, Layout6> &sM,
Tensor<Engine7, Layout7> &sScale0,
Tensor<Engine8, Layout8> &sScale1,
Tensor<Engine9, Layout9> &rQ8,
Tensor<Engine10, Layout10> &rP1,
Tensor<Engine11, Layout11> &rO1,
float rL[2],
int rRightBorderForQSeq[2],
TMABarrier barriers_K0[9],
TMABarrier barriers_K1[9],
bool &cur_phase_K1,
const TMAParams &tma_params,
const DenseAttnDecodeParams &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
int end_block_idx,
int idx_in_warpgroup
) {
int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE;
int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2);
int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3);
Tensor rP1b = make_tensor<T::InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});
Tensor sV0R = get_half_V<T, 1>(sK0);
Tensor sV1R = get_half_V<T, 1>(sK1);
// Wait for rP1 and warpgroup 0, run bunch 1, notify warpgroup 0
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready);
wg1_bunch_0<T, IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST>(rP1b, sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup);
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready);
// Save rPb to sP, and issue rO1 += rP1b @ sV1R
// We do this after notifying warpgroup 1, since both "saving rPb to sP" and "issuing" WGMMA are high-latency operations
if constexpr (!IS_BLK0_LAST) {
save_rPb_to_sP<T>(rP1b, sP1, idx_in_warpgroup);
}
if constexpr (!IS_BLK0_LAST) {
if constexpr (IS_BLK1_LAST) {
fill_oob_V<T>(sV1R, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
warpgroup_cooperative_pv_gemm_localP<T>(rP1b, sV1R, rO1, idx_in_warpgroup);
if constexpr (!IS_BLK1_LAST) {
// We use this proxy for making sP1 visible to the async proxy
// We skip it if IS_BLK1_LAST, since in that case we have already put a fence
cutlass::arch::fence_view_async_shared();
}
}
// Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready);
if constexpr (IS_BLK0_LAST) {
fill_oob_V<T>(sV0R, seqlen_k-start_token_idx, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
warpgroup_cooperative_pv_gemm_remoteP<T>(sP0, sV0R, rO1, idx_in_warpgroup);
if constexpr (!IS_BLK0_LAST) {
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued);
}
// Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) {
cute::warpgroup_wait<1>();
launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup);
}
// Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
cute::warpgroup_wait<0>();
launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup);
}
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) {
// Issue rP1 = sQ @ sK1, wait
warpgroup_cooperative_qkt_gemm<T, 1>(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup);
}
// We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise
// nvcc cannot correctly analyse the loop, and will think that we are using accumulator
// registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized.
// This is also the reason why we put QK^T here, instead of the first operation in the loop
cute::warpgroup_wait<0>();
}
// A helper function for determining the length of the causal mask for one q token
__forceinline__ __device__ int get_mask_len(const DenseAttnDecodeParams &params, int m_block_idx, int local_seq_q_idx) {
int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx;
if (global_seq_q_idx < params.q_seq_per_hk) {
int s_q_idx = global_seq_q_idx / params.q_head_per_hk;
return params.s_q - s_q_idx - 1;
} else {
// Out-of-bound request, regard as no masks
return 0;
}
}
template<typename T, typename TmaParams>
__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const DenseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) {
// grid shape: [
// num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))),
// num_kv_heads,
// num_sm_parts
// ]
// An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx).
// If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx])
// For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file).
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
const int m_block_idx = blockIdx.x;
const int k_head_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int warpgroup_idx = threadIdx.x / 128;
const int idx_in_warpgroup = threadIdx.x % 128;
// Define shared tensors
extern __shared__ char wksp_buf[];
using SharedMemoryPlan = typename T::SharedMemoryPlan;
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ.data()), (typename T::SmemLayoutQ){});
Tensor sK0 = make_tensor(make_smem_ptr(plan.smem_sK0.data()), (typename T::SmemLayoutK){});
Tensor sK1 = make_tensor(make_smem_ptr(plan.smem_sK1.data()), (typename T::SmemLayoutK){});
Tensor sP0 = make_tensor(make_smem_ptr(plan.smem_sP0.data()), (typename T::SmemLayoutP0){});
Tensor sP1 = flat_divide(sQ, Shape<Int<T::BLOCK_SIZE_M>, Int<T::PAGE_BLOCK_SIZE>>{})(_, _, _0{}, _8{}); // Overlap with sQ's 8-th tile
Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));
Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{}));
Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));
Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));
char* sO_addr = (char*)plan.smem_sK0.data(); // Overlap with sK0 and sK1
// Prefetch TMA descriptors
if (threadIdx.x == 0) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
}
// Define TMA stuffs
Tensor tma_gK = tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _);
TMABarrier* barriers_K0 = plan.barriers_K0;
TMABarrier* barriers_K1 = plan.barriers_K1;
TMABarrier* barrier_Q = &(plan.barrier_Q);
// Initialize TMA barriers
if (threadIdx.x == 0) {
barrier_Q->init(1);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 9; ++i) {
barriers_K0[i].init(1);
barriers_K1[i].init(1);
}
cutlass::arch::fence_view_async_shared();
}
__syncthreads();
bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0;
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
if (sched_meta.begin_req_idx >= params.b) return;
// Copy the first Q
launch_q_copy<T>(tma_params, sched_meta.begin_req_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
int rRightBorderForQSeq[2];
if (params.is_causal) {
// The causal mask looks like:
// XXXX
// XXXX
// ...
// XXXX
// XXX
// XXX
// ...
// XXX
// XX
// XX
// ...
// XX
// Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. We consider the common_mask_len as a "reduction in the length of the k-sequence.", and adjust end_block_idx based on it, to save some calculation.
// Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks
// NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling
int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1);
int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN);
end_block_idx = batch_idx == sched_meta.end_req_idx ? min(sched_meta.end_block_idx, last_block_in_seq) : last_block_in_seq;
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE);
}
} else {
rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k;
}
// Define global tensors
using InputT = typename T::InputT;
InputT* o_ptr = (InputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride; // (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1)
float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride; // (/) : (1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{},
make_stride(params.o_row_stride, _1{})
));
Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout<
Shape<Int<T::BLOCK_SIZE_M>>,
Stride<_1>
>{});
// Copy K0 and K1
launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x);
if (start_block_idx+1 < end_block_idx) {
launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
}
Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V / 2>>{}); // ((2, 2, 32), 1, 1)
float rL[2];
rL[0] = rL[1] = 0.0f;
// Clear buffers
cute::fill(rO, 0.);
if (threadIdx.x < size(sM)) {
sM[threadIdx.x] = MAX_INIT_VAL_SM;
}
// Wait for Q
barrier_Q->wait(cur_phase_Q);
cur_phase_Q ^= 1;
Tensor rQ8 = make_tensor<InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});
retrieve_rP_from_sP<T>(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup);
if (warpgroup_idx == 0) {
// Warpgroup 0
Tensor rP0 = make_tensor<float>((typename T::rP0Layout){});
// NOTE We don't use the pipelined version of Q K^T here since it leads
// to a slow-down (or even register spilling, thanks to the great NVCC)
// Wait for K0
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 9; ++i) {
if (idx_in_warpgroup == 0)
barriers_K0[i].arrive_and_expect_tx(64*64*2);
barriers_K0[i].wait(cur_phase_K0);
}
cur_phase_K0 ^= 1;
// Issue P0 = Q @ K0^T, wait
if (start_block_idx-16777216 < end_block_idx) { // NOTE We use this `if` to prevent register spilling
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
}
// We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0
NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized);
cute::warpgroup_wait<0>();
#define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \
wg0_subroutine<T, IS_BLK0_LAST, IS_BLK1_LAST>( \
tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \
rQ8, rP0, rO, rL, rRightBorderForQSeq, \
barriers_K0, barriers_K1, cur_phase_K0, \
tma_params, params, \
block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \
);
int block_idx = start_block_idx;
#pragma unroll 1
for (; block_idx < end_block_idx-2; block_idx += 2) {
LAUNCH_WG0_SUBROUTINE(false, false);
}
if (block_idx+1 < end_block_idx) {
LAUNCH_WG0_SUBROUTINE(false, true);
} else if (block_idx < end_block_idx) {
LAUNCH_WG0_SUBROUTINE(true, false);
}
} else {
// Warpgroup 1
Tensor rP1 = make_tensor<float>((typename T::rP0Layout){});
if (start_block_idx+1 < end_block_idx) {
// Issue rP1 = sQ @ sK1, wait
warpgroup_cooperative_qkt_gemm<T, 1>(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup);
cute::warpgroup_wait<0>();
}
#define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \
wg1_subroutine<T, IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST>( \
tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \
rQ8, rP1, rO, rL, rRightBorderForQSeq, \
barriers_K0, barriers_K1, cur_phase_K1, \
tma_params, params, \
block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \
);
int block_idx = start_block_idx;
#pragma unroll 1
for (; block_idx < end_block_idx-3; block_idx += 2) {
LAUNCH_WG1_SUBROUTINE(false, false, false);
}
if (block_idx+2 < end_block_idx) {
LAUNCH_WG1_SUBROUTINE(false, false, true);
block_idx += 2;
LAUNCH_WG1_SUBROUTINE(true, false, false);
} else if (block_idx+1 < end_block_idx) {
LAUNCH_WG1_SUBROUTINE(false, true, false);
} else if (block_idx < end_block_idx) {
LAUNCH_WG1_SUBROUTINE(true, false, false);
}
}
// Reduce rL across threads within the same warp
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
// Reduce rL across warpgroups
int my_row = get_AorC_row_idx(0, idx_in_warpgroup);
if (idx_in_warpgroup%4 == 0) {
sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0];
sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1];
}
__syncthreads();
if (warpgroup_idx == 0) {
rL[0] += sL_reduction_wksp[my_row + 64];
rL[1] += sL_reduction_wksp[my_row + 8 + 64];
} else {
if (idx_in_warpgroup%4 == 0) {
sL_reduction_wksp[my_row] += rL[0];
sL_reduction_wksp[my_row + 8] += rL[1];
}
__syncwarp();
rL[0] = sL_reduction_wksp[my_row];
rL[1] = sL_reduction_wksp[my_row+8];
}
// Prune out when rL is 0.0f or NaN
// rL may be 0.0f if there are large values (~10^12) in QK^T, which leads
// to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error.
// When this happens, we set rL to 1.0f. This aligns with the old version
// of the MLA kernel.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; ++i)
rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i];
// Copy Q for the next batch
if (batch_idx+1 <= sched_meta.end_req_idx) {
launch_q_copy<T>(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q);
} else {
// Allow the next kernel (the combine kernel) to launch
// The next kernel MUST be the combine kernel
cudaTriggerProgrammaticLaunchCompletion();
}
int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M);
if (is_no_split) {
store_o<T, true>(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL_reduction_wksp[i];
gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E;
}
cute::tma_store_wait<0>();
} else {
// Don't use __ldg because of PDL and instruction reordering
int split_idx = params.num_splits_ptr[batch_idx] + n_split_idx;
float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>,
Stride<Int<T::HEAD_DIM_V>, _1>
>{});
Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout<
Shape<Int<T::BLOCK_SIZE_M>>,
Stride<_1>
>{});
store_o<T, false>(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL_reduction_wksp[i];
gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? -INFINITY : log2f(cur_L) + sM(i);
}
cute::tma_store_wait<0>();
}
if (batch_idx != sched_meta.end_req_idx)
__syncthreads();
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
} }
...@@ -1279,76 +31,22 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) { ...@@ -1279,76 +31,22 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
using T = Traits<InputT>; using T = Traits<InputT>;
auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b); auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{}, auto mla_kernel = &flash_fwd_splitkv_mla_kernel<T>;
make_tensor(
make_gmem_ptr((InputT*)params.q_ptr),
make_layout(
shape_Q,
make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride)
)
),
tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_K>>{}
)
);
auto shape_K = make_shape(Int<T::PAGE_BLOCK_SIZE>{}, Int<T::HEAD_DIM_K>{}, params.h_k, params.num_blocks);
auto tma_K = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((InputT*)params.k_ptr),
make_layout(
shape_K,
make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride)
)
),
tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Layout<
Shape<Int<T::PAGE_BLOCK_SIZE>, Int<64>>,
Stride<Int<T::HEAD_DIM_K>, _1>
>{}
)
);
auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((InputT*)params.o_ptr),
make_layout(
shape_O,
make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride)
)
),
tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{}
)
);
TmaParams<decltype(shape_Q), decltype(tma_Q), decltype(shape_K), decltype(tma_K), decltype(shape_O), decltype(tma_O)> tma_params = {
shape_Q, tma_Q,
shape_K, tma_K,
shape_O, tma_O
};
auto mla_kernel = &flash_fwd_splitkv_mla_kernel<T, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan); constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch) // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M); const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
cudaLaunchAttribute mla_kernel_attributes[1];
mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; // cudaLaunchConfig_t mla_kernel_config = {
mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1; // dim3(num_m_block, params.h_k, params.num_sm_parts),
cudaLaunchConfig_t mla_kernel_config = { // dim3(T::NUM_THREADS, 1, 1),
dim3(num_m_block, params.h_k, params.num_sm_parts), // smem_size,
dim3(T::NUM_THREADS, 1, 1), // params.stream,
smem_size, // mla_kernel_attributes,
params.stream, // 1
mla_kernel_attributes, // };
1 // cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
};
cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
CHECK_CUDA_KERNEL_LAUNCH(); CHECK_CUDA_KERNEL_LAUNCH();
} }
......
...@@ -84,24 +84,5 @@ struct Traits { ...@@ -84,24 +84,5 @@ struct Traits {
}; };
template<
typename ShapeQ, typename TMA_Q,
typename ShapeK, typename TMA_K,
typename ShapeO, typename TMA_O
>
struct TmaParams {
ShapeQ shape_Q;
TMA_Q tma_Q;
ShapeK shape_K;
TMA_K tma_K;
ShapeO shape_O;
TMA_O tma_O;
};
enum NamedBarriers : int {
sScale0Ready = 0,
sScale1Ready = 1,
sP0Ready = 2,
rO1sP0sV0RIssued = 3,
sMInitialized = 4,
};
#pragma once #pragma once
#include <cutlass/numeric_types.h> #include <cutlass/numeric_types.h>
#include <cutlass/arch/barrier.h> // #include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include "defines.h" #include "defines.h"
......
#pragma once #pragma once
#include <cuda_fp8.h> // #include <cuda_fp8.h>
#include <cuda_bf16.h> // #include <cuda_bf16.h>
#include "defines.h" #include "defines.h"
namespace sm90::decode::sparse_fp8 { namespace sm90::decode::sparse_fp8 {
struct fp8x8 { // struct fp8x8 {
__nv_fp8x4_e4m3 lo; // // __nv_fp8x4_e4m3 lo;
__nv_fp8x4_e4m3 hi; // // __nv_fp8x4_e4m3 hi;
}; // };
struct fp8x16 { // struct fp8x16 {
fp8x8 lo; // fp8x8 lo;
fp8x8 hi; // fp8x8 hi;
}; // };
__device__ __forceinline__ // __device__ __forceinline__
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) { // bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) {
#define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \ // #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
{ \ // { \
float4 fp32x4 = (float4)(FP8x4); \ // float4 fp32x4 = (float4)(FP8x4); \
OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \ // OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \
OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \ // OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \
} // }
bf16x8 result; // bf16x8 result;
DEQUANT_FP8x4(result.a01, result.a23, inputs.lo); // DEQUANT_FP8x4(result.a01, result.a23, inputs.lo);
DEQUANT_FP8x4(result.a45, result.a67, inputs.hi); // DEQUANT_FP8x4(result.a45, result.a67, inputs.hi);
return result; // return result;
} // }
enum class L1CacheHint { // enum class L1CacheHint {
NO_ALLOCATE, // NO_ALLOCATE,
EVICT_FIRST, // EVICT_FIRST,
EVICT_NORMAL, // EVICT_NORMAL,
EVICT_LAST // EVICT_LAST
}; // };
enum class L2PrefetchHint { // enum class L2PrefetchHint {
B64, // B64,
B128, // B128,
B256 // B256
}; // };
template< // template<
typename T, // typename T,
L1CacheHint l1_cache_hint, // L1CacheHint l1_cache_hint,
L2PrefetchHint l2_prefetch_hint // L2PrefetchHint l2_prefetch_hint
> // >
__device__ __forceinline__ // __device__ __forceinline__
T load_128b_from_gmem(const void* addr) { // T load_128b_from_gmem(const void* addr) {
static_assert(sizeof(T) == 128/8); // static_assert(sizeof(T) == 128/8);
int4 ret; // int4 ret;
#define EXEC(L1_HINT_STR, L2_HINT_STR) { \ // #define EXEC(L1_HINT_STR, L2_HINT_STR) { \
asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v4.s32 {%0, %1, %2, %3}, [%4];" \ // asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v4.s32 {%0, %1, %2, %3}, [%4];" \
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) \ // : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) \
: "l"(addr)); \ // : "l"(addr)); \
} // }
#define DISPATCH_L2(L1_HINT_STR) { \ // #define DISPATCH_L2(L1_HINT_STR) { \
if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \ // if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \
EXEC(L1_HINT_STR, "64B") \ // EXEC(L1_HINT_STR, "64B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \ // else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \
EXEC(L1_HINT_STR, "128B") \ // EXEC(L1_HINT_STR, "128B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \ // else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \
EXEC(L1_HINT_STR, "256B") \ // EXEC(L1_HINT_STR, "256B") \
} // }
if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE) // if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)
DISPATCH_L2("no_allocate") // DISPATCH_L2("no_allocate")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST) // else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)
DISPATCH_L2("evict_first") // DISPATCH_L2("evict_first")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL) // else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)
DISPATCH_L2("evict_normal") // DISPATCH_L2("evict_normal")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST) // else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)
DISPATCH_L2("evict_last") // DISPATCH_L2("evict_last")
#undef EXEC // #undef EXEC
#undef DISPATCH_L2 // #undef DISPATCH_L2
return *reinterpret_cast<T*>(&ret); // return *reinterpret_cast<T*>(&ret);
} // }
template< // template<
typename T, // typename T,
L1CacheHint l1_cache_hint, // L1CacheHint l1_cache_hint,
L2PrefetchHint l2_prefetch_hint // L2PrefetchHint l2_prefetch_hint
> // >
__device__ __forceinline__ // __device__ __forceinline__
T load_64b_from_gmem(const void* addr) { // T load_64b_from_gmem(const void* addr) {
static_assert(sizeof(T) == 64/8); // static_assert(sizeof(T) == 64/8);
int2 ret; // int2 ret;
#define EXEC(L1_HINT_STR, L2_HINT_STR) { \ // #define EXEC(L1_HINT_STR, L2_HINT_STR) { \
asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v2.s32 {%0, %1}, [%2];" \ // asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v2.s32 {%0, %1}, [%2];" \
: "=r"(ret.x), "=r"(ret.y) \ // : "=r"(ret.x), "=r"(ret.y) \
: "l"(addr)); \ // : "l"(addr)); \
} // }
#define DISPATCH_L2(L1_HINT_STR) { \ // #define DISPATCH_L2(L1_HINT_STR) { \
if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \ // if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \
EXEC(L1_HINT_STR, "64B") \ // EXEC(L1_HINT_STR, "64B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \ // else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \
EXEC(L1_HINT_STR, "128B") \ // EXEC(L1_HINT_STR, "128B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \ // else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \
EXEC(L1_HINT_STR, "256B") \ // EXEC(L1_HINT_STR, "256B") \
} // }
if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE) // if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)
DISPATCH_L2("no_allocate") // DISPATCH_L2("no_allocate")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST) // else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)
DISPATCH_L2("evict_first") // DISPATCH_L2("evict_first")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL) // else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)
DISPATCH_L2("evict_normal") // DISPATCH_L2("evict_normal")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST) // else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)
DISPATCH_L2("evict_last") // DISPATCH_L2("evict_last")
#undef EXEC // #undef EXEC
#undef DISPATCH_L2 // #undef DISPATCH_L2
return *reinterpret_cast<T*>(&ret); // return *reinterpret_cast<T*>(&ret);
} // }
} }
...@@ -9,101 +9,101 @@ using namespace cute; ...@@ -9,101 +9,101 @@ using namespace cute;
namespace sm90::decode::sparse_fp8 { namespace sm90::decode::sparse_fp8 {
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx // // In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a // // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { // __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); // int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx; // return row_idx;
} // }
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h // // Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> // template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { // __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value; // constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const // // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } // if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC); // warpgroup_fence_operand(tCrC);
if constexpr (arrive) { // if constexpr (arrive) {
warpgroup_arrive(); // warpgroup_arrive();
} // }
if constexpr (zero_init) { // if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; // tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1 // // Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); // cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
} else { // } else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC); // // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1 // // Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); // cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
} // }
if constexpr (commit) { // if constexpr (commit) {
warpgroup_commit_batch(); // warpgroup_commit_batch();
} // }
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); } // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC); // warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } // if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
} // }
template< // template<
typename TMA, // typename TMA,
typename Tensor0, // typename Tensor0,
typename Tensor1 // typename Tensor1
> // >
CUTE_DEVICE // CUTE_DEVICE
void launch_tma_copy( // void launch_tma_copy(
const TMA &tma_copy, // const TMA &tma_copy,
const Tensor0 &src, // const Tensor0 &src,
Tensor1 &dst, // Tensor1 &dst,
transac_bar_t &bar, // transac_bar_t &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL, // const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL,
const uint16_t &multicast_mask = 0 // const uint16_t &multicast_mask = 0
) { // ) {
auto thr_tma = tma_copy.get_slice(_0{}); // auto thr_tma = tma_copy.get_slice(_0{});
cute::copy( // cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), multicast_mask, cache_hint), // tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), multicast_mask, cache_hint),
thr_tma.partition_S(src), // thr_tma.partition_S(src),
thr_tma.partition_D(dst) // thr_tma.partition_D(dst)
); // );
} // }
template<typename T> // template<typename T>
CUTE_DEVICE // CUTE_DEVICE
static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) { // static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
long2 data_long2 = *reinterpret_cast<const long2*>(&data); // long2 data_long2 = *reinterpret_cast<const long2*>(&data);
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); // uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); // uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
asm volatile ( // asm volatile (
"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n" // "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
: // :
: "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr) // : "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
); // );
} // }
CUTE_DEVICE // CUTE_DEVICE
static void cp_async_bulk_shared_cta_shared_cluster(void* dst_ptr, void* src_ptr, int size, transac_bar_t* mbar_ptr) { // static void cp_async_bulk_shared_cta_shared_cluster(void* dst_ptr, void* src_ptr, int size, transac_bar_t* mbar_ptr) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr); // uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t src_addr = cute::cast_smem_ptr_to_uint(src_ptr); // uint32_t src_addr = cute::cast_smem_ptr_to_uint(src_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr); // uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
asm volatile ( // asm volatile (
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n" // "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
: // :
: "r"(dst_addr), "r"(src_addr), "r"(size), "r"(mbar_addr) // : "r"(dst_addr), "r"(src_addr), "r"(size), "r"(mbar_addr)
); // );
} // }
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. // static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK.
template<typename T> // template<typename T>
CUTE_DEVICE // CUTE_DEVICE
T* get_peer_addr(T* p) { // T* get_peer_addr(T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); // return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
} // }
} }
...@@ -89,188 +89,54 @@ using SmemLayoutS = decltype(tile_to_shape( ...@@ -89,188 +89,54 @@ using SmemLayoutS = decltype(tile_to_shape(
)); ));
struct SharedMemoryPlan { struct SharedMemoryPlan {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q; // array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
union { // union {
array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS]; // array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf; // array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf; // array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
} u; // } u;
CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s; // CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE]; // bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M]; // float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS]; // transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
}; };
template< // template<
typename Shape_Q, typename TMA_Q // typename Shape_Q, typename TMA_Q
> // >
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
CUtensorMap tensor_map_o;
};
using TiledMMA_QK = decltype(make_tiled_mma( // using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{}, // GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{} // Layout<Shape<_1, _1, _1>>{}
)); // ));
using TiledMMA_QK_rQ = decltype(make_tiled_mma( // using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{}, // GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{} // Layout<Shape<_1, _1, _1>>{}
)); // ));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma( // using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{}, // GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{} // Layout<Shape<_1, _1, _1>>{}
)); // ));
// using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
// GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
// Layout<Shape<_1, _1, _1>>{}
// ));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
enum NamedBarriers : uint32_t {
sScale_and_sS_ready = 0,
sScale_and_sS_free = 1,
oBuf_free_and_sL_ready = 2,
epilogue_r2s_ready = 3,
batch_loop_sync = 4,
warpgroup0_sync = 5
};
// Synchronize all threads within the cluster (which processes one q token)
static __forceinline__ __device__ void sync_all_threads_in_cluster() {
if constexpr (CLUSTER_SIZE == 1) {
__syncthreads();
} else {
ku::barrier_cluster_arrive_relaxed();
ku::barrier_cluster_wait_acquire();
}
}
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
static __forceinline__ __device__ void save_rPb_to_sP(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
template<
bool IS_NO_SPLIT,
typename TMAParams,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3
>
static __forceinline__ __device__ void store_o(
Tensor0 &rO, // ((2, 2, 32), 1, 1)
Tensor1 &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V)
Tensor2 &sOutputBuf,
Tensor3 &sOutputAccumBuf,
SharedMemoryPlan &plan,
float o_scales[2],
TMAParams &tma_params,
int batch_idx,
int s_q_idx,
int head_block_idx,
int num_valid_seq_q,
int warpgroup_idx,
int idx_in_warpgroup
) {
using cutlass::arch::NamedBarrier;
if constexpr (IS_NO_SPLIT) {
// Should convert the output to bfloat16 / float16, and save it to O
// Here we don't pipeline STSM and tma store because it's slower
Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));
// Calculate "base" ptrs in advance
// Each STSM fills a chunk of shape 16x16, while we are using SW-OBUF_SW, so we need OBUF_SW/16 base pointers
constexpr int NUM_CHUNKS_IN_SW_ATOM = OBUF_SW/16;
bf16* base_output_buf_ptrs[NUM_CHUNKS_IN_SW_ATOM];
CUTE_UNROLL
for (int i = 0; i < NUM_CHUNKS_IN_SW_ATOM; ++i) {
base_output_buf_ptrs[i] = &sMyOutputBuf((idx_in_warpgroup/32)*16+idx_in_warpgroup%16, idx_in_warpgroup%32/16*8 + i*16);
}
CUTE_UNROLL
for (int idx = 0; idx < (HEAD_DIM_V/2)/16; idx += 1) {
// In each iteration we deal with a chunk of shape 16x16
using bf16x2 = __nv_bfloat162;
bf16x2 a01 = __float22bfloat162_rn(float2{rO(idx*8+0)*o_scales[0], rO(idx*8+1)*o_scales[0]});
bf16x2 a23 = __float22bfloat162_rn(float2{rO(idx*8+2)*o_scales[1], rO(idx*8+3)*o_scales[1]});
bf16x2 a45 = __float22bfloat162_rn(float2{rO(idx*8+4)*o_scales[0], rO(idx*8+5)*o_scales[0]});
bf16x2 a67 = __float22bfloat162_rn(float2{rO(idx*8+6)*o_scales[1], rO(idx*8+7)*o_scales[1]});
SM90_U32x4_STSM_N::copy(
*reinterpret_cast<uint32_t*>(&a01),
*reinterpret_cast<uint32_t*>(&a23),
*reinterpret_cast<uint32_t*>(&a45),
*reinterpret_cast<uint32_t*>(&a67),
*reinterpret_cast<uint128_t*>(base_output_buf_ptrs[idx%4] + (idx/4*4)*16*64)
);
}
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (threadIdx.x == 0) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o,
plan.u.oBuf.data(),
0, head_block_idx*64, 0,
s_q_idx, batch_idx
);
cute::tma_store_arrive();
}
} else {
// Should save the result to OAccum
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); idx += 2) {
int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
*(float2*)(&(sOutputAccumBuf(row, col))) = float2 {
rO(idx) * o_scales[idx%4>=2],
rO(idx+1) * o_scales[idx%4>=2],
};
}
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (elect_one_sync()) {
CUTLASS_PRAGMA_UNROLL
for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) {
int row = local_row * (256/32) + (threadIdx.x / 32);
if (row < num_valid_seq_q) {
SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float));
}
}
cute::tma_store_arrive();
}
}
}
template<typename TMAParams>
static __device__ __forceinline__ void static __device__ __forceinline__ void
devfunc(const SparseAttnDecodeParams &params, const TMAParams &tma_params); devfunc(const SparseAttnDecodeParams &params);
static void run(const SparseAttnDecodeParams &params); static void run(const SparseAttnDecodeParams &params);
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
#include "splitkv_mla.h" #include "splitkv_mla.h"
#include <cuda_fp8.h> // #include <cuda_fp8.h>
#include <math_constants.h> // #include <math_constants.h>
#include <cutlass/barrier.h> // #include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h> // #include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h> // #include <cutlass/arch/reg_reconfig.h>
#include <cutlass/cluster_launch.hpp> // #include <cutlass/cluster_launch.hpp>
#include <kerutils/kerutils.cuh> #include <kerutils/kerutils.cuh>
...@@ -20,667 +20,18 @@ using namespace cute; ...@@ -20,667 +20,18 @@ using namespace cute;
namespace sm90::decode::sparse_fp8 { namespace sm90::decode::sparse_fp8 {
static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using fp8_e8m0 = __nv_fp8_e8m0;
template<
typename Tensor0,
typename Tensor1,
typename Tensor2
>
__forceinline__ __device__ void scale_softmax(
Tensor0 &rP,
Tensor1 &rS,
Tensor2 &rO,
float scale_softmax_log2,
float sScale[],
float rM[2],
float rL[2],
bool is_kv_valid[],
int block_idx,
int idx_in_warpgroup
) {
float scale_for_olds[2];
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, _), _, _));
Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _));
Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _));
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = 0; i < size(cur_rP); ++i) {
if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2])
cur_rP(i) = -INFINITY;
cur_max = max(cur_max, cur_rP(i));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale_softmax_log2;
float old_max = rM[local_row_idx];
rM[local_row_idx] = max(cur_max, old_max);
float scale_for_old = exp2f(old_max - rM[local_row_idx]);
scale_for_olds[local_row_idx] = scale_for_old;
CUTE_UNROLL
for (int i = 0; i < size(cur_rO); ++i) {
cur_rO(i) *= scale_for_old;
}
float cur_sum = 0;
CUTE_UNROLL
for (int i = 0; i < size(cur_rP); ++i) {
cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]);
cur_rS(i) = (bf16)cur_rP(i);
cur_sum += cur_rP(i);
}
rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;
}
if (idx_in_warpgroup%4 == 0)
*(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds);
}
template<ModelType MODEL_TYPE, int NUM_HEADS> template<ModelType MODEL_TYPE, int NUM_HEADS>
template<typename TMAParams> __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::devfunc(const SparseAttnDecodeParams &params) {
__device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::devfunc(const SparseAttnDecodeParams &params, const TMAParams &tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
const int head_block_idx = NUM_M_BLOCKS == 1 ? 0 : blockIdx.x;
const int s_q_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int idx_in_cluster = CLUSTER_SIZE == 1 ? 0 : head_block_idx % 2;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{});
Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{});
Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{});
Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{});
float* sM = plan.sM;
float* sL = plan.sL;
float* sScale = plan.sScale;
// Prefetch TMA descriptors
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_o);
}
// Initialize TMA barriers
if (warp_idx == 0 && elect_one_sync()) {
plan.bar_q.init(1);
if constexpr (CLUSTER_SIZE == 2) {
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
plan.bar_k_local_ready[i].init(128);
plan.bar_k_remote_ready[i].init(1);
plan.bar_k_avail[i].init(4);
}
} else {
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
plan.bar_k_local_ready[i].init(128);
plan.bar_k_avail[i].init(256);
}
}
cutlass::arch::fence_barrier_init();
}
ku::barrier_cluster_arrive_relaxed();
int bar_phase_k = 0; // Don't use array here to prevent using local memory
// Programmatic Dependent Launch: Wait for the previous kernel to finish
// Don't use PDL because of compiler bugs!
// cudaGridDependencySynchronize();
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
if (sched_meta.begin_req_idx >= params.b) return;
if (warp_idx == 0 && elect_one_sync()) {
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, sched_meta.begin_req_idx),
Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}
)(_, _, head_block_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));
}
ku::barrier_cluster_wait_acquire();
struct MainloopArgs {
int start_block_idx, end_block_idx;
bool is_no_split;
// The following fields are only valid for MODEL1
int topk_length, extra_topk_length, num_orig_kv_blocks;
};
auto get_cur_req_info = [&](int batch_idx) -> MainloopArgs {
MainloopArgs args;
int total_topk_padded;
if constexpr (MODEL_TYPE == ModelType::V32) {
total_topk_padded = params.topk;
} else {
int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
int orig_topk_padded = max(ku::ceil(topk_length, (int)TOPK_BLOCK_SIZE), (int)TOPK_BLOCK_SIZE);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)TOPK_BLOCK_SIZE);
args.topk_length = topk_length;
args.extra_topk_length = extra_topk_length;
args.num_orig_kv_blocks = orig_topk_padded / TOPK_BLOCK_SIZE;
}
args.start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
args.end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / TOPK_BLOCK_SIZE;
args.is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
return args;
};
if (warpgroup_idx == 0) {
cutlass::arch::warpgroup_reg_alloc<192>();
TiledMMA tiled_mma_QK = TiledMMA_QK{};
ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup);
TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{};
ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup);
float rL[2], rM[2];
Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V/2>>{});
Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});
Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}));
float rAttn_sink[2] = {-CUDART_INF_F, -CUDART_INF_F};
if (params.attn_sink != nullptr) {
for (int i = 0; i < 2; ++i) {
int head_idx = head_block_idx*BLOCK_M + get_AorC_row_idx(i, idx_in_warpgroup);
rAttn_sink[i] = __ldg((float*)params.attn_sink + head_idx) * CUDART_L2E_F;
}
}
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
MainloopArgs args = get_cur_req_info(batch_idx);
rL[0] = rL[1] = 0.0f;
rM[0] = rM[1] = MAX_INIT_VAL;
cute::fill(rO, 0.);
// Wait for Q
plan.bar_q.wait((sched_meta.begin_req_idx-batch_idx)&1);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {
int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;
Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{});
// Wait, issue WGMMA
plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
}
gemm<true, -1>(
tiled_mma_QK,
thr_mma_QK.partition_fragment_A(sQ),
thr_mma_QK.partition_fragment_B(sK),
rP
);
bar_phase_k ^= 1<<buf_idx;
cute::warpgroup_wait<0>();
// Calculate S = softmax(mask(scale(P)))
if (block_idx != args.start_block_idx)
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS is free
// Since in our case TOPK_BLOCK_SIZE == BLOCK_M, so we only need to do OOB checking for the last 2 blocks
scale_softmax(rP, rS, rO, params.sm_scale_div_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup);
// Store S into shared, inform warpgroup 1
save_rPb_to_sP(rS, sS, idx_in_warpgroup);
fence_view_async_shared();
// Issue O += S @ V
gemm<false, -1>(
tiled_mma_PV,
rS,
thr_mma_PV.partition_fragment_B(sV),
rO
);
NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready);
cute::warpgroup_wait<0>();
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
} else {
plan.bar_k_avail[buf_idx].arrive();
}
}
// Copy the next q
if (threadIdx.x/32 == 0 && elect_one_sync()) {
if (batch_idx != sched_meta.end_req_idx) {
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1),
Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}
)(_, _, head_block_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));
} else {
// This kernel is followed by the combine kernel, so we signal PDL here
cudaTriggerProgrammaticLaunchCompletion();
}
}
// Synchronize L and M across warpgroups
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
if (idx_in_warpgroup%4 == 0) {
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int row = get_AorC_row_idx(i, idx_in_warpgroup);
sL[row] = rL[i];
sM[row] = rM[i];
}
}
float o_scales[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
if (args.is_no_split) {
o_scales[i] = rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i] + exp2f(rAttn_sink[i] - rM[i]));
} else {
o_scales[i] = rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i]);
}
if (idx_in_warpgroup%4 == 0) {
int row = get_AorC_row_idx(i, idx_in_warpgroup);
plan.sOScale[row] = o_scales[i];
}
}
// This is a synchronization point for warpgroup 0/1.
// Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free
// Warpgroup 1 should wait wg 0 for sL to be ready
NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready);
CUTE_UNROLL
for (int i = 0; i < 2; ++i)
rL[i] = rL[i] == 0.0f ? 1.0f : rL[i];
int start_head_idx = head_block_idx*BLOCK_M;
int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M);
if (args.is_no_split) {
bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.stride_o_h_q, _1{})
));
float* gSoftmaxLse = (float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + start_head_idx; // (BLOCK_M) : (1)
store_o<true>(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL[i];
gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E;
}
cute::tma_store_wait<0>();
} else {
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1)
float* gSoftmaxLseAccum = (float*)params.lse_accum + split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + start_head_idx; // (BLOCK_M) : (1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.stride_o_accum_h_q, _1{})
));
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL[i];
gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i];
}
cute::tma_store_wait<0>();
}
sync_all_threads_in_cluster();
}
} else if (warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_dealloc<160>();
TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{};
ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup);
Tensor rO = partition_fragment_C(tiled_mma_PV, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V/2>>{});
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
MainloopArgs args = get_cur_req_info(batch_idx);
cute::fill(rO, 0.);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {
int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{});
// Wait for S and sScale
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready);
// Scale O
float cur_scales[2];
*(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2);
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _));
CUTE_UNROLL
for (int i = 0; i < size(cur_rO); ++i) {
cur_rO(i) *= cur_scales[local_row_idx];
}
}
// Issue O += S @ V, and wait
gemm<false, -1>(
tiled_mma_PV,
thr_mma_PV.partition_fragment_A(sS),
thr_mma_PV.partition_fragment_B(sV),
rO
);
cute::warpgroup_wait<0>();
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
} else {
plan.bar_k_avail[buf_idx].arrive();
}
if (block_idx != args.end_block_idx-1)
NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available
}
NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready);
float o_scales[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int row = get_AorC_row_idx(i, idx_in_warpgroup);
o_scales[i] = plan.sOScale[row];
}
int start_head_idx = head_block_idx*BLOCK_M;
int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M);
if (args.is_no_split) {
bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.stride_o_h_q, _1{})
));
store_o<true>(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
cute::tma_store_wait<0>();
} else {
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.stride_o_accum_h_q, _1{})
));
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
cute::tma_store_wait<0>();
}
sync_all_threads_in_cluster();
}
} else {
// Producer warpgroup
cutlass::arch::warpgroup_reg_dealloc<152>();
static_assert(CLUSTER_SIZE == 1 || CLUSTER_SIZE == 2);
static constexpr int NUM_TOKENS_PER_THREAD = CLUSTER_SIZE == 1 ? 2 : 1;
static constexpr int NUM_TOKENS_PER_ROUND = 32; // If head is 128, each CTA is responsible for dequantizing 32 tokens (1 rounds); if head is 64, each CTA is responsible for dequantizing 64 tokens (2 rounds)
int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0);
int lane_idx = idx_in_warpgroup % 32;
int my_token_idx_base = warp_idx*8 + lane_idx%8;
CUTE_NO_UNROLL
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
MainloopArgs args = get_cur_req_info(batch_idx);
int* gIndices = params.indices + batch_idx*params.stride_indices_b + s_q_idx*params.stride_indices_s_q; // (topk) : (1)
int* gExtraIndices = params.extra_indices + batch_idx*params.stride_extra_indices_b + s_q_idx*params.stride_extra_indices_s_q; // (extra_topk) : (1)
int nxt_token_indexs[NUM_TOKENS_PER_THREAD];
CUTE_UNROLL
for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) {
if (MODEL_TYPE == ModelType::V32 || args.start_block_idx < args.num_orig_kv_blocks)
nxt_token_indexs[round] = __ldg(gIndices + args.start_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + round*NUM_TOKENS_PER_ROUND + my_token_idx_base);
}
struct IsOrigBlock {};
struct IsExtraBlock {};
struct IsFirstExtraBlock {};
struct IsNotFirstExtraBlock {};
auto process_one_block = [&](int block_idx, auto is_extra_block_t, auto is_first_extra_block_t) {
static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;
static constexpr bool IS_FIRST_EXTRA_BLOCK = std::is_same_v<decltype(is_first_extra_block_t), IsFirstExtraBlock>;
int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;
int* indices_base;
int page_block_size;
int64_t k_block_stride, k_row_stride;
fp8* k_ptr;
if constexpr (!IS_EXTRA_BLOCK) {
indices_base = gIndices + (block_idx)*TOPK_BLOCK_SIZE;
page_block_size = params.page_block_size;
k_block_stride = params.stride_kv_block;
k_row_stride = params.stride_kv_row;
k_ptr = (fp8*)params.kv;
} else {
indices_base = gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE;
page_block_size = params.extra_page_block_size;
k_block_stride = params.stride_extra_kv_block;
k_row_stride = params.stride_extra_kv_row;
k_ptr = (fp8*)params.extra_kv;
}
[[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length;
[[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx;
transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx]));
CUTE_UNROLL
for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) {
int my_token_idx = my_token_idx_base + round*NUM_TOKENS_PER_ROUND;
bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE;
bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base);
// Get prefetched token index
int token_index;
if constexpr (!IS_EXTRA_BLOCK) {
token_index = nxt_token_indexs[round];
if (block_idx+1 != (MODEL_TYPE == ModelType::V32 ? args.end_block_idx : args.num_orig_kv_blocks))
nxt_token_indexs[round] = __ldg(gIndices + (block_idx+1)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);
} else {
if constexpr (IS_FIRST_EXTRA_BLOCK) {
token_index = __ldg(gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);
} else {
token_index = nxt_token_indexs[round];
}
if (block_idx+1 != args.end_block_idx)
nxt_token_indexs[round] = __ldg(gExtraIndices + (block_idx+1-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);
}
if constexpr (MODEL_TYPE == ModelType::MODEL1) {
// For MODEL1, we need to check whether the token_index is within topk_length
if (rel_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx >= topk_length) {
token_index = -1; // To prevent IMA when we have invalid (e.g. INT_MAX) topk indexes outside topk_length
}
}
int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance
int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
fp8* gK_base;
bf16 scales[NUM_SCALES];
if constexpr (MODEL_TYPE == ModelType::V32) {
static_assert(NUM_SCALES == 4);
gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*k_row_stride;
float scales_float[NUM_SCALES];
*(float4*)(scales_float) = load_128b_from_gmem<float4, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>((float*)(gK_base+HEAD_DIM_NOPE));
CUTE_UNROLL
for (int i = 0; i < NUM_SCALES; ++i) {
scales[i] = (bf16)scales_float[i];
}
} else {
static_assert(NUM_SCALES == 8);
gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*sizeof(bf16));
fp8_e8m0* gK_scales_base = (fp8_e8m0*)(k_ptr + block_index*k_block_stride + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*sizeof(bf16)) + rel_idx_in_block*NUM_SCALES*sizeof(fp8_e8m0));
fp8_e8m0 scales_e8m0[NUM_SCALES];
*(int64_t*)scales_e8m0 = __ldg((int64_t*)gK_scales_base);
CUTE_UNROLL
for (int i = 0; i < NUM_SCALES; i += 2) {
*(__nv_bfloat162_raw*)(scales+i) = __nv_cvt_e8m0x2_to_bf162raw(*(__nv_fp8x2_storage_t*)(scales_e8m0+i));
}
}
// Wait for the nope buffer to be available
if (round == 0) {
plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1);
}
if (CLUSTER_SIZE == 2 && round == 0 && idx_in_warpgroup == 0) {
plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16));
}
// Collectively copy from global memory and dequant
// For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py
fp8* gK_nope = gK_base + (lane_idx/8)*16;
if (token_index == -1) {
CUTE_UNROLL
for (int i = 0; i < NUM_SCALES; ++i)
scales[i] = (bf16)0.0f;
}
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) {
fp8x16 cur_fp8x16 = load_128b_from_gmem<fp8x16, L1CacheHint::EVICT_LAST, L2PrefetchHint::B256>(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B (for V3.2) and the performance is the best among all cache hints (for MODEL1)
bf16 scale = scales[MODEL_TYPE == ModelType::V32 ? dim_idx/2 : dim_idx];
auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) {
int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE;
bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, __bfloat162bfloat162(*(__nv_bfloat16*)(&scale)));
*(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
if constexpr (CLUSTER_SIZE == 2) {
st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
}
};
if (token_index == -1)
*(uint128_t*)(&cur_fp8x16) = uint128_t();
dequant_and_save_bf16x8(cur_fp8x16.lo, 0);
dequant_and_save_bf16x8(cur_fp8x16.hi, 8);
}
bf16* gK_rope;
if constexpr (MODEL_TYPE == ModelType::V32) {
gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8;
} else {
gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE) + (lane_idx/8)*8;
}
bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE;
bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base);
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) {
bf16x8 cur_bf16x8 = load_128b_from_gmem<bf16x8, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>(gK_rope + dim_idx*32);
if constexpr (MODEL_TYPE == ModelType::V32) {
// NOTE We do not need to mask the RoPE part for V3.2 since it isn't involved in the SV gemm
} else {
if (token_index == -1)
*(uint128_t*)(&cur_bf16x8) = uint128_t();
}
int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE;
*(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
if constexpr (CLUSTER_SIZE == 2) {
st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
}
}
}
fence_view_async_shared();
if (idx_in_warpgroup < 32) {
// We put this after fence_view_async_shared() since this won't be read by async proxy
auto is_index_valid = [&](int index, int offset_within_thread) -> bool {
if constexpr (MODEL_TYPE == ModelType::V32) {
return index != -1;
} else {
return index != -1 && rel_block_idx*TOPK_BLOCK_SIZE + lane_idx*2 + offset_within_thread < topk_length;
}
};
int2 indices = __ldg((int2*)(indices_base + lane_idx*2));
*(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {
is_index_valid(indices.x, 0),
is_index_valid(indices.y, 1)
};
}
// Signal the barrier
plan.bar_k_local_ready[buf_idx].arrive();
bar_phase_k ^= 1 << buf_idx;
};
if constexpr (MODEL_TYPE == ModelType::V32) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{});
}
} else {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {
process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{});
}
if (args.num_orig_kv_blocks < args.end_block_idx) {
process_one_block(max(args.start_block_idx, args.num_orig_kv_blocks), IsExtraBlock{}, IsFirstExtraBlock{});
}
CUTE_NO_UNROLL
for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks)+1; block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsExtraBlock{}, IsNotFirstExtraBlock{});
}
}
sync_all_threads_in_cluster();
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
} }
template<typename Kernel, typename TMAParams> template<typename Kernel>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, Kernel::CLUSTER_SIZE) __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TMAParams tma_params) { flash_fwd_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams params) {
Kernel::devfunc(params, tma_params); Kernel::devfunc(params);
} }
template<ModelType MODEL_TYPE, int NUM_HEADS> template<ModelType MODEL_TYPE, int NUM_HEADS>
...@@ -701,82 +52,6 @@ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &pa ...@@ -701,82 +52,6 @@ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &pa
KU_ASSERT(params.topk_length == nullptr, "V3.2 does not support dynamic topk length"); KU_ASSERT(params.topk_length == nullptr, "V3.2 does not support dynamic topk length");
KU_ASSERT(params.stride_kv_row == 656); // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16) KU_ASSERT(params.stride_kv_row == 656); // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16)
} }
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q, params.b);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b)
)
),
SmemLayoutQ{}
);
CUtensorMap tensor_map_o;
{
// Here we manually construct TMA descriptor to store O, in order to leverage 5D TMA
uint64_t size[5] = {OBUF_SW, (unsigned long)params.h_q, HEAD_DIM_V/OBUF_SW, (unsigned long)params.s_q, (unsigned long)params.b};
uint64_t stride[4] = {params.stride_o_h_q*sizeof(bf16), OBUF_SW*sizeof(bf16), params.stride_o_s_q*sizeof(bf16), params.stride_o_b*sizeof(bf16)};
uint32_t box_size[5] = {OBUF_SW, BLOCK_M, HEAD_DIM_V/OBUF_SW, 1, 1};
uint32_t elem_stride[5] = {1, 1, 1, 1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_o,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
5,
params.out,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
OBUF_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B :
OBUF_SW == 32 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B :
OBUF_SW == 16 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B :
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
KU_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
decltype(shape_Q), decltype(tma_Q)
> tma_params = {
shape_Q, tma_Q,
tensor_map_o
};
auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE, NUM_HEADS>, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// NOTE Don't use PDL because of potential compiler bugs!
// cudaLaunchAttribute mla_kernel_attributes[1];
// mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t mla_kernel_config = {
// dim3(num_m_block, params.h_k, params.num_sm_parts),
// dim3(NUM_THREADS, 1, 1),
// smem_size,
// stream,
// mla_kernel_attributes,
// 1
// };
// cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
cutlass::ClusterLaunchParams launch_params = {
dim3(NUM_M_BLOCKS, params.s_q, params.num_sm_parts),
dim3(NUM_THREADS, 1, 1),
dim3(CLUSTER_SIZE, 1, 1),
smem_size,
params.stream
};
cutlass::launch_kernel_on_cluster(
launch_params, (void*)mla_kernel, params, tma_params
);
KU_CHECK_KERNEL_LAUNCH();
} }
template<ModelType MODEL_TYPE, int NUM_HEADS> template<ModelType MODEL_TYPE, int NUM_HEADS>
......
#pragma once #pragma once
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include <cutlass/arch/barrier.h> // #include <cutlass/arch/barrier.h>
namespace sm90 { namespace sm90 {
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) { // __forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); // uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n" // asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
:: "r"(dst_addr), // :: "r"(dst_addr),
"l"(src), // "l"(src),
"n"(16)); // "n"(16));
} // }
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) { // __forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst); // uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n" // asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n"
:: "r"(dst_addr), // :: "r"(dst_addr),
"l"(src), // "l"(src),
"r"(pred?16:0), // "r"(pred?16:0),
"l"(cache_policy)); // "l"(cache_policy));
} // }
__forceinline__ __device__ int64_t createpolicy_evict_last() { // __forceinline__ __device__ int64_t createpolicy_evict_last() {
int64_t res; // int64_t res;
asm volatile( // asm volatile(
"createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t" // "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
: "=l"(res) // : "=l"(res)
: // :
); // );
return res; // return res;
} // }
__forceinline__ __device__ int64_t createpolicy_evict_first() { // __forceinline__ __device__ int64_t createpolicy_evict_first() {
int64_t res; // int64_t res;
asm volatile( // asm volatile(
"createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t" // "createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t"
: "=l"(res) // : "=l"(res)
: // :
); // );
return res; // return res;
} // }
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) { // __forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
// In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx // // In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a // // You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4); // int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx; // return row_idx;
} // }
__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) { // __forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1); // int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);
return col_idx; // return col_idx;
} // }
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h // // Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao. // // * Copyright (c) 2024, Tri Dao.
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> // template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { // __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
using namespace cute; // using namespace cute;
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value; // constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const // // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } // if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC); // warpgroup_fence_operand(tCrC);
if constexpr (arrive) { // if constexpr (arrive) {
warpgroup_arrive(); // warpgroup_arrive();
} // }
if constexpr (zero_init) { // if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; // tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1 // // Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); // cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
} else { // } else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC); // // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1 // // Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { // for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); // cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
} // }
if constexpr (commit) { // if constexpr (commit) {
warpgroup_commit_batch(); // warpgroup_commit_batch();
} // }
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); } // if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC); // warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } // if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
} // }
// A simpler version of gemm // // A simpler version of gemm
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> // template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { // __forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute; // using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); // ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sA_frag = thr_mma.partition_fragment_A(sA); // Tensor sA_frag = thr_mma.partition_fragment_A(sA);
Tensor sB_frag = thr_mma.partition_fragment_B(sB); // Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(sA_frag) == size<2>(sB_frag)); // static_assert(size<2>(sA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
warpgroup_arrive(); // warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; // tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(sA_frag); ++k) { // for (int k = 0; k < size<2>(sA_frag); ++k) {
cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag); // cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
} // }
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> // template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) { // __forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute; // using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup); // ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sB_frag = thr_mma.partition_fragment_B(sB); // Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(rA_frag) == size<2>(sB_frag)); // static_assert(size<2>(rA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag)); // warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
warpgroup_arrive(); // warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One; // tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL // CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(rA_frag); ++k) { // for (int k = 0; k < size<2>(rA_frag); ++k) {
cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag); // cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One; // tiled_mma.accumulate_ = GMMA::ScaleOut::One;
} // }
warpgroup_fence_operand(rC_frag); // warpgroup_fence_operand(rC_frag);
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag)); // warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
} // }
__forceinline__ __device__ uint32_t get_sm_id() { // __forceinline__ __device__ uint32_t get_sm_id() {
uint32_t ret; // uint32_t ret;
asm("mov.u32 %0, %%smid;" : "=r"(ret)); // asm("mov.u32 %0, %%smid;" : "=r"(ret));
return ret; // return ret;
} // }
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. Not sure if this number is the same on all GPUs. // static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. Not sure if this number is the same on all GPUs.
template<typename T> // template<typename T>
CUTE_DEVICE // CUTE_DEVICE
T* get_peer_addr(const T* p) { // T* get_peer_addr(const T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK); // return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
} // }
template< // template<
typename TMA, // typename TMA,
typename Tensor0, // typename Tensor0,
typename Tensor1 // typename Tensor1
> // >
CUTE_DEVICE // CUTE_DEVICE
void launch_tma_copy( // void launch_tma_copy(
const TMA &tma_copy, // const TMA &tma_copy,
Tensor0 src, // Tensor0 src,
Tensor1 dst, // Tensor1 dst,
cutlass::arch::ClusterTransactionBarrier &bar, // cutlass::arch::ClusterTransactionBarrier &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL // const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) { // ) {
auto thr_tma = tma_copy.get_slice(cute::_0{}); // auto thr_tma = tma_copy.get_slice(cute::_0{});
cute::copy( // cute::copy(
tma_copy.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(bar), 0, cache_hint), // tma_copy.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(bar), 0, cache_hint),
thr_tma.partition_S(src), // thr_tma.partition_S(src),
thr_tma.partition_D(dst) // thr_tma.partition_D(dst)
); // );
} // }
} }
#pragma once #pragma once
#include <math_constants.h>
#include <cute/tensor.hpp> #include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cooperative_groups.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h> #include <cutlass/arch/arch.h>
#include <kerutils/kerutils.cuh> #include <kerutils/kerutils.cuh>
...@@ -78,66 +74,15 @@ struct SharedMemoryPlan { ...@@ -78,66 +74,15 @@ struct SharedMemoryPlan {
float2 sM[32]; float2 sM[32];
float2 sL[64]; // For reduction across WG0/1 in epilogue float2 sL[64]; // For reduction across WG0/1 in epilogue
float final_max_logits[64], final_lse[64]; float final_max_logits[64], final_lse[64];
transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready; // transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
}; };
using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
template<
typename Shape_Q, typename TMA_Q
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
CUtensorMap tensor_map_O;
};
enum NamedBarriers : uint32_t {
wg0_bunch_0_ready = 0,
wg1_bunch_0_ready = 1,
wg0_s0_ready = 2,
wg1_s1_ready = 3,
sL_ready = 4,
warpgroup0_sync = 5,
warpgroup1_sync = 6,
epilogue_sync = 7
};
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
static __forceinline__ __device__ void save_rS_to_sS(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
template<typename TMAParams>
static __device__ __forceinline__ void static __device__ __forceinline__ void
devfunc(const SparseAttnFwdParams &params, const TMAParams &tma_params); devfunc(const SparseAttnFwdParams &params);
static void run(const SparseAttnFwdParams &params); static void run(const SparseAttnFwdParams &params);
......
...@@ -9,566 +9,15 @@ namespace sm90::fwd { ...@@ -9,566 +9,15 @@ namespace sm90::fwd {
using namespace cute; using namespace cute;
CUTE_DEVICE void st_global_cs_128(float f0, float f1, float f2, float f3, void *dst_ptr) {
asm volatile("st.weak.global.cs.v4.f32 [%0], {%1, %2, %3, %4};\n"
:
: "l"(dst_ptr),
"f"(f0), "f"(f1), "f"(f2), "f"(f3)
);
}
CUTE_DEVICE
float2 __shfl_xor_sync_float2(
uint32_t mask, float2 value, int offset
) {
float2 res;
*reinterpret_cast<long long*>(&res) = __shfl_xor_sync(
mask,
*reinterpret_cast<long long*>(&value),
offset
);
return res;
}
CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(src_ptr);
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
:
: "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes)
: "memory");
}
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
template<typename TMAParams> __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params) {
__device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params, const TMAParams &tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
const int q_h_idx = blockIdx.x % (params.h_q/B_H);
const int s_q_idx = blockIdx.x / (params.h_q/B_H);
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int idx_in_warpgroup = threadIdx.x % 128;
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{});
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{});
Tensor sS0 = make_tensor(make_smem_ptr(D_QK == 576 ? plan.k[0].data()+64*512 : plan.s[1].data()), SmemLayoutS{}); // Overlap with sK0's RoPE part for V3.2
Tensor sS1 = make_tensor(make_smem_ptr(plan.s[0].data()), SmemLayoutS{});
if (warp_idx == 0 && elect_one_sync()) {
// Prefetch TMA descriptors
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_O);
// Initialize barriers
plan.bar_q.init(1);
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
plan.bar_k0_free[i].init(128);
plan.bar_k0_ready[i].init(128);
plan.bar_k1_free[i].init(128);
plan.bar_k1_ready[i].init(128);
}
plan.bar_is_kv_valid_ready.init(16);
fence_barrier_init();
}
__syncthreads();
const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_topk_blocks = HAVE_TOPK_LENGTH ? ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK);
if (warpgroup_idx == 0 || warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_alloc<216>();
if (warp_idx == 0 && elect_one_sync()) {
// Load Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
Tile<Int<B_H>, Int<D_Q>>{}
)(_, _, q_h_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));
}
float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation
float rL[2] = {0.0f, 0.0f};
Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<D_V/2>>{});
Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<B_H>, Int<B_TOPK>>{});
Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<B_TOPK>>{}));
cute::fill(rO, 0.0f);
// Wait for Q
plan.bar_q.wait(0);
bool cur_bar_wait_phase = 0;
struct Warpgroup0 {};
struct Warpgroup1 {};
auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
TiledMMA tiled_mma_QK = TiledMMA_QK{};
Tensor sQ_tile = flat_divide(sQ, Tile<Int<B_H>, Int<64>>{})(_, _, _0{}, tile_idx);
Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{});
gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup);
};
auto mask_rP = [&](auto warpgroup_idx) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
int col = 8*(i/4) + (idx_in_warpgroup%4)*2;
if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY;
if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY;
}
}
};
auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) {
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
const float scale = params.sm_scale_div_log2;
float r_sM[2];
if constexpr (IS_WG1) {
*(float2*)r_sM = plan.sM[idx_in_warpgroup/4];
}
float new_maxs[2];
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
// Get rowwise max
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
cur_max = max(cur_max, max(rP(i), rP(i+1)));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale;
// Get new max and scale
// For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round)
new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max);
// Scale O
float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]);
CUTE_UNROLL
for (int i = row_idx*2; i < size(rO); i += 4) {
rO(i) *= scale_for_o;
rO(i+1) *= scale_for_o;
}
// Get rS
float cur_sum = 0;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]);
rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]);
rS(i) = (bf16)rP(i);
rS(i+1) = (bf16)rP(i+1);
cur_sum += rP(i) + rP(i+1);
}
rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum;
}
__syncwarp();
if (idx_in_warpgroup%4 == 0) {
plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs;
}
rM[0] = new_maxs[0];
rM[1] = new_maxs[1];
};
auto reduce_L = [&]() {
// Reduce L
// For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
if (idx_in_warpgroup%4 == 0)
plan.sL[threadIdx.x/4] = *(float2*)(rL);
NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready);
float2 peer_L = plan.sL[(threadIdx.x/4)^32];
rL[0] += peer_L.x;
rL[1] += peer_L.y;
};
auto store_O = [&]() {
float scale_factors[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : params.attn_sink[q_h_idx*B_H + get_AorC_row_idx(i, idx_in_warpgroup)]*CUDART_L2E_F;
scale_factors[i] = 1.0f / (rL[i] + exp2f(attn_sink - rM[i]));
if (rL[i] == 0.0f)
scale_factors[i] = 0.0f; // The output should be 0 whatever attn_sink is
}
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{});
bf16* stsm_addrs[4];
int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16);
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i);
}
bool s2g_pred = warp_idx%4 == 0 && elect_one_sync();
warpgroup_wait<0>();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) {
// Convert
constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size
bf16 cur_rOb[NUM_ELEMS_EACH_TILE];
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) {
cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]);
}
// R -> S
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
SM90_U32x4_STSM_N::copy(
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 0),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 2),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 4),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 6),
*reinterpret_cast<uint128_t*>(stsm_addrs[i] + tile_idx*(B_H*64))
);
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, warpgroup_idx ? NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync);
// S -> G
if (s2g_pred) {
int g_tile_idx = warpgroup_idx*4 + tile_idx;
SM90_TMA_STORE_3D::copy(
&tma_params.tensor_map_O,
plan.q_o.o.data() + g_tile_idx*(B_H*64),
g_tile_idx*64,
q_h_idx*B_H,
s_q_idx
);
}
}
cute::tma_store_arrive();
};
if (warpgroup_idx == 0) {
// Warpgroup 0
auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 0, true);
qkt_gemm_one_tile(Warpgroup0{}, 1, false);
qkt_gemm_one_tile(Warpgroup0{}, 2, false);
qkt_gemm_one_tile(Warpgroup0{}, 3, false);
warpgroup_commit_batch();
};
auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 4, false);
qkt_gemm_one_tile(Warpgroup0{}, 5, false);
qkt_gemm_one_tile(Warpgroup0{}, 6, false);
qkt_gemm_one_tile(Warpgroup0{}, 7, false);
if constexpr (D_QK == 576) {
qkt_gemm_one_tile(Warpgroup0{}, 8, false);
}
warpgroup_commit_batch();
};
auto scale_rS = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rP); i += 4) {
rS(i) = (bf16)(rP(i) * scales[row]);
rS(i+1) = (bf16)(rP(i+1) * scales[row]);
}
}
};
auto rescale_rO = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rO); i += 4) {
rO(i) *= scales[row];
rO(i+1) *= scales[row];
}
rL[row] *= scales[row];
}
};
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{});
Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{});
if (block_idx == 0) {
// NOTE: We put this code here to avoid register spilling
pipelined_wait_and_qkt_gemm_l();
pipelined_wait_and_qkt_gemm_r();
warpgroup_wait<0>();
}
// Online softmax, inform WG1
mask_rP(Warpgroup0{});
online_softmax_and_rescale_o(Warpgroup0{});
NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready);
// Issue rO0 += rS0 @ sV0l
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Mark V0L as free
warpgroup_wait<0>();
plan.bar_k0_free[0].arrive();
// Wait for new sM, scale rS, save, inform WG1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready);
float new_rM[2], scale_factors[2];
*(float2*)new_rM = plan.sM[idx_in_warpgroup/4];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
scale_factors[i] = exp2f(rM[i] - new_rM[i]);
rM[i] = new_rM[i];
}
scale_rS(scale_factors);
save_rS_to_sS(rS, sS0, idx_in_warpgroup);
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready);
// Wait for sS1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready);
// Rescale rO0, Issue rO0 += sS1 @ sV1L
rescale_rO(scale_factors);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
cur_bar_wait_phase ^= 1;
if (block_idx+2 < num_topk_blocks) {
// Launch the next QK^T GEMM
pipelined_wait_and_qkt_gemm_l();
// Mark V1L as free
warpgroup_wait<1>();
plan.bar_k1_free[0].arrive();
pipelined_wait_and_qkt_gemm_r();
// Wait for rP0 = sQ @ sK0
warpgroup_wait<0>();
} else {
// Mark V1L as free
warpgroup_wait<0>();
plan.bar_k1_free[0].arrive();
}
}
reduce_L();
store_O();
} else {
// Warpgroup 1
auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) {
plan.bar_k1_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 4, true);
qkt_gemm_one_tile(Warpgroup1{}, 5, false);
qkt_gemm_one_tile(Warpgroup1{}, 6, false);
qkt_gemm_one_tile(Warpgroup1{}, 7, false);
if constexpr (D_QK == 576) {
qkt_gemm_one_tile(Warpgroup1{}, 8, false);
}
plan.bar_k1_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 0, false);
qkt_gemm_one_tile(Warpgroup1{}, 1, false);
qkt_gemm_one_tile(Warpgroup1{}, 2, false);
qkt_gemm_one_tile(Warpgroup1{}, 3, false);
warpgroup_commit_batch();
};
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{});
Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{});
// Issue rP1 = sQ @ sK1, and wait
pipelined_wait_and_qkt_gemm();
warpgroup_wait<0>();
mask_rP(Warpgroup1{});
// Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready)
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready);
online_softmax_and_rescale_o(Warpgroup1{});
NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready);
// Issue rO1 += rS1 @ sV1R
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R
save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Save rS1, inform WG0
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready);
// Wait for GEMM, and inform that sV1R is free
warpgroup_wait<1>();
plan.bar_k1_free[1].arrive();
// Wait for GEMM, and inform that sV0R is free
warpgroup_wait<0>();
plan.bar_k0_free[1].arrive();
cur_bar_wait_phase ^= 1;
}
reduce_L();
store_O();
// Save lse
if (idx_in_warpgroup%4 == 0) {
for (int row = 0; row < 2; ++row) {
int real_row = get_AorC_row_idx(row, idx_in_warpgroup);
bool is_no_valid_tokens = rL[row] == 0.0f;
plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]*CUDART_LN2_F;
plan.final_lse[real_row] = is_no_valid_tokens ? +INFINITY : logf(rL[row]) + rM[row]*CUDART_LN2_F;
}
fence_view_async_shared();
}
NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync);
if (idx_in_warpgroup == 0) {
int g_offset = s_q_idx*params.h_q + q_h_idx*B_H;
SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float));
SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float));
cute::tma_store_arrive();
}
}
} else {
// Producer warpgroup
cutlass::arch::warpgroup_reg_dealloc<72>();
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE;
constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS;
int idx_in_group = idx_in_warpgroup % GROUP_SIZE;
int group_idx = idx_in_warpgroup / GROUP_SIZE;
int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]
bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8));
bf16* my_gKV_base = params.kv + idx_in_group*8;
int64_t token_indices[2][NUM_ROWS_PER_GROUP];
bool is_token_valid[2][NUM_ROWS_PER_GROUP];
auto load_token_indices = [&](int block_idx) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx;
int t = __ldg(gIndices + offs);
token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster
bool is_cur_token_valid = t >= 0 && t < params.s_kv;
if constexpr (HAVE_TOPK_LENGTH) {
is_cur_token_valid &= offs < topk_length;
}
is_token_valid[buf_idx][local_row] = is_cur_token_valid;
}
}
};
int64_t cache_policy = createpolicy_evict_last();
auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) {
// Copy some K/V tiles from global memory to shared memory
// A tile has a shape of 64 (B_TOPK) x 64
// `buf_idx` is the index of the shared memory buffer, 0 or 1
// `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int64_t token_index = token_indices[buf_idx][local_row];
CUTE_UNROLL
for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) {
cp_async_cacheglobal_l2_prefetch_256B(
my_gKV_base + token_index + tile_idx*64,
my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64),
is_token_valid[buf_idx][local_row],
cache_policy
);
}
}
};
auto commit_to_mbar = [&](transac_bar_t &bar) {
cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar));
};
int cur_bar_wait_phase = 1;
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
load_token_indices(block_idx);
// V0L
plan.bar_k0_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 0, 4);
commit_to_mbar(plan.bar_k0_ready[0]);
// V1R
plan.bar_k1_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 4, D_K/64);
commit_to_mbar(plan.bar_k1_ready[1]);
// V0R
plan.bar_k0_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 4, D_K/64);
commit_to_mbar(plan.bar_k0_ready[1]);
// V1L
plan.bar_k1_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 0, 4);
commit_to_mbar(plan.bar_k1_ready[0]);
// Valid mask
// NOTE: V1R's finish implies maskings of the last round have finished
if (idx_in_group == 0) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx)
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row)
plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row];
plan.bar_is_kv_valid_ready.arrive();
}
cur_bar_wait_phase ^= 1;
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
} }
template<typename Kernel, typename TMAParams> template<typename Kernel>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1) __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1)
sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TMAParams tma_params) { sparse_attn_fwd_kernel(const SparseAttnFwdParams params) {
Kernel::devfunc(params, tma_params); Kernel::devfunc(params);
} }
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
...@@ -579,63 +28,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para ...@@ -579,63 +28,7 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &para
KU_ASSERT(params.h_q % B_H == 0); KU_ASSERT(params.h_q % B_H == 0);
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q); auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
SmemLayoutQ{}
);
CUtensorMap tensor_map_O;
{
uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q};
uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)};
uint32_t box_size[3] = {64, B_H, 1};
uint32_t elem_stride[3] = {1, 1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_O,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
3,
params.out,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
KU_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
decltype(shape_Q), decltype(tma_Q)
> tma_params = {
shape_Q, tma_Q,
tensor_map_O
};
auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
cutlass::ClusterLaunchParams launch_params = {
dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE: We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z)
dim3(NUM_THREADS, 1, 1),
dim3(1, 1, 1),
smem_size,
params.stream
};
cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
);
KU_CHECK_KERNEL_LAUNCH();
} }
template<int D_QK, bool HAVE_TOPK_LENGTH> template<int D_QK, bool HAVE_TOPK_LENGTH>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment