Commit c28eca99 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Reorganize files and add sparse prefill/decoding kernels on hopper

parent 261330bb
#pragma once
enum NamedBarriers : uint32_t {
sScale_and_sS_ready = 0,
sScale_and_sS_free = 1,
oBuf_free_and_sL_ready = 2,
epilogue_r2s_ready = 3,
batch_loop_sync = 4,
warpgroup0_sync = 5
};
#include "splitkv_mla.h"
#include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/cluster_launch.hpp>
#include "utils.h"
#include "components/config.h"
#include "components/epilogue.h"
#include "components/helpers.h"
#include "components/named_barriers.h"
#include "components/dequant.h"
using namespace cute;
namespace sm90 {
static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
__forceinline__ __device__ void save_rPb_to_sP(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
// Retrieve rPb (64x64, bfloat16) from sP using the ldmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
__forceinline__ __device__ void retrieve_rP_from_sP(
Tensor0 &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
TiledCopy s2r_copy = make_tiled_copy_A(
Copy_Atom<SM75_U32x4_LDSM_N, bf16>{},
TiledMMA_PV_LocalP{}
);
ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_sP = thr_copy.partition_S(sP);
Tensor thr_copy_rPb = thr_copy.retile_D(rPb);
cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb);
}
template<
typename Tensor0,
typename Tensor1,
typename Tensor2
>
__forceinline__ __device__ void scale_softmax(
Tensor0 &rP,
Tensor1 &rS,
Tensor2 &rO,
float scale_softmax_log2,
float sScale[],
float rM[2],
float rL[2],
bool is_kv_valid[],
int block_idx,
int idx_in_warpgroup
) {
float scale_for_olds[2];
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
Tensor cur_rP = flatten(rP(make_coord(_, local_row_idx, _), _, _));
Tensor cur_rS = flatten(rS(make_coord(_, local_row_idx, _), _, _));
Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _));
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = 0; i < size(cur_rP); ++i) {
if (!is_kv_valid[(i&1)+(i/2)*8+(idx_in_warpgroup%4)*2])
cur_rP(i) = -INFINITY;
cur_max = max(cur_max, cur_rP(i));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale_softmax_log2;
float old_max = rM[local_row_idx];
rM[local_row_idx] = max(cur_max, old_max);
float scale_for_old = exp2f(old_max - rM[local_row_idx]);
scale_for_olds[local_row_idx] = scale_for_old;
CUTE_UNROLL
for (int i = 0; i < size(cur_rO); ++i) {
cur_rO(i) *= scale_for_old;
}
float cur_sum = 0;
CUTE_UNROLL
for (int i = 0; i < size(cur_rP); ++i) {
cur_rP(i) = exp2f(cur_rP(i)*scale_softmax_log2 - rM[local_row_idx]);
cur_rS(i) = (bf16)cur_rP(i);
cur_sum += cur_rP(i);
}
rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;
}
if (idx_in_warpgroup%4 == 0)
*(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds);
}
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 2)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) {
#if IS_SM90
const int head_block_idx = blockIdx.x;
const int s_q_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int idx_in_cluster = head_block_idx % 2;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{});
Tensor sOBuf = make_tensor(make_smem_ptr(plan.u.oBuf.data()), SmemLayoutOBuf{});
Tensor sOAccumBuf = make_tensor(make_smem_ptr(plan.u.oAccumBuf.data()), SmemLayoutOAccumBuf{});
Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{});
float* sM = plan.sM;
float* sL = plan.sL;
float* sScale = plan.sScale;
// Prefetch TMA descriptors
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
}
// Initialize TMA barriers
if (warp_idx == 0 && elect_one_sync()) {
plan.bar_q.init(1);
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
plan.bar_k_local_ready[i].init(128);
plan.bar_k_remote_ready[i].init(1);
plan.bar_k_avail[i].init(4);
}
fence_view_async_shared();
}
cute::cluster_arrive();
bool bar_phase_q = 0;
int bar_phase_k = 0; // Don't use array here to prevent using local memory
// Programmatic Dependent Launch: Wait for the previous kernel to finish
// Don't use PDL because of compiler bugs!
// cudaGridDependencySynchronize();
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int sched_begin_block_idx = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int sched_end_block_idx = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
if (warp_idx == 0 && elect_one_sync()) {
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, begin_idx),
Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}
)(_, _, head_block_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));
}
cute::cluster_wait(); // Wait for barriers from the other CTA to be ready
auto get_cur_req_info = [&](int batch_idx) -> std::tuple<int, int, bool> {
constexpr int kBlockN = TOPK_BLOCK_SIZE;
const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0;
// NOTE TopK attention has nothing to do with causal mask and sliding window
int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(params.topk, kBlockN);
const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(params.topk, kBlockN);
return {start_block_idx, end_block_idx, is_no_split};
};
if (warpgroup_idx == 0) {
cutlass::arch::warpgroup_reg_alloc<192>();
TiledMMA tiled_mma_QK = TiledMMA_QK{};
ThrMMA thr_mma_QK = tiled_mma_QK.get_slice(idx_in_warpgroup);
TiledMMA tiled_mma_PV = TiledMMA_PV_LocalP{};
ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup);
float rL[2], rM[2];
Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V/2>>{});
Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});
Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}));
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
rL[0] = rL[1] = 0.0f;
rM[0] = rM[1] = MAX_INIT_VAL;
cute::fill(rO, 0.);
// Wait for Q
plan.bar_q.wait(bar_phase_q);
bar_phase_q ^= 1;
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS;
Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{});
// Wait, issue WGMMA
plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
gemm<true, -1>(
tiled_mma_QK,
thr_mma_QK.partition_fragment_A(sQ),
thr_mma_QK.partition_fragment_B(sK),
rP
);
bar_phase_k ^= 1<<buf_idx;
cute::warpgroup_wait<0>();
// Calculate S = softmax(mask(scale(P)))
if (block_idx != start_block_idx)
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS is free
// Since in our case TOPK_BLOCK_SIZE == BLOCK_M, so we only need to do OOB checking for the last 2 blocks
scale_softmax(rP, rS, rO, params.scale_softmax_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup);
// Store S into shared, inform warpgroup 1
save_rPb_to_sP(rS, sS, idx_in_warpgroup);
fence_view_async_shared();
// Issue O += S @ V
gemm<false, -1>(
tiled_mma_PV,
rS,
thr_mma_PV.partition_fragment_B(sV),
rO
);
NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_ready);
cute::warpgroup_wait<0>();
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
}
// Copy the next q
if (warp_idx == 0 && elect_one_sync()) {
if (batch_idx != end_idx) {
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1),
Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}
)(_, _, head_block_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));
} else {
cudaTriggerProgrammaticLaunchCompletion();
}
}
// Synchronize L and M across warpgroups
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
if (idx_in_warpgroup%4 == 0) {
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int row = get_AorC_row_idx(i, idx_in_warpgroup);
sL[row] = rL[i];
sM[row] = rM[i];
}
}
// This is a synchronization point for warpgroup 0/1.
// Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free
// Warpgroup 1 should wait wg 0 for sL to be ready
NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready);
CUTE_UNROLL
for (int i = 0; i < 2; ++i)
rL[i] = rL[i] == 0.0f ? 1.0f : rL[i];
int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M);
int start_seq_idx = s_q_idx*params.q_head_per_hk + head_block_idx*BLOCK_M;
if (is_no_split) {
bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.o_row_stride, _1{})
));
float* gSoftmaxLse = (float*)params.softmax_lse_ptr + batch_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1)
store_o<true>(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL[i];
gSoftmaxLse[i] = cur_L == 0.0f ? INFINITY : logf(cur_L) + sM[i] / (float)M_LOG2E;
}
cute::tma_store_wait<0>();
} else {
int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
float* gSoftmaxLseAccum = (float*)params.softmax_lseaccum_ptr + split_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<HEAD_DIM_V>, _1>
>{});
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL[i];
gSoftmaxLseAccum[i] = cur_L == 0.0f ? -INFINITY : log2f(cur_L) + sM[i];
}
cute::tma_store_wait<0>();
}
cute::cluster_sync(); // Must use arrive_and_wait here to prevent overwritting sL while WG1 is writing back its result
}
} else if (warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_dealloc<160>();
TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{};
ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup);
Tensor rO = partition_fragment_C(tiled_mma_PV, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V/2>>{});
float rL[2];
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
cute::fill(rO, 0.);
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS;
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{});
// Wait for S and sScale
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_ready);
// Scale O
float cur_scales[2];
*(float2*)cur_scales = *(float2*)(sScale + (idx_in_warpgroup/4)*2);
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
Tensor cur_rO = flatten(rO(make_coord(_, local_row_idx, _), _, _));
CUTE_UNROLL
for (int i = 0; i < size(cur_rO); ++i) {
cur_rO(i) *= cur_scales[local_row_idx];
}
}
// Issue O += S @ V, and wait
gemm<false, -1>(
tiled_mma_PV,
thr_mma_PV.partition_fragment_A(sS),
thr_mma_PV.partition_fragment_B(sV),
rO
);
cute::warpgroup_wait<0>();
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
if (block_idx != end_block_idx-1)
NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available
}
NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready);
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int row = get_AorC_row_idx(i, idx_in_warpgroup);
rL[i] = sL[row];
}
CUTE_UNROLL
for (int i = 0; i < 2; ++i)
rL[i] = rL[i] == 0.0f ? 1.0f : rL[i];
int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M);
int start_seq_idx = s_q_idx*params.q_head_per_hk+head_block_idx*BLOCK_M;
if (is_no_split) {
bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.o_row_stride, _1{})
));
store_o<true>(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
cute::tma_store_wait<0>();
} else {
int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<HEAD_DIM_V>, _1>
>{});
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
cute::tma_store_wait<0>();
}
cute::cluster_sync(); // We must use arrive_and_wait instead of arrive here to create an order between "forall warp in WG1, warp has done written back O" and "warp 2 signals `bar_k_avail`"
}
} else {
// Producer warpgroup
cutlass::arch::warpgroup_reg_dealloc<152>();
int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0); // NOTE TPBNO
int lane_idx = idx_in_warpgroup % 32;
int my_token_idx = warp_idx*8 + lane_idx%8;
CUTE_NO_UNROLL
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1)
#define GET_TOKEN_INDEX(block_idx) __ldg(gIndices + (block_idx)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)
int nxt_token_index = GET_TOKEN_INDEX(start_block_idx);
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS;
// Define shared and global tensors
bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE;
bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base);
transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx]));
int token_index = nxt_token_index;
if (block_idx+1 != end_block_idx)
nxt_token_index = GET_TOKEN_INDEX(block_idx+1);
int block_index = token_index/PAGE_BLOCK_SIZE;
int rel_idx_in_block = (token_index+PAGE_BLOCK_SIZE) % PAGE_BLOCK_SIZE; // NOTE When token_index is -1, -1/PAGE_BLOCK_SIZE = 0 and (-1+PAGE_BLOCK_SIZE)%PAGE_BLOCK_SIZE = 63, so there will be no illegal-memory-access error
fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride;
float4 scales = load_128b_from_gmem<float4, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>((float*)(gK_base+HEAD_DIM_NOPE));
// Wait for the nope buffer to be available
plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1);
bar_phase_k ^= 1 << buf_idx;
// Copy block #block_index
if (idx_in_warpgroup == 0) {
plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16));
}
// Collectively copy from global memory and dequant
// For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py
fp8* gK_nope = gK_base + (lane_idx/8)*16;
if (token_index == -1) {
scales = {0.0f, 0.0f, 0.0f, 0.0f};
}
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) {
fp8x16 cur_fp8x16 = load_128b_from_gmem<fp8x16, L1CacheHint::EVICT_LAST, L2PrefetchHint::B256>(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B
float scale = dim_idx < 4 ? (dim_idx < 2 ? scales.x : scales.y) : (dim_idx < 6 ? scales.z : scales.w);
auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) {
int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE;
bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, scale);
*(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
};
if (token_index == -1)
*(uint128_t*)(&cur_fp8x16) = uint128_t();
dequant_and_save_bf16x8(cur_fp8x16.lo, 0);
dequant_and_save_bf16x8(cur_fp8x16.hi, 8);
}
bf16* gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8;
bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE;
bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base);
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) {
bf16x8 cur_bf16x8 = load_128b_from_gmem<bf16x8, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>(gK_rope + dim_idx*32);
if (token_index == -1)
*(uint128_t*)(&cur_bf16x8) = uint128_t();
int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE;
*(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
}
fence_view_async_shared();
if (idx_in_warpgroup < 32) {
// We put this after fence_view_async_shared() since this won't be read by async proxy
int2 indices = __ldg((int2*)(gIndices + block_idx*TOPK_BLOCK_SIZE + lane_idx*2));
*(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {indices.x != -1, indices.y != -1};
}
// Signal the barrier
plan.bar_k_local_ready[buf_idx].arrive();
}
cute::cluster_sync();
}
}
if (begin_idx > end_idx) {
cute::cluster_sync(); // Don't need a cluster_sync() when begin_idx <= end_idx, since the loop will execute at least once and the final statement is cluster_sync()
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
}
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream) {
FLASH_ASSERT(params.h_k == 1);
FLASH_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0);
auto shape_Q = make_shape(params.q_head_per_hk, params.d, params.s_q, params.b);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q_ptr),
make_layout(
shape_Q,
make_stride(params.q_row_stride, _1{}, params.q_head_per_hk*params.q_row_stride, params.q_batch_stride)
)
),
SmemLayoutQ{}
);
auto shape_O = make_shape(params.q_head_per_hk, params.d_v, params.s_q, params.b);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((bf16*)params.o_ptr),
make_layout(
shape_O,
make_stride(params.o_row_stride, _1{}, params.q_head_per_hk*params.o_row_stride, params.o_batch_stride)
)
),
SmemLayoutOBuf{}
);
TmaParams<
decltype(shape_Q), decltype(tma_Q),
decltype(shape_O), decltype(tma_O)
> tma_params = {
shape_Q, tma_Q,
shape_O, tma_O
};
auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
const int num_m_block = cute::ceil_div(params.q_head_per_hk, 2*BLOCK_M) * 2;
// NOTE Don't use PDL because of potential compiler bugs!
// cudaLaunchAttribute mla_kernel_attributes[1];
// mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t mla_kernel_config = {
// dim3(num_m_block, params.h_k, params.num_sm_parts),
// dim3(NUM_THREADS, 1, 1),
// smem_size,
// stream,
// mla_kernel_attributes,
// 1
// };
// cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
cutlass::ClusterLaunchParams launch_params = {
dim3(num_m_block, params.s_q, params.num_sm_parts),
dim3(NUM_THREADS, 1, 1),
dim3(2, 1, 1),
smem_size,
stream
};
cutlass::launch_kernel_on_cluster(
launch_params, (void*)mla_kernel, params, tma_params
);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace sm90 {
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream);
}
#pragma once
#include "params.h"
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream);
#include "fwd.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cooperative_groups.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include "utils.h"
#include "helpers.h"
namespace sm90 {
using namespace cute;
constexpr int D_Q = 576;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr int B_H = 64;
constexpr int B_TOPK = 64; // TopK block size
constexpr int NUM_THREADS = 128*3;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits)
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<Shape<Int<64*NUM_TILES>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{}
));
using SmemLayoutQ = SmemLayoutQTiles<9>;
using SmemLayoutO = SmemLayoutOTiles<8>;
using SmemLayoutK = SmemLayoutKTiles<9>;
using SmemLayoutV = SmemLayoutKTilesTransposed<8>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>;
using SmemLayoutS = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{}
), Shape<_1, _1>{}));
struct SharedMemoryPlan {
union {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} q_o;
array_aligned<bf16, cosize_v<SmemLayoutK>> k[2];
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
bool is_kv_valid[2][B_TOPK];
float2 sM[32];
float2 sL[64]; // For reduction across WG0/1 in epilogue
float final_max_logits[64], final_lse[64];
transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
};
using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
template<
typename Shape_Q, typename TMA_Q
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
CUtensorMap tensor_map_O;
};
enum NamedBarriers : uint32_t {
wg0_bunch_0_ready = 0,
wg1_bunch_0_ready = 1,
wg0_s0_ready = 2,
wg1_s1_ready = 3,
sL_ready = 4,
warpgroup0_sync = 5,
warpgroup1_sync = 6
};
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
__forceinline__ __device__ void save_rS_to_sS(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 1)
sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) {
// NOTE This kernel uses a similar schedule to Flash MLA - 0422. For a detailed explanation, please refer to https://github.com/deepseek-ai/FlashMLA/blob/main/docs/20250422-new-kernel-deep-dive.md
#if IS_SM90
const int q_h_idx = blockIdx.x % (params.h_q/B_H);
const int s_q_idx = blockIdx.x / (params.h_q/B_H);
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int idx_in_warpgroup = threadIdx.x % 128;
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{});
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{});
Tensor sS0 = make_tensor(make_smem_ptr(plan.k[0].data()+64*512), SmemLayoutS{}); // Overlap with sK0's RoPE part
Tensor sS1 = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{});
if (warp_idx == 0 && elect_one_sync()) {
// Prefetch TMA descriptors
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_O);
// Initialize barriers
plan.bar_q.init(1);
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
plan.bar_k0_free[i].init(128);
plan.bar_k0_ready[i].init(128);
plan.bar_k1_free[i].init(128);
plan.bar_k1_ready[i].init(128);
}
plan.bar_is_kv_valid_ready.init(16);
fence_barrier_init();
}
__syncthreads();
const int num_topk_blocks = params.topk / B_TOPK;
if (warpgroup_idx == 0 || warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_alloc<216>();
if (warp_idx == 0 && elect_one_sync()) {
// Load Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
Tile<Int<B_H>, Int<D_Q>>{}
)(_, _, q_h_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));
}
float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation
float rL[2] = {0.0f, 0.0f};
Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<D_V/2>>{});
Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<B_H>, Int<B_TOPK>>{});
Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<B_TOPK>>{}));
cute::fill(rO, 0.0f);
// Wait for Q
plan.bar_q.wait(0);
bool cur_bar_wait_phase = 0;
struct Warpgroup0 {};
struct Warpgroup1 {};
auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
TiledMMA tiled_mma_QK = TiledMMA_QK{};
Tensor sQ_tile = flat_divide(sQ, Tile<Int<B_H>, Int<64>>{})(_, _, _0{}, tile_idx);
Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{});
gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup);
};
auto mask_rP = [&](auto warpgroup_idx) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
int col = 8*(i/4) + (idx_in_warpgroup%4)*2;
if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY;
if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY;
}
}
};
auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) {
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
const float scale = params.sm_scale_div_log2;
float r_sM[2];
if constexpr (IS_WG1) {
*(float2*)r_sM = plan.sM[idx_in_warpgroup/4];
}
float new_maxs[2];
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
// Get rowwise max
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
cur_max = max(cur_max, max(rP(i), rP(i+1)));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale;
// Get new max and scale
// For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round)
new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max);
// Scale O
float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]);
CUTE_UNROLL
for (int i = row_idx*2; i < size(rO); i += 4) {
rO(i) *= scale_for_o;
rO(i+1) *= scale_for_o;
}
// Get rS
float cur_sum = 0;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]);
rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]);
rS(i) = (bf16)rP(i);
rS(i+1) = (bf16)rP(i+1);
cur_sum += rP(i) + rP(i+1);
}
rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum;
}
__syncwarp();
if (idx_in_warpgroup%4 == 0) {
plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs;
}
rM[0] = new_maxs[0];
rM[1] = new_maxs[1];
};
auto reduce_L = [&]() {
// Reduce L
// For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
if (idx_in_warpgroup%4 == 0)
plan.sL[threadIdx.x/4] = *(float2*)(rL);
NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready);
float2 peer_L = plan.sL[(threadIdx.x/4)^32];
rL[0] += peer_L.x;
rL[1] += peer_L.y;
};
auto store_O = [&]() {
float scale_factors[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i)
scale_factors[i] = rL[i] == 0.0f ? 1.0f : 1.0f / rL[i];
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{});
bf16* stsm_addrs[4];
int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16);
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i);
}
bool s2g_pred = warp_idx%4 == 0 && elect_one_sync();
warpgroup_wait<0>();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) {
// Convert
constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size
bf16 cur_rOb[NUM_ELEMS_EACH_TILE];
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) {
cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]);
}
// R -> S
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
SM90_U32x4_STSM_N::copy(
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 0),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 2),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 4),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 6),
*reinterpret_cast<uint128_t*>(stsm_addrs[i] + tile_idx*(B_H*64))
);
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, warpgroup_idx ? NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync);
// S -> G
if (s2g_pred) {
int g_tile_idx = warpgroup_idx*4 + tile_idx;
SM90_TMA_STORE_3D::copy(
&tma_params.tensor_map_O,
plan.q_o.o.data() + g_tile_idx*(B_H*64),
g_tile_idx*64,
q_h_idx*B_H,
s_q_idx
);
}
}
cute::tma_store_arrive();
};
if (warpgroup_idx == 0) {
// Warpgroup 0
auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 0, true);
qkt_gemm_one_tile(Warpgroup0{}, 1, false);
qkt_gemm_one_tile(Warpgroup0{}, 2, false);
qkt_gemm_one_tile(Warpgroup0{}, 3, false);
warpgroup_commit_batch();
};
auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 4, false);
qkt_gemm_one_tile(Warpgroup0{}, 5, false);
qkt_gemm_one_tile(Warpgroup0{}, 6, false);
qkt_gemm_one_tile(Warpgroup0{}, 7, false);
qkt_gemm_one_tile(Warpgroup0{}, 8, false);
warpgroup_commit_batch();
};
auto scale_rS = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rP); i += 4) {
rS(i) = (bf16)(rP(i) * scales[row]);
rS(i+1) = (bf16)(rP(i+1) * scales[row]);
}
}
};
auto rescale_rO = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rO); i += 4) {
rO(i) *= scales[row];
rO(i+1) *= scales[row];
}
rL[row] *= scales[row];
}
};
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{});
Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{});
if (block_idx == 0) {
// NOTE We put these code here to avoid register spilling
pipelined_wait_and_qkt_gemm_l();
pipelined_wait_and_qkt_gemm_r();
warpgroup_wait<0>();
}
// Online softmax, inform WG1
mask_rP(Warpgroup0{});
online_softmax_and_rescale_o(Warpgroup0{});
NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready);
// Issue rO0 += rS0 @ sV0l
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Mark V0L as free
warpgroup_wait<0>();
plan.bar_k0_free[0].arrive();
// Wait for new sM, scale rS, save, inform WG1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready);
float new_rM[2], scale_factors[2];
*(float2*)new_rM = plan.sM[idx_in_warpgroup/4];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
scale_factors[i] = exp2f(rM[i] - new_rM[i]);
rM[i] = new_rM[i];
}
scale_rS(scale_factors);
save_rS_to_sS(rS, sS0, idx_in_warpgroup);
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready);
// Wait for sS1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready);
// Rescale rO0, Issue rO0 += sS1 @ sV1L
rescale_rO(scale_factors);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
cur_bar_wait_phase ^= 1;
if (block_idx+2 < num_topk_blocks) {
// Launch the next QK^T GEMM
pipelined_wait_and_qkt_gemm_l();
// Mark V1L as free
warpgroup_wait<1>();
plan.bar_k1_free[0].arrive();
pipelined_wait_and_qkt_gemm_r();
// Wait for rP0 = sQ @ sK0
warpgroup_wait<0>();
} else {
// Mark V1L as free
warpgroup_wait<0>();
plan.bar_k1_free[0].arrive();
}
}
reduce_L();
store_O();
} else {
// Warpgroup 1
auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) {
plan.bar_k1_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 4, true);
qkt_gemm_one_tile(Warpgroup1{}, 5, false);
qkt_gemm_one_tile(Warpgroup1{}, 6, false);
qkt_gemm_one_tile(Warpgroup1{}, 7, false);
qkt_gemm_one_tile(Warpgroup1{}, 8, false);
plan.bar_k1_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 0, false);
qkt_gemm_one_tile(Warpgroup1{}, 1, false);
qkt_gemm_one_tile(Warpgroup1{}, 2, false);
qkt_gemm_one_tile(Warpgroup1{}, 3, false);
warpgroup_commit_batch();
};
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{});
Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{});
// Issue rP1 = sQ @ sK1, and wait
pipelined_wait_and_qkt_gemm();
warpgroup_wait<0>();
mask_rP(Warpgroup1{});
// Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready)
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready);
online_softmax_and_rescale_o(Warpgroup1{});
NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready);
// Issue rO1 += rS1 @ sV1R
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R
save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Save rS1, inform WG0
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready);
// Wait for GEMM, and inform that sV1R is free
warpgroup_wait<1>();
plan.bar_k1_free[1].arrive();
// Wait for GEMM, and inform that sV0R is free
warpgroup_wait<0>();
plan.bar_k0_free[1].arrive();
cur_bar_wait_phase ^= 1;
}
reduce_L();
store_O();
// Save lse
if (idx_in_warpgroup%4 == 0) {
for (int row = 0; row < 2; ++row) {
int real_row = get_AorC_row_idx(row, idx_in_warpgroup);
bool is_no_valid_tokens = rL[row] == 0.0f;
plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row];
plan.final_lse[real_row] = is_no_valid_tokens ? -INFINITY : log2f(rL[row]) + rM[row];
}
fence_view_async_shared();
}
NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync);
if (idx_in_warpgroup == 0) {
int g_offset = s_q_idx*params.h_q + q_h_idx*B_H;
SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float));
SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float));
cute::tma_store_arrive();
}
}
} else {
// Producer warpgroup
cutlass::arch::warpgroup_reg_dealloc<72>();
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE;
constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS;
int idx_in_group = idx_in_warpgroup % GROUP_SIZE;
int group_idx = idx_in_warpgroup / GROUP_SIZE;
int* gIndices = params.indices + s_q_idx*params.topk; // [topk]
bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8));
bf16* my_gKV_base = params.kv + idx_in_group*8;
int64_t token_indices[2][NUM_ROWS_PER_GROUP];
bool is_token_valid[2][NUM_ROWS_PER_GROUP];
auto load_token_indices = [&](int block_idx) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx;
int t = __ldg(gIndices + offs);
token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster
is_token_valid[buf_idx][local_row] = t >= 0 && t < params.s_kv;
}
}
};
int64_t cache_policy = createpolicy_evict_last();
auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) {
// Copy some K/V tiles from global memory to shared memory
// A tile has a shape of 64 (B_TOPK) x 64
// `buf_idx` is the index of the shared memory buffer, 0 or 1
// `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int64_t token_index = token_indices[buf_idx][local_row];
CUTE_UNROLL
for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) {
cp_async_cacheglobal_l2_prefetch_256B(
my_gKV_base + token_index + tile_idx*64,
my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64),
is_token_valid[buf_idx][local_row],
cache_policy
);
}
}
};
auto commit_to_mbar = [&](transac_bar_t &bar) {
cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar));
};
int cur_bar_wait_phase = 1;
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
load_token_indices(block_idx);
// V0L
plan.bar_k0_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 0, 4);
commit_to_mbar(plan.bar_k0_ready[0]);
// V1R
plan.bar_k1_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 4, 9);
commit_to_mbar(plan.bar_k1_ready[1]);
// V0R
plan.bar_k0_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 4, 9);
commit_to_mbar(plan.bar_k0_ready[1]);
// V1L
plan.bar_k1_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 0, 4);
commit_to_mbar(plan.bar_k1_ready[0]);
// Valid mask
// NOTE V1R's finish implies maskings of the last round have finished
if (idx_in_group == 0) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx)
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row)
plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row];
plan.bar_is_kv_valid_ready.arrive();
}
cur_bar_wait_phase ^= 1;
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
}
void run_fwd_kernel(const SparsePrefillParams& params) {
FLASH_ASSERT(params.h_kv == 1);
FLASH_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings
FLASH_ASSERT(params.topk > 0);
FLASH_ASSERT(params.h_q % B_H == 0);
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
SmemLayoutQ{}
);
CUtensorMap tensor_map_O;
{
uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q};
uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)};
uint32_t box_size[3] = {64, B_H, 1};
uint32_t elem_stride[3] = {1, 1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_O,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
3,
params.out,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
FLASH_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
decltype(shape_Q), decltype(tma_Q)
> tma_params = {
shape_Q, tma_Q,
tensor_map_O
};
auto kernel = &sparse_attn_fwd_kernel<decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
cutlass::ClusterLaunchParams launch_params = {
dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z)
dim3(NUM_THREADS, 1, 1),
dim3(1, 1, 1),
smem_size,
params.stream
};
cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace sm90 {
void run_fwd_kernel(const SparsePrefillParams& params);
}
#pragma once
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
namespace sm90 {
using bf16 = cutlass::bfloat16_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::fence_barrier_init;
using cutlass::arch::NamedBarrier;
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
:: "r"(dst_addr),
"l"(src),
"n"(16));
}
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst, bool pred, int64_t cache_policy) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::cache_hint.L2::256B [%0], [%1], 16, %2, %3;\n"
:: "r"(dst_addr),
"l"(src),
"r"(pred?16:0),
"l"(cache_policy));
}
__forceinline__ __device__ int64_t createpolicy_evict_last() {
int64_t res;
asm volatile(
"createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
: "=l"(res)
:
);
return res;
}
__forceinline__ __device__ int64_t createpolicy_evict_first() {
int64_t res;
asm volatile(
"createpolicy.fractional.L2::evict_first.b64 %0, 1.0; \n\t"
: "=l"(res)
:
);
return res;
}
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx;
}
__forceinline__ __device__ int get_AorC_col_idx(int local_elem_idx, int idx_in_warpgroup) {
int col_idx = 8*(local_elem_idx/4) + (idx_in_warpgroup%4)*2 + (local_elem_idx&1);
return col_idx;
}
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao.
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
using namespace cute;
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
// A simpiler version of gemm
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sA_frag = thr_mma.partition_fragment_A(sA);
Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(sA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(rC_frag);
warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(sA_frag); ++k) {
cute::gemm(tiled_mma, sA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_fence_operand(rC_frag);
}
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Tensor0 rA_frag, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute;
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor sB_frag = thr_mma.partition_fragment_B(sB);
static_assert(size<2>(rA_frag) == size<2>(sB_frag));
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
warpgroup_fence_operand(rC_frag);
warpgroup_arrive();
tiled_mma.accumulate_ = clear_accum ? GMMA::ScaleOut::Zero : GMMA::ScaleOut::One;
CUTLASS_PRAGMA_UNROLL
for (int k = 0; k < size<2>(rA_frag); ++k) {
cute::gemm(tiled_mma, rA_frag(_, _, k), sB_frag(_, _, k), rC_frag);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
warpgroup_fence_operand(rC_frag);
warpgroup_fence_operand(const_cast<Tensor0 &>(rA_frag));
}
__forceinline__ __device__ uint32_t get_sm_id() {
uint32_t ret;
asm("mov.u32 %0, %smid;" : "=r"(ret));
return ret;
}
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}
template<
typename TMA,
typename Tensor0,
typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
const TMA &tma_copy,
Tensor0 src,
Tensor1 dst,
transac_bar_t &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
auto thr_tma = tma_copy.get_slice(cute::_0{});
cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint),
thr_tma.partition_S(src),
thr_tma.partition_D(dst)
);
}
}
......@@ -6,7 +6,7 @@
#include "utils.h"
__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
......@@ -18,12 +18,26 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
extern __shared__ int shared_mem[];
int* num_blocks_shared = shared_mem; // [batch_size]
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int* seqlens_k_shared = shared_mem + batch_size*2+1; // [batch_size]
int* first_block_idx_shared = shared_mem + batch_size*3+1; // [batch_size]
int* last_block_idx_shared = shared_mem + batch_size*4+1; // [batch_size]
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
int cur_s_k = params.topk == -1 ? __ldg(seqlens_k_ptr + i) : params.topk;
seqlens_k_shared[i] = cur_s_k;
int first_token_idx = 0;
int last_token_idx = max(cur_s_k-1, 0);
int cur_first_block_idx = first_token_idx / block_size_n;
int cur_last_block_idx = last_token_idx / block_size_n;
// NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx]
// NOTE Before clamping, first_token_idx <= last_token_idx always holds, so after clamping, first_token_idx <= last_token_idx still holds.
// NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel.
int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks;
first_block_idx_shared[i] = cur_first_block_idx;
last_block_idx_shared[i] = cur_last_block_idx;
}
for (int offset = 16; offset >= 1; offset /= 2) {
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
......@@ -31,14 +45,14 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
__syncwarp();
if (threadIdx.x == 0) {
int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks);
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block * block_size_n;
tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx];
tile_scheduler_metadata1 = now_n_split_idx;
int remain_payload = payload;
while (now_idx < batch_size) {
......@@ -61,7 +75,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
tile_scheduler_metadata0[3] = now_block > 0 ? now_block + first_block_idx_shared[now_idx] : (seqlens_k_shared[now_idx-1] == 0 ? 0 : last_block_idx_shared[now_idx-1] + 1);
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
}
......@@ -74,8 +88,8 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
}
}
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*2+1);
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*5+1);
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
......
#pragma once
#include "params.h"
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream);
......@@ -7,13 +7,12 @@
#include "params.h"
#include "utils.h"
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
using namespace cute;
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
// grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
......@@ -176,12 +175,14 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream) {
static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA
FLASH_ASSERT(params.d_v == HEAD_DIM_V);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 8;
constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, Config::HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
cudaLaunchAttribute attribute[1];
......@@ -200,8 +201,8 @@ void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t str
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(DecodingParams &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
template void run_flash_mla_combine_kernel<cutlass::half_t>(DecodingParams &params, cudaStream_t stream);
#endif
\ No newline at end of file
......@@ -3,4 +3,4 @@
#include "params.h"
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);
void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream);
......@@ -30,3 +30,37 @@
} while(0)
#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); }
template<typename T>
__inline__ __host__ __device__ T ceil_div(const T &a, const T &b) {
return (a + b - 1) / b;
}
#ifndef TRAP_ONLY_DEVICE_ASSERT
#define TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
// For development, we define both IS_SM100 and IS_SM90 when using CLion or VSCode IDEs so code highlighting will be correct.
#if defined(__CLION_IDE__) || defined(__VSCODE_IDE__)
#define IS_SM100 1
#define IS_SM90 1
#else
// We define the following macros to detect the CUDA architecture, so that we can enable/disable certains kernels that depends on specific architectures.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000)
#define IS_SM100 1
#else
#define IS_SM100 0
#endif
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)
#define IS_SM90 1
#else
#define IS_SM90 0
#endif
#endif // defined(__CLION_IDE__) || defined(__VSCODE_IDE__)
\ No newline at end of file
......@@ -6,4 +6,5 @@ from flash_mla.flash_mla_interface import (
flash_attn_varlen_func,
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_mla_sparse_fwd
)
......@@ -2,30 +2,33 @@ from typing import Optional, Tuple
import torch
import flash_mla_sm90
import flash_mla_sm100
import flash_mla.cuda as flash_mla_cuda
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k: The number of k heads.
num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_sm90.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk)
def flash_mla_with_kvcache_sm90(
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
......@@ -35,6 +38,8 @@ def flash_mla_with_kvcache_sm90(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
......@@ -47,6 +52,8 @@ def flash_mla_with_kvcache_sm90(
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
......@@ -54,7 +61,9 @@ def flash_mla_with_kvcache_sm90(
"""
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_sm90.fwd_kvcache_mla(
if indices is not None:
assert causal == False, "causal must be `false` if sparse attention is enabled."
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
......@@ -64,10 +73,42 @@ def flash_mla_with_kvcache_sm90(
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices
)
return out, softmax_lse
def flash_mla_sparse_fwd(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
Args:
q: [s_q, h_q, d_qk], bfloat16
kv: [s_kv, h_kv, d_qk], bfloat16
indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
sm_scale: float
d_v: The dimension of value vectors. Can only be 512
Returns:
(output, max_logits, lse)
About the definition of output, max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results = flash_mla_cuda.sparse_prefill_fwd(
q, kv, indices, sm_scale, d_v
)
return results
def _flash_attn_varlen_forward(
q: torch.Tensor,
k: torch.Tensor,
......@@ -96,7 +137,7 @@ def _flash_attn_varlen_forward(
lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device)
flash_mla_sm100.fwd(
flash_mla_cuda.dense_prefill_fwd(
workspace_buffer,
q,
k,
......@@ -159,7 +200,7 @@ def _flash_attn_varlen_backward(
if num_qo_heads != num_kv_heads:
workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc
workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device)
flash_mla_sm100.bwd(
flash_mla_cuda.dense_prefill_bwd(
workspace_buffer,
do,
q,
......@@ -195,7 +236,7 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
causal: bool = False,
softmax_scale: Optional[float] = None,
is_varlen: bool = True,
):
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = _flash_attn_varlen_forward(
q, k, v,
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
......@@ -290,40 +331,3 @@ def flash_attn_varlen_kvpacked_func(
cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv,
causal, softmax_scale, is_varlen,
)
def flash_mla_with_kvcache_sm100(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
pass
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: Optional[torch.Tensor] = None,
num_splits: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
capability = torch.cuda.get_device_capability(q.device.index)
if capability == (9, 0):
return flash_mla_with_kvcache_sm90(
q, k_cache, block_table, cache_seqlens, head_dim_v,
tile_scheduler_metadata, num_splits,
softmax_scale, causal,
)
elif capability == (10, 0):
raise ValueError(f"Unsupported device capability: {capability}")
else:
raise ValueError(f"Unsupported device capability: {capability}")
......@@ -12,29 +12,31 @@ from torch.utils.cpp_extension import (
)
def append_nvcc_threads(nvcc_extra_args):
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
return nvcc_extra_args + ["--threads", nvcc_threads]
def is_flag_set(flag: str) -> bool:
return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"]
def get_features_args():
features_args = []
DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"]
if DISABLE_FP16:
if is_flag_set("FLASH_MLA_DISABLE_FP16"):
features_args.append("-DFLASH_MLA_DISABLE_FP16")
return features_args
def get_arch_flags():
DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
arch_flags = []
if not DISABLE_SM100:
arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"])
if not DISABLE_SM90:
arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"])
return arch_flags
def get_nvcc_thread_args():
nvcc_threads = os.getenv("NVCC_THREADS") or "32"
return ["--threads", nvcc_threads]
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
cc_flag_sm90 = []
cc_flag_sm90.append("-gencode")
cc_flag_sm90.append("arch=compute_90a,code=sm_90a")
cc_flag_sm100 = []
cc_flag_sm100.append("-gencode")
cc_flag_sm100.append("arch=compute_100a,code=sm_100a")
this_dir = os.path.dirname(os.path.abspath(__file__))
if IS_WINDOWS:
......@@ -45,17 +47,20 @@ else:
ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla_sm90",
name="flash_mla.cuda",
sources=[
"csrc/sm90/flash_api.cpp",
"csrc/sm90/kernels/get_mla_metadata.cu",
"csrc/sm90/kernels/mla_combine.cu",
"csrc/sm90/kernels/splitkv_mla.cu",
"csrc/pybind.cpp",
"csrc/smxx/get_mla_metadata.cu",
"csrc/smxx/mla_combine.cu",
"csrc/sm90/decode/dense/splitkv_mla.cu",
"csrc/sm90/decode/sparse_fp8/splitkv_mla.cu",
"csrc/sm90/prefill/sparse/fwd.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
"nvcc": append_nvcc_threads(
[
"nvcc": [
"-O3",
"-std=c++17",
"-DNDEBUG",
......@@ -69,55 +74,17 @@ ext_modules.append(
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v,--register-usage-level=10"
]
+ cc_flag_sm90
) + get_features_args(),
] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
},
include_dirs=[
Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include",
],
)
)
ext_modules.append(
CUDAExtension(
name="flash_mla_sm100",
sources=[
"csrc/sm100/pybind.cu",
"csrc/sm100/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/fmha_cutlass_bwd_sm100.cu",
],
extra_compile_args={
"cxx": ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"],
"nvcc": append_nvcc_threads(
[
"-O3",
"-std=c++17",
"-DNDEBUG",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"-lineinfo",
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
]
+ cc_flag_sm100
),
},
include_dirs=[
Path(this_dir) / "csrc" / "sm100",
Path(this_dir) / "csrc" / "cutlass" / "include",
Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
],
)
)
try:
cmd = ['git', 'rev-parse', '--short', 'HEAD']
rev = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
......
from typing import List
import torch
def cdiv(x: int, y: int):
return (x+y-1) // y
def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7):
"""
Check if two tensors are close enough
"""
def get_cos_diff(x: torch.Tensor, y: torch.Tensor) -> float:
"""
Calculate the cosine diff between two tensors
"""
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum().item()
if denominator == 0:
return 0
sim = 2 * (x * y).sum().item() / denominator
return 1 - sim
assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"
ans = ans.clone().to(torch.float)
ref = ref.clone().to(torch.float)
# Deal with anomalies
def deal_with_anomalies(val: float):
ref_mask = (ref == val) if (val == val) else (ref != ref)
ans_mask = (ans == val) if (val == val) else (ans != ans)
ref[ref_mask] = 0.0
ans[ans_mask] = 0.0
if not torch.equal(ref_mask, ans_mask):
print(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref")
return False
return True
anomalies_check_passed = True
anomalies_check_passed &= deal_with_anomalies(float("inf"))
anomalies_check_passed &= deal_with_anomalies(float("-inf"))
anomalies_check_passed &= deal_with_anomalies(float("nan"))
if not anomalies_check_passed:
return False
cos_diff = get_cos_diff(ans, ref)
raw_abs_err = torch.abs(ans-ref)
raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6))
rel_err = raw_rel_err.masked_fill(raw_abs_err<abs_tol, 0)
abs_err = raw_abs_err.masked_fill(raw_rel_err<rel_tol, 0)
pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)
if not pass_mask.all():
print(f"`{name}` mismatch")
max_abs_err_pos: int = torch.argmax(abs_err, keepdim=True).item() # type: ignore
max_rel_err_pos: int = torch.argmax(rel_err, keepdim=True).item() # type: ignore
def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]:
result = []
for size in t.shape[::-1]:
result.append(pos % size)
pos = pos // size
assert pos == 0
return result[::-1]
print(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}")
print(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}")
print(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)")
print(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})")
return False
else:
if abs(cos_diff) > cos_diff_tol:
print(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})")
return False
return True
\ No newline at end of file
import enum
import torch
def quantize_k_cache(
input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d)
dv: int,
tile_size: int = 128,
) -> torch.Tensor:
"""
Quantize the k-cache
Return a tensor with shape (num_blocks, block_size, h_k, dv + 4(dv/tile_size) + t(d-dv)) of dtype uint8_t, where t = input_k_cache.element_size()
For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py or README.md
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, d = input_k_cache.shape
assert h_k == 1
input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d]
input_elem_size = input_k_cache.element_size()
result = torch.empty((num_blocks, block_size, dv + num_tiles*4 + input_elem_size*(d-dv)), dtype=torch.float8_e4m3fn, device=input_k_cache.device)
result_k_nope_part = result[..., :dv]
result_k_scale_factor = result[..., dv:dv + num_tiles*4].view(torch.float32)
result_k_rope_part = result[..., dv + num_tiles*4:].view(input_k_cache.dtype)
result_k_rope_part[:] = input_k_cache[..., dv:]
for tile_idx in range(0, num_tiles):
cur_scale_factors_inv = torch.abs(input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size]).max(dim=-1).values / 448.0 # [num_blocks, block_size]
result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv
cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1]
cur_quantized_nope = (input_k_cache[..., tile_idx*tile_size:(tile_idx+1)*tile_size].float() / cur_scale_factors_inv.float()).to(torch.float8_e4m3fn)
result_k_nope_part[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_quantized_nope
result = result.view(num_blocks, block_size, 1, -1)
return result
def dequantize_k_cache(
quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token)
dv: int = 512,
tile_size: int = 128,
d: int = 576
) -> torch.Tensor:
"""
De-quantize the k-cache
"""
assert dv % tile_size == 0
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
result = torch.empty((num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device)
quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1)
input_nope = quant_k_cache[..., :dv]
input_scale = quant_k_cache[..., dv:dv + num_tiles*4].view(torch.float32)
input_rope = quant_k_cache[..., dv + num_tiles*4:].view(torch.bfloat16)
result[..., dv:] = input_rope
for tile_idx in range(0, num_tiles):
cur_nope = input_nope[..., tile_idx*tile_size:(tile_idx+1)*tile_size].to(torch.float32)
cur_scales = input_scale[..., tile_idx].unsqueeze(-1)
result[..., tile_idx*tile_size:(tile_idx+1)*tile_size] = cur_nope * cur_scales
result = result.view(num_blocks, block_size, 1, d)
return result
import argparse
import math
import random
import dataclasses
from typing import Optional, Tuple, List
import torch
import triton
import quant
import flash_mla
from lib import cdiv, check_is_allclose
@dataclasses.dataclass
class TestParam:
b: int # Batch size
s_q: int # Number of queries for one request
s_k: int # Seq len, or mean seq len if varlen == True
is_varlen: bool
is_causal: bool
is_fp8: bool
topk: Optional[int] = None
test_performance: bool = True
is_all_indices_invalid: bool = False
have_zero_seqlen_k: bool = False
block_size: int = 64
h_q: int = 128 # Number of q heads
h_kv: int = 1 # Number of kv heads
d: int = 576 # Q/K head dim (= dv + RoPE dim)
dv: int = 512 # V head dim
seed: int = 0
def generate_test_data(t: TestParam) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Generate test data from a given configuration
Return: [cache_seqlens, q, block_table, blocked_k]
Pay attention: This function changes the random seed
"""
random.seed(t.seed)
torch.manual_seed(t.seed)
torch.cuda.manual_seed(t.seed)
torch.backends.cudnn.deterministic = True
assert t.h_q % t.h_kv == 0
cache_seqlens_cpu = torch.full((t.b,), t.s_k, dtype=torch.int32, device='cpu')
if t.is_varlen:
for i in range(t.b):
cache_seqlens_cpu[i] = max(random.normalvariate(t.s_k, t.s_k / 2), t.s_q)
if t.have_zero_seqlen_k:
zeros_mask = torch.randn(t.b, dtype=torch.float32, device='cpu') > 0
cache_seqlens_cpu[zeros_mask] = 0
max_seqlen = cache_seqlens_cpu.max().item()
max_seqlen_pad = cdiv(max_seqlen, 256) * 256
cache_seqlens = cache_seqlens_cpu.cuda()
q = torch.randn(t.b, t.s_q, t.h_q, t.d)
q.clamp_(min=-1.0, max=1.0)
block_table = torch.arange(t.b * max_seqlen_pad // t.block_size, dtype=torch.int32).view(t.b, max_seqlen_pad // t.block_size)
block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view(t.b, -1)
blocked_k = torch.randn(block_table.numel(), t.block_size, t.h_kv, t.d) / 10
blocked_k.clamp_(min=-1.0, max=1.0)
if t.topk is None:
for i in range(t.b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, t.block_size)
blocked_k[block_table[i][cur_num_blocks:]] = float("nan")
if cur_len % t.block_size != 0:
blocked_k[block_table[i][cur_num_blocks-1]][cur_len % t.block_size:] = float("nan")
block_table[i][cur_num_blocks:] = 2147480000
return cache_seqlens, q, block_table, blocked_k, None, None
else:
block_table_cpu = block_table.cpu()
abs_indices = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
indices_in_kvcache = torch.empty(t.b, t.s_q, t.topk, dtype=torch.int32, device="cpu")
for i in range(t.b):
# Generate indices
for j in range(t.s_q):
cur_abs_indices = torch.randperm(int(cache_seqlens_cpu[i].item()), device="cpu")[:t.topk]
cur_blocked_indices = block_table_cpu[i, cur_abs_indices//t.block_size]*t.block_size + (cur_abs_indices%t.block_size)
if len(cur_abs_indices) < t.topk:
pad_len = t.topk - len(cur_abs_indices)
cur_abs_indices = torch.cat([cur_abs_indices, torch.full((pad_len,), -1, device='cpu')])
cur_blocked_indices = torch.cat([cur_blocked_indices, torch.full((pad_len,), -1, device='cpu')])
# Mask KV
perm = torch.randperm(t.topk, device='cpu')
cur_abs_indices = cur_abs_indices[perm]
cur_blocked_indices = cur_blocked_indices[perm]
# Fill it with invalid indices if needed
if t.is_all_indices_invalid:
cur_abs_indices.fill_(-1)
cur_blocked_indices.fill_(-1)
abs_indices[i, j, :] = cur_abs_indices
indices_in_kvcache[i, j, :] = cur_blocked_indices
# Mask nonused KV as NaN
all_indices = indices_in_kvcache.flatten().tolist()
all_indices = list(set(all_indices))
if -1 in all_indices:
all_indices.remove(-1)
all_indices = torch.tensor(all_indices, dtype=torch.int32, device='cpu')
blocked_k = blocked_k.view(-1, t.h_kv, t.d)
nonused_indices_mask = torch.ones(blocked_k.size(0)*blocked_k.size(1), dtype=torch.bool, device='cpu')
nonused_indices_mask[all_indices] = False
blocked_k[nonused_indices_mask, :, :] = float("nan")
blocked_k = blocked_k.view(-1, t.block_size, t.h_kv, t.d)
abs_indices = abs_indices.to(q.device)
indices_in_kvcache = indices_in_kvcache.to(q.device)
return cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache
def reference_torch(
cache_seqlens: torch.Tensor, # [batch_size]
block_table: torch.Tensor, # [batch_size, ?]
q: torch.Tensor, # [batch_size, s_q, h_q, d]
blocked_k: torch.Tensor, # [?, block_size, h_kv, d]
dv: int,
is_causal: bool,
indices: Optional[torch.Tensor] = None # [batch_size, s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
A reference implementation in PyTorch
"""
def get_topk_attn_mask(s_q: int, s_k: int, indices: torch.Tensor):
mask = torch.zeros(s_q, s_k, dtype=torch.bool)
for i in range(s_q):
cur_indices = indices[i]
valid_indices = cur_indices[cur_indices != -1]
mask[i, valid_indices] = True
return mask
def scaled_dot_product_attention(
batch_idx: int,
query: torch.Tensor, # [h_q, s_q, d]
kv: torch.Tensor, # [h_kv, s_k, d]
dv: int,
is_causal,
indices: Optional[torch.Tensor], # [s_q, topk]
) -> Tuple[torch.Tensor, torch.Tensor]:
h_q = query.size(0)
h_kv = kv.size(0)
s_q = query.shape[-2]
s_k = kv.shape[-2]
query = query.float()
kv = kv.float()
if h_kv != 1:
kv = kv.repeat_interleave(h_q // h_kv, dim=0)
kv[kv != kv] = 0.0
attn_weight = query @ kv.transpose(-2, -1) # [h_q, s_q, s_k]
if (is_causal and query.size(1) > 1) or indices is not None:
mask = torch.ones(s_q, s_k, dtype=torch.bool)
if is_causal:
assert indices is None
mask = mask.tril(diagonal=s_k - s_q)
if indices is not None:
mask &= get_topk_attn_mask(s_q, s_k, indices)
attn_bias = torch.zeros(s_q, s_k, dtype=torch.float)
attn_bias.masked_fill_(mask.logical_not(), float("-inf"))
attn_weight += attn_bias.to(q.dtype)
attn_weight /= math.sqrt(query.size(-1))
lse = attn_weight.logsumexp(dim=-1) # [h_q, s_q]
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
output = attn_weight @ kv[..., :dv] # [h_q, s_q, dv]
# Correct for q tokens which has no attendable k
lonely_q_mask = (lse == float("-inf"))
output[lonely_q_mask.unsqueeze(-1).broadcast_to(h_q, s_q, dv)] = 0.0
lse[lonely_q_mask] = float("+inf")
return output, lse
b, s_q, h_q, d = q.size()
block_size = blocked_k.size(1)
h_kv = blocked_k.size(2)
cache_seqlens_cpu = cache_seqlens.cpu()
out_ref = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse_ref = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
cur_len = cache_seqlens_cpu[i].item()
cur_num_blocks = cdiv(cur_len, block_size)
cur_block_indices = block_table[i][0: cur_num_blocks]
cur_kv = blocked_k[cur_block_indices].view(-1, h_kv, d)[:cur_len, ...]
cur_out, cur_lse = scaled_dot_product_attention(
i,
q[i].transpose(0, 1),
cur_kv.transpose(0, 1),
dv,
is_causal,
indices[i] if indices is not None else None
)
out_ref[i] = cur_out.transpose(0, 1)
lse_ref[i] = cur_lse
out_ref = out_ref.to(torch.bfloat16)
return out_ref, lse_ref
@torch.inference_mode()
def test_flash_mla(t: TestParam):
print('-------------------------------')
print(f"Running on {t}...")
# Generating test data
torch.cuda.synchronize()
cache_seqlens, q, block_table, blocked_k, abs_indices, indices_in_kvcache = generate_test_data(t)
if t.is_fp8:
# The quantization error may be too large to be distinguished from wrong kernels
# So we quantize and de-quantize kv-cache here to mitigate quantization error
blocked_k_quantized = quant.quantize_k_cache(blocked_k, t.dv, 128)
blocked_k_dequantized = quant.dequantize_k_cache(blocked_k_quantized)
blocked_k = blocked_k_dequantized
# Get schedule metadata
torch.cuda.synchronize()
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
cache_seqlens,
t.s_q * t.h_q // t.h_kv,
t.h_kv,
t.h_q,
t.is_fp8,
t.topk
)
torch.cuda.synchronize()
def run_flash_mla():
return flash_mla.flash_mla_with_kvcache(
q,
blocked_k if not t.is_fp8 else blocked_k_quantized, # type: ignore
block_table,
cache_seqlens,
t.dv,
tile_scheduler_metadata,
num_splits,
causal=t.is_causal,
is_fp8_kvcache=t.is_fp8,
indices=indices_in_kvcache
)
out_ans, lse_ans = run_flash_mla()
out_ref, lse_ref = reference_torch(cache_seqlens, block_table, q, blocked_k, t.dv, t.is_causal, abs_indices)
assert check_is_allclose("out", out_ans, out_ref, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=5e-6)
assert check_is_allclose("lse", lse_ans, lse_ref, abs_tol=1e-6, rel_tol=8.01/65536)
if t.test_performance:
time_usage: float = triton.testing.do_bench(run_flash_mla)/1000 # type: ignore
mean_attended_seqlens = cache_seqlens.float().mean().item() if t.topk is None else t.topk
compute_volume_flop = t.b*t.h_q*t.s_q*sum([
2*t.d*mean_attended_seqlens, # Q * K^T
2*mean_attended_seqlens*t.dv, # attention * V
])
q_elem_size = torch.bfloat16.itemsize
kv_token_size = 656 if t.is_fp8 else t.d*torch.bfloat16.itemsize
memory_volume_B = t.b*sum([
t.s_q*t.h_q*(t.d*q_elem_size), # Q
(t.s_q if t.topk is not None else 1) * mean_attended_seqlens*t.h_kv*kv_token_size, # K/V
t.s_q*t.h_q*(t.dv*q_elem_size), # Output
])
achieved_tflops = compute_volume_flop / time_usage / 1e12
achieved_gBps = memory_volume_B / time_usage / 1e9
print(f"{time_usage*1000:.3f} ms, {achieved_tflops:.0f} TFLOPS, {achieved_gBps:.0f} GB/s")
def main(torch_dtype):
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
correctness_cases = [
TestParam(b, s_q, s_k, is_varlen, is_causal, is_fp8, topk, test_performance=False)
for b in [1, 2, 6, 64]
for s_q in [1, 2, 4]
for s_k in [20, 140, 4096]
for is_varlen in [False, True]
for is_causal in [False, True]
for (is_fp8, topk) in [
(False, None),
(True, 128),
(True, 2048)
]
if not (is_causal and topk is not None)
]
corner_cases = [
# Cases where all topk indices are invalid
TestParam(128, 2, 4096, is_varlen=True, is_causal=False, is_fp8=True, topk=topk, test_performance=False, is_all_indices_invalid=True)
for topk in [128, 2048, 4096]
] + [
# Cases where some kv cache have zero length
TestParam(128, 2, 4096, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=False, have_zero_seqlen_k=True)
for (is_causal, is_fp8, topk) in [
(False, False, None),
(True, False, None),
(False, True, 128),
(False, True, 2048),
]
]
performance_cases = [
TestParam(128, s_q, s_k, is_varlen=True, is_causal=is_causal, is_fp8=is_fp8, topk=topk, test_performance=True)
for (is_causal, is_fp8, topk) in [
(False, False, None),
(True, False, None),
(False, True, 2048),
]
for s_q in [1, 2]
for s_k in [4096, 8192, 16384, 32768]
]
testcases = correctness_cases + corner_cases + performance_cases
for testcase in testcases:
test_flash_mla(testcase)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16"],
default="bf16",
help="Data type to use for testing (bf16 or fp16)",
)
args = parser.parse_args()
torch_dtype = torch.bfloat16
if args.dtype == "fp16":
torch_dtype = torch.float16
main(torch_dtype)
import math
import time
from typing import Tuple
import random
import dataclasses
import torch
import triton
from flash_mla import flash_mla_sparse_fwd
from lib import check_is_allclose
@dataclasses.dataclass
class TestParam:
b: int
s_q: int
s_kv: int
topk: int
h_q: int = 128
h_kv: int = 1
d_qk: int = 576
d_v: int = 512
seed: int = 0
check_correctness: bool = True
benchmark: bool = True
@dataclasses.dataclass
class Testcase:
t: TestParam
q: torch.Tensor
kv: torch.Tensor
indices: torch.Tensor
def generate_testcase(t: TestParam) -> Testcase:
torch.manual_seed(t.seed)
torch.cuda.manual_seed(t.seed)
random.seed(t.seed)
q = torch.randn((t.b, t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16)/10
kv = torch.randn((t.b, t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16)/10
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
indices = torch.full((t.b, t.s_q, t.h_kv, t.topk), t.s_kv, dtype=torch.int32)
for b in range(t.b):
for s in range(t.s_q):
for h in range(t.h_kv):
# TODO Comment
near_mask = torch.randint(0, 32, (min(t.topk, t.s_kv),)) < 31
cur_indices = torch.randperm(t.s_kv)[:t.topk]
cur_indices[near_mask] = torch.randint(max(0, t.s_kv-20000), t.s_kv-1, (near_mask.sum().item(),))
if len(cur_indices) < t.topk:
cur_indices = torch.cat([cur_indices, torch.full((t.topk - len(cur_indices),), 2147480000)])
cur_indices = cur_indices[torch.randperm(t.topk)]
indices[b, s, h] = cur_indices
indices = indices.to(q.device)
return Testcase(
t=t,
q=q,
kv=kv,
indices=indices
)
def get_flop(p: TestParam) -> float:
flop = 2 * sum([
p.h_q * p.d_qk * p.topk,
p.h_q * p.d_v * p.topk
]) * p.b * p.s_q
return flop
def reference_torch(p: TestParam, t: Testcase, sm_scale: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
assert p.b == 1
indices = t.indices[0, :, 0, :] # [s_q, topk]
invalid_indices_mask = (indices < 0) | (indices >= p.s_kv)
qs = t.q[0, :, :, :].float() # [s_q, h_q, d_qk]
kvs = t.kv[0, :, 0, :].float() # [s_kv, d_qk]
kvs = torch.index_select(kvs, 0, indices.masked_fill(invalid_indices_mask, 0).flatten()).view(p.s_q, p.topk, p.d_qk) # [s_q, topk, d_qk]
attn_score = qs @ kvs.transpose(1, 2) # [s_q, h_q, topk]
attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float('-inf'))
attn_score *= sm_scale * math.log2(math.e)
max_logits = torch.max(attn_score, dim=-1)[0] # [s_q, h_q]
lse = log2sumexp2(attn_score, dim=-1) # [s_q, h_q]
attn_score = torch.exp2(attn_score - lse.unsqueeze(-1)) # [s_q, h_q, topk]
result = attn_score @ kvs[:, :, :p.d_v]
return (max_logits, lse, result)
@torch.inference_mode()
def run_test(p: TestParam) -> bool:
print("================")
print(f"Running on {p}")
torch.cuda.empty_cache()
assert p.b == 1
t = generate_testcase(p)
sm_scale = 1 / math.sqrt(p.d_qk)
torch.cuda.synchronize()
def run_ans():
return flash_mla_sparse_fwd(
t.q.squeeze(0), t.kv.squeeze(0), t.indices.squeeze(0), sm_scale=sm_scale
)
ans_out, ans_max_logits, ans_lse = run_ans()
torch.cuda.synchronize()
if p.benchmark:
flop = get_flop(p)
prefill_ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20)/1000 # type: ignore
prefill_flops = flop/prefill_ans_time/1e12
print(f"Prefill: {prefill_ans_time*1e6:4.0f} us, {prefill_flops:.3f} TFlops")
if p.check_correctness:
torch.cuda.synchronize()
ref_max_logits, ref_lse, ref_out = reference_torch(p, t, sm_scale)
torch.cuda.synchronize()
is_correct = True
is_correct &= check_is_allclose("out", ans_out, ref_out, abs_tol=8e-4, rel_tol=2.01/128, cos_diff_tol=7e-6)
is_correct &= check_is_allclose("max_logits", ans_max_logits, ref_max_logits, abs_tol=1e-6, rel_tol=2.01/65536)
is_correct &= check_is_allclose("lse", ans_lse, ref_lse, abs_tol=1e-6, rel_tol=2.01/65536)
return is_correct
else:
return True
if __name__ == '__main__':
device = torch.device("cuda:0")
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.set_float32_matmul_precision('high')
correctness_cases = [
# Regular shapes
TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
for s_kv, topk in [
# Regular shapes
(128, 128),
(256, 256),
(512, 512),
# Irregular shapes
(592, 128),
(1840, 256),
(1592, 384),
(1521, 512),
# Irregular shapes with OOB TopK
(95, 128),
(153, 256),
(114, 384),
]
for s_q in [
1, 62
]
]
corner_cases = [
# In these cases, some blocks may not have any valid topk indices
TestParam(1, s_q, s_kv, topk, h_q=128, benchmark=False)
for s_kv, topk in [
(32, 2048),
(64, 8192)
]
for s_q in [1, 1024]
]
performance_cases = [
TestParam(1, s_q, s_kv, topk, h_q=128)
for s_q in [4096]
for s_kv in [4096, 8192, 16384, 32768, 49152, 65536, 81920, 98304, 114688, 131072]
for topk in [2048]
]
testcases = correctness_cases + corner_cases + performance_cases
failed_cases = []
for test in testcases:
if test.benchmark:
time.sleep(0.2)
is_correct = run_test(test)
if not is_correct:
failed_cases.append(test)
if len(failed_cases) > 0:
print(f"\033[31m\033[1m{len(failed_cases)} / {len(testcases)} cases failed:\033[0m")
for case in failed_cases:
print(f" {case}")
else:
print(f"\033[32m\033[1mAll {len(testcases)} cases passed!\033[0m")
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
# print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < 1e-5
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
float("nan")
)
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv
)
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
)
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype):
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192, 16384]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16"],
default="bf16",
help="Data type to use for testing (bf16 or fp16)",
)
args = parser.parse_args()
torch_dtype = torch.bfloat16
if args.dtype == "fp16":
torch_dtype = torch.float16
main(torch_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment