Commit 082094b7 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
#pragma once
#include <cute/tensor.hpp>
namespace cute {
// Extensions to CuTe
// CuTe don't support UTCMMA with .ws, so we add it here
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_SS_NOELECT
{
static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
static_assert(N == 64 || N == 128 || N == 256,
"SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 32, 64 or 128");
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
}
};
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
// Logical shape-K is always 256bits, transform to units of elements
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_1>;
using ALayout = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
Stride<_0,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
Stride<_0,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
Stride<_0,Stride< _1,Int<M>>>>;
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
UMMA::Saturate c_sat = UMMA::Saturate::False>
struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT
{
static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed");
using DRegisters = void;
using ARegisters = uint32_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void
fma(uint32_t const& tmem_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
"r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
}
};
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,
UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, c_sat>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types");
using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
// Size of instructions' K extent is always 256 bits; convert to units of element
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_2>;
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_tmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t tmem_a = raw_pointer_cast(A.data());
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync()
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT
{
static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
"r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
#else
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
}
};
// template <class a_type, class b_type, class c_type,
// int M, int N, UMMA::Major a_major, UMMA::Major b_major,
// UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
// struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> : MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> {};
template <class a_type, class b_type, class c_type,
int M, int N,
UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
M, N, a_major, b_major,
a_neg, b_neg>>
{
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;
// Size of instructions's K extent is always 256bits, convert to units of element
constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
using ThrID = Layout<_2>;
using ALayout = Layout<Shape < _2,Shape <Int<M/2>,Int<K>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
using BLayout = Layout<Shape < _2,Shape <Int<N/2>,Int<K>>>,
Stride<Int<N/2>,Stride< _1,Int<N>>>>;
using CLayout = Layout<Shape < _2,Shape <Int<M/2>,Int<N>>>,
Stride<Int<M/2>,Stride< _1,Int<M>>>>;
UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
template <class TD, class DLayout,
class TA, class ALayout,
class TB, class BLayout,
class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend
void
mma_unpack(MMA_Traits const& traits,
Tensor<TD, DLayout> & D,
Tensor<TA, ALayout> const& A,
Tensor<TB, BLayout> const& B,
Tensor<TC, CLayout> const& C)
{
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
}
};
}
\ No newline at end of file
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(DenseAttnDecodeParams &params);
}
#include "../splitkv_mla.cuh"
#include "../splitkv_mla.h"
namespace sm90 {
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(DenseAttnDecodeParams &params);
#endif
}
......@@ -758,7 +758,7 @@ __forceinline__ __device__ void wg0_subroutine(
TMABarrier barriers_K1[9],
bool &cur_phase_K0,
const TMAParams &tma_params,
const DecodingParams &params,
const DenseAttnDecodeParams &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
......@@ -870,7 +870,7 @@ __forceinline__ __device__ void wg1_subroutine(
TMABarrier barriers_K1[9],
bool &cur_phase_K1,
const TMAParams &tma_params,
const DecodingParams &params,
const DenseAttnDecodeParams &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
......@@ -945,7 +945,7 @@ __forceinline__ __device__ void wg1_subroutine(
}
// A helper function for determining the length of the causal mask for one q token
__forceinline__ __device__ int get_mask_len(const DecodingParams &params, int m_block_idx, int local_seq_q_idx) {
__forceinline__ __device__ int get_mask_len(const DenseAttnDecodeParams &params, int m_block_idx, int local_seq_q_idx) {
int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx;
if (global_seq_q_idx < params.q_seq_per_hk) {
int s_q_idx = global_seq_q_idx / params.q_head_per_hk;
......@@ -958,7 +958,7 @@ __forceinline__ __device__ int get_mask_len(const DecodingParams &params, int m_
template<typename T, typename TmaParams>
__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) {
flash_fwd_splitkv_mla_kernel(__grid_constant__ const DenseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) {
// grid shape: [
// num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))),
// num_kv_heads,
......@@ -968,7 +968,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
// If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx])
// For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file).
#if IS_SM90
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
const int m_block_idx = blockIdx.x;
const int k_head_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
......@@ -1016,30 +1016,21 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
__syncthreads();
bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0;
// Programmatic Dependent Launch: Wait for the previous kernel to finish
cudaGridDependencySynchronize();
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
// We don't use __ldg here, otherwise NVCC (ptxas, in particular) will do instruction reorder and place __ldg (LDG.E.128.CONSTANT in SASS) in front of cudaGridDependencySynchronize() (ACQBULK in SASS), leading to data race.
int4 tile_scheduler_metadata = *(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int sched_begin_block_idx = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int sched_end_block_idx = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
if (sched_meta.begin_req_idx >= params.b) return;
// Copy the first Q
launch_q_copy<T>(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
launch_q_copy<T>(tma_params, sched_meta.begin_req_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
const int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0;
int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = __ldg(params.num_splits_ptr + batch_idx + 1) - __ldg(params.num_splits_ptr + batch_idx) == 1;
const int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
int rRightBorderForQSeq[2];
if (params.is_causal) {
......@@ -1061,7 +1052,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
// NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling
int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1);
int last_block_in_seq = cute::ceil_div(seqlen_k-common_mask_len, kBlockN);
end_block_idx = batch_idx == end_idx ? min(sched_end_block_idx, last_block_in_seq) : last_block_in_seq;
end_block_idx = batch_idx == sched_meta.end_req_idx ? min(sched_meta.end_block_idx, last_block_in_seq) : last_block_in_seq;
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
......@@ -1127,7 +1118,9 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
cur_phase_K0 ^= 1;
// Issue P0 = Q @ K0^T, wait
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
if (start_block_idx-16777216 < end_block_idx) { // NOTE We use this `if` to prevent register spilling
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
}
// We add a barrier here, making sure that previous writes to sM are visible to warpgroup 0
NamedBarrier::arrive_and_wait(128, NamedBarriers::sMInitialized);
cute::warpgroup_wait<0>();
......@@ -1225,7 +1218,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i];
// Copy Q for the next batch
if (batch_idx+1 <= end_idx) {
if (batch_idx+1 <= sched_meta.end_req_idx) {
launch_q_copy<T>(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q);
} else {
// Allow the next kernel (the combine kernel) to launch
......@@ -1268,7 +1261,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
cute::tma_store_wait<0>();
}
if (batch_idx != end_idx)
if (batch_idx != sched_meta.end_req_idx)
__syncthreads();
}
#else
......@@ -1280,7 +1273,10 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
template<typename InputT>
void run_flash_splitkv_mla_kernel(DecodingParams &params, cudaStream_t stream) {
void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
FLASH_ASSERT(params.d == Config::HEAD_DIM_K);
FLASH_ASSERT(params.d_v == Config::HEAD_DIM_V);
using T = Traits<InputT>;
auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);
auto tma_Q = cute::make_tma_copy(
......@@ -1348,7 +1344,7 @@ void run_flash_splitkv_mla_kernel(DecodingParams &params, cudaStream_t stream) {
dim3(num_m_block, params.h_k, params.num_sm_parts),
dim3(T::NUM_THREADS, 1, 1),
smem_size,
stream,
params.stream,
mla_kernel_attributes,
1
};
......@@ -1356,10 +1352,4 @@ void run_flash_splitkv_mla_kernel(DecodingParams &params, cudaStream_t stream) {
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(DecodingParams &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(DecodingParams &params, cudaStream_t stream);
#endif
}
......@@ -5,6 +5,6 @@
namespace sm90 {
template<typename InputT>
void run_flash_splitkv_mla_kernel(DecodingParams &params, cudaStream_t stream);
void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params);
}
......@@ -3,119 +3,19 @@
#include <cutlass/numeric_types.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
#include "defines.h"
using namespace cute;
static constexpr int NUM_THREADS = 128*3;
static constexpr int BLOCK_M = 64;
static constexpr int TOPK_BLOCK_SIZE = 64;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int QUANT_TILE_SIZE = 128;
namespace sm90::decode::sparse_fp8 {
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
static constexpr int HEAD_DIM_NOPE = HEAD_DIM_V;
static constexpr int HEAD_DIM_ROPE = HEAD_DIM_K - HEAD_DIM_V;
static constexpr int QUANT_TILE_SIZE = 128;
static constexpr int NUM_SCALES = HEAD_DIM_NOPE / QUANT_TILE_SIZE;
static constexpr int NUM_BYTES_PER_TOKEN = HEAD_DIM_NOPE + NUM_SCALES*sizeof(float) + HEAD_DIM_ROPE*sizeof(bf16);
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int NUM_K_BUFS = 2;
using SmemLayoutQTile = decltype(tile_to_shape(
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
Shape<Int<BLOCK_M>, Int<64>>{}
));
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(tile_to_shape(
SmemLayoutQTile{},
Shape<Int<BLOCK_M>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
using SmemLayoutQ = SmemLayoutQTiles<9>;
using SmemLayoutKTile = decltype(tile_to_shape(
GMMA::Layout_INTER_Atom<bf16, GMMA::Major::K>{},
Shape<Int<TOPK_BLOCK_SIZE>, _64>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(tile_to_shape(
SmemLayoutKTile{},
Shape<Int<TOPK_BLOCK_SIZE>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<Shape<Int<64*NUM_TILES>, Int<TOPK_BLOCK_SIZE>>, Stride<Int<TOPK_BLOCK_SIZE>, _1>>{}
));
using SmemLayoutOBuf = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{}
));
using SmemLayoutOAccumBuf = Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
using SmemLayoutK = SmemLayoutKTiles<9>;
using SmemLayoutV = SmemLayoutKTilesTransposed<8>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>;
using SmemLayoutS = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}
));
struct SharedMemoryPlan {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
union {
array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
} u;
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M];
transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
};
template<
typename Shape_Q, typename TMA_Q,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
Shape_O shape_O; TMA_O tma_O;
};
using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
}
\ No newline at end of file
......@@ -3,6 +3,10 @@
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include "defines.h"
namespace sm90::decode::sparse_fp8 {
struct fp8x8 {
__nv_fp8x4_e4m3 lo;
__nv_fp8x4_e4m3 hi;
......@@ -13,14 +17,8 @@ struct fp8x16 {
fp8x8 hi;
};
struct bf16x8 {
__nv_bfloat162 a, b, c, d;
};
__device__ __forceinline__
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) {
__nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale);
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const __nv_bfloat162 &scale_bf162) {
#define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
{ \
float4 fp32x4 = (float4)(FP8x4); \
......@@ -29,8 +27,8 @@ bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) {
}
bf16x8 result;
DEQUANT_FP8x4(result.a, result.b, inputs.lo);
DEQUANT_FP8x4(result.c, result.d, inputs.hi);
DEQUANT_FP8x4(result.a01, result.a23, inputs.lo);
DEQUANT_FP8x4(result.a45, result.a67, inputs.hi);
return result;
}
......@@ -86,3 +84,44 @@ T load_128b_from_gmem(const void* addr) {
#undef DISPATCH_L2
return *reinterpret_cast<T*>(&ret);
}
template<
typename T,
L1CacheHint l1_cache_hint,
L2PrefetchHint l2_prefetch_hint
>
__device__ __forceinline__
T load_64b_from_gmem(const void* addr) {
static_assert(sizeof(T) == 64/8);
int2 ret;
#define EXEC(L1_HINT_STR, L2_HINT_STR) { \
asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v2.s32 {%0, %1}, [%2];" \
: "=r"(ret.x), "=r"(ret.y) \
: "l"(addr)); \
}
#define DISPATCH_L2(L1_HINT_STR) { \
if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \
EXEC(L1_HINT_STR, "64B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \
EXEC(L1_HINT_STR, "128B") \
else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \
EXEC(L1_HINT_STR, "256B") \
}
if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)
DISPATCH_L2("no_allocate")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)
DISPATCH_L2("evict_first")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)
DISPATCH_L2("evict_normal")
else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)
DISPATCH_L2("evict_last")
#undef EXEC
#undef DISPATCH_L2
return *reinterpret_cast<T*>(&ret);
}
}
#pragma once
#include "named_barriers.h"
// Store O / OAccum
template<
bool IS_NO_SPLIT,
typename TMAParams,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3
>
__forceinline__ __device__ void store_o(
Tensor0 &rO, // ((2, 2, 32), 1, 1)
Tensor1 &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V)
Tensor2 &sOutputBuf,
Tensor3 &sOutputAccumBuf,
float rL[2],
TMAParams &tma_params,
int batch_idx,
int s_q_idx,
int head_block_idx,
int num_valid_seq_q,
int warpgroup_idx,
int idx_in_warpgroup
) {
using cutlass::arch::NamedBarrier;
if constexpr (IS_NO_SPLIT) {
// Should convert the output to bfloat16 / float16, and save it to O
Tensor rOb = make_tensor_like<bf16>(rO);
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); ++idx) {
rOb(idx) = (bf16)(rO(idx) / rL[idx%4 >= 2]);
}
Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));
TiledCopy r2s_tiled_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_PV_LocalP{}
);
ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup);
Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb);
Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf);
cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf);
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (threadIdx.x == 0) {
Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{})(_, _, head_block_idx, _0{});
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sOutputBuf),
thr_tma.partition_D(my_tma_gO)
);
cute::tma_store_arrive();
}
} else {
// Should save the result to OAccum
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); idx += 2) {
int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
*(float2*)(&(sOutputAccumBuf(row, col))) = float2 {
rO(idx) / rL[idx%4 >= 2],
rO(idx+1) / rL[idx%4 >= 2],
};
}
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (elect_one_sync()) {
CUTLASS_PRAGMA_UNROLL
for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) {
int row = local_row * (256/32) + (threadIdx.x / 32);
if (row < num_valid_seq_q) {
SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float));
}
}
cute::tma_store_arrive();
}
}
}
#pragma once
#include <cooperative_groups.h>
#include <cute/tensor.hpp>
#include "config.h"
using namespace cute;
namespace sm90::decode::sparse_fp8 {
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~1) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
......@@ -78,9 +87,23 @@ static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mba
);
}
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字
CUTE_DEVICE
static void cp_async_bulk_shared_cta_shared_cluster(void* dst_ptr, void* src_ptr, int size, transac_bar_t* mbar_ptr) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t src_addr = cute::cast_smem_ptr_to_uint(src_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
asm volatile (
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3]; \n"
:
: "r"(dst_addr), "r"(src_addr), "r"(size), "r"(mbar_addr)
);
}
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK.
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
T* get_peer_addr(T* p) {
return (T*)((int64_t)(p) ^ PEER_ADDR_MASK);
}
}
#pragma once
enum NamedBarriers : uint32_t {
sScale_and_sS_ready = 0,
sScale_and_sS_free = 1,
oBuf_free_and_sL_ready = 2,
epilogue_r2s_ready = 3,
batch_loop_sync = 4,
warpgroup0_sync = 5
};
#pragma once
#include <cutlass/numeric_types.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
using namespace cute;
namespace sm90::decode::sparse_fp8 {
template<ModelType MODEL_TYPE, int NUM_HEADS>
class KernelTemplate {
public:
static_assert(NUM_HEADS == 64 || NUM_HEADS == 128);
static constexpr int NUM_M_BLOCKS = NUM_HEADS / 64;
static constexpr int CLUSTER_SIZE = NUM_M_BLOCKS;
static constexpr int HEAD_DIM_K = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int HEAD_DIM_V = 512;
static constexpr int HEAD_DIM_ROPE = 64;
static constexpr int HEAD_DIM_NOPE = HEAD_DIM_K - HEAD_DIM_ROPE;
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
static constexpr int NUM_SCALES = MODEL_TYPE == ModelType::V32 ? 4 : 8; // For MODEL1: 7 fp8_e4m3 + 1 padding
static constexpr int NUM_THREADS = 128*3;
static constexpr int BLOCK_M = 64;
static constexpr int TOPK_BLOCK_SIZE = 64;
static constexpr int NUM_K_BUFS = 2;
using SmemLayoutQTile = decltype(tile_to_shape(
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
Shape<Int<BLOCK_M>, Int<64>>{}
));
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(tile_to_shape(
SmemLayoutQTile{},
Shape<Int<BLOCK_M>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
using SmemLayoutQ = SmemLayoutQTiles<HEAD_DIM_K/64>;
using SmemLayoutKTile = decltype(tile_to_shape(
GMMA::Layout_INTER_Atom<bf16, GMMA::Major::K>{},
Shape<Int<TOPK_BLOCK_SIZE>, _64>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(tile_to_shape(
SmemLayoutKTile{},
Shape<Int<TOPK_BLOCK_SIZE>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<Shape<Int<64*NUM_TILES>, Int<TOPK_BLOCK_SIZE>>, Stride<Int<TOPK_BLOCK_SIZE>, _1>>{}
));
static constexpr int OBUF_SW = 64;
using SmemLayoutOBufAtom = GMMA::Layout_K_SW128_Atom<bf16>;
using SmemLayoutOBuf = decltype(tile_to_shape(
SmemLayoutOBufAtom{},
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
Step<_1, _2>{}
));
using SmemLayoutOAccumBuf = Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
using SmemLayoutK = SmemLayoutKTiles<HEAD_DIM_K/64>;
using SmemLayoutV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<HEAD_DIM_V/64/2>;
using SmemLayoutS = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}
));
struct SharedMemoryPlan {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
union {
array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_K_BUFS];
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> oBuf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> oAccumBuf;
} u;
CUTE_ALIGNAS(1024) array_aligned<bf16, cosize_v<SmemLayoutS>> s;
bool is_kv_valid[NUM_K_BUFS][TOPK_BLOCK_SIZE];
float sM[BLOCK_M], sL[BLOCK_M], sScale[BLOCK_M], sOScale[BLOCK_M];
transac_bar_t bar_q, bar_k_local_ready[NUM_K_BUFS], bar_k_remote_ready[NUM_K_BUFS], bar_k_avail[NUM_K_BUFS];
};
template<
typename Shape_Q, typename TMA_Q
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
CUtensorMap tensor_map_o;
};
using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
enum NamedBarriers : uint32_t {
sScale_and_sS_ready = 0,
sScale_and_sS_free = 1,
oBuf_free_and_sL_ready = 2,
epilogue_r2s_ready = 3,
batch_loop_sync = 4,
warpgroup0_sync = 5
};
// Synchronize all threads within the cluster (which processes one q token)
static __forceinline__ __device__ void sync_all_threads_in_cluster() {
if constexpr (CLUSTER_SIZE == 1) {
__syncthreads();
} else {
ku::barrier_cluster_arrive_relaxed();
ku::barrier_cluster_wait_acquire();
}
}
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
static __forceinline__ __device__ void save_rPb_to_sP(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
template<
bool IS_NO_SPLIT,
typename TMAParams,
typename Tensor0,
typename Tensor1,
typename Tensor2,
typename Tensor3
>
static __forceinline__ __device__ void store_o(
Tensor0 &rO, // ((2, 2, 32), 1, 1)
Tensor1 &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V)
Tensor2 &sOutputBuf,
Tensor3 &sOutputAccumBuf,
SharedMemoryPlan &plan,
float o_scales[2],
TMAParams &tma_params,
int batch_idx,
int s_q_idx,
int head_block_idx,
int num_valid_seq_q,
int warpgroup_idx,
int idx_in_warpgroup
) {
using cutlass::arch::NamedBarrier;
if constexpr (IS_NO_SPLIT) {
// Should convert the output to bfloat16 / float16, and save it to O
// Here we don't pipeline STSM and tma store because it's slower
Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));
// Calculate "base" ptrs in advance
// Each STSM fills a chunk of shape 16x16, while we are using SW-OBUF_SW, so we need OBUF_SW/16 base pointers
constexpr int NUM_CHUNKS_IN_SW_ATOM = OBUF_SW/16;
bf16* base_output_buf_ptrs[NUM_CHUNKS_IN_SW_ATOM];
CUTE_UNROLL
for (int i = 0; i < NUM_CHUNKS_IN_SW_ATOM; ++i) {
base_output_buf_ptrs[i] = &sMyOutputBuf((idx_in_warpgroup/32)*16+idx_in_warpgroup%16, idx_in_warpgroup%32/16*8 + i*16);
}
CUTE_UNROLL
for (int idx = 0; idx < (HEAD_DIM_V/2)/16; idx += 1) {
// In each iteration we deal with a chunk of shape 16x16
using bf16x2 = __nv_bfloat162;
bf16x2 a01 = __float22bfloat162_rn(float2{rO(idx*8+0)*o_scales[0], rO(idx*8+1)*o_scales[0]});
bf16x2 a23 = __float22bfloat162_rn(float2{rO(idx*8+2)*o_scales[1], rO(idx*8+3)*o_scales[1]});
bf16x2 a45 = __float22bfloat162_rn(float2{rO(idx*8+4)*o_scales[0], rO(idx*8+5)*o_scales[0]});
bf16x2 a67 = __float22bfloat162_rn(float2{rO(idx*8+6)*o_scales[1], rO(idx*8+7)*o_scales[1]});
SM90_U32x4_STSM_N::copy(
*reinterpret_cast<uint32_t*>(&a01),
*reinterpret_cast<uint32_t*>(&a23),
*reinterpret_cast<uint32_t*>(&a45),
*reinterpret_cast<uint32_t*>(&a67),
*reinterpret_cast<uint128_t*>(base_output_buf_ptrs[idx%4] + (idx/4*4)*16*64)
);
}
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (threadIdx.x == 0) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o,
plan.u.oBuf.data(),
0, head_block_idx*64, 0,
s_q_idx, batch_idx
);
cute::tma_store_arrive();
}
} else {
// Should save the result to OAccum
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); idx += 2) {
int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
*(float2*)(&(sOutputAccumBuf(row, col))) = float2 {
rO(idx) * o_scales[idx%4>=2],
rO(idx+1) * o_scales[idx%4>=2],
};
}
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive_and_wait(256, NamedBarriers::epilogue_r2s_ready);
if (elect_one_sync()) {
CUTLASS_PRAGMA_UNROLL
for (int local_row = 0; local_row < BLOCK_M / (256/32); ++local_row) {
int row = local_row * (256/32) + (threadIdx.x / 32);
if (row < num_valid_seq_q) {
SM90_BULK_COPY_S2G::copy(&sOutputAccumBuf(row, _0{}), &gOorAccum(row, _0{}), HEAD_DIM_V*sizeof(float));
}
}
cute::tma_store_arrive();
}
}
}
template<typename TMAParams>
static __device__ __forceinline__ void
devfunc(const SparseAttnDecodeParams &params, const TMAParams &tma_params);
static void run(const SparseAttnDecodeParams &params);
};
}
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 128>(const SparseAttnDecodeParams &params);
}
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1, 64>(const SparseAttnDecodeParams &params);
}
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 128>(const SparseAttnDecodeParams &params);
}
#include "../splitkv_mla.cuh"
namespace sm90::decode::sparse_fp8 {
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32, 64>(const SparseAttnDecodeParams &params);
}
#pragma once
#include "splitkv_mla.h"
#include <cuda_fp8.h>
#include <math_constants.h>
#include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/cluster_launch.hpp>
#include <kerutils/kerutils.cuh>
#include "utils.h"
#include "components/config.h"
#include "components/epilogue.h"
#include "components/helpers.h"
#include "components/named_barriers.h"
#include "components/dequant.h"
#include "components/helpers.h"
#include "config.h"
using namespace cute;
namespace sm90 {
namespace sm90::decode::sparse_fp8 {
static constexpr float MAX_INIT_VAL = -1e30; // Prevent (-inf) - (-inf) = nan
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
__forceinline__ __device__ void save_rPb_to_sP(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
// Retrieve rPb (64x64, bfloat16) from sP using the ldmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
__forceinline__ __device__ void retrieve_rP_from_sP(
Tensor0 &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
TiledCopy s2r_copy = make_tiled_copy_A(
Copy_Atom<SM75_U32x4_LDSM_N, bf16>{},
TiledMMA_PV_LocalP{}
);
ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_sP = thr_copy.partition_S(sP);
Tensor thr_copy_rPb = thr_copy.retile_D(rPb);
cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb);
}
using fp8_e8m0 = __nv_fp8_e8m0;
template<
typename Tensor0,
......@@ -113,20 +76,21 @@ __forceinline__ __device__ void scale_softmax(
cur_rS(i) = (bf16)cur_rP(i);
cur_sum += cur_rP(i);
}
rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;
}
if (idx_in_warpgroup%4 == 0)
*(float2*)(sScale + 2*(idx_in_warpgroup/4)) = *(float2*)(scale_for_olds);
}
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 2)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) {
#if IS_SM90
const int head_block_idx = blockIdx.x;
template<ModelType MODEL_TYPE, int NUM_HEADS>
template<typename TMAParams>
__device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::devfunc(const SparseAttnDecodeParams &params, const TMAParams &tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
const int head_block_idx = NUM_M_BLOCKS == 1 ? 0 : blockIdx.x;
const int s_q_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int idx_in_cluster = head_block_idx % 2;
const int idx_in_cluster = CLUSTER_SIZE == 1 ? 0 : head_block_idx % 2;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
......@@ -145,56 +109,78 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
// Prefetch TMA descriptors
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_o);
}
// Initialize TMA barriers
if (warp_idx == 0 && elect_one_sync()) {
plan.bar_q.init(1);
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
plan.bar_k_local_ready[i].init(128);
plan.bar_k_remote_ready[i].init(1);
plan.bar_k_avail[i].init(4);
if constexpr (CLUSTER_SIZE == 2) {
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
plan.bar_k_local_ready[i].init(128);
plan.bar_k_remote_ready[i].init(1);
plan.bar_k_avail[i].init(4);
}
} else {
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
plan.bar_k_local_ready[i].init(128);
plan.bar_k_avail[i].init(256);
}
}
fence_view_async_shared();
cutlass::arch::fence_barrier_init();
}
cute::cluster_arrive();
ku::barrier_cluster_arrive_relaxed();
bool bar_phase_q = 0;
int bar_phase_k = 0; // Don't use array here to prevent using local memory
// Programmatic Dependent Launch: Wait for the previous kernel to finish
// Don't use PDL because of compiler bugs!
// cudaGridDependencySynchronize();
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int sched_begin_block_idx = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int sched_end_block_idx = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
if (sched_meta.begin_req_idx >= params.b) return;
if (warp_idx == 0 && elect_one_sync()) {
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, begin_idx),
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, sched_meta.begin_req_idx),
Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}
)(_, _, head_block_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));
}
cute::cluster_wait(); // Wait for barriers from the other CTA to be ready
ku::barrier_cluster_wait_acquire();
auto get_cur_req_info = [&](int batch_idx) -> std::tuple<int, int, bool> {
constexpr int kBlockN = TOPK_BLOCK_SIZE;
const int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0;
// NOTE TopK attention has nothing to do with causal mask and sliding window
int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : cute::ceil_div(params.topk, kBlockN);
const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(params.topk, kBlockN);
return {start_block_idx, end_block_idx, is_no_split};
struct MainloopArgs {
int start_block_idx, end_block_idx;
bool is_no_split;
// The following fields are only valid for MODEL1
int topk_length, extra_topk_length, num_orig_kv_blocks;
};
auto get_cur_req_info = [&](int batch_idx) -> MainloopArgs {
MainloopArgs args;
int total_topk_padded;
if constexpr (MODEL_TYPE == ModelType::V32) {
total_topk_padded = params.topk;
} else {
int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
int orig_topk_padded = max(ku::ceil(topk_length, (int)TOPK_BLOCK_SIZE), (int)TOPK_BLOCK_SIZE);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)TOPK_BLOCK_SIZE);
args.topk_length = topk_length;
args.extra_topk_length = extra_topk_length;
args.num_orig_kv_blocks = orig_topk_padded / TOPK_BLOCK_SIZE;
}
args.start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
args.end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / TOPK_BLOCK_SIZE;
args.is_no_split = batch_idx == sched_meta.begin_req_idx ? !sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? !sched_meta.is_last_req_splitted : true);
return args;
};
if (warpgroup_idx == 0) {
......@@ -210,27 +196,36 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{});
Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<BLOCK_M>, Int<TOPK_BLOCK_SIZE>>{}));
float rAttn_sink[2] = {-CUDART_INF_F, -CUDART_INF_F};
if (params.attn_sink != nullptr) {
for (int i = 0; i < 2; ++i) {
int head_idx = head_block_idx*BLOCK_M + get_AorC_row_idx(i, idx_in_warpgroup);
rAttn_sink[i] = __ldg((float*)params.attn_sink + head_idx) * CUDART_L2E_F;
}
}
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
MainloopArgs args = get_cur_req_info(batch_idx);
rL[0] = rL[1] = 0.0f;
rM[0] = rM[1] = MAX_INIT_VAL;
cute::fill(rO, 0.);
// Wait for Q
plan.bar_q.wait(bar_phase_q);
bar_phase_q ^= 1;
plan.bar_q.wait((sched_meta.begin_req_idx-batch_idx)&1);
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS;
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {
int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;
Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutHalfV{});
// Wait, issue WGMMA
plan.bar_k_local_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_remote_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
}
gemm<true, -1>(
tiled_mma_QK,
......@@ -244,11 +239,11 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
cute::warpgroup_wait<0>();
// Calculate S = softmax(mask(scale(P)))
if (block_idx != start_block_idx)
if (block_idx != args.start_block_idx)
NamedBarrier::arrive_and_wait(256, NamedBarriers::sScale_and_sS_free); // Make sure that sScale and sS is free
// Since in our case TOPK_BLOCK_SIZE == BLOCK_M, so we only need to do OOB checking for the last 2 blocks
scale_softmax(rP, rS, rO, params.scale_softmax_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup);
scale_softmax(rP, rS, rO, params.sm_scale_div_log2, sScale, rM, rL, plan.is_kv_valid[buf_idx], block_idx, idx_in_warpgroup);
// Store S into shared, inform warpgroup 1
save_rPb_to_sP(rS, sS, idx_in_warpgroup);
......@@ -266,13 +261,17 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
cute::warpgroup_wait<0>();
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
} else {
plan.bar_k_avail[buf_idx].arrive();
}
}
// Copy the next q
if (warp_idx == 0 && elect_one_sync()) {
if (batch_idx != end_idx) {
if (threadIdx.x/32 == 0 && elect_one_sync()) {
if (batch_idx != sched_meta.end_req_idx) {
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx+1),
Tile<Int<BLOCK_M>, Int<HEAD_DIM_K>>{}
......@@ -280,6 +279,7 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(BLOCK_M*HEAD_DIM_K*sizeof(bf16));
} else {
// This kernel is followed by the combine kernel, so we signal PDL here
cudaTriggerProgrammaticLaunchCompletion();
}
}
......@@ -289,6 +289,7 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
if (idx_in_warpgroup%4 == 0) {
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
......@@ -297,6 +298,20 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
sM[row] = rM[i];
}
}
float o_scales[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
if (args.is_no_split) {
o_scales[i] = rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i] + exp2f(rAttn_sink[i] - rM[i]));
} else {
o_scales[i] = rL[i] == 0.0f ? 0.0f : __fdividef(1.0f, rL[i]);
}
if (idx_in_warpgroup%4 == 0) {
int row = get_AorC_row_idx(i, idx_in_warpgroup);
plan.sOScale[row] = o_scales[i];
}
}
// This is a synchronization point for warpgroup 0/1.
// Warpgroup 0 should wait wg 1 for oBuf/oAccumBuf (overlapped with k) to be free
......@@ -307,17 +322,17 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
for (int i = 0; i < 2; ++i)
rL[i] = rL[i] == 0.0f ? 1.0f : rL[i];
int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M);
int start_seq_idx = s_q_idx*params.q_head_per_hk + head_block_idx*BLOCK_M;
if (is_no_split) {
bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1)
int start_head_idx = head_block_idx*BLOCK_M;
int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M);
if (args.is_no_split) {
bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.o_row_stride, _1{})
make_stride(params.stride_o_h_q, _1{})
));
float* gSoftmaxLse = (float*)params.softmax_lse_ptr + batch_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1)
float* gSoftmaxLse = (float*)params.lse + batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + start_head_idx; // (BLOCK_M) : (1)
store_o<true>(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
store_o<true>(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
......@@ -327,15 +342,15 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
cute::tma_store_wait<0>();
} else {
int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
float* gSoftmaxLseAccum = (float*)params.softmax_lseaccum_ptr + split_idx*params.q_seq_per_hk + start_seq_idx; // (BLOCK_M) : (1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<HEAD_DIM_V>, _1>
>{});
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1)
float* gSoftmaxLseAccum = (float*)params.lse_accum + split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + start_head_idx; // (BLOCK_M) : (1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.stride_o_accum_h_q, _1{})
));
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
......@@ -346,7 +361,7 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
cute::tma_store_wait<0>();
}
cute::cluster_sync(); // Must use arrive_and_wait here to prevent overwritting sL while WG1 is writing back its result
sync_all_threads_in_cluster();
}
} else if (warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_dealloc<160>();
......@@ -354,16 +369,15 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
TiledMMA tiled_mma_PV = TiledMMA_PV_RemoteP{};
ThrMMA thr_mma_PV = tiled_mma_PV.get_slice(idx_in_warpgroup);
Tensor rO = partition_fragment_C(tiled_mma_PV, Shape<Int<BLOCK_M>, Int<HEAD_DIM_V/2>>{});
float rL[2];
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
MainloopArgs args = get_cur_req_info(batch_idx);
cute::fill(rO, 0.);
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS;
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; block_idx++) {
int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data() + (SmemLayoutV{})(_256{}, _0{})), SmemLayoutHalfV{});
// Wait for S and sScale
......@@ -390,149 +404,271 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
);
cute::warpgroup_wait<0>();
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
if constexpr (CLUSTER_SIZE == 2) {
plan.bar_k_avail[buf_idx].arrive(0, idx_in_warpgroup == 32);
plan.bar_k_avail[buf_idx].arrive(1, idx_in_warpgroup == 64);
} else {
plan.bar_k_avail[buf_idx].arrive();
}
if (block_idx != end_block_idx-1)
if (block_idx != args.end_block_idx-1)
NamedBarrier::arrive(256, NamedBarriers::sScale_and_sS_free); // Tell WG0 that sScale and sS are available
}
NamedBarrier::arrive_and_wait(256, NamedBarriers::oBuf_free_and_sL_ready);
float o_scales[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int row = get_AorC_row_idx(i, idx_in_warpgroup);
rL[i] = sL[row];
o_scales[i] = plan.sOScale[row];
}
CUTE_UNROLL
for (int i = 0; i < 2; ++i)
rL[i] = rL[i] == 0.0f ? 1.0f : rL[i];
int num_valid_seq_q = min(params.q_head_per_hk - head_block_idx*BLOCK_M, BLOCK_M);
int start_seq_idx = s_q_idx*params.q_head_per_hk+head_block_idx*BLOCK_M;
if (is_no_split) {
bf16* o_ptr = (bf16*)params.o_ptr + batch_idx*params.o_batch_stride + start_seq_idx*params.o_row_stride; // (BLOCK_M, HEAD_DIM_V) : (params.o_row_stride, 1)
int start_head_idx = head_block_idx*BLOCK_M;
int num_valid_seq_q = min(params.h_q - start_head_idx, BLOCK_M);
if (args.is_no_split) {
bf16* o_ptr = (bf16*)params.out + batch_idx*params.stride_o_b + s_q_idx*params.stride_o_s_q + start_head_idx*params.stride_o_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_h_q, 1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.o_row_stride, _1{})
make_stride(params.stride_o_h_q, _1{})
));
store_o<true>(rO, gO, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
store_o<true>(rO, gO, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
cute::tma_store_wait<0>();
} else {
int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_split_idx : 0;
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
float* oaccum_ptr = (float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx)*HEAD_DIM_V; // (BLOCK_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>,
Stride<Int<HEAD_DIM_V>, _1>
>{});
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, rL, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
float* oaccum_ptr = (float*)params.o_accum + split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + start_head_idx*params.stride_o_accum_h_q; // (BLOCK_M, HEAD_DIM_V) : (params.stride_o_accum_h_q, 1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), make_layout(
Shape<Int<BLOCK_M>, Int<HEAD_DIM_V>>{},
make_stride(params.stride_o_accum_h_q, _1{})
));
store_o<false>(rO, gOAccum, sOBuf, sOAccumBuf, plan, o_scales, tma_params, batch_idx, s_q_idx, head_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
cute::tma_store_wait<0>();
}
cute::cluster_sync(); // We must use arrive_and_wait instead of arrive here to create an order between "forall warp in WG1, warp has done written back O" and "warp 2 signals `bar_k_avail`"
sync_all_threads_in_cluster();
}
} else {
// Producer warpgroup
cutlass::arch::warpgroup_reg_dealloc<152>();
int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0); // NOTE TPBNO
static_assert(CLUSTER_SIZE == 1 || CLUSTER_SIZE == 2);
static constexpr int NUM_TOKENS_PER_THREAD = CLUSTER_SIZE == 1 ? 2 : 1;
static constexpr int NUM_TOKENS_PER_ROUND = 32; // If head is 128, each CTA is responsible for dequantizing 32 tokens (1 rounds); if head is 64, each CTA is responsible for dequantizing 64 tokens (2 rounds)
int warp_idx = __shfl_sync(0xffffffff, idx_in_warpgroup / 32, 0);
int lane_idx = idx_in_warpgroup % 32;
int my_token_idx = warp_idx*8 + lane_idx%8;
int my_token_idx_base = warp_idx*8 + lane_idx%8;
CUTE_NO_UNROLL
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1)
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
MainloopArgs args = get_cur_req_info(batch_idx);
int* gIndices = params.indices + batch_idx*params.stride_indices_b + s_q_idx*params.stride_indices_s_q; // (topk) : (1)
int* gExtraIndices = params.extra_indices + batch_idx*params.stride_extra_indices_b + s_q_idx*params.stride_extra_indices_s_q; // (extra_topk) : (1)
#define GET_TOKEN_INDEX(block_idx) __ldg(gIndices + (block_idx)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)
int nxt_token_index = GET_TOKEN_INDEX(start_block_idx);
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = (block_idx-start_block_idx) % NUM_K_BUFS;
int nxt_token_indexs[NUM_TOKENS_PER_THREAD];
CUTE_UNROLL
for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) {
if (MODEL_TYPE == ModelType::V32 || args.start_block_idx < args.num_orig_kv_blocks)
nxt_token_indexs[round] = __ldg(gIndices + args.start_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + round*NUM_TOKENS_PER_ROUND + my_token_idx_base);
}
// Define shared and global tensors
bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE;
bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base);
transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx]));
int token_index = nxt_token_index;
if (block_idx+1 != end_block_idx)
nxt_token_index = GET_TOKEN_INDEX(block_idx+1);
int block_index = token_index/PAGE_BLOCK_SIZE;
int rel_idx_in_block = (token_index+PAGE_BLOCK_SIZE) % PAGE_BLOCK_SIZE; // NOTE When token_index is -1, -1/PAGE_BLOCK_SIZE = 0 and (-1+PAGE_BLOCK_SIZE)%PAGE_BLOCK_SIZE = 63, so there will be no illegal-memory-access error
fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride;
float4 scales = load_128b_from_gmem<float4, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>((float*)(gK_base+HEAD_DIM_NOPE));
// Wait for the nope buffer to be available
plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1);
bar_phase_k ^= 1 << buf_idx;
// Copy block #block_index
if (idx_in_warpgroup == 0) {
plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16));
struct IsOrigBlock {};
struct IsExtraBlock {};
struct IsFirstExtraBlock {};
struct IsNotFirstExtraBlock {};
auto process_one_block = [&](int block_idx, auto is_extra_block_t, auto is_first_extra_block_t) {
static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;
static constexpr bool IS_FIRST_EXTRA_BLOCK = std::is_same_v<decltype(is_first_extra_block_t), IsFirstExtraBlock>;
int buf_idx = (block_idx-args.start_block_idx) % NUM_K_BUFS;
int* indices_base;
int page_block_size;
int64_t k_block_stride, k_row_stride;
fp8* k_ptr;
if constexpr (!IS_EXTRA_BLOCK) {
indices_base = gIndices + (block_idx)*TOPK_BLOCK_SIZE;
page_block_size = params.page_block_size;
k_block_stride = params.stride_kv_block;
k_row_stride = params.stride_kv_row;
k_ptr = (fp8*)params.kv;
} else {
indices_base = gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE;
page_block_size = params.extra_page_block_size;
k_block_stride = params.stride_extra_kv_block;
k_row_stride = params.stride_extra_kv_row;
k_ptr = (fp8*)params.extra_kv;
}
[[maybe_unused]] int topk_length = IS_EXTRA_BLOCK ? args.extra_topk_length : args.topk_length;
[[maybe_unused]] int rel_block_idx = IS_EXTRA_BLOCK ? (block_idx - args.num_orig_kv_blocks) : block_idx;
transac_bar_t* peer_bar_k_remote_ready = get_peer_addr(&(plan.bar_k_remote_ready[buf_idx]));
// Collectively copy from global memory and dequant
// For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py
fp8* gK_nope = gK_base + (lane_idx/8)*16;
if (token_index == -1) {
scales = {0.0f, 0.0f, 0.0f, 0.0f};
}
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) {
fp8x16 cur_fp8x16 = load_128b_from_gmem<fp8x16, L1CacheHint::EVICT_LAST, L2PrefetchHint::B256>(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B
float scale = dim_idx < 4 ? (dim_idx < 2 ? scales.x : scales.y) : (dim_idx < 6 ? scales.z : scales.w);
auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) {
int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE;
bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, scale);
*(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
};
if (token_index == -1)
*(uint128_t*)(&cur_fp8x16) = uint128_t();
dequant_and_save_bf16x8(cur_fp8x16.lo, 0);
dequant_and_save_bf16x8(cur_fp8x16.hi, 8);
}
for (int round = 0; round < NUM_TOKENS_PER_THREAD; ++round) {
int my_token_idx = my_token_idx_base + round*NUM_TOKENS_PER_ROUND;
bf16* sK_nope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*16)*TOPK_BLOCK_SIZE;
bf16* sK_nope_peer_base = get_peer_addr(sK_nope_base);
// Get prefetched token index
int token_index;
if constexpr (!IS_EXTRA_BLOCK) {
token_index = nxt_token_indexs[round];
if (block_idx+1 != (MODEL_TYPE == ModelType::V32 ? args.end_block_idx : args.num_orig_kv_blocks))
nxt_token_indexs[round] = __ldg(gIndices + (block_idx+1)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);
} else {
if constexpr (IS_FIRST_EXTRA_BLOCK) {
token_index = __ldg(gExtraIndices + (block_idx-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);
} else {
token_index = nxt_token_indexs[round];
}
if (block_idx+1 != args.end_block_idx)
nxt_token_indexs[round] = __ldg(gExtraIndices + (block_idx+1-args.num_orig_kv_blocks)*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx);
}
if constexpr (MODEL_TYPE == ModelType::MODEL1) {
// For MODEL1, we need to check whether the token_index is within topk_length
if (rel_block_idx*TOPK_BLOCK_SIZE + idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx >= topk_length) {
token_index = -1; // To prevent IMA when we have invalid (e.g. INT_MAX) topk indexes outside topk_length
}
}
bf16* gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8;
bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE;
bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base);
int block_index = token_index == -1 ? 0 : (int)((uint32_t)token_index/(uint32_t)page_block_size); // Use uint32_t division and mod to improve performance
int rel_idx_in_block = (uint32_t)token_index % (uint32_t)page_block_size; // NOTE When token_index is -1 (UINT_MAX), UINT_MAX%page_block_size < page_block_size, so there will be no illegal-memory-access error
fp8* gK_base;
bf16 scales[NUM_SCALES];
if constexpr (MODEL_TYPE == ModelType::V32) {
static_assert(NUM_SCALES == 4);
gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*k_row_stride;
float scales_float[NUM_SCALES];
*(float4*)(scales_float) = load_128b_from_gmem<float4, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>((float*)(gK_base+HEAD_DIM_NOPE));
CUTE_UNROLL
for (int i = 0; i < NUM_SCALES; ++i) {
scales[i] = (bf16)scales_float[i];
}
} else {
static_assert(NUM_SCALES == 8);
gK_base = k_ptr + block_index*k_block_stride + rel_idx_in_block*(HEAD_DIM_NOPE + HEAD_DIM_ROPE*sizeof(bf16));
fp8_e8m0* gK_scales_base = (fp8_e8m0*)(k_ptr + block_index*k_block_stride + page_block_size*(HEAD_DIM_NOPE+HEAD_DIM_ROPE*sizeof(bf16)) + rel_idx_in_block*NUM_SCALES*sizeof(fp8_e8m0));
fp8_e8m0 scales_e8m0[NUM_SCALES];
*(int64_t*)scales_e8m0 = __ldg((int64_t*)gK_scales_base);
CUTE_UNROLL
for (int i = 0; i < NUM_SCALES; i += 2) {
*(__nv_bfloat162_raw*)(scales+i) = __nv_cvt_e8m0x2_to_bf162raw(*(__nv_fp8x2_storage_t*)(scales_e8m0+i));
}
}
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) {
bf16x8 cur_bf16x8 = load_128b_from_gmem<bf16x8, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>(gK_rope + dim_idx*32);
if (token_index == -1)
*(uint128_t*)(&cur_bf16x8) = uint128_t();
int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE;
*(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
// Wait for the nope buffer to be available
if (round == 0) {
plan.bar_k_avail[buf_idx].wait((bar_phase_k>>buf_idx&1)^1);
}
if (CLUSTER_SIZE == 2 && round == 0 && idx_in_warpgroup == 0) {
plan.bar_k_remote_ready[buf_idx].arrive_and_expect_tx((TOPK_BLOCK_SIZE/2)*(HEAD_DIM_NOPE+HEAD_DIM_ROPE)*sizeof(bf16));
}
// Collectively copy from global memory and dequant
// For more detail about the layout of K/V, please refer to comments in flash_mla_interface.py
fp8* gK_nope = gK_base + (lane_idx/8)*16;
if (token_index == -1) {
CUTE_UNROLL
for (int i = 0; i < NUM_SCALES; ++i)
scales[i] = (bf16)0.0f;
}
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_NOPE/64; dim_idx += 1) {
fp8x16 cur_fp8x16 = load_128b_from_gmem<fp8x16, L1CacheHint::EVICT_LAST, L2PrefetchHint::B256>(gK_nope + dim_idx*64); // We use EVICT_LAST here since gK_base may not be aligned to 32B (for V3.2) and the performance is the best among all cache hints (for MODEL1)
bf16 scale = scales[MODEL_TYPE == ModelType::V32 ? dim_idx/2 : dim_idx];
auto dequant_and_save_bf16x8 = [&](const fp8x8 &data, int offset) {
int smem_offset = (dim_idx*64 + offset) * TOPK_BLOCK_SIZE;
bf16x8 cur_bf16x8 = cvt_fp8x8_bf16x8(data, __bfloat162bfloat162(*(__nv_bfloat16*)(&scale)));
*(__int128_t*)(sK_nope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
if constexpr (CLUSTER_SIZE == 2) {
st_async_128b(sK_nope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
}
};
if (token_index == -1)
*(uint128_t*)(&cur_fp8x16) = uint128_t();
dequant_and_save_bf16x8(cur_fp8x16.lo, 0);
dequant_and_save_bf16x8(cur_fp8x16.hi, 8);
}
bf16* gK_rope;
if constexpr (MODEL_TYPE == ModelType::V32) {
gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE+NUM_SCALES*sizeof(float)) + (lane_idx/8)*8;
} else {
gK_rope = (bf16*)(gK_base+HEAD_DIM_NOPE) + (lane_idx/8)*8;
}
bf16* sK_rope_base = plan.u.k[buf_idx].data() + (idx_in_cluster*(TOPK_BLOCK_SIZE/2) + my_token_idx)*8 + ((lane_idx/8)*8)*TOPK_BLOCK_SIZE;
bf16* sK_rope_peer_base = get_peer_addr(sK_rope_base);
CUTE_UNROLL
for (int dim_idx = 0; dim_idx < HEAD_DIM_ROPE/32; dim_idx += 1) {
bf16x8 cur_bf16x8 = load_128b_from_gmem<bf16x8, L1CacheHint::EVICT_LAST, L2PrefetchHint::B128>(gK_rope + dim_idx*32);
if constexpr (MODEL_TYPE == ModelType::V32) {
// NOTE We do not need to mask the RoPE part for V3.2 since it isn't involved in the SV gemm
} else {
if (token_index == -1)
*(uint128_t*)(&cur_bf16x8) = uint128_t();
}
int smem_offset = (HEAD_DIM_NOPE + dim_idx*32) * TOPK_BLOCK_SIZE;
*(__int128_t*)(sK_rope_base + smem_offset) = *(__int128_t*)&cur_bf16x8;
if constexpr (CLUSTER_SIZE == 2) {
st_async_128b(sK_rope_peer_base + smem_offset, cur_bf16x8, peer_bar_k_remote_ready);
}
}
}
fence_view_async_shared();
if (idx_in_warpgroup < 32) {
// We put this after fence_view_async_shared() since this won't be read by async proxy
int2 indices = __ldg((int2*)(gIndices + block_idx*TOPK_BLOCK_SIZE + lane_idx*2));
*(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {indices.x != -1, indices.y != -1};
auto is_index_valid = [&](int index, int offset_within_thread) -> bool {
if constexpr (MODEL_TYPE == ModelType::V32) {
return index != -1;
} else {
return index != -1 && rel_block_idx*TOPK_BLOCK_SIZE + lane_idx*2 + offset_within_thread < topk_length;
}
};
int2 indices = __ldg((int2*)(indices_base + lane_idx*2));
*(char2*)(&plan.is_kv_valid[buf_idx][lane_idx*2]) = {
is_index_valid(indices.x, 0),
is_index_valid(indices.y, 1)
};
}
// Signal the barrier
plan.bar_k_local_ready[buf_idx].arrive();
bar_phase_k ^= 1 << buf_idx;
};
if constexpr (MODEL_TYPE == ModelType::V32) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{});
}
} else {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {
process_one_block(block_idx, IsOrigBlock{}, IsNotFirstExtraBlock{});
}
if (args.num_orig_kv_blocks < args.end_block_idx) {
process_one_block(max(args.start_block_idx, args.num_orig_kv_blocks), IsExtraBlock{}, IsFirstExtraBlock{});
}
CUTE_NO_UNROLL
for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks)+1; block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsExtraBlock{}, IsNotFirstExtraBlock{});
}
}
cute::cluster_sync();
sync_all_threads_in_cluster();
}
}
if (begin_idx > end_idx) {
cute::cluster_sync(); // Don't need a cluster_sync() when begin_idx <= end_idx, since the loop will execute at least once and the final statement is cluster_sync()
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
......@@ -541,50 +677,82 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
}
template<typename Kernel, typename TMAParams>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, Kernel::CLUSTER_SIZE)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TMAParams tma_params) {
Kernel::devfunc(params, tma_params);
}
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream) {
FLASH_ASSERT(params.h_k == 1);
FLASH_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0);
template<ModelType MODEL_TYPE, int NUM_HEADS>
void KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(const SparseAttnDecodeParams &params) {
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk % TOPK_BLOCK_SIZE == 0);
KU_ASSERT(params.d_qk == HEAD_DIM_K);
KU_ASSERT(params.d_v == HEAD_DIM_V);
KU_ASSERT(params.h_q % BLOCK_M == 0);
if constexpr (MODEL_TYPE == ModelType::MODEL1) {
constexpr int BYTES_PER_TOKEN = HEAD_DIM_NOPE + 2*HEAD_DIM_ROPE + 8;
KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous
if (params.extra_kv != nullptr) {
KU_ASSERT(params.stride_extra_kv_row == BYTES_PER_TOKEN, "Each page block in extra KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1"); // Each block must be contiguous
}
} else {
KU_ASSERT(params.extra_kv == nullptr, "V3.2 does not support extra KV cache");
KU_ASSERT(params.topk_length == nullptr, "V3.2 does not support dynamic topk length");
KU_ASSERT(params.stride_kv_row == 656); // number of bytes per token (512 fp8 + 4 float32 + 64 bfloat16)
}
auto shape_Q = make_shape(params.q_head_per_hk, params.d, params.s_q, params.b);
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q, params.b);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q_ptr),
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q,
make_stride(params.q_row_stride, _1{}, params.q_head_per_hk*params.q_row_stride, params.q_batch_stride)
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b)
)
),
SmemLayoutQ{}
);
auto shape_O = make_shape(params.q_head_per_hk, params.d_v, params.s_q, params.b);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((bf16*)params.o_ptr),
make_layout(
shape_O,
make_stride(params.o_row_stride, _1{}, params.q_head_per_hk*params.o_row_stride, params.o_batch_stride)
)
),
SmemLayoutOBuf{}
);
CUtensorMap tensor_map_o;
{
// Here we manually construct TMA descriptor to store O, in order to leverage 5D TMA
uint64_t size[5] = {OBUF_SW, (unsigned long)params.h_q, HEAD_DIM_V/OBUF_SW, (unsigned long)params.s_q, (unsigned long)params.b};
uint64_t stride[4] = {params.stride_o_h_q*sizeof(bf16), OBUF_SW*sizeof(bf16), params.stride_o_s_q*sizeof(bf16), params.stride_o_b*sizeof(bf16)};
uint32_t box_size[5] = {OBUF_SW, BLOCK_M, HEAD_DIM_V/OBUF_SW, 1, 1};
uint32_t elem_stride[5] = {1, 1, 1, 1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_o,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
5,
params.out,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
OBUF_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B :
OBUF_SW == 32 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B :
OBUF_SW == 16 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B :
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
KU_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
decltype(shape_Q), decltype(tma_Q),
decltype(shape_O), decltype(tma_O)
decltype(shape_Q), decltype(tma_Q)
> tma_params = {
shape_Q, tma_Q,
shape_O, tma_O
tensor_map_o
};
auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<decltype(tma_params)>;
auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE, NUM_HEADS>, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
const int num_m_block = cute::ceil_div(params.q_head_per_hk, 2*BLOCK_M) * 2;
// NOTE Don't use PDL because of potential compiler bugs!
// cudaLaunchAttribute mla_kernel_attributes[1];
// mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
......@@ -599,16 +767,21 @@ void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_
// };
// cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
cutlass::ClusterLaunchParams launch_params = {
dim3(num_m_block, params.s_q, params.num_sm_parts),
dim3(NUM_M_BLOCKS, params.s_q, params.num_sm_parts),
dim3(NUM_THREADS, 1, 1),
dim3(2, 1, 1),
dim3(CLUSTER_SIZE, 1, 1),
smem_size,
stream
params.stream
};
cutlass::launch_kernel_on_cluster(
launch_params, (void*)mla_kernel, params, tma_params
);
CHECK_CUDA_KERNEL_LAUNCH();
KU_CHECK_KERNEL_LAUNCH();
}
template<ModelType MODEL_TYPE, int NUM_HEADS>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params) {
KernelTemplate<MODEL_TYPE, NUM_HEADS>::run(params);
}
}
......@@ -2,8 +2,10 @@
#include "params.h"
namespace sm90 {
namespace sm90::decode::sparse_fp8 {
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream);
template<ModelType MODEL_TYPE, int NUM_HEADS>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);
}
#pragma once
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
#include <cute/tensor.hpp>
#include <cutlass/arch/barrier.h>
namespace sm90 {
using bf16 = cutlass::bfloat16_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::fence_barrier_init;
using cutlass::arch::NamedBarrier;
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
......@@ -51,7 +44,7 @@ __forceinline__ __device__ int64_t createpolicy_evict_first() {
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// In the layout of fragment A and fragment C during WGMMA, the data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx;
......@@ -99,7 +92,7 @@ __forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, T
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
// A simpiler version of gemm
// A simpler version of gemm
template <typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm_ss(bool clear_accum, TiledMma tiled_mma, Tensor0 const &sA, Tensor1 const &sB, Tensor2 &rC_frag, int idx_in_warpgroup) {
using namespace cute;
......@@ -142,11 +135,11 @@ __forceinline__ __device__ void gemm_rs(bool clear_accum, TiledMma tiled_mma, Te
__forceinline__ __device__ uint32_t get_sm_id() {
uint32_t ret;
asm("mov.u32 %0, %smid;" : "=r"(ret));
asm("mov.u32 %0, %%smid;" : "=r"(ret));
return ret;
}
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. 不确定是不是在所有显卡上都是这个数字
static constexpr int PEER_ADDR_MASK = 16777216; // peer_addr = my_addr ^ PEER_ADDR_MASK. Not sure if this number is the same on all GPUs.
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
......@@ -163,12 +156,12 @@ void launch_tma_copy(
const TMA &tma_copy,
Tensor0 src,
Tensor1 dst,
transac_bar_t &bar,
cutlass::arch::ClusterTransactionBarrier &bar,
const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
auto thr_tma = tma_copy.get_slice(cute::_0{});
cute::copy(
tma_copy.with(reinterpret_cast<typename transac_bar_t::ValueType&>(bar), 0, cache_hint),
tma_copy.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(bar), 0, cache_hint),
thr_tma.partition_S(src),
thr_tma.partition_D(dst)
);
......
#pragma once
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cooperative_groups.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm90::fwd {
using namespace cute;
template<int D_QK, bool HAVE_TOPK_LENGTH>
class KernelTemplate {
public:
static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK;
static constexpr int D_V = 512;
static constexpr int B_H = 64;
static constexpr int B_TOPK = 64; // TopK block size
static constexpr int NUM_THREADS = 128*3;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits)
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<Shape<Int<64*NUM_TILES>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{}
));
using SmemLayoutQ = SmemLayoutQTiles<D_Q/64>;
using SmemLayoutO = SmemLayoutOTiles<D_V/64>;
using SmemLayoutK = SmemLayoutKTiles<D_Q/64>;
using SmemLayoutV = SmemLayoutKTilesTransposed<D_V/64>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<D_V/64/2>;
using SmemLayoutS = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{}
), Shape<_1, _1>{}));
struct SharedMemoryPlan {
union {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} q_o;
array_aligned<bf16, cosize_v<SmemLayoutK>> k[2];
array_aligned<bf16, cosize_v<SmemLayoutS>> s[D_QK == 576 ? 1 : 2]; // For V3.2 (whose D_QK is 576), we overlap sS[0] with k's RoPE part to save shared memory; For MODEL1 (whose D_QK is 512), we allocate two buffers
bool is_kv_valid[2][B_TOPK];
float2 sM[32];
float2 sL[64]; // For reduction across WG0/1 in epilogue
float final_max_logits[64], final_lse[64];
transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
};
using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
template<
typename Shape_Q, typename TMA_Q
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
CUtensorMap tensor_map_O;
};
enum NamedBarriers : uint32_t {
wg0_bunch_0_ready = 0,
wg1_bunch_0_ready = 1,
wg0_s0_ready = 2,
wg1_s1_ready = 3,
sL_ready = 4,
warpgroup0_sync = 5,
warpgroup1_sync = 6,
epilogue_sync = 7
};
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
static __forceinline__ __device__ void save_rS_to_sS(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
template<typename TMAParams>
static __device__ __forceinline__ void
devfunc(const SparseAttnFwdParams &params, const TMAParams &tma_params);
static void run(const SparseAttnFwdParams &params);
};
};
#include "fwd.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cooperative_groups.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include <stdexcept>
#include "utils.h"
#include "helpers.h"
#include "phase1.h"
namespace sm90 {
using namespace cute;
void run_fwd_kernel(const SparseAttnFwdParams& params) {
const bool have_topk_length = params.topk_length != nullptr;
constexpr int D_Q = 576;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr int B_H = 64;
constexpr int B_TOPK = 64; // TopK block size
constexpr int NUM_THREADS = 128*3;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits)
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
GMMA::Layout_SW128_Atom<bf16, GMMA::Major::K>{},
Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<Shape<Int<64*NUM_TILES>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{}
));
using SmemLayoutQ = SmemLayoutQTiles<9>;
using SmemLayoutO = SmemLayoutOTiles<8>;
using SmemLayoutK = SmemLayoutKTiles<9>;
using SmemLayoutV = SmemLayoutKTilesTransposed<8>;
using SmemLayoutHalfV = SmemLayoutKTilesTransposed<4>;
using SmemLayoutS = decltype(coalesce(tile_to_shape(
GMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{}
), Shape<_1, _1>{}));
struct SharedMemoryPlan {
union {
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} q_o;
array_aligned<bf16, cosize_v<SmemLayoutK>> k[2];
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
bool is_kv_valid[2][B_TOPK];
float2 sM[32];
float2 sL[64]; // For reduction across WG0/1 in epilogue
float final_max_logits[64], final_lse[64];
transac_bar_t bar_q, bar_k0_free[2], bar_k0_ready[2], bar_k1_free[2], bar_k1_ready[2], bar_is_kv_valid_ready;
};
using TiledMMA_QK = decltype(make_tiled_mma(
GMMA::MMA_64x64x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_RS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::MMA_64x256x16_F32BF16BF16_SS<GMMA::Major::K, GMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{}
));
template<
typename Shape_Q, typename TMA_Q
>
struct TmaParams {
Shape_Q shape_Q; TMA_Q tma_Q;
CUtensorMap tensor_map_O;
};
enum NamedBarriers : uint32_t {
wg0_bunch_0_ready = 0,
wg1_bunch_0_ready = 1,
wg0_s0_ready = 2,
wg1_s1_ready = 3,
sL_ready = 4,
warpgroup0_sync = 5,
warpgroup1_sync = 6
};
// Save rPb (64x64, bfloat16) to sP using the stmatrix instruction
template<
typename Tensor0,
typename Tensor1
>
__forceinline__ __device__ void save_rS_to_sS(
Tensor0 const &rPb,
Tensor1 const &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, bf16>{},
TiledMMA_QK{}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 1)
sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) {
// NOTE This kernel uses a similar schedule to Flash MLA - 0422. For a detailed explanation, please refer to https://github.com/deepseek-ai/FlashMLA/blob/main/docs/20250422-new-kernel-deep-dive.md
#if IS_SM90
const int q_h_idx = blockIdx.x % (params.h_q/B_H);
const int s_q_idx = blockIdx.x / (params.h_q/B_H);
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int idx_in_warpgroup = threadIdx.x % 128;
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q_o.q.data()), SmemLayoutQ{});
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data()), SmemLayoutO{});
Tensor sS0 = make_tensor(make_smem_ptr(plan.k[0].data()+64*512), SmemLayoutS{}); // Overlap with sK0's RoPE part
Tensor sS1 = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{});
if (warp_idx == 0 && elect_one_sync()) {
// Prefetch TMA descriptors
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_O);
// Initialize barriers
plan.bar_q.init(1);
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
plan.bar_k0_free[i].init(128);
plan.bar_k0_ready[i].init(128);
plan.bar_k1_free[i].init(128);
plan.bar_k1_ready[i].init(128);
}
plan.bar_is_kv_valid_ready.init(16);
fence_barrier_init();
}
__syncthreads();
const int num_topk_blocks = params.topk / B_TOPK;
if (warpgroup_idx == 0 || warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_alloc<216>();
if (warp_idx == 0 && elect_one_sync()) {
// Load Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
Tile<Int<B_H>, Int<D_Q>>{}
)(_, _, q_h_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));
// Dispatch based on d_qk dimension and presence of topk_length
if (params.d_qk == 512) {
if (have_topk_length) {
sm90::fwd::run_fwd_phase1_kernel<512, true>(params);
} else {
sm90::fwd::run_fwd_phase1_kernel<512, false>(params);
}
float rM[2] = {MAX_INIT_VAL, MAX_INIT_VAL}; // Meaning: the `max_logits` used for O / rL calculation
float rL[2] = {0.0f, 0.0f};
Tensor rO = partition_fragment_C(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<D_V/2>>{});
Tensor rP = partition_fragment_C(TiledMMA_QK{}, Shape<Int<B_H>, Int<B_TOPK>>{});
Tensor rS = make_tensor<bf16>(partition_shape_A(TiledMMA_PV_LocalP{}, Shape<Int<B_H>, Int<B_TOPK>>{}));
cute::fill(rO, 0.0f);
// Wait for Q
plan.bar_q.wait(0);
bool cur_bar_wait_phase = 0;
struct Warpgroup0 {};
struct Warpgroup1 {};
auto qkt_gemm_one_tile = [&](auto warpgroup_idx, int tile_idx, bool clear_accum) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
TiledMMA tiled_mma_QK = TiledMMA_QK{};
Tensor sQ_tile = flat_divide(sQ, Tile<Int<B_H>, Int<64>>{})(_, _, _0{}, tile_idx);
Tensor sK_tile = make_tensor(make_smem_ptr(plan.k[(int)IS_WG1].data() + tile_idx*B_TOPK*64), SmemLayoutKTiles<1>{});
gemm_ss(clear_accum, tiled_mma_QK, sQ_tile, sK_tile, rP, idx_in_warpgroup);
};
auto mask_rP = [&](auto warpgroup_idx) {
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
int col = 8*(i/4) + (idx_in_warpgroup%4)*2;
if (!plan.is_kv_valid[IS_WG1][col]) rP(i) = -INFINITY;
if (!plan.is_kv_valid[IS_WG1][col+1]) rP(i+1) = -INFINITY;
}
}
};
auto online_softmax_and_rescale_o = [&](auto warpgroup_idx) {
plan.bar_is_kv_valid_ready.wait(cur_bar_wait_phase);
constexpr bool IS_WG1 = std::is_same_v<decltype(warpgroup_idx), Warpgroup1>;
const float scale = params.sm_scale_div_log2;
float r_sM[2];
if constexpr (IS_WG1) {
*(float2*)r_sM = plan.sM[idx_in_warpgroup/4];
}
float new_maxs[2];
CUTE_UNROLL
for (int row_idx = 0; row_idx < 2; ++row_idx) {
// Get rowwise max
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
cur_max = max(cur_max, max(rP(i), rP(i+1)));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale;
// Get new max and scale
// For WG1, old_max comes from sM (written by WG0); for WG0, old_max comes from rM (read by WG0 from sM in the last round)
new_maxs[row_idx] = max(IS_WG1 ? r_sM[row_idx] : rM[row_idx], cur_max);
// Scale O
float scale_for_o = exp2f(rM[row_idx]-new_maxs[row_idx]);
CUTE_UNROLL
for (int i = row_idx*2; i < size(rO); i += 4) {
rO(i) *= scale_for_o;
rO(i+1) *= scale_for_o;
}
// Get rS
float cur_sum = 0;
CUTE_UNROLL
for (int i = row_idx*2; i < size(rP); i += 4) {
rP(i) = exp2f(rP(i)*scale - new_maxs[row_idx]);
rP(i+1) = exp2f(rP(i+1)*scale - new_maxs[row_idx]);
rS(i) = (bf16)rP(i);
rS(i+1) = (bf16)rP(i+1);
cur_sum += rP(i) + rP(i+1);
}
rL[row_idx] = rL[row_idx]*scale_for_o + cur_sum;
}
__syncwarp();
if (idx_in_warpgroup%4 == 0) {
plan.sM[idx_in_warpgroup/4] = *(float2*)new_maxs;
}
rM[0] = new_maxs[0];
rM[1] = new_maxs[1];
};
auto reduce_L = [&]() {
// Reduce L
// For example, thread 0 reduces with thread 1, 2, and 3, as well as thread 128, 129, 130, and 131
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
if (idx_in_warpgroup%4 == 0)
plan.sL[threadIdx.x/4] = *(float2*)(rL);
NamedBarrier::arrive_and_wait(256, NamedBarriers::sL_ready);
float2 peer_L = plan.sL[(threadIdx.x/4)^32];
rL[0] += peer_L.x;
rL[1] += peer_L.y;
};
auto store_O = [&]() {
float scale_factors[2];
CUTE_UNROLL
for (int i = 0; i < 2; ++i)
scale_factors[i] = rL[i] == 0.0f ? 1.0f : 1.0f / rL[i];
Tensor sO = make_tensor(make_smem_ptr(plan.q_o.o.data() + warpgroup_idx*B_H*(D_V/2)), SmemLayoutOTiles<4>{});
bf16* stsm_addrs[4];
int stsm_row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%16);
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
stsm_addrs[i] = &sO(stsm_row, (idx_in_warpgroup%32/16*8) + 16*i);
}
bool s2g_pred = warp_idx%4 == 0 && elect_one_sync();
warpgroup_wait<0>();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < (D_V/2)/64; tile_idx += 1) {
// Convert
constexpr int NUM_ELEMS_EACH_TILE = B_H*64 / 128; // 64: tile size, 128: warpgroup size
bf16 cur_rOb[NUM_ELEMS_EACH_TILE];
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_EACH_TILE; ++i) {
cur_rOb[i] = (bf16)(rO(tile_idx*NUM_ELEMS_EACH_TILE + i) * scale_factors[i%4>=2]);
}
// R -> S
CUTE_UNROLL
for (int i = 0; i < 64/16; ++i) {
SM90_U32x4_STSM_N::copy(
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 0),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 2),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 4),
*reinterpret_cast<uint32_t*>(cur_rOb + i*8 + 6),
*reinterpret_cast<uint128_t*>(stsm_addrs[i] + tile_idx*(B_H*64))
);
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, warpgroup_idx ? NamedBarriers::warpgroup1_sync : NamedBarriers::warpgroup0_sync);
// S -> G
if (s2g_pred) {
int g_tile_idx = warpgroup_idx*4 + tile_idx;
SM90_TMA_STORE_3D::copy(
&tma_params.tensor_map_O,
plan.q_o.o.data() + g_tile_idx*(B_H*64),
g_tile_idx*64,
q_h_idx*B_H,
s_q_idx
);
}
}
cute::tma_store_arrive();
};
if (warpgroup_idx == 0) {
// Warpgroup 0
auto pipelined_wait_and_qkt_gemm_l = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 0, true);
qkt_gemm_one_tile(Warpgroup0{}, 1, false);
qkt_gemm_one_tile(Warpgroup0{}, 2, false);
qkt_gemm_one_tile(Warpgroup0{}, 3, false);
warpgroup_commit_batch();
};
auto pipelined_wait_and_qkt_gemm_r = [&]() __attribute__((always_inline)) {
plan.bar_k0_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup0{}, 4, false);
qkt_gemm_one_tile(Warpgroup0{}, 5, false);
qkt_gemm_one_tile(Warpgroup0{}, 6, false);
qkt_gemm_one_tile(Warpgroup0{}, 7, false);
qkt_gemm_one_tile(Warpgroup0{}, 8, false);
warpgroup_commit_batch();
};
auto scale_rS = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rP); i += 4) {
rS(i) = (bf16)(rP(i) * scales[row]);
rS(i+1) = (bf16)(rP(i+1) * scales[row]);
}
}
};
auto rescale_rO = [&](float scales[2]) {
CUTE_UNROLL
for (int row = 0; row < 2; ++row) {
CUTE_UNROLL
for (int i = row*2; i < size(rO); i += 4) {
rO(i) *= scales[row];
rO(i+1) *= scales[row];
}
rL[row] *= scales[row];
}
};
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
Tensor sV0l = make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTilesTransposed<4>{});
Tensor sV1l = make_tensor(make_smem_ptr(plan.k[1].data()), SmemLayoutKTilesTransposed<4>{});
if (block_idx == 0) {
// NOTE We put these code here to avoid register spilling
pipelined_wait_and_qkt_gemm_l();
pipelined_wait_and_qkt_gemm_r();
warpgroup_wait<0>();
}
// Online softmax, inform WG1
mask_rP(Warpgroup0{});
online_softmax_and_rescale_o(Warpgroup0{});
NamedBarrier::arrive(256, NamedBarriers::wg0_bunch_0_ready);
// Issue rO0 += rS0 @ sV0l
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV0l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Mark V0L as free
warpgroup_wait<0>();
plan.bar_k0_free[0].arrive();
// Wait for new sM, scale rS, save, inform WG1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_bunch_0_ready);
float new_rM[2], scale_factors[2];
*(float2*)new_rM = plan.sM[idx_in_warpgroup/4];
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
scale_factors[i] = exp2f(rM[i] - new_rM[i]);
rM[i] = new_rM[i];
}
scale_rS(scale_factors);
save_rS_to_sS(rS, sS0, idx_in_warpgroup);
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg0_s0_ready);
// Wait for sS1
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg1_s1_ready);
// Rescale rO0, Issue rO0 += sS1 @ sV1L
rescale_rO(scale_factors);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS1, sV1l, rO, idx_in_warpgroup);
warpgroup_commit_batch();
cur_bar_wait_phase ^= 1;
if (block_idx+2 < num_topk_blocks) {
// Launch the next QK^T GEMM
pipelined_wait_and_qkt_gemm_l();
// Mark V1L as free
warpgroup_wait<1>();
plan.bar_k1_free[0].arrive();
pipelined_wait_and_qkt_gemm_r();
// Wait for rP0 = sQ @ sK0
warpgroup_wait<0>();
} else {
// Mark V1L as free
warpgroup_wait<0>();
plan.bar_k1_free[0].arrive();
}
}
reduce_L();
store_O();
} else if (params.d_qk == 576) {
if (have_topk_length) {
sm90::fwd::run_fwd_phase1_kernel<576, true>(params);
} else {
// Warpgroup 1
auto pipelined_wait_and_qkt_gemm = [&]() __attribute__((always_inline)) {
plan.bar_k1_ready[1].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 4, true);
qkt_gemm_one_tile(Warpgroup1{}, 5, false);
qkt_gemm_one_tile(Warpgroup1{}, 6, false);
qkt_gemm_one_tile(Warpgroup1{}, 7, false);
qkt_gemm_one_tile(Warpgroup1{}, 8, false);
plan.bar_k1_ready[0].wait(cur_bar_wait_phase);
qkt_gemm_one_tile(Warpgroup1{}, 0, false);
qkt_gemm_one_tile(Warpgroup1{}, 1, false);
qkt_gemm_one_tile(Warpgroup1{}, 2, false);
qkt_gemm_one_tile(Warpgroup1{}, 3, false);
warpgroup_commit_batch();
};
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
Tensor sV0r = make_tensor(make_smem_ptr(plan.k[0].data()+64*256), SmemLayoutKTilesTransposed<4>{});
Tensor sV1r = make_tensor(make_smem_ptr(plan.k[1].data()+64*256), SmemLayoutKTilesTransposed<4>{});
// Issue rP1 = sQ @ sK1, and wait
pipelined_wait_and_qkt_gemm();
warpgroup_wait<0>();
mask_rP(Warpgroup1{});
// Wait for WG0 (for sM), online softmax, Notify WG0 (sM ready)
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_bunch_0_ready);
online_softmax_and_rescale_o(Warpgroup1{});
NamedBarrier::arrive(256, NamedBarriers::wg1_bunch_0_ready);
// Issue rO1 += rS1 @ sV1R
gemm_rs(false, TiledMMA_PV_LocalP{}, rS, sV1r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Wait for WG0 (for sS0), Issue rO1 += rS0 @ sV0R
save_rS_to_sS(rS, sS1, idx_in_warpgroup); // Put it here is faster
NamedBarrier::arrive_and_wait(256, NamedBarriers::wg0_s0_ready);
gemm_ss(false, TiledMMA_PV_RemoteP{}, sS0, sV0r, rO, idx_in_warpgroup);
warpgroup_commit_batch();
// Save rS1, inform WG0
fence_view_async_shared();
NamedBarrier::arrive(256, NamedBarriers::wg1_s1_ready);
// Wait for GEMM, and inform that sV1R is free
warpgroup_wait<1>();
plan.bar_k1_free[1].arrive();
// Wait for GEMM, and inform that sV0R is free
warpgroup_wait<0>();
plan.bar_k0_free[1].arrive();
cur_bar_wait_phase ^= 1;
}
reduce_L();
store_O();
// Save lse
if (idx_in_warpgroup%4 == 0) {
for (int row = 0; row < 2; ++row) {
int real_row = get_AorC_row_idx(row, idx_in_warpgroup);
bool is_no_valid_tokens = rL[row] == 0.0f;
plan.final_max_logits[real_row] = is_no_valid_tokens ? -INFINITY : rM[row];
plan.final_lse[real_row] = is_no_valid_tokens ? -INFINITY : log2f(rL[row]) + rM[row];
}
fence_view_async_shared();
}
NamedBarrier::arrive_and_wait(128, NamedBarriers::warpgroup1_sync);
if (idx_in_warpgroup == 0) {
int g_offset = s_q_idx*params.h_q + q_h_idx*B_H;
SM90_BULK_COPY_S2G::copy(plan.final_max_logits, params.max_logits + g_offset, B_H*sizeof(float));
SM90_BULK_COPY_S2G::copy(plan.final_lse, params.lse + g_offset, B_H*sizeof(float));
cute::tma_store_arrive();
}
sm90::fwd::run_fwd_phase1_kernel<576, false>(params);
}
} else {
// Producer warpgroup
cutlass::arch::warpgroup_reg_dealloc<72>();
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/GROUP_SIZE;
constexpr int NUM_ROWS_PER_GROUP = B_TOPK / NUM_GROUPS;
int idx_in_group = idx_in_warpgroup % GROUP_SIZE;
int group_idx = idx_in_warpgroup / GROUP_SIZE;
int* gIndices = params.indices + s_q_idx*params.topk; // [topk]
bf16* my_sKV_base = &(make_tensor(make_smem_ptr(plan.k[0].data()), SmemLayoutKTiles<1>{})(group_idx, idx_in_group*8));
bf16* my_gKV_base = params.kv + idx_in_group*8;
int64_t token_indices[2][NUM_ROWS_PER_GROUP];
bool is_token_valid[2][NUM_ROWS_PER_GROUP];
auto load_token_indices = [&](int block_idx) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int offs = (block_idx+buf_idx)*B_TOPK + local_row*NUM_GROUPS + group_idx;
int t = __ldg(gIndices + offs);
token_indices[buf_idx][local_row] = t*(int64_t)params.stride_kv_s_kv; // We mult it with params.stride_kv_s_kv here since it's faster
is_token_valid[buf_idx][local_row] = t >= 0 && t < params.s_kv;
}
}
};
int64_t cache_policy = createpolicy_evict_last();
auto copy_tiles = [&](int block_idx, int buf_idx, int tile_start, int tile_end) {
// Copy some K/V tiles from global memory to shared memory
// A tile has a shape of 64 (B_TOPK) x 64
// `buf_idx` is the index of the shared memory buffer, 0 or 1
// `tile_idx` is the index of the tile to load, from 0 to D_K/64-1 = 8
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row) {
int64_t token_index = token_indices[buf_idx][local_row];
CUTE_UNROLL
for (int tile_idx = tile_start; tile_idx < tile_end; ++tile_idx) {
cp_async_cacheglobal_l2_prefetch_256B(
my_gKV_base + token_index + tile_idx*64,
my_sKV_base + (buf_idx*B_TOPK*D_K + tile_idx*(B_TOPK*64) + local_row*NUM_GROUPS*64),
is_token_valid[buf_idx][local_row],
cache_policy
);
}
}
};
auto commit_to_mbar = [&](transac_bar_t &bar) {
cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)(&bar));
};
int cur_bar_wait_phase = 1;
CUTE_NO_UNROLL
for (int block_idx = 0; block_idx < num_topk_blocks; block_idx += 2) {
load_token_indices(block_idx);
// V0L
plan.bar_k0_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 0, 4);
commit_to_mbar(plan.bar_k0_ready[0]);
// V1R
plan.bar_k1_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 4, 9);
commit_to_mbar(plan.bar_k1_ready[1]);
// V0R
plan.bar_k0_free[1].wait(cur_bar_wait_phase);
copy_tiles(block_idx+0, 0, 4, 9);
commit_to_mbar(plan.bar_k0_ready[1]);
// V1L
plan.bar_k1_free[0].wait(cur_bar_wait_phase);
copy_tiles(block_idx+1, 1, 0, 4);
commit_to_mbar(plan.bar_k1_ready[0]);
// Valid mask
// NOTE V1R's finish implies maskings of the last round have finished
if (idx_in_group == 0) {
CUTE_UNROLL
for (int buf_idx = 0; buf_idx < 2; ++buf_idx)
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_GROUP; ++local_row)
plan.is_kv_valid[buf_idx][local_row*NUM_GROUPS+group_idx] = is_token_valid[buf_idx][local_row];
plan.bar_is_kv_valid_ready.arrive();
}
cur_bar_wait_phase ^= 1;
}
throw std::runtime_error("Unsupported d_qk value in sparse attention fwd kernel");
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm90");
}
#endif
}
void run_fwd_kernel(const SparsePrefillParams& params) {
FLASH_ASSERT(params.h_kv == 1);
FLASH_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings
FLASH_ASSERT(params.topk > 0);
FLASH_ASSERT(params.h_q % B_H == 0);
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
SmemLayoutQ{}
);
CUtensorMap tensor_map_O;
{
uint64_t size[3] = {D_V, (unsigned long)params.h_q, (unsigned long)params.s_q};
uint64_t stride[2] = {D_V*sizeof(bf16), D_V*params.h_q*sizeof(bf16)};
uint32_t box_size[3] = {64, B_H, 1};
uint32_t elem_stride[3] = {1, 1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_O,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
3,
params.out,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
FLASH_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
decltype(shape_Q), decltype(tma_Q)
> tma_params = {
shape_Q, tma_Q,
tensor_map_O
};
auto kernel = &sparse_attn_fwd_kernel<decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
cutlass::ClusterLaunchParams launch_params = {
dim3((params.h_q/B_H)*params.s_q, 1, 1), // NOTE We put s_q on the first dim since it can be larger than 65536 (the maximum size of griddim.y and griddim.z)
dim3(NUM_THREADS, 1, 1),
dim3(1, 1, 1),
smem_size,
params.stream
};
cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
} // namespace sm90
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment