Commit 082094b7 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
......@@ -4,6 +4,6 @@
namespace sm90 {
void run_fwd_kernel(const SparsePrefillParams& params);
void run_fwd_kernel(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm90::fwd {
// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH
// = true / false respectively, to compile them in parallel.
// Explicit instantiation for the first template argument = 512 and HAVE_TOPK_LENGTH = false.
template void run_fwd_phase1_kernel<512, false>(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm90::fwd {
// NOTE (intlsy): We instantiate run_fwd_phase1_kernel in two .cu files as functions with HAVE_TOPK_LENGTH
// = true / false respectively, to compile them in parallel.
// Explicit instantiation for the first template argument = 512 and HAVE_TOPK_LENGTH = true.
template void run_fwd_phase1_kernel<512, true>(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm90::fwd {
// Explicit instantiation of run_fwd_phase1_kernel for the first template argument = 576 and the
// boolean template argument = false.
template void run_fwd_phase1_kernel<576, false>(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm90::fwd {
// Explicit instantiation of run_fwd_phase1_kernel for the first template argument = 576 and the
// boolean template argument = true.
template void run_fwd_phase1_kernel<576, true>(const SparseAttnFwdParams& params);
}
#pragma once
#include "config.h"
#include "utils.h"
#include "../../helpers.h"
namespace sm90::fwd {
using namespace cute;
// Stores four floats (16 bytes in total) to global memory at `dst_ptr` using a single
// `st.weak.global.cs.v4.f32` instruction.
//
// f0..f3:  the four values, written in this order starting at `dst_ptr`.
// dst_ptr: destination address in global memory.
//          NOTE(review): a v4.f32 vector store presumably requires 16-byte alignment - confirm against the PTX ISA.
//
// The asm statement writes memory through a pointer operand, which the compiler cannot see from the
// operand list alone, so it is declared with a "memory" clobber (same as tma_bulk_reduce_add).
CUTE_DEVICE void st_global_cs_128(float f0, float f1, float f2, float f3, void *dst_ptr) {
    asm volatile("st.weak.global.cs.v4.f32 [%0], {%1, %2, %3, %4};\n"
        :
        : "l"(dst_ptr),
          "f"(f0), "f"(f1), "f"(f2), "f"(f3)
        : "memory"
    );
}
// XOR-shuffle for a float2: exchanges both components with the lane whose id is `lane ^ offset`.
// The 8-byte float2 is reinterpreted as one 64-bit integer so that a single __shfl_xor_sync call
// carries both components.
//
// mask:   participating-lane mask forwarded to __shfl_xor_sync.
// value:  this lane's float2.
// offset: XOR distance forwarded to __shfl_xor_sync.
// Returns the float2 held by the partner lane.
CUTE_DEVICE
float2 __shfl_xor_sync_float2(
    uint32_t mask, float2 value, int offset
) {
    const long long packed_in = *reinterpret_cast<long long*>(&value);
    long long packed_out = __shfl_xor_sync(mask, packed_in, offset);
    return *reinterpret_cast<float2*>(&packed_out);
}
// Issues a `cp.reduce.async.bulk ... .add.f32` instruction: `store_bytes` bytes of f32 data starting at
// `src_ptr` (shared memory of this CTA) are add-reduced into global memory at `dst_ptr`.
// The operation belongs to the current bulk group (`.bulk_group` completion mechanism); this helper
// neither commits nor waits for that group.
CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
    // The PTX instruction takes the shared-memory source as a 32-bit shared-space address
    const uint32_t src_smem_addr = cast_smem_ptr_to_uint(src_ptr);
    asm volatile(
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
        : /* no outputs */
        : "l"(dst_ptr), "r"(src_smem_addr), "r"(store_bytes)
        : "memory"
    );
}
// Device-side body of the sm90 sparse attention forward kernel.
//
// Work decomposition (as visible in this function):
//   - blockIdx.x enumerates (s_q_idx, q_h_idx) pairs, q_h_idx varying fastest. Each CTA handles B_H
//     query heads (head block q_h_idx) of one query token (s_q_idx).
//   - Warpgroups 0 and 1 are consumers: they compute QK^T, the online softmax, and the PV GEMMs, each
//     accumulating one half (D_V/2 columns) of the output, and finally store O (WG1 also stores
//     max_logits / lse).
//   - Every other thread takes the producer branch: it reads `params.indices`, gathers the selected
//     K/V rows into shared memory with cp_async_cacheglobal_l2_prefetch_256B, and publishes a per-token
//     validity mask.
//   - The KV blocks (B_TOPK tokens each) are processed two at a time: within a pair, buffer 0
//     (plan.k[0]) feeds WG0's QK^T and buffer 1 (plan.k[1]) feeds WG1's QK^T.
//
// params:     kernel parameters (pointers, strides, sm scale, optional attn_sink / topk_length).
// tma_params: TMA descriptor for Q, the Q shape, and the tensor map used to store O.
template<int D_QK, bool HAVE_TOPK_LENGTH>
template<typename TMAParams>
__device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttnFwdParams &params, const TMAParams &tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
// Decompose the flat block index: head-block index varies fastest, query-token index slowest
const int q_h_idx = blockIdx.x % (params.h_q/B_H);
const int s_q_idx = blockIdx.x / (params.h_q/B_H);
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int warp_idx = cutlass::canonical_warp_idx_sync();
// Thread index within its warpgroup (a warpgroup is treated as 128 threads throughout this function)
const int idx_in_warpgroup = threadIdx.x % 128;
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{});
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{});
Tensor sS0 = make_tensor(make_smem_ptr(D_QK == 576 ? plan.k[0].data()+64*512 : plan.s[1].data()), SmemLayoutS{}); // Overlap with sK0's RoPE part for V3.2
Tensor sS1 = make_tensor(make_smem_ptr(plan.s[0].data()), SmemLayoutS{});
// One elected thread of warp 0 prefetches the TMA descriptors and initializes all mbarriers
if (warp_idx == 0 && elect_one_sync()) {
// Prefetch TMA descriptors
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_O);
// Initialize barriers
plan.bar_q.init(1);
// The K "free"/"ready" barriers expect 128 arrivals: every thread of the arriving warpgroup arrives once
// (consumers on the *_free barriers, producer threads on the *_ready barriers)
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
plan.bar_k0_free[i].init(128);
plan.bar_k0_ready[i].init(128);
plan.bar_k1_free[i].init(128);
plan.bar_k1_ready[i].init(128);
}
// 16 arrivals: one per producer group leader (128 producer threads / GROUP_SIZE 8, see the producer branch)
plan.bar_is_kv_valid_ready.init(16);
fence_barrier_init();
}
// Make the initialized barriers visible to the whole CTA before anyone uses them
__syncthreads();
// Number of KV blocks to process. With HAVE_TOPK_LENGTH the per-query length is read from
// params.topk_length and rounded up to whole blocks; otherwise params.topk is divided by B_TOPK
// (run() asserts that params.topk is a multiple of 2*B_TOPK).
const int topk_length = HAVE_TOPK_LENGTH ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_topk_blocks = HAVE_TOPK_LENGTH ? ku::ceil_div(topk_length, (int)B_TOPK) : (int)((unsigned int)params.topk/(unsigned int)B_TOPK);
if (warpgroup_idx == 0 || warpgroup_idx == 1) {
// ---------------- Consumer warpgroups ----------------
cutlass::arch::warpgroup_reg_alloc<216>();
if (warp_idx == 0 && elect_one_sync()) {
// Load Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
Tile<Int<B_H>, Int<D_Q>>{}
)(_, _, q_h_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));
}
// Per-thread softmax state; each thread tracks two rows (index 0 / 1)
float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation
float rL[2] = {0.0f, 0.0f};
// rO: output accumulator for this warpgroup's half of D_V; rP: QK^T accumulator; rS: bf16 probabilities fed to the PV GEMM
Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<D_V/2>>{});
Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<B_H>, Int<B_TOPK>>{});
Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<B_TOPK>>{}));
cute::fill(rO, 0.0f);
// Wait for Q
plan.bar_q.wait(0);
// Phase bit used for every mbarrier wait of the current iteration; flipped once per loop iteration
bool cur_bar_wait_phase = 0;
// Tag types used to select the warpgroup at compile time inside the generic lambdas below
struct Warpgroup0 {};
struct Warpgroup1 {};
// Issues one 64-wide slice (tile_idx) of the QK^T GEMM into rP. WG0 reads K from plan.k[0], WG1 from plan.k[1].
// NOTE: the `warpgroup_idx` parameter is a tag object and shadows the outer integer `warpgroup_idx`.
auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
TiledMMA tiled_mma_QK = TiledMMA_QK{};
Tensor sQ_tile = flat_divide(sQ, Tile<Int<B_H>, Int<64>>{})(_, _, _0{}, tile_idx);
Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{});
gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup);
};
// Replaces the logits of invalid KV tokens with -inf, using the validity mask published by the producer
auto mask_rP = [&](auto warpgroup_idx) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
// Column (KV token within the block) of accumulator element i for this thread; i+1 is the next column
int col = 8*(i/4) + (idx_in_warpgroup%4)*2;
if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY;
if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY;
}
}
};
// Online softmax step: computes the new running row maxima, rescales rO and rL accordingly, turns rP
// into probabilities (also written to rS as bf16), and publishes the new maxima to plan.sM.
auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) {
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
const float scale = params.sm_scale_div_log2;
float r_sM[2];
if constexpr (IS_WG1) {
*(float2*)r_sM = plan.sM[idx_in_warpgroup/4];
}
float new_maxs[2];
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
// Get rowwise max
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
cur_max = max(cur_max, max(rP(i), rP(i+1)));
}
// Reduce across the 4 threads (lane xor 1, xor 2) that hold elements of the same row
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale;
// Get new max and scale
// For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round)
new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max);
// Scale O
float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]);
CUTE_UNROLL
for (int i = row_idx*2; i < size(rO); i += 4) {
rO(i) *= scale_for_o;
rO(i+1) *= scale_for_o;
}
// Get rS
float cur_sum = 0;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]);
rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]);
rS(i) = (bf16)rP(i);
rS(i+1) = (bf16)rP(i+1);
cur_sum += rP(i) + rP(i+1);
}
rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum;
}
__syncwarp();
// One thread out of each group of 4 writes the two row maxima
if (idx_in_warpgroup%4 == 0) {
plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs;
}
rM[0] = new_maxs[0];
rM[1] = new_maxs[1];
};
// Combines the partial softmax denominators: first across the 4 threads sharing a row, then across
// the two consumer warpgroups through plan.sL.
auto reduce_L = [&]() {
// Reduce L
// For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
if (idx_in_warpgroup%4 == 0)
plan.sL[threadIdx.x/4] = *(float2*)(rL);
// Barrier over the 256 consumer threads so both warpgroups' partial sums are visible
NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready);
// threadIdx.x/4 is in [0, 64) for consumer threads; xor 32 selects the matching entry of the other warpgroup
float2 peer_L = plan.sL[(threadIdx.x/4)^32];
rL[0] += peer_L.x;
rL[1] += peer_L.y;
};
// Normalizes rO, converts it to bf16, stages it in shared memory and stores it to global memory
// through tma_params.tensor_map_O. Uses the outer integer `warpgroup_idx` (0 or 1).
auto store_O = [&]() {
float scale_factors[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
// attn_sink is optional; when absent it is treated as -inf, i.e. exp2f(...) below contributes 0
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : params.attn_sink[q_h_idx*B_H + get_AorC_row_idx(i, idx_in_warpgroup)]*CUDART_L2E_F;
scale_factors[i] = 1.0f / (rL[i] + exp2f(attn_sink - rM[i]));
if (rL[i] == 0.0f)
scale_factors[i] = 0.0f; // The output should be 0 whatever attn_sink is
}
// View of this warpgroup's half of the O staging buffer (shadows the outer sO)
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{});
bf16* stsm_addrs[4];
int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16);
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i);
}
// Exactly one thread per warpgroup (elected from its first warp) issues the TMA stores
bool s2g_pred = warp_idx%4 == 0 && elect_one_sync();
// Make sure all outstanding WGMMAs that write rO have finished
warpgroup_wait<0>();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) {
// Convert
constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size
bf16 cur_rOb[NUM_ELEMS_EACH_TILE];
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) {
// i%4 in {0,1} belongs to row 0 and {2,3} to row 1, matching the `i = row*2; i += 4` loops above
cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]);
}
// R -> S
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
SM90_U32x4_STSM_N::copy(
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 0),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 2),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 4),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 6),
*reinterpret_cast<uint128_t*>(stsm_addrs[i] + tile_idx*(B_H*64))
);
}
// Order the shared-memory writes before the TMA store, then sync the warpgroup
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, warpgroup_idx ? NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync);
// S -> G
if (s2g_pred) {
// Global 64-column tile index: WG0 owns tiles 0..3, WG1 owns tiles 4..7
int g_tile_idx = warpgroup_idx*4 + tile_idx;
SM90_TMA_STORE_3D::copy(
&tma_params.tensor_map_O,
plan.q_o.o.data() + g_tile_idx*(B_H*64),
g_tile_idx*64,
q_h_idx*B_H,
s_q_idx
);
}
}
cute::tma_store_arrive();
};
if (warpgroup_idx == 0) {
// Warpgroup 0
// QK^T for KV buffer 0 is split into a "left" part (tiles 0..3) and a "right" part (tiles 4..),
// each gated by its own "ready" barrier and committed as its own WGMMA batch
auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 0, true);
qkt_gemm_one_tile(Warpgroup0{}, 1, false);
qkt_gemm_one_tile(Warpgroup0{}, 2, false);
qkt_gemm_one_tile(Warpgroup0{}, 3, false);
warpgroup_commit_batch();
};
auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 4, false);
qkt_gemm_one_tile(Warpgroup0{}, 5, false);
qkt_gemm_one_tile(Warpgroup0{}, 6, false);
qkt_gemm_one_tile(Warpgroup0{}, 7, false);
if constexpr (D_QK == 576) {
// Extra ninth 64-wide tile, only present when D_QK == 576
qkt_gemm_one_tile(Warpgroup0{}, 8, false);
}
warpgroup_commit_batch();
};
// Recomputes rS from rP with a per-row scale factor (applied after WG1 raised the running max)
auto scale_rS = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rP); i += 4) {
rS(i) = (bf16)(rP(i) * scales[row]);
rS(i+1) = (bf16)(rP(i+1) * scales[row]);
}
}
};
// Applies the same per-row scale factor to the output accumulator and the running denominator
auto rescale_rO = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rO); i += 4) {
rO(i) *= scales[row];
rO(i+1) *= scales[row];
}
rL[row] *= scales[row];
}
};
// Main loop: each iteration consumes two KV blocks (block_idx in buffer 0, block_idx+1 in buffer 1)
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
// "Left" halves of the two KV buffers, viewed transposed for use as V
Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{});
Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{});
if (block_idx == 0) {
// NOTE: We put this code here to avoid register spilling
pipelined_wait_and_qkt_gemm_l();
pipelined_wait_and_qkt_gemm_r();
warpgroup_wait<0>();
}
// Online softmax, inform WG1
mask_rP(Warpgroup0{});
online_softmax_and_rescale_o(Warpgroup0{});
NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready);
// Issue rO0 += rS0 @ sV0l
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Mark V0L as free
warpgroup_wait<0>();
plan.bar_k0_free[0].arrive();
// Wait for new sM, scale rS, save, inform WG1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready);
float new_rM[2], scale_factors[2];
*(float2*)new_rM = plan.sM[idx_in_warpgroup/4];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
scale_factors[i] = exp2f(rM[i] - new_rM[i]);
rM[i] = new_rM[i];
}
scale_rS(scale_factors);
save_rS_to_sS(rS, sS0, idx_in_warpgroup);
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready);
// Wait for sS1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready);
// Rescale rO0, Issue rO0 += sS1 @ sV1L
rescale_rO(scale_factors);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
cur_bar_wait_phase ^= 1;
if (block_idx+2 < num_topk_blocks) {
// Launch the next QK^T GEMM
pipelined_wait_and_qkt_gemm_l();
// Mark V1L as free
warpgroup_wait<1>();
plan.bar_k1_free[0].arrive();
pipelined_wait_and_qkt_gemm_r();
// Wait for rP0 = sQ @ sK0
warpgroup_wait<0>();
} else {
// Mark V1L as free
warpgroup_wait<0>();
plan.bar_k1_free[0].arrive();
}
}
reduce_L();
store_O();
} else {
// Warpgroup 1
// QK^T for KV buffer 1: the "right" tiles (4..) are issued first, then the "left" tiles (0..3),
// each part gated by its own "ready" barrier; everything is committed as one WGMMA batch
auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) {
plan.bar_k1_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 4, true);
qkt_gemm_one_tile(Warpgroup1{}, 5, false);
qkt_gemm_one_tile(Warpgroup1{}, 6, false);
qkt_gemm_one_tile(Warpgroup1{}, 7, false);
if constexpr (D_QK == 576) {
qkt_gemm_one_tile(Warpgroup1{}, 8, false);
}
plan.bar_k1_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 0, false);
qkt_gemm_one_tile(Warpgroup1{}, 1, false);
qkt_gemm_one_tile(Warpgroup1{}, 2, false);
qkt_gemm_one_tile(Warpgroup1{}, 3, false);
warpgroup_commit_batch();
};
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
// "Right" halves of the two KV buffers (offset by 64*256 elements), viewed transposed for use as V
Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{});
Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{});
// Issue rP1 = sQ @ sK1, and wait
pipelined_wait_and_qkt_gemm();
warpgroup_wait<0>();
mask_rP(Warpgroup1{});
// Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready)
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready);
online_softmax_and_rescale_o(Warpgroup1{});
NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready);
// Issue rO1 += rS1 @ sV1R
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R
save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Save rS1, inform WG0
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready);
// Wait for GEMM, and inform that sV1R is free
warpgroup_wait<1>();
plan.bar_k1_free[1].arrive();
// Wait for GEMM, and inform that sV0R is free
warpgroup_wait<0>();
plan.bar_k0_free[1].arrive();
cur_bar_wait_phase ^= 1;
}
reduce_L();
store_O();
// Save lse
// rM / rL are kept in the base-2 domain; multiplying by ln(2) converts max_logits and lse to natural log.
// Rows without any valid token get max_logits = -inf and lse = +inf.
if (idx_in_warpgroup%4 == 0) {
for (int row = 0; row < 2; ++row) {
int real_row = get_AorC_row_idx(row, idx_in_warpgroup);
bool is_no_valid_tokens = rL[row] == 0.0f;
plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row]*CUDART_LN2_F;
plan.final_lse[real_row] = is_no_valid_tokens ? +INFINITY : logf(rL[row]) + rM[row]*CUDART_LN2_F;
}
fence_view_async_shared();
}
NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync);
// A single thread of WG1 bulk-copies the B_H per-head results from shared to global memory
if (idx_in_warpgroup == 0) {
int g_offset = s_q_idx*params.h_q + q_h_idx*B_H;
SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float));
SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float));
cute::tma_store_arrive();
}
}
} else {
// Producer warpgroup
cutlass::arch::warpgroup_reg_dealloc<72>();
// The 128 producer threads form 16 groups of 8 threads. Group g loads rows g, g+16, g+32, ... of a
// KV block; within a group, thread t handles the 8 bf16 elements starting at column t*8 of each 64-wide tile.
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE;
constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS;
int idx_in_group = idx_in_warpgroup % GROUP_SIZE;
int group_idx = idx_in_warpgroup / GROUP_SIZE;
int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]
bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8));
bf16* my_gKV_base = params.kv + idx_in_group*8;
// Per-thread cache for the two KV blocks of the current iteration: element offsets into params.kv and validity flags
int64_t token_indices[2][NUM_ROWS_PER_GROUP];
bool is_token_valid[2][NUM_ROWS_PER_GROUP];
// Reads the token indices for blocks block_idx and block_idx+1 and records which of them are valid
auto load_token_indices = [&](int block_idx) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx;
int t = __ldg(gIndices + offs);
token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster
// A token is valid when its index lies in [0, s_kv) and, with HAVE_TOPK_LENGTH, its slot is below topk_length
bool is_cur_token_valid = t >= 0 && t < params.s_kv;
if constexpr (HAVE_TOPK_LENGTH) {
is_cur_token_valid &= offs < topk_length;
}
is_token_valid[buf_idx][local_row] = is_cur_token_valid;
}
}
};
int64_t cache_policy = createpolicy_evict_last();
// NOTE: `block_idx` is not used inside copy_tiles; the rows come from token_indices[buf_idx]
auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) {
// Copy some K/V tiles from global memory to shared memory
// A tile has a shape of 64 (B_TOPK) x 64
// `buf_idx` is the index of the shared memory buffer, 0 or 1
// `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int64_t token_index = token_indices[buf_idx][local_row];
CUTE_UNROLL
for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) {
// The validity flag is passed as the copy predicate
// NOTE(review): behaviour for invalid tokens (skip vs. zero-fill) depends on the helper - confirm
cp_async_cacheglobal_l2_prefetch_256B(
my_gKV_base + token_index + tile_idx*64,
my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64),
is_token_valid[buf_idx][local_row],
cache_policy
);
}
}
};
// Makes this thread's outstanding cp.async copies arrive on `bar` once they complete
auto commit_to_mbar = [&](transac_bar_t &bar) {
cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar));
};
// The producer waits on the "free" barriers starting with phase 1 (consumers start with phase 0 on "ready")
// NOTE(review): presumably this lets the very first wait on a freshly initialized barrier pass immediately - confirm
int cur_bar_wait_phase = 1;
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
load_token_indices(block_idx);
// Each quarter below: wait until the consumers released the region, issue the copies, then signal "ready"
// V0L
plan.bar_k0_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 0, 4);
commit_to_mbar(plan.bar_k0_ready[0]);
// V1R
plan.bar_k1_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 4, D_K/64);
commit_to_mbar(plan.bar_k1_ready[1]);
// V0R
plan.bar_k0_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 4, D_K/64);
commit_to_mbar(plan.bar_k0_ready[1]);
// V1L
plan.bar_k1_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 0, 4);
commit_to_mbar(plan.bar_k1_ready[0]);
// Valid mask
// NOTE: V1R's finish implies maskings of the last round have finished
// One thread per group publishes the validity flags of the rows its group loaded, then arrives on
// bar_is_kv_valid_ready (16 arrivals in total, matching the barrier's init count)
if (idx_in_group == 0) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx)
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row)
plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row];
plan.bar_is_kv_valid_ready.arrive();
}
cur_bar_wait_phase ^= 1;
}
}
#else
// Compiled for an architecture other than sm90: fail loudly instead of silently doing nothing
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
}
// __global__ entry point of the sparse attention forward kernel.
// It only forwards to Kernel::devfunc; all work is done there. The thread-block size is bounded by
// Kernel::NUM_THREADS, and both parameter structs are passed as __grid_constant__.
template<typename Kernel, typename TMAParams>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1)
sparse_attn_fwd_kernel(
    __grid_constant__ const SparseAttnFwdParams params,
    __grid_constant__ const TMAParams tma_params
) {
    Kernel::devfunc(params, tma_params);
}
// Host-side launcher for the sm90 sparse attention forward kernel.
//
// Steps:
//   1. Validate the preconditions the kernel relies on.
//   2. Build the TMA copy object used to load Q and the tensor map used to store O.
//   3. Raise the dynamic shared-memory limit to sizeof(SharedMemoryPlan) and launch one CTA per
//      (query token, block of B_H query heads) on params.stream.
//
// params: problem description (tensor pointers, sizes, strides, stream).
// Fails through KU_ASSERT / KU_CUDA_CHECK / KU_CHECK_KERNEL_LAUNCH when a precondition or CUDA call fails.
template<int D_QK, bool HAVE_TOPK_LENGTH>
void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams &params) {
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.topk % (2*B_TOPK) == 0);   // To save some boundary checks
    KU_ASSERT(params.topk > 0);
    KU_ASSERT(params.h_q % B_H == 0);
    if constexpr (HAVE_TOPK_LENGTH) {
        // In this instantiation the kernel unconditionally reads params.topk_length[s_q_idx],
        // so a null pointer must be rejected on the host instead of faulting on the device.
        KU_ASSERT(params.topk_length != nullptr);
    }

    // Q is described as a (h_q, d_qk, s_q) tensor; d_qk is the contiguous dimension
    auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
    auto tma_Q = cute::make_tma_copy(
        SM90_TMA_LOAD{},
        make_tensor(
            make_gmem_ptr((bf16*)params.q),
            make_layout(
                shape_Q,
                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
            )
        ),
        SmemLayoutQ{}
    );

    // O is described as a (D_V, h_q, s_q) bf16 tensor, stored in 64 x B_H x 1 boxes.
    // NOTE: the strides below describe a densely packed [s_q, h_q, D_V] output buffer.
    CUtensorMap tensor_map_O;
    {
        // All extents and byte strides are computed in 64-bit arithmetic
        uint64_t size[3] = {D_V, (uint64_t)params.h_q, (uint64_t)params.s_q};
        uint64_t stride[2] = {
            (uint64_t)D_V*sizeof(bf16),
            (uint64_t)D_V*(uint64_t)params.h_q*sizeof(bf16)
        };
        uint32_t box_size[3] = {64, B_H, 1};
        uint32_t elem_stride[3] = {1, 1, 1};
        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
            &tensor_map_O,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            3,
            params.out,
            size,
            stride,
            box_size,
            elem_stride,
            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
        );
        KU_ASSERT(res == CUresult::CUDA_SUCCESS);
    }

    TmaParams<
        decltype(shape_Q), decltype(tma_Q)
    > tma_params = {
        shape_Q, tma_Q,
        tensor_map_O
    };

    auto kernel = &sparse_attn_fwd_kernel<KernelTemplate<D_QK, HAVE_TOPK_LENGTH>, decltype(tma_params)>;

    // The whole SharedMemoryPlan lives in dynamic shared memory, so the limit must be raised explicitly
    constexpr size_t smem_size = sizeof(SharedMemoryPlan);
    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    cutlass::ClusterLaunchParams launch_params = {
        dim3((params.h_q/B_H)*params.s_q, 1, 1),    // NOTE: We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z)
        dim3(NUM_THREADS, 1, 1),
        dim3(1, 1, 1),
        smem_size,
        params.stream
    };
    cutlass::launch_kernel_on_cluster(
        launch_params, (void*)kernel, params, tma_params
    );
    KU_CHECK_KERNEL_LAUNCH();
}
// Free-function entry point for the phase-1 forward kernel: selects the KernelTemplate specialization
// for (D_QK, HAVE_TOPK_LENGTH) and delegates to its static run().
template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
    using Kernel = KernelTemplate<D_QK, HAVE_TOPK_LENGTH>;
    Kernel::run(params);
}
}
#pragma once
#include "../../../params.h"
namespace sm90::fwd {
// Runs the sm90 phase-1 sparse attention forward kernel for the given parameters.
// D_QK:             compile-time QK head dimension of the kernel variant.
// HAVE_TOPK_LENGTH: compile-time switch selecting the variant that uses per-query top-k lengths.
// Declaration only: the template is not defined here, so callers depend on explicit instantiations
// being provided by other translation units.
template<int D_QK, bool HAVE_TOPK_LENGTH>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);
}
#include "mla_combine.h"
#include "combine.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <kerutils/kerutils.cuh>
#include "params.h"
#include "utils.h"
using namespace cute;
namespace smxx::decode {
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
// grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
flash_fwd_mla_combine_kernel(__grid_constant__ const CombineParams params) {
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
const int batch_idx = blockIdx.x;
const int m_block_idx = blockIdx.y;
const int s_q_idx = blockIdx.y;
const int h_block_idx = blockIdx.z;
const int warp_idx = threadIdx.x / 32;
const int lane_idx = threadIdx.x % 32;
int num_valid_heads = std::min(BLOCK_SIZE_M, params.h_q - BLOCK_SIZE_M*h_block_idx);
if (warp_idx >= num_valid_heads) {
return;
}
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
const int my_num_splits = end_split_idx - start_split_idx;
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
if (my_num_splits == 1) {
return;
}
const int num_q_seqs = params.q_seq_per_hk * params.h_k;
const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M);
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
Tensor gLseAccum = make_tensor(
make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
make_gmem_ptr((float*)params.lse_accum + start_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + h_block_idx*BLOCK_SIZE_M),
Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
make_stride(num_q_seqs, _1{})
make_stride(params.stride_lse_accum_split, _1{})
);
Tensor gLse = make_tensor(
make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
make_gmem_ptr((float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + h_block_idx*BLOCK_SIZE_M),
Shape<Int<BLOCK_SIZE_M>>{},
Stride<_1>{}
);
extern __shared__ float smem_buf[];
Tensor sLseScale = make_tensor(
make_smem_ptr(smem_buf),
Shape<Int<BLOCK_SIZE_M>, Int<MAX_SPLITS>>{},
Stride<Int<MAX_SPLITS+1>, _1>{} // +1 to avoid bank conflict
);
__shared__ float smem_buf[BLOCK_SIZE_M][MAX_SPLITS];
// Wait for the previous kernel (the MLA kernel) to finish
cudaGridDependencySynchronize();
// Read gLseAccum into sLseScale
{
#pragma unroll 4
for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) {
int split_idx = elem_idx / BLOCK_SIZE_M;
int seq_idx = elem_idx % BLOCK_SIZE_M;
sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY;
}
__syncthreads();
// Prefetch
static_assert(HEAD_DIM_V % (32*4) == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / (32*4);
float* oaccum_ptr = params.o_accum + start_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + (h_block_idx*BLOCK_SIZE_M + warp_idx)*params.stride_o_accum_h_q;
float4 datas[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
datas[i] = *(float4*)(oaccum_ptr + lane_idx*4 + i*128); // NOTE We don't use __ldg here since it is incompatible with PDL
}
if (warp_idx >= num_cur_valid_q_seqs)
return;
// Warp #i gathers LseAccum for seq #i
{
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
......@@ -73,7 +74,7 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY;
local_lse[i] = split_idx < my_num_splits ? gLseAccum(split_idx, warp_idx) : -INFINITY;
}
float max_lse = -INFINITY;
......@@ -93,14 +94,26 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
for (int offset = 16; offset >= 1; offset /= 2)
sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse;
float global_lse = (sum_lse == 0.f || sum_lse == -INFINITY) ? INFINITY : log2f(sum_lse) + max_lse;
if (lane_idx == 0)
gLse(warp_idx) = global_lse / (float)M_LOG2E;
if (params.attn_sink != nullptr) {
int q_head_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
float attn_sink = __ldg(params.attn_sink + q_head_idx);
if (global_lse != INFINITY) {
// If attn_sink is +inf, global_lse will be +inf and scale factors will be exp2f(local_lse - inf) = 0 (since local_lse never becomes +inf)
// If attn_sink is -inf, this has no effect on global_lse
global_lse += log2f(1 + exp2f(attn_sink*CUDART_L2E_F - global_lse));
} else {
// We have no tokens to attend, so global lse should be attn_sink*CUDART_L2E_F (+inf if it's -inf or +inf)
global_lse = attn_sink == -INFINITY ? +INFINITY : attn_sink*CUDART_L2E_F;
}
}
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse);
smem_buf[warp_idx][split_idx] = exp2f(local_lse[i] - global_lse);
}
}
......@@ -108,45 +121,42 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
// Warp #i accumulates activation for seq #i
{
const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V;
Tensor gOaccum = make_tensor(
make_gmem_ptr(reinterpret_cast<float *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<MAX_SPLITS>, Int<HEAD_DIM_V>>{},
make_stride(num_q_seqs*HEAD_DIM_V, _1{})
);
static_assert(HEAD_DIM_V % 32 == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32;
float result[ELEMS_PER_THREAD];
float4 result[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
result[i] = 0.0f;
result[i] = {0.0f, 0.0f, 0.0f, 0.0f};
#pragma unroll 2
#pragma unroll 1
for (int split = 0; split < my_num_splits; ++split) {
float lse_scale = sLseScale(warp_idx, split);
if (lse_scale != 0.f) {
float lse_scale = smem_buf[warp_idx][split];
// if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
result[i] += lse_scale * gOaccum(split, lane_idx + i*32);
result[i].x += lse_scale * datas[i].x;
result[i].y += lse_scale * datas[i].y;
result[i].z += lse_scale * datas[i].z;
result[i].w += lse_scale * datas[i].w;
if (split != my_num_splits-1) {
datas[i] = *(float4*)(oaccum_ptr + (split+1)*params.stride_o_accum_split + lane_idx*4 + i*128);
}
}
// }
}
cudaTriggerProgrammaticLaunchCompletion();
const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx;
const int k_head_idx = q_seq_idx / params.q_seq_per_hk;
auto o_ptr = reinterpret_cast<ElementT *>(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride;
Tensor gO = make_tensor(
make_gmem_ptr(o_ptr),
Shape<Int<HEAD_DIM_V>>{},
Stride<_1>{}
);
const int h_q_idx = h_block_idx*BLOCK_SIZE_M + warp_idx;
ElementT* o_ptr = (ElementT*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + h_q_idx*params.stride_o_h_q;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
gO(lane_idx+i*32) = (ElementT)result[i];
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
float4 data = result[i];
ElementT data_converted[4];
data_converted[0] = (ElementT)(data.x);
data_converted[1] = (ElementT)(data.y);
data_converted[2] = (ElementT)(data.z);
data_converted[3] = (ElementT)(data.w);
static_assert(sizeof(ElementT) == 2);
*(uint64_t*)(o_ptr + lane_idx*4 + i*128) = *(uint64_t*)data_converted;
}
}
}
......@@ -175,7 +185,7 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
template<typename ElementT>
void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream) {
void run_flash_mla_combine_kernel(CombineParams &params) {
static constexpr int HEAD_DIM_V = 512; // Since only this head dimension is supported by Flash MLA
FLASH_ASSERT(params.d_v == HEAD_DIM_V);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
......@@ -189,20 +199,22 @@ void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream) {
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t combine_kernel_config = {
dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1),
dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
dim3(NUM_THREADS, 1, 1),
smem_size,
stream,
0,
params.stream,
attribute,
1
};
cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params);
CHECK_CUDA(cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params));
});
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(DecodingParams &params, cudaStream_t stream);
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(CombineParams &params);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(DecodingParams &params, cudaStream_t stream);
template void run_flash_mla_combine_kernel<cutlass::half_t>(CombineParams &params);
#endif
}
......@@ -2,5 +2,9 @@
#include "params.h"
namespace smxx::decode {
template<typename ElementT>
void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream);
void run_flash_mla_combine_kernel(CombineParams &params);
}
#include "get_mla_metadata.h"
#include "get_decoding_sched_meta.h"
#include <cuda_runtime_api.h>
#include <cutlass/fast_math.h>
#include <kerutils/kerutils.cuh>
#include "utils.h"
namespace smxx::decode {
__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params) {
get_mla_metadata_kernel(__grid_constant__ const GetDecodeSchedMetaParams params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
DecodingSchedMeta *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.batch_size;
int batch_size = params.b;
int block_size_n = params.block_size_n;
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
......@@ -24,14 +27,25 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
int cur_s_k = params.topk == -1 ? __ldg(seqlens_k_ptr + i) : params.topk;
int cur_s_k;
if (params.topk == -1) {
// Dense model, cur_s_k = actual s_k
cur_s_k = __ldg(seqlens_k_ptr + i);
} else {
// Sparse model, cur_s_k = topk (+ extra topk)
cur_s_k = params.topk_length ? __ldg(params.topk_length + i) : params.topk;
if (cur_s_k == 0) cur_s_k = 1; // Ensure the main loop will never be empty
if (params.extra_topk) {
cur_s_k = ku::ceil(cur_s_k, block_size_n);
cur_s_k += params.extra_topk_length ? __ldg(params.extra_topk_length + i) : params.extra_topk;
}
}
seqlens_k_shared[i] = cur_s_k;
int first_token_idx = 0;
int last_token_idx = max(cur_s_k-1, 0);
int cur_first_block_idx = first_token_idx / block_size_n;
int cur_last_block_idx = last_token_idx / block_size_n;
// NOTE Should attend to tokens [first_token_idx, last_token_idx], i.e. blocks [cur_first_block_idx, cur_last_block_idx]
// NOTE Before clamping, first_token_idx <= last_token_idx always holds, so after clamping, first_token_idx <= last_token_idx still holds.
// NOTE if seqlens_k is 0, then first_token_idx == last_token_idx == cur_first_block_idx == cur_last_block_idx == 0. So the sequence will have 1 block. We will correct this later in this kernel.
int num_blocks = cur_last_block_idx - cur_first_block_idx + 1;
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
......@@ -47,22 +61,23 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params
if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
int now_req_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx];
tile_scheduler_metadata1 = now_n_split_idx;
DecodingSchedMeta cur_meta;
cur_meta.begin_req_idx = now_req_idx;
cur_meta.begin_block_idx = now_block + first_block_idx_shared[now_req_idx];
cur_meta.begin_split_idx = now_n_split_idx;
cur_meta.is_first_req_splitted = (now_block != 0);
int remain_payload = payload;
while (now_idx < batch_size) {
int num_blocks = num_blocks_shared[now_idx];
while (now_req_idx < batch_size) {
int num_blocks = num_blocks_shared[now_req_idx];
int now_remain_blocks = num_blocks - now_block;
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
cum_num_splits += now_n_split_idx + 1;
num_splits_shared[now_idx + 1] = cum_num_splits;
num_splits_shared[now_req_idx + 1] = cum_num_splits;
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
++now_idx;
++now_req_idx;
now_block = 0;
now_n_split_idx = 0;
} else {
......@@ -74,12 +89,15 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params
break;
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block + first_block_idx_shared[now_idx] : (seqlens_k_shared[now_idx-1] == 0 ? 0 : last_block_idx_shared[now_idx-1] + 1);
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
cur_meta.end_req_idx = now_block > 0 ? now_req_idx : now_req_idx - 1;
cur_meta.end_block_idx = now_block > 0 ? now_block + first_block_idx_shared[now_req_idx] : (seqlens_k_shared[now_req_idx-1] == 0 ? 0 : last_block_idx_shared[now_req_idx-1] + 1);
cur_meta.is_last_req_splitted = cur_meta.end_block_idx != last_block_idx_shared[cur_meta.end_req_idx] + 1 && seqlens_k_shared[cur_meta.end_req_idx] != 0;
if (cur_meta.begin_req_idx == cur_meta.end_req_idx) {
cur_meta.is_first_req_splitted = cur_meta.is_last_req_splitted = cur_meta.is_first_req_splitted || cur_meta.is_last_req_splitted;
}
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
tile_scheduler_metadata_ptr[i] = cur_meta;
}
FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncwarp();
......@@ -88,9 +106,11 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params
}
}
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*5+1);
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params) {
int smem_size = sizeof(int) * (params.b*5+1);
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
get_mla_metadata_kernel<<<1, 32, smem_size, params.stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace smxx::decode {
// Host-side entry point that launches the decoding scheduler-metadata kernel.
// `params` is taken by non-const reference; the launch is issued on `params.stream`.
void run_get_decoding_sched_meta_kernel(GetDecodeSchedMetaParams &params);
}
#pragma once
#include "params.h"
void run_get_mla_metadata_kernel(GetDecodingMetadataParams &params, cudaStream_t stream);
#pragma once
#include <cstdint>
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
......@@ -44,23 +46,37 @@ do { \
} while (0)
#endif
// For development, we define both IS_SM100 and IS_SM90 when using CLion or VSCode IDEs so code highlighting will be correct.
#if defined(__CLION_IDE__) || defined(__VSCODE_IDE__)
#define IS_SM100 1
#define IS_SM90 1
#else
// We define the following macros to detect the CUDA architecture, so that we can enable/disable certains kernels that depends on specific architectures.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000)
#define IS_SM100 1
#else
#define IS_SM100 0
#ifndef TRAP_ONLY_DEVICE_ASSERT
// Message-less assertion: when `cond` evaluates to false, the PTX `trap`
// instruction is executed. Nothing is printed by this macro itself.
// The #ifndef guard lets an earlier definition of the same macro take precedence.
#define TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)
#define IS_SM90 1
#else
#define IS_SM90 0
#endif
#endif // defined(__CLION_IDE__) || defined(__VSCODE_IDE__)
\ No newline at end of file
// Position tracker for walking a ring of NUM_STAGES slots.
// It stores a single monotonically increasing block counter; the slot index
// and a parity ("phase") bit are derived from that counter on demand.
struct RingBufferState {
    // Number of blocks consumed/produced so far (starts at 0).
    uint32_t cur_block_idx = 0u;

    // Move on to the next block.
    __device__ __forceinline__
    void update() {
        ++cur_block_idx;
    }

    // Returns {stage index, phase} for the current block:
    //   - stage index = cur_block_idx modulo NUM_STAGES
    //   - phase       = lowest bit of (cur_block_idx / NUM_STAGES), i.e. it
    //                   flips every time the counter wraps around all stages.
    template<uint32_t NUM_STAGES>
    __device__ __forceinline__
    std::pair<uint32_t, bool> get() const {
        const uint32_t num_wraps = cur_block_idx / NUM_STAGES;
        const uint32_t stage_idx = cur_block_idx - num_wraps * NUM_STAGES;
        const bool phase = (num_wraps & 1u) != 0u;
        return {stage_idx, phase};
    }

    // Returns a copy of this state shifted by `offset` blocks (may be negative).
    // The caller must guarantee that the shifted counter does not underflow.
    __device__ __forceinline__
    RingBufferState offset_by(const int offset) const {
        RingBufferState shifted;
        shifted.cur_block_idx = static_cast<uint32_t>(static_cast<int>(cur_block_idx) + offset);
        return shifted;
    }
};
......@@ -8,3 +8,12 @@ from flash_mla.flash_mla_interface import (
flash_attn_varlen_kvpacked_func,
flash_mla_sparse_fwd
)
__all__ = [
"get_mla_metadata",
"flash_mla_with_kvcache",
"flash_attn_varlen_func",
"flash_attn_varlen_qkvpacked_func",
"flash_attn_varlen_kvpacked_func",
"flash_mla_sparse_fwd"
]
from typing import Optional, Tuple
import dataclasses
import torch
import flash_mla.cuda as flash_mla_cuda
@dataclasses.dataclass
class FlashMLASchedMeta:
    """
    Holder for FlashMLA's tile scheduler metadata.

    A freshly constructed instance is empty (`have_initialized` is False and
    every other field is None); the fields are meant to be filled in later.
    """

    @dataclasses.dataclass
    class Config:
        """
        Snapshot of the problem configuration associated with the metadata.

        All ten fields are required and positional, in the order listed below.
        """
        b: int
        s_q: int
        h_q: int
        page_block_size: int
        h_k: int

        causal: bool
        is_fp8_kvcache: bool
        topk: Optional[int]

        extra_page_block_size: Optional[int]
        extra_topk: Optional[int]

    # Whether `config` and the metadata tensors have been populated.
    have_initialized: bool = dataclasses.field(default=False)

    # Configuration the metadata was generated for (None until initialized).
    config: Optional[Config] = dataclasses.field(default=None)

    # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    tile_scheduler_metadata: Optional[torch.Tensor] = dataclasses.field(default=None)

    # (1), dtype torch.int32.
    num_splits: Optional[torch.Tensor] = dataclasses.field(default=None)
def get_mla_metadata(
cache_seqlens: torch.Tensor,
num_q_tokens_per_head_k: int,
num_heads_k: int,
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
*args,
**kwargs
) -> Tuple[FlashMLASchedMeta, None]:
"""
Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache.
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
num_heads_k: The number of k heads.
num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled
is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to.
This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
Return:
A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful.
"""
return flash_mla_cuda.get_mla_decoding_metadata(cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, is_fp8_kvcache, topk)
return FlashMLASchedMeta(), None
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
block_table: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
attn_sink: Optional[torch.Tensor] = None,
extra_k_cache: Optional[torch.Tensor] = None,
extra_indices_in_kvcache: Optional[torch.Tensor] = None,
topk_length: Optional[torch.Tensor] = None,
extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md
indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md.
Returns:
Different modes (including fp8/bf16, sparsity, and model version (i.e. V3.2 or MODEL1)) has different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
Besides, some kernels also have their own requirements on the layout of k cache, including:
- For sparse fp8 decoding kernel on F3, k_cache.stride(0) must be a multiple of 656B (for V32) or 576B (for MODEL1). Padding is needed sometimes.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
is_fp8_kvcache: bool.
indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled.
Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block),
where t is the k-th token of the j-th q-sequence in the i-th batch.
attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0.
extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. This is used to support MODEL1. Their format requirements are the same as k_cache and indices_in_kvcache respectively.
topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking.
For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
head_dim should be 576 while head_dim_v should be 512.
In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as:
- The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
- First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
- Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
- Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.
For DeepSeek MODEL1:
head_dim should be 512 while head_dim_v is also 512.
In FP8+sparse mode, every block can be divided into two parts. The first parts stores NoPE0, RoPE0, NoPE1, RoPE1, ... while the second part stores scale factors: 7xue8m0, 1Bpad, 7xue8m0, 1Bpad, ...
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
sched_meta = tile_scheduler_metadata
indices_in_kvcache = indices
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if indices is not None:
assert causal == False, "causal must be `false` if sparse attention is enabled."
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
if not sched_meta.have_initialized:
# Sanity check. We only perform sanity check during the first invocation to save CPU time.
if indices_in_kvcache is not None:
assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)"
# Initialize the tile scheduler metadata during the first invocation.
sched_meta.have_initialized = True
sched_meta.config = FlashMLASchedMeta.Config(
q.shape[0],
q.shape[1],
q.shape[2],
k_cache.shape[1],
k_cache.shape[2],
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices
topk,
extra_k_page_block_size,
extra_topk,
)
else:
# Check whether the input arguments are consistent with sched_meta
helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assert sched_meta.config is not None
assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg
assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg
assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg
if topk is not None:
# Sparse attention
assert not causal, "causal must be False when sparse attention is enabled"
assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled"
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
q, k_cache, indices_in_kvcache, topk_length, attn_sink,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
head_dim_v, softmax_scale
)
else:
# Dense attention
assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
q, k_cache, head_dim_v,
cache_seqlens, block_table,
softmax_scale, causal,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits
)
return out, softmax_lse
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits
return (out, lse)
def flash_mla_sparse_fwd(
......@@ -85,6 +186,8 @@ def flash_mla_sparse_fwd(
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
attn_sink: Optional[torch.Tensor] = None,
topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
......@@ -95,16 +198,22 @@ def flash_mla_sparse_fwd(
indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv
sm_scale: float
d_v: The dimension of value vectors. Can only be 512
attn_sink: optional, [h_q], float32.
If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)).
+-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros).
This argument has no effect on lse and max_logits.
topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices).
In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation.
Returns:
(output, max_logits, lse)
About the definition of output, max_logits and lse, please refer to README.md
Please refer to tests/ref.py for the precise definitions of these parameters.
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
- lse: [s_q, h_q], float, log-sum-exp of attention scores
"""
results = flash_mla_cuda.sparse_prefill_fwd(
q, kv, indices, sm_scale, d_v
q, kv, indices, sm_scale, d_v, attn_sink, topk_length
)
return results
......
......@@ -36,11 +36,11 @@ def get_arch_flags():
DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
if major < 12 or (major == 12 and minor <= 8):
assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."
assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment." # TODO Implement this
arch_flags = []
if not DISABLE_SM100:
arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"])
arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"])
if not DISABLE_SM90:
arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"])
return arch_flags
......@@ -54,31 +54,60 @@ subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
this_dir = os.path.dirname(os.path.abspath(__file__))
if IS_WINDOWS:
cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"]
else:
cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"]
ext_modules = []
ext_modules.append(
CUDAExtension(
name="flash_mla.cuda",
sources=[
"csrc/pybind.cpp",
"csrc/smxx/get_mla_metadata.cu",
"csrc/smxx/mla_combine.cu",
"csrc/sm90/decode/dense/splitkv_mla.cu",
"csrc/sm90/decode/sparse_fp8/splitkv_mla.cu",
# API
"csrc/api/api.cpp",
# Misc kernels for decoding
"csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
"csrc/smxx/decode/combine/combine.cu",
# sm90 dense decode
"csrc/sm90/decode/dense/instantiations/fp16.cu",
"csrc/sm90/decode/dense/instantiations/bf16.cu",
# sm90 sparse decode
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
"csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
# sm90 sparse prefill
"csrc/sm90/prefill/sparse/fwd.cu",
"csrc/sm100/decode/sparse_fp8/splitkv_mla.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
"csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
# sm100 dense prefill & backward
"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
"csrc/sm100/prefill/sparse/fwd.cu",
# sm100 sparse prefill
"csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
"csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
"csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
"csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
"csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",
# sm100 sparse decode
"csrc/sm100/decode/head64/instantiations/v32.cu",
"csrc/sm100/decode/head64/instantiations/model1.cu",
"csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
"nvcc": [
"-O3",
"-std=c++17",
"-std=c++20",
"-DNDEBUG",
"-D_USE_MATH_DEFINES",
"-Wno-deprecated-declarations",
......@@ -89,11 +118,14 @@ ext_modules.append(
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v,--register-usage-level=10"
"--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use",
"-lineinfo",
"--source-in-ptx",
] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
},
include_dirs=[
Path(this_dir) / "csrc",
Path(this_dir) / "csrc" / "kerutils" / "include", # TODO Remove me
Path(this_dir) / "csrc" / "sm90",
Path(this_dir) / "csrc" / "cutlass" / "include",
Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
......
build
*.so
*.egg-info/
__pycache__/
dist/
/.vscode
.cache
/temp
/profiles
from . import bench
from . import compare
from . import generate
from . import precision
from . import utils
from .bench import bench_kineto, bench_by_cuda_events
from .compare import get_cos_diff, check_is_bitwise_equal, check_is_allclose, check_is_bitwise_equal_comparator, check_is_allclose_comparator
from .generate import gen_non_contiguous_randn_tensor, gen_non_contiguous_tensor, non_contiguousify
from .precision import LowPrecisionMode, is_low_precision_mode, optional_cast_to_bf16_and_cast_back
from .utils import colors, cdiv, is_using_profiling_tools, set_random_seed, Counter
from typing import Tuple, List, Callable, Union, Dict, overload
import dataclasses
import torch
import triton
from .utils import is_using_profiling_tools
class empty_suppress:
    """
    A do-nothing context manager.

    Entering yields the instance itself; leaving returns None, so any
    exception raised inside the `with` body keeps propagating.
    """

    def __enter__(self):
        return self

    def __exit__(self, *exc_details):
        return None
# Empty Triton kernel used purely as a marker: bench_kineto searches the recorded
# profiler events for an event with this exact name and discards everything up to it.
@triton.jit
def profiler_range_start_marker_kernel():
pass
def _run_profiler_range_start_marker_kernel():
# Launch the (empty) marker kernel with the launch grid (1,).
profiler_range_start_marker_kernel[(1,)]()
@dataclasses.dataclass
class BenchKinetoRawResult:
    """
    Raw outcome of a `bench_kineto` run.

    Attributes:
        is_using_nsys: flag recorded by the producer of this result.
        num_tests: how many times the benchmarked callable was run per pass.
        time_ranges: kernel name -> list of (start, end) timestamps, one pair per launch.
    """
    is_using_nsys: bool
    num_tests: int
    time_ranges: Dict[str, List[Tuple[float, float]]]

    def _get_matched_kernel_name(self, name_substr: str, allow_no_match: bool = False, allow_multiple_match: bool = False) -> List[str]:
        """
        Return every recorded kernel name that contains `name_substr`.

        Raises ValueError when nothing matches (unless `allow_no_match`) or when
        more than one name matches (unless `allow_multiple_match`).
        """
        hits = []
        for kernel_name in self.time_ranges.keys():
            if name_substr in kernel_name:
                hits.append(kernel_name)

        if not hits and not allow_no_match:
            bullet = '\n - '
            all_kernel_names_str = bullet + bullet.join(self.time_ranges.keys())
            raise ValueError(f"Error: No kernel name matched for substring {name_substr}.\nAvailable kernels are: {all_kernel_names_str}")
        if len(hits) > 1 and not allow_multiple_match:
            raise ValueError(f"Error: Multiple kernel matched for substring {name_substr}: {', '.join(hits)}")
        return hits

    def get_kernel_names(self) -> List[str]:
        """Return the names of all recorded kernels, in insertion order."""
        return [kernel_name for kernel_name in self.time_ranges]

    def get_kernel_times(self, kernel_names_substr: List[str], allow_indivisible_run_count: bool = False, allow_missing: bool = False, allow_multiple_match: bool = False, return_avg_individual_run: bool = False) -> List[float]:
        """
        Compute, for each substring in `kernel_names_substr`, the average time of the matching kernel(s).

        The summed duration of all matching launches is divided by `num_tests`
        (default) or, if `return_avg_individual_run` is set, by the number of launches.
        A substring with no match yields 0 (only permitted with `allow_missing`).
        Unless `allow_indivisible_run_count` is set, a kernel whose launch count is not
        a multiple of `num_tests` raises RuntimeError.

        When is_using_profiling_tools() is true (which conflicts with bench_kineto),
        a placeholder of 1 second is returned for every requested substring.
        """
        if is_using_profiling_tools():
            return [1] * len(kernel_names_substr)

        averages = []
        for substr in kernel_names_substr:
            names = self._get_matched_kernel_name(substr, allow_no_match=allow_missing, allow_multiple_match=allow_multiple_match)
            if not names:
                assert allow_missing
                averages.append(0)
                continue

            total_time = 0
            total_runs = 0
            for name in names:
                ranges = self.time_ranges[name]
                num_runs = len(ranges)
                # Keep the flag test first so the modulo is skipped entirely when the check is disabled.
                if not allow_indivisible_run_count and num_runs % self.num_tests != 0:
                    raise RuntimeError(f"Error: the number of runs for kernel {name} ({num_runs}) is indivisible by `num_tests` ({self.num_tests})")
                total_time += sum([stop - begin for (begin, stop) in ranges])
                total_runs += num_runs

            divisor = total_runs if return_avg_individual_run else self.num_tests
            averages.append(total_time / divisor)
        return averages

    def get_kernel_time(self, kernel_name_substr: str) -> float:
        """Convenience wrapper around `get_kernel_times` for a single substring."""
        times = self.get_kernel_times([kernel_name_substr])
        return times[0]

    def get_e2e_time(self, start_kernel_name_substr: str, end_kenrel_name_substr: str) -> float:
        """
        Average end-to-end time of a kernel sequence, i.e.
        "end of the last kernel" minus "start of the first kernel", averaged over `num_tests`.

        Each kernel's launches are split evenly into `num_tests` consecutive groups;
        for group i the span runs from the first launch of the start kernel to the
        last launch of the end kernel. RuntimeError is raised if a launch count is
        not a multiple of `num_tests`.

        When is_using_profiling_tools() is true (which conflicts with bench_kineto),
        a placeholder of 1 second is returned.
        """
        if is_using_profiling_tools():
            return 1

        first_name = self._get_matched_kernel_name(start_kernel_name_substr)[0]
        last_name = self._get_matched_kernel_name(end_kenrel_name_substr)[0]
        first_ranges = self.time_ranges[first_name]
        last_ranges = self.time_ranges[last_name]

        for name, ranges in ((first_name, first_ranges), (last_name, last_ranges)):
            if len(ranges) % self.num_tests != 0:
                raise RuntimeError(f"Error: the number of runs for kernel {name} ({len(ranges)}) is indivisible by `num_tests` ({self.num_tests})")

        first_per_test = len(first_ranges) // self.num_tests
        last_per_test = len(last_ranges) // self.num_tests
        spans = []
        for test_idx in range(self.num_tests):
            finish = last_ranges[(test_idx + 1) * last_per_test - 1][1]
            begin = first_ranges[test_idx * first_per_test][0]
            spans.append((begin, finish))
        return sum([finish - begin for (begin, finish) in spans]) / self.num_tests
def bench_kineto(fn: Callable, num_tests: int = 30,
flush_l2: bool = True) -> BenchKinetoRawResult:
"""
Run `fn` for `num_tests` times under `bench_kineto` (CUPTI), and returns a BenchKinetoRawResult

Arguments:
fn: zero-argument callable to benchmark.
num_tests: number of invocations of `fn` per pass; also stored in the returned result.
flush_l2: if True, a large int tensor is allocated on the CUDA device and zeroed as part of the benchmark loop.

Returns:
A BenchKinetoRawResult. When is_using_profiling_tools() is true, no torch profiler is created and the
result carries an empty `time_ranges`. Otherwise `time_ranges` maps every event name recorded after the
range-start marker kernel to its list of (start, end) pairs.

Raises:
RuntimeError: if no event named "profiler_range_start_marker_kernel" is found among the profiler events.
"""
using_nsys = is_using_profiling_tools()
# By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
flush_l2_size = int(8e9 // 4)
# The torch profiler and its schedule are only created when no external profiling tool is in use;
# otherwise a no-op context manager stands in for the profiler.
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
if i == 1 and not using_nsys:
_run_profiler_range_start_marker_kernel() # This marks the start of the profiling range
for _ in range(num_tests):
if flush_l2:
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
# The NVTX range "profile_target" is only enabled for the last invocation (index num_tests-1) of pass i == 1.
enable_nvtx_range = i == 1 and _ == num_tests-1
if enable_nvtx_range:
torch.cuda.nvtx.range_push("profile_target")
fn()
if enable_nvtx_range:
torch.cuda.nvtx.range_pop()
if not using_nsys:
if i == 0:
torch.cuda.synchronize()
profiler.step()
# With an external profiling tool there are no torch profiler events to parse, so return an empty result.
if using_nsys:
return BenchKinetoRawResult(True, num_tests, {})
from torch.autograd.profiler_util import EventList, FunctionEvent # pylint: disable=import-outside-toplevel
events: EventList = profiler.events() # type: ignore
# Filter out all events that are not function events
events: List[FunctionEvent] = [event for event in events if isinstance(event, FunctionEvent)]
# Filter out all events before the range marker
for idx, event in enumerate(events):
if event.name == "profiler_range_start_marker_kernel":
events = events[idx+1:]
break
else:
raise RuntimeError("Could not find profiler range start marker kernel event")
# Get time ranges of each kernel
kernel_times = {}
for event in events:
kernel_name = event.name
if kernel_name not in kernel_times:
kernel_times[kernel_name] = []
# Timestamps are divided by 1e6 (presumably microseconds -> seconds; TODO confirm the profiler's time unit).
kernel_times[kernel_name].append((event.time_range.start/1e6, event.time_range.end/1e6))
return BenchKinetoRawResult(False, num_tests, kernel_times)
@overload
def bench_by_cuda_events(kernels: List[Callable], num_warmups_each: int, num_runs_each: int) -> List[float]: ...
@overload
def bench_by_cuda_events(kernels: Callable, num_warmups_each: int, num_runs_each: int) -> float: ...
def bench_by_cuda_events(kernels: Union[List[Callable], Callable], num_warmups_each: int, num_runs_each: int) -> Union[List[float], float]:
    """
    Time one or more kernels with CUDA events and return their average run time.

    The kernels are run round-robin: every warmup / timed iteration calls each kernel once.
    Before every timed call a 256e6-byte scratch buffer is overwritten with random values
    (judging by its name, to clear the L2 cache). The calls of the last timed iteration
    are wrapped in a "profile_target" NVTX range.

    Args:
        kernels: A single callable, or a list of callables; each is called with no arguments.
        num_warmups_each: Number of untimed warmup calls per kernel.
        num_runs_each: Number of timed calls per kernel. Must be at least 1.

    Returns:
        The average time per call, i.e. the `elapsed_time` of each event pair (milliseconds)
        scaled by 1e-3. A float when `kernels` is a single callable, otherwise a list with
        one float per kernel. If synchronizing after a kernel's first warmup call raises,
        a message is printed and an empty list is returned (in both cases).

    Raises:
        ValueError: If `num_runs_each` is less than 1.
    """
    # Validate before touching the GPU: zero runs would otherwise end in a ZeroDivisionError
    # after all the work, and a negative count would silently yield a meaningless -0.0.
    if num_runs_each < 1:
        raise ValueError(f"`num_runs_each` must be at least 1, got {num_runs_each}")

    buf_for_l2_clear = torch.empty(int(256e6//4), dtype=torch.int32, device='cuda')

    is_kernel_single_callable = callable(kernels)
    if is_kernel_single_callable:
        kernels = [kernels]

    torch.cuda.synchronize()
    for i in range(num_warmups_each):
        for kernel in kernels:
            kernel()
            if i == 0:
                # Ensure the first run is successful
                try:
                    torch.cuda.synchronize()
                except Exception as e:
                    # Not every callable has `__name__` (e.g. functools.partial objects);
                    # fall back to repr() so that the real failure is still reported.
                    kernel_name = getattr(kernel, "__name__", repr(kernel))
                    print(f"Kernel {kernel_name} failed on warmup run {i}: {e}")
                    return []

    start_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels]
    end_events = [[torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] for _ in kernels]

    for i in range(num_runs_each):
        is_last_run = i == num_runs_each-1
        for j, kernel in enumerate(kernels):
            buf_for_l2_clear.random_()
            if is_last_run:
                torch.cuda.nvtx.range_push("profile_target")
            start_events[j][i].record()
            kernel()
            end_events[j][i].record()
            if is_last_run:
                torch.cuda.nvtx.range_pop()

    # All events must have completed before their elapsed time can be queried
    torch.cuda.synchronize()

    time_usages = [
        sum([start.elapsed_time(end)*1e-3 for start, end in zip(start_events[j], end_events[j])]) / num_runs_each
        for j in range(len(kernels))
    ]
    if is_kernel_single_callable:
        time_usages = time_usages[0]
    return time_usages
from typing import List
import torch
def check_is_bitwise_equal_comparator(ans: torch.Tensor, ref: torch.Tensor, result: torch.Tensor):
    """
    Check whether every element of `ans` equals the corresponding element of `ref`.

    Nothing is returned: the outcome of the reduction is written into the
    caller-provided `result` tensor instead.

    Args:
        ans: The tensor under test.
        ref: The reference tensor; must have the same shape as `ans`.
        result: Output tensor that receives the outcome of `torch.all`.
    """
    assert ans.shape == ref.shape, "Shape mismatch"
    elementwise_match = torch.eq(ans, ref)
    torch.all(elementwise_match, out=result)
def check_is_bitwise_equal(name: str, ans: torch.Tensor, ref: torch.Tensor, quiet: bool = False) -> bool:
    """
    Check whether `ans` and `ref` are equal according to `torch.equal`.

    Args:
        name: Label used in the mismatch message.
        ans: The tensor under test.
        ref: The reference tensor.
        quiet: If True, print nothing on mismatch.

    Returns:
        True if the tensors have the same shape and equal elements, False otherwise.
    """
    # Handle differing shapes up front: the element-wise `ans != ref` used for the
    # mismatch count below raises for non-broadcastable shapes and gives a misleading
    # count for broadcastable ones.
    if ans.shape != ref.shape:
        if not quiet:
            print(f"`{name}` mismatch: not bitwise equal. Shape mismatch: {tuple(ans.shape)} vs {tuple(ref.shape)}")
        return False

    is_bitwise_equal = torch.equal(ans, ref)
    if not quiet and not is_bitwise_equal:
        print(f"`{name}` mismatch: not bitwise equal. Mismatch count: {(ans != ref).sum().item()} out of {ans.numel()}")
    return is_bitwise_equal
def get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float:
    """
    Calculate the cosine diff between two tensors.

    Both tensors are converted to double precision, and the result is
    `1 - 2*sum(ans*ref) / sum(ans*ans + ref*ref)`.

    Args:
        ans: The tensor under test.
        ref: The reference tensor.

    Returns:
        The cosine diff as a Python float. If the sum of squares of `ref` is below
        1e-12, 0.0 is returned without looking at `ans`.
    """
    ans, ref = ans.double(), ref.double()
    if (ref*ref).sum().item() < 1e-12:
        # Always return a float, as promised by the annotation
        return 0.0
    denominator = (ans*ans + ref*ref).sum().item()
    sim = 2 * (ans*ref).sum().item() / denominator
    return 1 - sim
def check_is_allclose(name: str, ans: torch.Tensor, ref: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7, quiet: bool = False) -> bool:
    """
    Check if two tensors are close enough.

    The check works on float32 copies of the inputs (the inputs are not modified) and
    passes only if all of the following hold:
      1. `inf`, `-inf` and `nan` show up at exactly the same positions in `ans` and `ref`
         (those positions are zeroed in the copies before the numeric comparison);
      2. every element has an absolute error below `abs_tol` or a relative error
         (`|ans-ref| / (|ref| + 1e-6)`) below `rel_tol`;
      3. the absolute value of the cosine diff from `get_cos_diff` is at most `cos_diff_tol`.

    Args:
        name: Label used in the assertion and mismatch messages.
        ans: The tensor under test.
        ref: The reference tensor; must have the same shape and dtype as `ans`.
        abs_tol: Absolute error tolerance.
        rel_tol: Relative error tolerance.
        cos_diff_tol: Cosine diff tolerance.
        quiet: If True, print nothing on mismatch.

    Returns:
        True if the tensors are close enough, False otherwise.

    Raises:
        AssertionError: If the shapes or dtypes of `ans` and `ref` differ.
    """
    assert ans.shape == ref.shape, f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}"
    assert ans.dtype == ref.dtype, f"`{name}` Dtype mismatch: {ans.dtype} vs {ref.dtype}"

    # Work on copies, since anomalies are zeroed in place below
    ans = ans.clone().to(torch.float)
    ref = ref.clone().to(torch.float)

    def report_err(*args, **kwargs):
        if not quiet:
            print(*args, **kwargs)

    # Deal with anomalies
    def deal_with_anomalies(val: float) -> bool:
        # `val == val` is False only for NaN, which has to be located with `x != x`
        ref_mask = (ref == val) if (val == val) else (ref != ref)
        ans_mask = (ans == val) if (val == val) else (ans != ans)
        ref[ref_mask] = 0.0
        ans[ans_mask] = 0.0
        if not torch.equal(ref_mask, ans_mask):
            report_err(f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref")
            return False
        return True

    anomalies_check_passed = True
    for anomaly in (float("inf"), float("-inf"), float("nan")):
        # `&=` always evaluates its right-hand side, so every anomaly gets zeroed and reported
        anomalies_check_passed &= deal_with_anomalies(anomaly)
    if not anomalies_check_passed:
        # The verdict is already known; skip the numeric comparison
        return False

    cos_diff = get_cos_diff(ans, ref)
    raw_abs_err = torch.abs(ans-ref)
    raw_rel_err = raw_abs_err / (torch.abs(ref)+(1e-6))
    # An element that satisfies one criterion is excluded from the error statistics of the other
    rel_err = raw_rel_err.masked_fill(raw_abs_err<abs_tol, 0)
    abs_err = raw_abs_err.masked_fill(raw_rel_err<rel_tol, 0)
    pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol)

    if not pass_mask.all():
        report_err(f"`{name}` mismatch")
        # Without `dim`, argmax returns an index into the flattened tensor
        max_abs_err_pos: int = torch.argmax(abs_err).item()
        max_rel_err_pos: int = torch.argmax(rel_err).item()

        def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]:
            # Turn a flat index into a per-dimension index, innermost dimension first
            result = []
            for size in t.shape[::-1]:
                pos, idx_in_dim = divmod(pos, size)
                result.append(idx_in_dim)
            assert pos == 0
            return result[::-1]

        report_err(f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}")
        report_err(f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}")
        report_err(f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)")
        report_err(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})")
        return False

    if abs(cos_diff) > cos_diff_tol:
        report_err(f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol}")
        return False
    return True
def check_is_allclose_comparator(name: str, ans: torch.Tensor, ref: torch.Tensor, out: torch.Tensor, abs_tol: float = 1e-5, rel_tol: float = 1e-2, cos_diff_tol: float = 1e-7):
    """
    Run `check_is_allclose` and store its verdict in `out`.

    Nothing is returned: `out` is filled in place with the boolean result.

    Args:
        name: Label passed on to `check_is_allclose`.
        ans: The tensor under test.
        ref: The reference tensor.
        out: Output tensor that is filled with the verdict.
        abs_tol: Absolute error tolerance.
        rel_tol: Relative error tolerance.
        cos_diff_tol: Cosine diff tolerance.
    """
    is_close = check_is_allclose(
        name, ans, ref,
        abs_tol=abs_tol, rel_tol=rel_tol, cos_diff_tol=cos_diff_tol
    )
    out.fill_(is_close)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment