"plugins/vscode:/vscode.git/clone" did not exist on "41abd9fb9464de0b9dc859aeea520393034f42e9"
Commit e2e0225c authored by zhanghj2's avatar zhanghj2
Browse files

空kernel可以编译通过

parent 48c6dc42
#pragma once
#include "kernel.h"
#include <cuda_fp8.h>
#include <cutlass/barrier.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm100::decode::head64 {
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using e8m0 = __nv_fp8_e8m0;
using e4m3 = cutlass::float_e4m3_t;
using namespace cute;
// IDs of the named barriers used by the kernel in this namespace.
// The numeric values are significant: they are passed directly to NamedBarrier.
enum NamedBarriers : uint32_t {
    main_loop_sync = 0,
    // Used with 128 threads to synchronize the threads of warpgroup 0.
    wg0_sync,
    // Pairwise 64-thread barriers inside warpgroup 0 (warp 0 <-> warp 2 and warp 1 <-> warp 3).
    // The kernel selects between them with `wg0_warp02_sync + (warp_idx & 1)`,
    // so the two IDs must stay consecutive.
    wg0_warp02_sync,
    wg0_warp13_sync = wg0_warp02_sync + 1,
    // Used with NUM_THREADS threads at the end of every per-batch iteration.
    everyone_sync
};
// Compile-time configuration for the sparse FP8 split-KV MLA decode kernel in this namespace:
// problem dimensions, shared-memory / tensor-memory layouts, MMA types and the entry points.
// MODEL_TYPE selects between ModelType::V32 and the other variant (called MODEL1 in the comments below).
template<ModelType MODEL_TYPE>
struct KernelTemplate {
// ---- Head dimensions (in elements) ----
// D_Q / D_K: Q and K head dims; D_V: V head dim; D_NOPE / D_ROPE: NoPE and RoPE parts of a head.
static constexpr int D_Q = MODEL_TYPE == ModelType::V32 ? 576 : 512;
static constexpr int D_K = D_Q;
static constexpr int D_V = 512;
static constexpr int D_NOPE = MODEL_TYPE == ModelType::V32 ? 512 : 448;
static constexpr int D_ROPE = 64;
// NOTE(review): presumably the number of NoPE elements that share one quantization scale - confirm.
static constexpr int QUANT_TILE_SIZE = MODEL_TYPE == ModelType::V32 ? 128 : 64;
// Whether V includes the RoPE part (checked against D_V by the static_assert below).
static constexpr bool V_HAVE_ROPE = MODEL_TYPE == ModelType::V32 ? false : true;
static constexpr int NUM_SCALES_EACH_TOKEN = MODEL_TYPE == ModelType::V32 ? 4 : 8; // Padding is included
// Stride of K's tensormap. This stride must
//   1) be a factor of the actual stride between tokens, and
//   2) be large enough to cover the entire KV cache.
// Since a TMA copy's coordinates can only be 32-bit signed integers, this number must be >= 128,
// preferably >= 256. So we set this to 656 for V32 and 576 for MODEL1.
// Extra padding may be necessary for KV blocks.
static constexpr int TMA_K_STRIDE = MODEL_TYPE == ModelType::V32 ? D_NOPE+2*D_ROPE+4*(D_NOPE/QUANT_TILE_SIZE) : D_NOPE+2*D_ROPE;
// Sanity checks: Q = NoPE + RoPE; V is either NoPE + RoPE or NoPE only, depending on V_HAVE_ROPE.
static_assert(D_NOPE + D_ROPE == D_Q);
static_assert(V_HAVE_ROPE ? (D_NOPE + D_ROPE == D_V) : (D_NOPE == D_V));
// ---- Tiling / pipelining ----
// B_H: rows of the Q / O tiles. NOTE(review): presumably the number of query heads handled per CTA - confirm.
static constexpr int B_H = 64;
// B_TOPK: number of top-k tokens processed per main-loop block.
static constexpr int B_TOPK = 64;
// NUM_BUFS: number of KV buffers (dequantized and raw) in the ring.
static constexpr int NUM_BUFS = 2;
static constexpr int NUM_INDEX_BUFS = 4; // Number of buffers for indices (tma_coords) & is_token_valid & scales
static constexpr int NUM_THREADS = 128*3; // 128 exp + 1/32 utcmma + 1/32 raw KV producer + 1/32 rope producer + 32 index+scale+valid_mask producer + 128 dequant
static constexpr float MAX_INIT_VAL = -1e30f; // To avoid (-inf) - (-inf) = NaN
// Q is split into a part stored with 128B swizzle and (for V32 only) a part stored with 64B swizzle.
static constexpr int D_Q_SW128 = 512;
static constexpr int D_Q_SW64 = MODEL_TYPE == ModelType::V32 ? 64 : 0;
static_assert(D_Q_SW128 + D_Q_SW64 == D_Q);
static constexpr int K_ROPE_SW = MODEL_TYPE == ModelType::V32 ? 64 : 128; // RoPE part stored in SW64 (for V32) or SW128 (for MODEL1), in bytes
// Bundle of TMA descriptors handed to the kernel: CuTe TMA objects (plus their shapes) for the
// SW128 part of Q and for O, and raw CUtensorMaps for the SW64 part of Q and for the
// (extra) KV NoPE / RoPE gathers.
template<
typename Shape_Q_SW128, typename TMA_Q_SW128,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q_SW128 shape_Q_SW128; TMA_Q_SW128 tma_Q_SW128;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_q_sw64; // Invalid if D_Q_SW64 == 0
CUtensorMap tensor_map_kv_nope;
CUtensorMap tensor_map_kv_rope;
CUtensorMap tensor_map_extra_kv_nope;
CUtensorMap tensor_map_extra_kv_rope;
};
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: output
// 256 ~ 256 + 64*D_Q/256: Q
// 400 ~ 464: P
static constexpr int O = 0;
static constexpr int Q = 256;
// First Q column after the NoPE part; the V32 path reads Q's RoPE part from here.
static constexpr int Q_Tail = 256 + B_H*D_NOPE/2/128;
static constexpr int P = 400;
};
// ---- Shared-memory layouts ----
// Q tile of shape B_H x (NUM_TILES*64) bf16, built from the K-major SW128 UMMA atom.
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<NUM_TILES*64>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutQ_SW128 = SmemLayoutQTiles<D_Q_SW128/64>;
// Full bf16 output buffer (B_H x D_V), built from the K-major SW128 UMMA atom.
using SmemLayoutOBuf = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_V>>{}
));
using SmemLayoutOBuf_TMA = decltype(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64>>{}
)); // A TMA tile
static_assert(D_V == 512);
// fp32 output-accumulator buffer (B_H x D_V), row-major with a padded row stride.
using SmemLayoutOAccumBuf = Layout<
Shape<Int<B_H>, Int<D_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
// S tile (B_H x B_TOPK bf16), built from the K-major interleaved (non-swizzled) UMMA atom.
using SmemLayoutS = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
));
// K tiles in SW128: B_H x (64*NUM_TILES) bf16.
template<int NUM_TILES>
using SmemLayoutKTiles_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Same as above but with 2*B_H rows, as used by the dual-GEMM QK product.
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW128 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// View of the SW128 K tiles with the two modes swapped: (64*NUM_TILES) x B_TOPK.
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW128 = decltype(composition(
SmemLayoutKTiles_SW128<NUM_TILES>{},
Layout<
Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
// SW64 counterparts of the three K layouts above (32 bf16 columns per tile instead of 64).
template<int NUM_TILES>
using SmemLayoutKTiles_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTiles_DualGemm_SW64 = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
template<int NUM_TILES>
using SmemLayoutKTilesTransposed_SW64 = decltype(composition(
SmemLayoutKTiles_SW64<NUM_TILES>{},
Layout<
Shape<Int<32*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
// Layout of the kernel's dynamic shared memory (the kernel reinterprets its shared buffer as this struct).
// Member order matters: it determines the shared-memory offsets and which buffers alias each other.
struct SharedMemoryPlan {
// `qo` (Q and O buffers) shares storage with `kv` (dequantized + raw KV buffers).
union {
struct {
array_aligned<bf16, cosize_v<SmemLayoutQ_SW128>> q;
bf16 q_sw64[B_H*D_Q_SW64]; // NOTE D_Q_SW64 may be 0 but array_aligned<bf16, 0> will have a size of 16, so we use array here. The former tensor (`q`) promises its alignment.
// bf16 output staging buffer and fp32 accumulator staging buffer share storage.
union {
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;
} o;
} qo;
struct {
struct {
array_aligned<bf16, B_H*D_NOPE> nope; // NoPE part, dequantized
array_aligned<bf16, B_H*D_ROPE> rope; // RoPE part, dequantized. SW64 in v32 mode, SW128 in MODEL1 mode
} dequant[NUM_BUFS];
static_assert(sizeof(dequant) >= sizeof(bf16) * (B_H*D_Q)); // So that Q does not cover raw_nope
array_aligned<e4m3, B_H*D_NOPE> raw_nope[NUM_BUFS]; // Raw (quantized) NoPE part
} kv;
} u;
// p_exchange_buf: staging area through which paired warps of warpgroup 0 exchange partial P values.
// s: the S tile that is fed to the SV GEMM. Both share storage.
union {
float4 p_exchange_buf[4][16 * B_TOPK / 4];
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
} s_p;
// One float per warpgroup-0 thread; used to exchange row-wise maxima (and, after the main loop, li).
CUTE_ALIGNAS(16) float rowwise_max_buf[128];
// Per index buffer: one validity bit per token of the block.
char is_token_valid[NUM_INDEX_BUFS][B_TOPK/8];
// Per index buffer: TMA gather coordinate of each token (-1 for invalid tokens).
int tma_coord[NUM_INDEX_BUFS][B_TOPK];
// Per index buffer: per-token scale factors.
e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN];
// Receives the base address returned by the tensor-memory allocator.
array_aligned<uint32_t, 1> tmem_start_addr;
// Transaction barriers. They are initialized once by warp 0 at kernel start.
transac_bar_t bar_last_store_done;
transac_bar_t bar_q_tma, bar_q_utccp;
transac_bar_t bar_rope_ready[NUM_BUFS];
transac_bar_t bar_nope_ready[NUM_BUFS];
transac_bar_t bar_raw_ready[NUM_BUFS], bar_raw_free[NUM_BUFS];
transac_bar_t bar_valid_coord_scale_ready[NUM_INDEX_BUFS], bar_valid_coord_scale_free[NUM_INDEX_BUFS];
transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS], bar_sv_done[NUM_BUFS];
};
// MMA for P = Q * K^T: A operand from tensor memory, B operand from shared memory.
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}
)); // *2 for dual gemm
// MMA for O += S * V: both operands from shared memory.
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}
));
// Device-side body of the kernel (defined out of line).
//   params:     runtime parameters of the decode problem.
//   tma_params: TMA descriptors; expected to be an instantiation of TmaParams above.
template<typename TmaParam>
static __device__ void
flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params);
// Host-side entry point. NOTE(review): presumably builds the TMA descriptors and launches the kernel;
// its definition is not visible here - confirm.
static void run(const SparseAttnDecodeParams &params);
};
}
\ No newline at end of file
#include "../kernel.cuh"
namespace sm100::decode::head64 {

// Explicitly instantiate run_flash_splitkv_mla_fp8_sparse_kernel for ModelType::MODEL1 so that
// this translation unit emits the MODEL1 specialization of the template.
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::MODEL1>(const SparseAttnDecodeParams &params);

}  // namespace sm100::decode::head64
#include "../kernel.cuh"
namespace sm100::decode::head64 {

// Explicitly instantiate run_flash_splitkv_mla_fp8_sparse_kernel for ModelType::V32 so that
// this translation unit emits the V32 specialization of the template.
template void run_flash_splitkv_mla_fp8_sparse_kernel<ModelType::V32>(const SparseAttnDecodeParams &params);

}  // namespace sm100::decode::head64
#include "kernel.h"
#include <math_constants.h>
#include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/tensor.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include "kerutils/kerutils.cuh"
#include "utils.h"
#include "sm100/helpers.h"
#include "config.h"
namespace sm100::decode::head64 {
template<ModelType MODEL_TYPE>
template<typename TmaParam>
__device__ void
KernelTemplate<MODEL_TYPE>
::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(const SparseAttnDecodeParams &params, const TmaParam &tma_params) {
#if defined(KERUTILS_ENABLE_SM100A)
const int s_q_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q_SW128.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&tma_params.tensor_map_q_sw64);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);
}
if (warp_idx == 0) {
if (elect_one_sync()) {
plan.bar_last_store_done.init(128);
plan.bar_q_tma.init(1);
plan.bar_q_utccp.init(1);
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_rope_ready[i].init(1);
plan.bar_nope_ready[i].init(128);
plan.bar_raw_ready[i].init(1);
plan.bar_raw_free[i].init(128);
plan.bar_qk_done[i].init(1);
plan.bar_so_ready[i].init(128);
plan.bar_sv_done[i].init(1);
}
for (int i = 0; i < NUM_INDEX_BUFS; ++i) {
plan.bar_valid_coord_scale_ready[i].init(32);
plan.bar_valid_coord_scale_free[i].init(128+128+1+1);
}
cutlass::arch::fence_barrier_init();
}
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());
KU_TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator1Sm().release_allocation_lock();
}
__syncthreads();
struct MainLoopArgs {
int batch_idx, start_block_idx, end_block_idx;
bool is_no_split; int n_split_idx;
bool bar_phase_batch_rel; // Bar phase of barriers that are used once per batch
int topk_length, extra_topk_length, num_orig_kv_blocks;
bool is_last_batch;
};
auto run_main_loop = [&](auto f) {
// NOTE Putting the following code outside the warpgroup specialization switch results in register spilling.
// [[maybe_unused]] int begin_req_idx, end_req_idx, sched_begin_block_idx, sched_end_block_idx, begin_n_split_idx, is_first_req_splitted, is_last_req_splitted;
DecodingSchedMeta sched_meta;
KU_LDG_256(
params.tile_scheduler_metadata_ptr + partition_idx,
&sched_meta,
".nc",
"no_allocate",
"evict_normal",
"256B"
);
if (sched_meta.begin_req_idx >= params.b) {
return;
}
bool bar_phase_batch_rel = 0;
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx, bar_phase_batch_rel ^= 1) {
int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0
int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK;
bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false);
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx);
MainLoopArgs args = {
batch_idx, start_block_idx, end_block_idx,
!is_split, n_split_idx,
bar_phase_batch_rel,
topk_length, extra_topk_length,
orig_topk_padded / B_TOPK,
batch_idx == sched_meta.end_req_idx
};
f(args);
NamedBarrier(NUM_THREADS, NamedBarriers::everyone_sync).arrive_and_wait_unaligned();
}
};
struct RingState {
int buf_idx = 0;
bool bar_phase = 0;
int index_buf_idx = 0;
bool index_bar_phase = 0;
CUTE_DEVICE void update() {
bar_phase ^= (buf_idx == NUM_BUFS-1);
buf_idx = (buf_idx+1) % NUM_BUFS;
index_bar_phase ^= (index_buf_idx == NUM_INDEX_BUFS-1);
index_buf_idx = (index_buf_idx+1) % NUM_INDEX_BUFS;
}
};
RingState rs;
if (warpgroup_idx == 0) {
// Scale & Exp warpgroup
// The same technique (and highly similar code) as the sm100 sparse prefill head64 kernel
cutlass::arch::warpgroup_reg_alloc<224>();
constexpr int B_EPI = 64; // Must be equal to the size of the swizzle atom
Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_buf.data()), SmemLayoutOBuf{});
bf16* sO_bases[B_EPI/8]; // 64 is the size of the swizzle atom (in number of elements) while 8 is the width of each write
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i)
sO_bases[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*128 + i*8);
const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};
bf16* sS_base = plan.s_p.s.data() + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2);
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg((float*)params.attn_sink + (idx_in_warpgroup%64)) * CUDART_L2E_F;
run_main_loop([&](const MainLoopArgs &args) {
cute::tma_store_wait<0>();
plan.bar_last_store_done.arrive();
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // Make sure all intermediate buffers (including p_exchange_buf, rowwise max_buf) are free
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase); // Put the barrier wait here for more code reordering space
plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
// Load P
float p[B_TOPK/2], p_peer[B_TOPK/2];
if (warp_idx < 2) {
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p);
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p_peer);
} else {
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P, p_peer);
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::P+32, p);
}
cutlass::arch::fence_view_async_tmem_load();
ku::tcgen05_before_thread_sync();
// Reduce within shared mem
{
// Store
// Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (row 0 ~ 31) part
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/4; ++i)
plan.s_p.p_exchange_buf[warp_idx^2][i*32 + lane_idx] = *(float4*)(p_peer + i*4);
NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1)); // Synchronize between warp 0 and warp 2, as well as warp 1 - warp 3
// Load
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/4; ++i) {
float2 t[2];
*(float4*)t = plan.s_p.p_exchange_buf[warp_idx][i*32 + lane_idx];
float2* cur_p = (float2*)(p + i*4);
cur_p[0] = ku::float2_add(cur_p[0], t[0]);
cur_p[1] = ku::float2_add(cur_p[1], t[1]);
}
}
// Since dual gemm is utilized, the layout of P in register now look like:
//
// 32 32
// +-------+-------+
// | | |
// 32 | Warp0 | Warp2 |
// | | |
// +-------+-------+
// | | |
// 32 | Warp1 | Warp3 |
// | | |
// +-------+-------+
// Mask
uint32_t valid_mask = *((uint32_t*)plan.is_token_valid[rs.index_buf_idx] + (idx_in_warpgroup>=64?1:0));
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2; i += 1) {
if (!(valid_mask>>i&1))
p[i] = -CUDART_INF_F;
}
// Get rowwise max of Pi
float cur_pi_max = -CUDART_INF_F;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2); i += 1) {
cur_pi_max = max(cur_pi_max, p[i]);
}
cur_pi_max *= params.sm_scale_div_log2;
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync); // This also separates "reading p_exchange_buf" and "writing S"
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// By this point:
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. among threads 0~31+64~95; and is identical among threads 32~63+96~127)
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
// Calculate S
__nv_bfloat162 s[(B_TOPK/2)/2];
float2 neg_new_max = float2 {-new_max, -new_max};
float2 cur_sum = float2 {0.0f, 0.0f};
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
float2 d = ku::float2_fma(float2{p[i*2], p[i*2+1]}, scale, neg_new_max);
d.x = exp2f(d.x);
d.y = exp2f(d.y);
cur_sum = ku::float2_add(cur_sum, d);
s[i] = __float22bfloat162_rn(d);
}
li = fma(li, scale_for_old, (cur_sum.x + cur_sum.y));
// Write S
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2/8; i += 1) {
*(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4);
}
// Scale O
if (block_idx != args.start_block_idx && should_scale_o) {
float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old};
ku::tcgen05_after_thread_sync();
static constexpr int CHUNK_SIZE = 64;
float2 o[CHUNK_SIZE/2];
CUTE_UNROLL
for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
// Load O
ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::O + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_load();
// Mult
for (int i = 0; i < CHUNK_SIZE/2; ++i) {
o[i] = ku::float2_mul(o[i], scale_for_old_float2);
}
// Store O
ku::tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::O + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_store();
}
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
plan.bar_so_ready[rs.buf_idx].arrive();
if (block_idx != args.end_block_idx-1) {
rs.update(); // Don't update rs for the last round since we want to wait for the last SV gemm
}
}
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Exchange li
plan.rowwise_max_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
li += plan.rowwise_max_buf[idx_in_warpgroup^64];
// Store li
if (idx_in_warpgroup < B_H) {
if (args.is_no_split) {
float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
float* gSoftmaxLse = (float*)params.lse + args.batch_idx*params.stride_lse_b + s_q_idx*params.stride_lse_s_q + idx_in_warpgroup;
*gSoftmaxLse = cur_lse;
} else {
float cur_lse = log2f(li) + mi;
float* gSoftmaxLseAccum = (float*)params.lse_accum + args.n_split_idx*params.stride_lse_accum_split + s_q_idx*params.stride_lse_accum_s_q + idx_in_warpgroup;
*gSoftmaxLseAccum = cur_lse;
}
}
plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase);
rs.update();
ku::tcgen05_after_thread_sync();
if (args.is_last_batch) {
cudaTriggerProgrammaticLaunchCompletion();
}
if (args.is_no_split) {
Tensor tma_gO = flat_divide(
tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, args.batch_idx),
Shape<Int<B_H>, Int<64>>{}
)(_, _, _0{}, _);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor tma_sO = flat_divide(
sO,
Shape<Int<B_H>, Int<64>>{}
)(_, _, _0{}, _);
float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li + exp2f(attn_sink - mi));
float2 o_scale_float2 = {o_scale, o_scale};
float2 o[B_EPI/2];
__nv_bfloat162 o_bf16[B_EPI/2];
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j) {
o[j] = ku::float2_mul(o[j], o_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[j]);
}
// Store
int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 8; ++j)
*(__int128_t*)(sO_bases[j] + col_base*B_H) = *(__int128_t*)(&o_bf16[j*4]);
// Sync
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
// S -> G
if (warp_idx == 0 && elect_one_sync()) {
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(tma_sO(_, _, col_base/64)),
thr_tma.partition_D(tma_gO(_, _, col_base/64))
);
}
if (warp_idx == 1 && elect_one_sync()) {
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(tma_sO(_, _, col_base/64 + (D_V/4)/64)),
thr_tma.partition_D(tma_gO(_, _, col_base/64 + (D_V/4)/64))
);
}
}
cute::tma_store_arrive();
} else {
float o_scale = li == 0.0f ? 0.0f : __fdividef(1.0f, li); // Here we leave attn_sink to the combine kernel, otherwise attn_sink will take effect for multiple times
float2 o_scale_float2 = {o_scale, o_scale};
constexpr int B_EPI = 64;
float2 o[B_EPI/2];
Tensor sO = make_tensor(make_smem_ptr(plan.u.qo.o.o_accum_buf.data()), SmemLayoutOAccumBuf{});
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j)
o[j] = ku::float2_mul(o[j], o_scale_float2);
// Store
int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 4; ++j)
*(__int128_t*)&sO(idx_in_warpgroup%64, col_base + j*4) = *(__int128_t*)(&o[j*2]);
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
if (elect_one_sync()) {
CUTE_UNROLL
for (int local_row = 0; local_row < B_H/4; ++local_row) {
int smem_row = local_row*4 + warp_idx;
SM90_BULK_COPY_S2G::copy(
&sO(smem_row, _0{}),
(float*)params.o_accum + args.n_split_idx*params.stride_o_accum_split + s_q_idx*params.stride_o_accum_s_q + smem_row*params.stride_o_accum_h_q,
D_V*sizeof(float)
);
}
cute::tma_store_arrive();
}
}
});
if (warp_idx == 0) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
cutlass::arch::warpgroup_reg_dealloc<72>();
const int warp_idx = cutlass::canonical_warp_idx_sync(); // Missing this leads to reg spilling
if (warp_idx == 4 && elect_one_sync()) {
// MMA Warp
run_main_loop([&](const MainLoopArgs &args) {
if (args.start_block_idx >= args.end_block_idx) {
ku::trap();
}
// Issue Q (SW128) G->S
{
Tensor gQ = tma_params.tma_Q_SW128.get_tma_tensor(tma_params.shape_Q_SW128)(_, _, s_q_idx, args.batch_idx);
Tensor sQ = make_tensor(make_smem_ptr(plan.u.qo.q.data()), SmemLayoutQ_SW128{});
ku::launch_tma_copy(
tma_params.tma_Q_SW128,
gQ,
sQ,
plan.bar_q_tma,
TMA::CacheHintSm90::EVICT_FIRST
);
}
// Issue Q (SW64) G -> S
if constexpr (D_Q_SW64 > 0) {
cute::SM90_TMA_LOAD_5D::copy(
&tma_params.tensor_map_q_sw64,
(uint64_t*)&plan.bar_q_tma,
(uint64_t)TMA::CacheHintSm90::EVICT_FIRST,
plan.u.qo.q_sw64,
0, 0, 0,
s_q_idx, args.batch_idx
);
}
plan.bar_q_tma.arrive_and_expect_tx(B_H*D_Q*sizeof(bf16));
plan.bar_q_tma.wait(args.bar_phase_batch_rel);
ku::tcgen05_after_thread_sync();
// Issue Q (SW128) UTCCP
{
UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.qo.q.data()),
tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64>>{} // *2 to leverage dual GEMM
)
)
);
static_assert(D_Q_SW128%128 == 0);
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_Q_SW128/128; ++tile_idx) {
// Each tile: 64 x (64*2) logically, 128 x 64 bf16 on TMEM
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 64/16; ++subtile_idx) {
// Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_desc + (tile_idx*(B_H*128) + subtile_idx*16) * 2 / 16,
tmem_cols::Q + tile_idx*32 + subtile_idx*8
);
}
}
}
// Issue Q (SW64) UTCCP
if constexpr (D_Q_SW64 > 0) {
UMMA::SmemDescriptor sQ_SW64_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.qo.q_sw64),
tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32>>{} // *2 to leverage dual GEMM
)
)
);
static_assert(D_Q_SW64%64 == 0);
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_Q_SW64/64; ++tile_idx) {
// Each tile: 64 x (32*2) logically, 128 x 32 bf16 on TMEM
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 32/16; ++subtile_idx) {
// Each subtile: 64 x (16*2) logically, 128 x 16 bf16 (128dp256b) on TMEM
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_SW64_desc + (tile_idx*(B_H*64) + subtile_idx*16) * 2 / 16,
tmem_cols::Q + (B_H*D_Q_SW128/2/128) + tile_idx*16 + subtile_idx*8
);
}
}
}
ku::umma_arrive_noelect(plan.bar_q_utccp);
// Allocate tmem tensors
TiledMMA tiled_mma_P = TiledMMA_P{};
TiledMMA tiled_mma_O = TiledMMA_O{};
// NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm)
Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<B_H>, _128>{});
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H>, Int<D_V>>{});
tP.data().get() = tmem_cols::P;
tO.data().get() = tmem_cols::O;
// Wait for UTCCP
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
ku::tcgen05_after_thread_sync();
// Mainloop
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
if constexpr (MODEL_TYPE == ModelType::V32) {
// V3.2: RoPE behaves like an extra block with size 64, so we can do RoPE first
// QK RoPE
plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_ROPE/2>>{})
);
tQ_rope.data().get() = tmem_cols::Q_Tail;
Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].rope.data()), SmemLayoutKTiles_DualGemm_SW64<2/2>{});
ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true);
// QK NoPE
plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor tQ_nope = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_NOPE/2>>{})
);
tQ_nope.data().get() = tmem_cols::Q;
Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{});
ku::utcmma_ts(tiled_mma_P, tQ_nope, sK_nope, tP, false);
} else {
// MODEL1: RoPE is the last 64 dims within the full 512 dim, which couples with the last 64 dim from the NoPE part when performing dual GEMM. i.e.
//
// logical view: |0|1|2|3|4|5|6|7| (where 7 is the RoPE part)
// dual gemm's view:
// |0|2|4|6|
// |1|3|5|7|
//
// So we must wait for both the NoPE and the RoPE part, and then perform dual GEMM
plan.bar_rope_ready[rs.buf_idx].wait(rs.bar_phase);
plan.bar_nope_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<D_Q/2>>{})
);
tQ.data().get() = tmem_cols::Q;
Tensor sK = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTiles_DualGemm_SW128<512/64/2>{});
ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true);
}
ku::umma_arrive_noelect(plan.bar_qk_done[rs.buf_idx]);
// SV
plan.bar_so_ready[rs.buf_idx].wait(rs.bar_phase);
ku::tcgen05_after_thread_sync();
Tensor sS = make_tensor(make_smem_ptr(plan.s_p.s.data()), SmemLayoutS{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.kv.dequant[rs.buf_idx].nope.data()), SmemLayoutKTilesTransposed_SW128<D_V/64>{}); // NOTE: For MODEL1, it "expands" to the RoPE part.
ku::utcmma_ss(tiled_mma_O, sS, sV, tO, block_idx == args.start_block_idx);
ku::umma_arrive_noelect(plan.bar_sv_done[rs.buf_idx]);
rs.update();
}
});
} else if (warp_idx == 5 && elect_one_sync()) {
// Raw KV NoPE retrieval warp
run_main_loop([&](const MainLoopArgs &args) {
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
plan.bar_last_store_done.wait(args.bar_phase_batch_rel);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);
plan.bar_raw_free[rs.buf_idx].wait(rs.bar_phase^1);
int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0);
int4 nxt_cur_indices;
CUTE_UNROLL
for (int row = 0; row < B_TOPK; row += 4) {
if (row+4 < B_TOPK)
nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4);
ku::tma_gather4(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope,
plan.bar_raw_ready[rs.buf_idx],
plan.u.kv.raw_nope[rs.buf_idx].data() + D_NOPE*row,
0,
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
cur_indices = nxt_cur_indices;
}
plan.bar_raw_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_NOPE*sizeof(e4m3));
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
rs.update();
}
});
} else if (warp_idx == 6 && elect_one_sync()) {
// KV RoPE retrieval warp
run_main_loop([&](const MainLoopArgs &args) {
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
plan.bar_last_store_done.wait(args.bar_phase_batch_rel);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);
if constexpr (MODEL_TYPE == ModelType::V32) {
plan.bar_qk_done[rs.buf_idx].wait(rs.bar_phase^1);
} else {
plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1);
}
int4 cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + 0);
int4 nxt_cur_indices;
CUTE_UNROLL
for (int row = 0; row < B_TOPK; row += 4) {
if (row+4 < B_TOPK)
nxt_cur_indices = *(int4*)(plan.tma_coord[rs.index_buf_idx] + row + 4);
CUTE_UNROLL
for (int t = 0; t < D_ROPE/(K_ROPE_SW/2); ++t) {
ku::tma_gather4(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope,
plan.bar_rope_ready[rs.buf_idx],
plan.u.kv.dequant[rs.buf_idx].rope.data() + (K_ROPE_SW/2)*row + t*B_TOPK*(K_ROPE_SW/2),
t*(K_ROPE_SW/2),
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
cur_indices = nxt_cur_indices;
}
plan.bar_rope_ready[rs.buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16));
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
rs.update();
}
});
} else if (warp_idx == 7) {
// Indices transformation warp
// Responsible for generating: TMA coordinates, scale factors, and valid masks
static_assert(B_TOPK == 64);
static constexpr int tma_coords_step_per_token = MODEL_TYPE == ModelType::V32 ? 656/TMA_K_STRIDE : 576/TMA_K_STRIDE;
int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE > 512
int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE;
uint8_t* k_scales_ptr =
MODEL_TYPE == ModelType::V32 ?
(uint8_t*)params.kv + D_NOPE :
(uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE);
uint8_t* extra_k_scales_ptr =
MODEL_TYPE == ModelType::V32 ?
(uint8_t*)params.extra_kv + D_NOPE :
(uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE);
run_main_loop([&](const MainLoopArgs &args) {
int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*s_q_idx;
int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*s_q_idx;
struct IsOrigBlock {};
struct IsExtraBlock {};
auto process_one_block = [&](int block_idx, auto is_extra_block_t) {
static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;
int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size;
int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block;
[[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row;
uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr;
int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block;
int abs_pos, my_indices[2];
if (!IS_EXTRA_BLOCK) {
abs_pos = block_idx*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(indices + abs_pos));
} else {
abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos));
}
plan.bar_valid_coord_scale_free[rs.index_buf_idx].wait(rs.index_bar_phase^1);
int tma_coords[2];
e8m0 scales[2*NUM_SCALES_EACH_TOKEN];
char valid_mask = 0;
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int block_idx, idx_in_block;
block_idx = (unsigned int)my_indices[i] / cur_block_size;
idx_in_block = (unsigned int)my_indices[i] % cur_block_size;
bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length));
valid_mask |= is_token_valid << i;
tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN.
if constexpr (MODEL_TYPE == ModelType::V32) {
int64_t offset = is_token_valid ? block_idx*cur_k_block_stride + idx_in_block*cur_k_row_stride : 0;
float4 cur_scale_fp32 = __ldg((float4*)(cur_k_scales_ptr + offset));
e8m0 res[4];
*(__nv_fp8x2_storage_t*)(res+0) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.x, cur_scale_fp32.y}, __NV_NOSAT, cudaRoundZero);
*(__nv_fp8x2_storage_t*)(res+2) = __nv_cvt_float2_to_e8m0x2(float2{cur_scale_fp32.z, cur_scale_fp32.w}, __NV_NOSAT, cudaRoundZero);
if (!is_token_valid) *(uint32_t*)res = (uint32_t)0;
*(uint32_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = *(uint32_t*)(res);
} else {
int64_t offset = block_idx*cur_k_block_stride + idx_in_block*8; // Each token has 7 scale factors with an extra 1B padding
uint64_t scalesx8 = is_token_valid ? __ldg((uint64_t*)(cur_k_scales_ptr + offset)) : 0;
*(uint64_t*)(scales+i*NUM_SCALES_EACH_TOKEN) = scalesx8;
}
}
valid_mask <<= lane_idx%4*2;
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1);
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2);
if constexpr (MODEL_TYPE == ModelType::V32) {
*(uint64_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(uint64_t*)scales;
} else {
*(__int128_t*)(plan.scales[rs.index_buf_idx] + lane_idx*2) = *(__int128_t*)scales;
}
*(int2*)(plan.tma_coord[rs.index_buf_idx] + lane_idx*2) = *(int2*)tma_coords;
if (lane_idx%4 == 0)
plan.is_token_valid[rs.index_buf_idx][lane_idx/4] = valid_mask;
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].arrive();
rs.update();
};
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {
process_one_block(block_idx, IsOrigBlock{});
}
CUTE_NO_UNROLL
for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsExtraBlock{});
}
});
} else {
run_main_loop([&](const MainLoopArgs &args) {});
}
} else {
// Dequant warpgroup
cutlass::arch::warpgroup_reg_alloc<208>();
// 8 threads per token
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = D_NOPE/(GROUP_SIZE*8);
int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE;
Tensor nope0 = make_tensor(make_smem_ptr(plan.u.kv.dequant[0].nope.data()), SmemLayoutKTiles_SW128<D_NOPE/64>{});
bf16* nope0_base = &nope0(group_idx, idx_in_group*8);
bf16* nope1_base = nope0_base + (plan.u.kv.dequant[1].nope.data() - plan.u.kv.dequant[0].nope.data());
e4m3* raw_nope0_base = plan.u.kv.raw_nope[rs.buf_idx].data() + group_idx*D_NOPE + idx_in_group*8;
e4m3* raw_nope1_base = raw_nope0_base + B_H*D_NOPE;
run_main_loop([&](const MainLoopArgs &args) {
// plan.bar_last_store_done.wait(args.bar_phase_batch_rel); // No need to wait since the raw nope producer must wait
plan.bar_q_utccp.wait(args.bar_phase_batch_rel);
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
plan.bar_valid_coord_scale_ready[rs.index_buf_idx].wait(rs.index_bar_phase);
plan.bar_raw_ready[rs.buf_idx].wait(rs.bar_phase);
plan.bar_sv_done[rs.buf_idx].wait(rs.bar_phase^1);
uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(rs.buf_idx == 0 ? nope0_base : nope1_base);
e4m3* raw_nope_base = rs.buf_idx == 0 ? raw_nope0_base : raw_nope1_base;
auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) {
asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n"
:
: "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16)
); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS
};
auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t {
return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*D_NOPE + local_col_idx*(GROUP_SIZE*8));
};
// The following code suffers from a 2-way bank conflict when reading from SMEM.
if constexpr (MODEL_TYPE == ModelType::V32) {
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {
int row_idx = local_row_idx*NUM_GROUPS + group_idx;
bf16 scales[4];
e8m0 scales_e8m0[4];
*(uint32_t*)scales_e8m0 = *(uint32_t*)plan.scales[rs.index_buf_idx][row_idx];
*(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));
*(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));
uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);
CUTE_UNROLL
for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {
ku::nve4m3x2 data_fp8[4];
ku::nvbf16x2 data_bf16[4];
*(uint64_t*)data_fp8 = cur_data_fp8x8;
if (local_col_idx+1 < COLS_PER_GROUP)
cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);
bf16 scale = scales[local_col_idx / (D_NOPE/(GROUP_SIZE*8)/4)];
CUTE_UNROLL
for (int i = 0; i < 4; ++i) {
data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));
}
st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);
}
}
} else {
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {
int row_idx = local_row_idx*NUM_GROUPS + group_idx;
bf16 scales[8];
e8m0 scales_e8m0[8];
*(uint64_t*)scales_e8m0 = *(uint64_t*)plan.scales[rs.index_buf_idx][row_idx];
*(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));
*(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));
*(__nv_bfloat162_raw*)(scales+4) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+4));
*(__nv_bfloat162_raw*)(scales+6) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+6));
uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);
CUTE_UNROLL
for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {
ku::nve4m3x2 data_fp8[4];
ku::nvbf16x2 data_bf16[4];
*(uint64_t*)data_fp8 = cur_data_fp8x8;
if (local_col_idx+1 < COLS_PER_GROUP)
cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);
bf16 scale = scales[local_col_idx];
CUTE_UNROLL
for (int i = 0; i < 4; ++i) {
data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));
}
st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);
}
}
}
cutlass::arch::fence_view_async_shared();
plan.bar_nope_ready[rs.buf_idx].arrive();
plan.bar_raw_free[rs.buf_idx].arrive();
plan.bar_valid_coord_scale_free[rs.index_buf_idx].arrive();
rs.update();
}
});
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100 ~ sm119");
}
#endif
}
// Global kernel entry point. It only forwards its two arguments to the static device function
// `Kernel::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc`, which contains the actual implementation.
//   Kernel    - type providing `NUM_THREADS` and the device function (the host launcher in this file
//               instantiates it with KernelTemplate<MODEL_TYPE>).
//   params    - attention parameters, passed as a `__grid_constant__` kernel argument.
//   tma_params - TMA descriptors/shapes, passed as a `__grid_constant__` kernel argument.
// `__launch_bounds__(Kernel::NUM_THREADS, 1, 1)` tells the compiler the block never exceeds Kernel::NUM_THREADS threads.
template<typename Kernel, typename TmaParams>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const SparseAttnDecodeParams params, __grid_constant__ const TmaParams tma_params) {
Kernel::flash_fwd_splitkv_mla_fp8_sparse_kernel_devfunc(params, tma_params);
}
// Host-side launcher for the head64 sparse fp8 split-KV MLA decoding kernel.
//
// Steps:
//   1. Validate the problem shape against the compile-time tile configuration.
//   2. Build the TMA descriptors for Q (SW128 part and, if D_Q_SW64 > 0, SW64 part), O, and the
//      NoPE/RoPE parts of the KV cache (plus the extra KV cache when params.extra_topk > 0).
//   3. Opt in to sizeof(SharedMemoryPlan) bytes of dynamic shared memory and launch the kernel on
//      params.stream with grid (s_q, num_sm_parts, 1) and NUM_THREADS threads per block.
//
// Any violated precondition is reported through KU_ASSERT; CUDA failures are reported through
// KU_CUDA_CHECK / KU_CHECK_KERNEL_LAUNCH.
template<ModelType MODEL_TYPE>
void KernelTemplate<MODEL_TYPE>::run(const SparseAttnDecodeParams &params) {
    // The kernel processes indices in tiles of B_TOPK tokens and exactly B_H query heads.
    KU_ASSERT(params.topk % B_TOPK == 0, "topk (%d) mod B_TOPK (%d) must be 0", params.topk, B_TOPK);
    KU_ASSERT(params.extra_topk % B_TOPK == 0, "extra_topk (%d) mod B_TOPK (%d) must be 0", params.extra_topk, B_TOPK);
    KU_ASSERT(params.h_q == B_H);
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.d_qk == D_Q);
    KU_ASSERT(params.d_v == D_V);
    if constexpr (MODEL_TYPE == ModelType::MODEL1) {
        constexpr int BYTES_PER_TOKEN = D_NOPE + 2*D_ROPE + 8;
        KU_ASSERT(params.stride_kv_row == BYTES_PER_TOKEN, "Each page block in KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1");    // Each block must be contiguous
        // The kernel addresses the extra KV cache with the same hard-coded packed layout as the
        // main one (it never reads the row stride for MODEL1), so it needs the same guarantee.
        // Only checked when the extra cache is actually used.
        if (params.extra_topk > 0) {
            KU_ASSERT(params.stride_extra_kv_row == BYTES_PER_TOKEN, "Each page block in extra KV cache must be contiguous for head64 sparse fp8 decoding attention in MODEL1");
        }
    }

    // TMA load descriptor for Q, viewed as (head, dim, s_q, batch) with unit stride along dim.
    auto shape_Q_SW128 = make_shape(B_H, D_Q, params.s_q, params.b);
    auto tma_Q_SW128 = cute::make_tma_copy(
        SM90_TMA_LOAD{},
        make_tensor(
            make_gmem_ptr((bf16*)params.q),
            make_layout(
                shape_Q_SW128,
                make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q, params.stride_q_b)
            )
        ),
        SmemLayoutQ_SW128{}
    );

    // TMA store descriptor for O, viewed as (head, dim, s_q, batch) with unit stride along dim.
    auto shape_O = make_shape(B_H, D_V, params.s_q, params.b);
    auto tma_O = cute::make_tma_copy(
        SM90_TMA_STORE{},
        make_tensor(
            make_gmem_ptr((bf16*)params.out),
            make_layout(
                shape_O,
                make_stride(params.stride_o_h_q, _1{}, params.stride_o_s_q, params.stride_o_b)
            )
        ),
        SmemLayoutOBuf_TMA{}
    );

    // Optional 64B-swizzled descriptor covering the Q columns starting at D_Q_SW128.
    // Left zero-initialized when the configuration has no SW64 part.
    CUtensorMap tensor_map_q_sw64{};
    if constexpr (D_Q_SW64 > 0) {
        tensor_map_q_sw64 = ku::make_tensor_map(
            {D_Q_SW64, (uint64_t)params.h_q, D_Q_SW64/32, (uint64_t)params.s_q, (uint64_t)params.b},
            ku::make_stride_helper(std::vector<int64_t>{params.stride_q_h_q, (int64_t)32, params.stride_q_s_q, params.stride_q_b}, sizeof(bf16)),
            {32, B_H, D_Q_SW64/32, 1, 1},
            (bf16*)params.q + D_Q_SW128,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
        );
    }

    // Builds the (NoPE, RoPE) tensor-map pair for one KV cache. Both maps treat the cache as a 2D
    // array whose rows are TMA_K_STRIDE bytes apart, so the block stride must be a multiple of it.
    auto get_nope_rope_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t k_batch_stride) -> std::pair<CUtensorMap, CUtensorMap> {
        static_assert(D_NOPE%8 == 0);
        KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr);
        KU_ASSERT(k_batch_stride % TMA_K_STRIDE == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary", is_extra?"extra_":"", k_batch_stride, TMA_K_STRIDE);
        CUtensorMap tensor_map_kv_nope = ku::make_tensor_map(
            {D_NOPE/8, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},
            {TMA_K_STRIDE},
            {D_NOPE/8, 1},
            k_ptr,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_INT64,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
        );  // NOTE We combine 8 float8 into 1 int64 since boxdim cannot > 256
        // The RoPE part starts right after the NoPE bytes; for V32 an extra 16 bytes are skipped.
        CUtensorMap tensor_map_kv_rope = ku::make_tensor_map(
            {D_ROPE, (uint64_t)num_blocks * (k_batch_stride/TMA_K_STRIDE)},
            {TMA_K_STRIDE},
            {K_ROPE_SW/2, 1},
            (uint8_t*)k_ptr + (MODEL_TYPE == ModelType::V32 ? (D_NOPE+16) : D_NOPE),
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            K_ROPE_SW == 64 ? CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B : CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
        );
        return {tensor_map_kv_nope, tensor_map_kv_rope};
    };

    auto [tensor_map_kv_nope, tensor_map_kv_rope] = get_nope_rope_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block);
    // The extra KV cache descriptors stay zero-initialized unless extra tokens are requested.
    CUtensorMap tensor_map_extra_kv_nope{}, tensor_map_extra_kv_rope{};
    if (params.extra_topk > 0) {
        std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_nope_rope_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block);
    }

    TmaParams<
        decltype(shape_Q_SW128), decltype(tma_Q_SW128),
        decltype(shape_O), decltype(tma_O)
    > tma_params = {
        shape_Q_SW128, tma_Q_SW128,
        shape_O, tma_O,
        tensor_map_q_sw64,
        tensor_map_kv_nope,
        tensor_map_kv_rope,
        tensor_map_extra_kv_nope,
        tensor_map_extra_kv_rope
    };

    auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<MODEL_TYPE>, decltype(tma_params)>;
    // The whole SharedMemoryPlan lives in dynamic shared memory, which must be opted in explicitly.
    constexpr size_t smem_size = sizeof(SharedMemoryPlan);
    static_assert(smem_size < 227*1024);
    KU_CUDA_CHECK(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    // NOTE Don't use PDL because of potential compiler bugs!
    mla_kernel<<<dim3(params.s_q, params.num_sm_parts, 1), dim3(NUM_THREADS, 1, 1), smem_size, params.stream>>>(params, tma_params);
    KU_CHECK_KERNEL_LAUNCH();
}
// Public entry point for the head64 sparse fp8 split-KV MLA decoding kernel.
// Simply dispatches to KernelTemplate<MODEL_TYPE>::run(params), which validates the
// parameters, builds the TMA descriptors and launches the kernel.
template<ModelType MODEL_TYPE>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params) {
KernelTemplate<MODEL_TYPE>::run(params);
}
}
#pragma once
#include "params.h"
namespace sm100::decode::head64 {
// Declaration of the host-side entry point for the head64 sparse fp8 split-KV MLA decoding kernel.
// MODEL_TYPE selects the kernel configuration; `params` carries all runtime arguments.
// NOTE(review): the definition is a template, so each used MODEL_TYPE presumably needs an
// instantiation in a translation unit that includes the implementation - confirm.
template<ModelType MODEL_TYPE>
void run_flash_splitkv_mla_fp8_sparse_kernel(const SparseAttnDecodeParams &params);
}
#pragma once
#include <cute/tensor.hpp>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include "defines.h"
namespace sm100 {
using namespace cute;
// Returns the largest of the four components of `t`.
CUTE_DEVICE
int int4_max(int4 t) {
    // Reduce each half first, then combine the two partial results.
    int const max_xy = max(t.x, t.y);
    int const max_zw = max(t.z, t.w);
    return max(max_xy, max_zw);
}
// Returns the smallest of the four components of `t`.
CUTE_DEVICE
int int4_min(int4 t) {
    // Reduce each half first, then combine the two partial results.
    int const min_xy = min(t.x, t.y);
    int const min_zw = min(t.z, t.w);
    return min(min_xy, min_zw);
}
// Dequantizes a packed pair of fp8_e4m3 values: each element is widened to bf16
// (via float) and then multiplied by the bf16 `scale`. Returns the scaled pair.
CUTE_DEVICE
nv_bfloat162 fp8x2_to_bf16x2_with_scale(__nv_fp8x2_e4m3 data, nv_bfloat16 scale) {
    // TODO Use native conversion for CUDA >= 13.1
    // fp8x2 -> float2 -> bf16x2; the scale is applied afterwards in bf16 arithmetic.
    nv_bfloat162 const unscaled = __float22bfloat162_rn((float2)data);
    nv_bfloat16 const scaled_x = unscaled.x * scale;
    nv_bfloat16 const scaled_y = unscaled.y * scale;
    return nv_bfloat162 {scaled_x, scaled_y};
}
}
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/kernel_hardware_info.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
// Issues one cute::gemm per k-block (mode 2 of tA / tB), accumulating into tC.
// The atom's `accumulate_` flag is left untouched for the first k-block and is forced to
// `One` after every issued block, so all later blocks accumulate onto the existing result.
// All three operands must be rank-3 tensors.
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
  static_assert(decltype(rank(tA))::value == 3);
  static_assert(decltype(rank(tB))::value == 3);
  static_assert(decltype(rank(tC))::value == 3);

  using AccumulateFlag = decltype(atom.accumulate_);

  CUTLASS_PRAGMA_UNROLL
  for (int k = 0; k < size<2>(tA); ++k) {
    cute::gemm(atom, tA(_,_,k), tB(_,_,k), tC);
    // From now on every MMA adds to what is already in the accumulator.
    atom.accumulate_ = AccumulateFlag::One;
  }
}
// Same as gemm_reset_zero_acc, but first sets the atom's `accumulate_` flag to `Zero`
// so that the first k-block overwrites the accumulator instead of adding to it.
template<typename Atom, typename TA, typename TB, typename TC>
CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
  using AccumulateFlag = decltype(atom.accumulate_);
  atom.accumulate_ = AccumulateFlag::Zero;
  gemm_reset_zero_acc(atom, tA, tB, tC);
}
// Composes `layout` with a tiler of the same rank whose leading modes are `_` (kept as-is)
// and whose last mode is `make_layout(stages)`.
// `stages` defaults to _1.
// NOTE(review): presumably the last mode of `layout` is the pipeline-stage mode and this
// restricts it to `stages` entries - confirm against the call sites.
template<class Layout, class Stages = _1>
CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
return composition(layout, prepend<decltype(rank(layout))::value>(make_layout(stages), _));
}
// Broadcasts lane 0's value of `a` to every lane of the warp and returns it.
// All 32 lanes must participate (full mask).
template<class T>
CUTE_DEVICE T warp_uniform(T a) {
  constexpr unsigned kAllLanes = 0xffffffffu;
  constexpr int kSourceLane = 0;
  return __shfl_sync(kAllLanes, a, kSourceLane);
}
// Type-level conversion: given a TiledMMA built on the SM100_MMA_F8F6F4_SS atom, returns a
// default-constructed TiledMMA built on SM100_MMA_F8F6F4_TS with the same element types,
// M/N extents, operand majors and negate flags, and with UMMA::Saturate::False.
// The function argument is unnamed and is used only to deduce the template parameters;
// the remaining TiledMMA / MMA_Atom parameters (TAs..., TMs...) are carried over unchanged.
// NOTE(review): "SS"/"TS" presumably denote the operand sources of the UMMA instruction - confirm in the CuTe docs.
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
TiledMMA<MMA_Atom<
MMA_Traits<SM100_MMA_F8F6F4_SS, a_type, b_type, c_type,
cute::C<M>, cute::C<N>,
cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>>,
TAs...>, TMs...>) {
return TiledMMA<MMA_Atom<
MMA_Traits<SM100_MMA_F8F6F4_TS<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, UMMA::Saturate::False>>,
TAs...>, TMs...>{};
}
// Type-level conversion: given a TiledMMA built on the SM100_MMA_F16BF16_SS atom, returns a
// default-constructed TiledMMA built on SM100_MMA_F16BF16_TS with the same element types,
// M/N extents, operand majors and negate flags, and with UMMA::Saturate::False.
// The function argument is unnamed and is used only to deduce the template parameters;
// the remaining TiledMMA / MMA_Atom parameters (TAs..., TMs...) are carried over unchanged.
// NOTE(review): "SS"/"TS" presumably denote the operand sources of the UMMA instruction - confirm in the CuTe docs.
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg, class... TAs, class... TMs>
CUTE_HOST_DEVICE constexpr
auto
to_tiled_mma_sm100_ts(
TiledMMA<MMA_Atom<
SM100_MMA_F16BF16_SS<a_type, b_type, c_type,
M, N,
a_major,
b_major,
a_neg,
b_neg>,
TAs...>, TMs...>) {
return TiledMMA<MMA_Atom<
SM100_MMA_F16BF16_TS<a_type, b_type, c_type,
M, N,
a_major, b_major,
a_neg, b_neg, UMMA::Saturate::False>,
TAs...>, TMs...>{};
}
// Sets the calling warpgroup's register budget to RegCount, choosing the direction at
// compile time: budgets of 128 or more go through warpgroup_reg_alloc, smaller budgets
// go through warpgroup_reg_dealloc.
template<uint32_t RegCount>
CUTLASS_DEVICE
void warpgroup_reg_set() {
  if constexpr (RegCount >= 128) {
    cutlass::arch::warpgroup_reg_alloc<RegCount>();
  }
  else {
    cutlass::arch::warpgroup_reg_dealloc<RegCount>();
  }
}
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
// Masking policy that masks nothing. It also serves as the base class providing the
// default trip-count computation for the other mask policies.
struct NoMask {
  // Total number of iterations: ceil(get<1>(problem_size) / get<1>(tile_shape)).
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& /*blk_coord*/,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    return ceil_div(get<1>(problem_size), get<1>(tile_shape));
  }

  // No iteration ever needs masking.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& /*blk_coord*/,
      TileShape const& /*tile_shape*/,
      ProblemSize const& /*problem_size*/) {
    return 0;
  }

  // Every iteration is unmasked, so this equals the total trip count.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    int const total_trips = get_trip_count(blk_coord, tile_shape, problem_size);
    return total_trips;
  }

  // Intentionally a no-op: the accumulator is left untouched.
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& /*acc_qk*/,
      IndexQK const& /*index_qk*/,
      ProblemSize const& /*problem_size*/) {
  }
};
// Masking policy for a K extent (mode 1 of the problem size) that is not a multiple of the
// K tile size: the last, partial tile is treated as the single masked iteration.
struct ResidualMask : NoMask {
  using Base = NoMask;

  // 1 when the K extent leaves a partial tile, otherwise 0.
  template <class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE int get_masked_trip_count(
      BlkCoord const& /*blk_coord*/,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    return (get<1>(problem_size) % get<1>(tile_shape) != 0) ? 1 : 0;
  }

  // Total trip count minus the partial tile, if there is one.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    int trips = get_trip_count(blk_coord, tile_shape, problem_size);
    // The sequence length does not divide the tile size evenly: drop the partial tile.
    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
      trips -= 1;
    }
    return trips;
  }

  // Writes -INFINITY into every accumulator element whose K coordinate lies at or beyond
  // the K extent of the problem.
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {
    // This is useful if seqlen_k % kBlockN != 0 since it masks
    // the remaining elements out from softmax.
    // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
    // issues as they are transparently taken care of by TMA and the
    // epilogue, if it is instantiated with predication support.
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < size(acc_qk); ++idx) {
      auto coord = index_qk(idx);
      bool const past_k_end = get<1>(coord) >= get<1>(problem_size);
      if (past_k_end) {
        acc_qk(idx) = -INFINITY;
      }
    }
  }
};
// Residual masking policy used by the backward pass. Trip counts are computed as in
// ResidualMask; apply_mask additionally checks the Q coordinate, masking every element
// whose (Q, K) coordinate is not element-wise inside the first two problem extents.
struct ResidualMaskForBackward : NoMask {
  using Base = NoMask;

  // 1 when the K extent leaves a partial tile, otherwise 0.
  template <class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE int get_masked_trip_count(
      BlkCoord const& /*blk_coord*/,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    return (get<1>(problem_size) % get<1>(tile_shape) != 0) ? 1 : 0;
  }

  // Total trip count minus the partial tile, if there is one.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    int trips = get_trip_count(blk_coord, tile_shape, problem_size);
    // The sequence length does not divide the tile size evenly: drop the partial tile.
    if (get<1>(problem_size) % get<1>(tile_shape) != 0) {
      trips -= 1;
    }
    return trips;
  }

  // Writes -INFINITY into every accumulator element that falls outside the problem.
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {
    // This is useful if seqlen_k % kBlockN != 0 since it masks
    // the remaining elements out from softmax.
    // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar
    // issues as they are transparently taken care of by TMA and the
    // epilogue, if it is instantiated with predication support.
    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < size(acc_qk); ++idx) {
      auto coord = index_qk(idx);
      if (elem_less(coord, select<0,1>(problem_size))) {
        continue;  // inside the problem: keep the value
      }
      acc_qk(idx) = -INFINITY;
    }
  }
};
// Causal masking policy. When N_Q != N_K there are two ways to align Q with K:
// (1) kIsQBegin == true  : the Q is at the beginning of the matrix (default)
// (2) kIsQBegin == false : the Q is at the end of the matrix
template<bool kIsQBegin = true>
struct CausalMask : NoMask {
  using Base = NoMask;
  static constexpr bool IsQBegin = kIsQBegin;

  // Number of K tiles this Q tile has to visit: the unmasked trip count, capped by the
  // number of K tiles needed to reach the last row of the Q tile (shifted by
  // N_K - N_Q when Q is aligned to the end).
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    // See note below on different ways to think about causal attention
    // Again, we'd add the offset_q into the max_blocks_q calculation
    int k_tiles = Base::get_trip_count(blk_coord, tile_shape, problem_size);
    if constexpr (!IsQBegin) {
      const int q_shift = get<1>(problem_size) - get<0>(problem_size);
      int q_tiles = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + q_shift, get<1>(tile_shape));
      return std::min(k_tiles, q_tiles);
    } else {
      int q_tiles = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape));
      return std::min(k_tiles, q_tiles);
    }
  }

  // Number of visited K tiles that need masking: the tiles a Q tile spans along K
  // (plus one "corner" tile when Q is aligned to the end and either extent is not a
  // multiple of its tile size), capped by the trip count.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_masked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    int total_trips = get_trip_count(blk_coord, tile_shape, problem_size);
    if constexpr (!IsQBegin) {
      const int corner_tiles = int((get<1>(problem_size) % get<1>(tile_shape) || get<0>(problem_size) % get<0>(tile_shape))) ;
      return std::min(total_trips, int(ceil_div(get<0>(tile_shape), get<1>(tile_shape))) + corner_tiles);
    } else {
      return std::min(total_trips, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape))));
    }
  }

  // Remaining visited tiles, i.e. trip count minus masked trip count.
  template<class BlkCoord, class TileShape, class ProblemSize>
  CUTLASS_DEVICE
  int get_unmasked_trip_count(
      BlkCoord const& blk_coord,
      TileShape const& tile_shape,
      ProblemSize const& problem_size) {
    int const total_trips = get_trip_count(blk_coord, tile_shape, problem_size);
    int const masked_trips = get_masked_trip_count(blk_coord, tile_shape, problem_size);
    return total_trips - masked_trips;
  }

  // Writes -INFINITY into every accumulator element that lies above the (possibly
  // shifted) diagonal or beyond the K extent.
  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {
    // There are two ways to do causal if N_Q != N_K
    // (1) is to assume that the Q is at the beginning of the matrix
    //    - this is the default setting.
    // (2) is that it is at the end of the matrix
    //    - this is usually what we want for inference settings
    //      where we only compute the next row and use cache for the rest
    //    - if you'd like this, you only need to set kIsQBegin=false
    if constexpr (!IsQBegin) {
      // Q aligned to the end: shift the diagonal by N_K - N_Q.
      const auto q_shift = get<1>(problem_size) - get<0>(problem_size);
      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < size(acc_qk); ++idx) {
        auto coord = index_qk(idx);
        if ((get<0>(coord) + q_shift < get<1>(coord)) || (get<1>(coord) >= get<1>(problem_size))) {
          acc_qk(idx) = -INFINITY;
        }
      }
    } else {
      // Q aligned to the beginning: plain lower-triangular mask.
      CUTLASS_PRAGMA_UNROLL
      for (int idx = 0; idx < size(acc_qk); ++idx) {
        auto coord = index_qk(idx);
        if ((get<0>(coord) < get<1>(coord)) || (get<1>(coord) >= get<1>(problem_size))) {
          acc_qk(idx) = -INFINITY;
        }
      }
    }
  }
};
// Causal masking policy for the backward pass. Inherits from CausalMask<kIsQBegin> and
// ResidualMaskForBackward, and overrides apply_mask so that an element is masked when it is
// above the (possibly shifted) diagonal OR not element-wise inside the problem size.
template<bool kIsQBegin = true>
struct CausalForBackwardMask : CausalMask<kIsQBegin>, ResidualMaskForBackward {
  using Base = CausalMask<kIsQBegin>;

  template<class AccQK, class IndexQK, class ProblemSize>
  CUTLASS_DEVICE
  void apply_mask(
      AccQK& acc_qk,
      IndexQK const& index_qk,
      ProblemSize const& problem_size) {
    // There are two ways to do causal if N_Q != N_K
    // (1) is to assume that the Q is at the beginning of the matrix
    //    - this is what we demonstrate here
    // (2) is that it is at the end of the matrix
    //    - this is usually what we want for inference settings
    //      where we only compute the next row and use cache for the rest
    //    - if you'd like this, you only need to add an offset like so:
    //      get<0>(pos) + offset_q < get<1>(pos)
    int q_shift = 0;
    if constexpr (!kIsQBegin) {
      q_shift = get<1>(problem_size) - get<0>(problem_size);
    }

    CUTLASS_PRAGMA_UNROLL
    for (int idx = 0; idx < size(acc_qk); ++idx) {
      auto coord = index_qk(idx);
      bool keep = !((get<0>(coord) + q_shift < get<1>(coord)) || !elem_less(coord, problem_size));
      if (keep) {
        continue;
      }
      acc_qk(idx) = -INFINITY;
    }
  }
};
// Describes a variable-length (ragged) extent. It converts implicitly to `max_length`, so it
// can be used wherever a plain int extent is expected.
struct VariableLength {
// Value returned by the implicit conversion to int below.
int max_length;
// Cumulative offsets: the helpers in this file compute the length of entry `idx` as
// cumulative_length[idx+1] - cumulative_length[idx] and its start as cumulative_length[idx].
int* cumulative_length = nullptr;
// NOTE(review): not read anywhere in this part of the file; presumably the sum of all
// lengths, with -1 meaning "not provided" - confirm against the users of this struct.
int total_length = -1;
// Implicit conversion yielding the maximum length.
CUTE_HOST_DEVICE operator int() const {
return max_length;
}
};
// Type trait: true only for VariableLength. `is_variable_length_v<T>` strips cv/ref
// qualifiers from T before the check, so it can be applied directly to `decltype(expr)`.
template<class T> struct is_variable_length_impl : std::false_type {};
template<> struct is_variable_length_impl<VariableLength> : std::true_type {};
template<class T> constexpr bool is_variable_length_v = is_variable_length_impl<remove_cvref_t<T>>::value;
// Resolves a shape for one batch entry: every VariableLength leaf is replaced by the actual
// length of entry `idx` (cumulative_length[idx+1] - cumulative_length[idx]); all other
// leaves are passed through unchanged.
template<class Shape, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Idx const& idx) {
  return transform_leaf(shape, [&](auto const& extent) {
    if constexpr (!is_variable_length_v<decltype(extent)>) {
      return extent;
    }
    else {
      return extent.cumulative_length[idx+1] - extent.cumulative_length[idx];
    }
  });
}
// Resolves both a shape and a coordinate for one batch entry. Returns
// make_tuple(new_shape, new_coord) where
//   - new_shape is the result of the two-argument apply_variable_length(shape, idx), and
//   - new_coord pairs each coordinate that belongs to a VariableLength leaf with that
//     entry's start offset (cumulative_length[idx]); other coordinates are unchanged.
template<class Shape, class Coord, class Idx>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) {
  auto resolved_shape = apply_variable_length(shape, idx);
  auto resolved_coord = transform_leaf(shape, coord, [&](auto const& extent, auto const& c) {
    if constexpr (!is_variable_length_v<decltype(extent)>) {
      return c;
    }
    else {
      return cute::make_tuple(c, extent.cumulative_length[idx]);
    }
  });
  return cute::make_tuple(resolved_shape, resolved_coord);
}
// Resolves a shape and computes per-leaf start offsets for the batch entry selected by the
// last element of the last mode of `coord`. Returns make_tuple(result_shape, result_offset)
// where
//   - result_shape replaces every VariableLength leaf by that entry's actual length, and
//   - result_offset holds cumulative_length[idx] for VariableLength leaves and _0 elsewhere.
template<class Shape, class Coord>
CUTE_HOST_DEVICE
constexpr auto
apply_variable_length_offset(Shape const& shape, Coord const& coord) {
  auto idx = back(back(coord));
  // Same per-leaf length computation as the two-argument apply_variable_length.
  auto result_shape = apply_variable_length(shape, idx);
  auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& extent) {
    if constexpr (!is_variable_length_v<decltype(extent)>) {
      return _0{};
    }
    else {
      return extent.cumulative_length[idx];
    }
  });
  return cute::make_tuple(result_shape, result_offset);
}
} // namespace cutlass::fmha::collective
namespace cute {
// Marks VariableLength as an integral type for CuTe's `is_integral` trait.
// NOTE(review): presumably required so that a VariableLength can be used as a leaf extent
// inside CuTe shapes - confirm against CuTe's shape/int-tuple requirements.
template<>
struct is_integral<cutlass::fmha::collective::VariableLength> : true_type {};
// Prints a VariableLength as "Varlen<max_length, cumulative_length pointer>".
CUTE_HOST_DEVICE
void print(cutlass::fmha::collective::VariableLength a) {
  // %p requires a `void*` argument; passing an `int*` directly is not allowed by the
  // printf specification, so convert explicitly.
  printf("Varlen<%d, %p>", a.max_length, static_cast<void*>(a.cumulative_length));
}
}
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace cutlass::fmha::collective {
template<
class Element,
class ElementAcc,
class TileShape, // Q, D, _
class StrideO, // Q, D, B
class StrideLSE_, // Q, B
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
using Pipeline = cutlass::PipelineAsync<2>;
// using SmemLayoutO = decltype(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{})));
using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>());
// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{}));
using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{}));
using SmemLayoutO_ = SmemLayoutO;
using StrideLSE = StrideLSE_;
using ElementOut = Element;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
// Shared-memory backing store for the epilogue.
// Holds one buffer of Elements large enough to cover the whole staged output
// layout SmemLayoutO (its cosize); store() reads stage 0 and stage 1 of it.
struct TensorStorage {
// Re-exports the enclosing struct's layout under the nested name SmemLayoutO.
using SmemLayoutO = SmemLayoutO_;
// Staged O tiles, written by another role and read by store().
cute::array_aligned<Element, cute::cosize_v<SmemLayoutO>> smem_o;
};
// User-facing arguments of the epilogue, converted into Params by
// to_underlying_arguments().
struct Arguments {
// Base pointer of the output tensor O.
Element* ptr_O;
// Strides of O (modes Q, D, B per the StrideO template-parameter comment).
StrideO dO;
// Pointer forwarded unchanged into Params::ptr_LSE; it is not dereferenced in
// the visible part of this struct. Presumably the log-sum-exp output - confirm.
ElementAcc* ptr_LSE;
// Strides that go with ptr_LSE (modes Q, B per the StrideLSE_ comment).
StrideLSE dLSE;
};
using TMA_O = decltype(make_tma_copy(
SM90_TMA_STORE{},
make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}),
SmemLayoutO{}(_,_,_0{})
));
// Kernel-side parameters of the epilogue, produced by to_underlying_arguments().
struct Params {
// TMA store object built over the global O tensor; used by store() and
// prefetch_tma_descriptors().
TMA_O tma_store_o;
// Copied verbatim from Arguments::ptr_LSE.
ElementAcc* ptr_LSE;
// Copied verbatim from Arguments::dLSE.
StrideLSE dLSE;
};
// Derives the shape of the output tensor O from the kernel's problem shape.
//
// FMHA and MLA pass differently structured ProblemShapes, so the result is
// always built from modes 0, 2 and 3 of the input. When mode 2 of the input is
// itself a rank-2 tuple, only its first entry is kept as mode 1 of the result.
template<class ProblemShape>
CUTLASS_DEVICE static constexpr
auto get_problem_shape_O(ProblemShape const& problem_shape) {
  // Decided at compile time from the type only.
  constexpr bool kMode2IsPair = rank_v<decltype(get<2>(ProblemShape{}))> == 2;
  if constexpr (!kMode2IsPair) {
    return select<0,2,3>(problem_shape);
  } else {
    auto shape_from_modes_023 = select<0,2,3>(problem_shape);
    return replace<1>(shape_from_modes_023, get<2, 0>(problem_shape));
  }
}
// Converts host-side Arguments into kernel Params by building the TMA store
// object for the output tensor O.
//
//   problem_shape - problem shape; mode 0 may be a variable-length descriptor
//                   exposing cumulative_length and max_length.
//   args          - output pointer/stride plus the LSE pointer/stride, the
//                   latter two being forwarded unchanged.
//   workspace     - unused here.
//
// Every extent handed to the TMA descriptor is clamped to at least 1.
template<class ProblemShape>
static Params to_underlying_arguments(
    ProblemShape const& problem_shape,
    Arguments const& args,
    void* workspace = nullptr) {

  auto out_ptr = args.ptr_O;
  StrideO out_stride = args.dO;
  auto out_shape = get_problem_shape_O(problem_shape);

  if constexpr (!is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
    // Fixed-length case: only the Q extent needs clamping.
    get<0>(out_shape) = max(1, get<0>(out_shape));
  } else {
    if (get<0>(problem_shape).cumulative_length != nullptr) {
      int max_len_q = get<0>(problem_shape).max_length;
      get<0>(out_shape).max_length = max(1, max_len_q);

      // for variable sequence length, the batch is in units of row_stride
      get<2,1>(out_stride) = get<0>(out_stride);
      get<2,1>(out_shape) = max(1, max_len_q * (1 + get<2,1>(out_shape)));

      // offset ptr by the amount we add back in later
      out_ptr -= max_len_q * get<0>(out_stride);
    }
  }

  auto tma_store = make_tma_copy(
      SM90_TMA_STORE{},
      make_tensor(out_ptr, out_shape, out_stride),
      SmemLayoutO{}(_,_,_0{})
  );

  return {
    tma_store,
    args.ptr_LSE,
    args.dLSE
  };
}
// Prefetches the TMA descriptor of the O store held in `params`.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
  auto store_descriptor = params.tma_store_o.get_tma_descriptor();
  cute::prefetch_tma_descriptor(store_descriptor);
}
// Non-owning reference to the epilogue parameters; the referenced Params object
// must outlive this collective. Note that store() takes its own `params`
// argument, which shadows this member inside that function.
const Params& params;
// Binds the collective to `params` by storing the reference; nothing is copied.
CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {}
// Copies the two finished O tiles of this block from shared memory to the
// global O tensor using TMA stores.
//
//   blk_coord_in            - block coordinate; mode 0 selects the pair of output
//                             tiles (2*q and 2*q+1), mode 2 is the batch coordinate.
//   problem_shape           - problem shape used to build the TMA tensor view; in
//                             the variable-length path get<0> is read as the real length.
//   params                  - epilogue parameters (shadows the `params` member).
//   params_problem_shape    - problem shape carrying cumulative_length / max_length;
//                             only consulted in the variable-length path.
//   shared_storage          - provides smem_o, the staged O tiles.
//   pipeline                - pipeline this function consumes from ("corr" per the
//                             comment below).
//   pipeline_consumer_state - advanced by two stages on return.
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE auto
store(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& shared_storage,
Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) {
// Local copy: the variable-length path below rewrites the batch coordinate.
BlkCoord blk_coord = blk_coord_in;
// Only the lane for which this is non-zero issues the TMA copies below
// (elect_one_sync presumably elects a single lane of the warp - confirm).
uint32_t lane_predicate = cute::elect_one_sync();
using X = Underscore;
// Each block owns two consecutive tiles along the tiled Q mode.
int o0_index = 2 * get<0>(blk_coord);
int o1_index = 2 * get<0>(blk_coord) + 1;
Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape));
// offset mode 0 by (max_length - real_length)
// offset mode (2,1) by cumulative_length + real_length
// the ptr is already offset by - max_length (see to_underlying_arguments)
// so in total the row offsets add up to cumulative_length:
//   -max_length + (max_length - real_length) + (cumulative_length + real_length)
// (mode (2,1) is strided by the row stride in to_underlying_arguments).
// NOTE(review): the original comment said "mode 3,1"; the code offsets mode (2,1).
int offs_0 = 0;
int offs_2_1 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(params_problem_shape).max_length;
offs_0 = max_length_q - get<0>(problem_shape);
offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape);
// The batch index is now folded into offs_2_1, so address batch slot 0.
get<2,1>(blk_coord) = 0;
}
}
Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p);
// Tile O over (Q, D); the third TileShape mode is skipped via Step<_1, _1, X>.
Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{});
// Fix the D tile to 0 and the batch coordinate; the Q tile index stays free.
Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord));
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{});
auto block_tma = params.tma_store_o.get_slice(0);
Tensor tOsO = block_tma.partition_S(sO);
Tensor tOgO = block_tma.partition_D(gO);
// Separate cursor for releases: stages are released later than they are waited on.
auto pipeline_release_state = pipeline_consumer_state;
// O1 O2
// one pipeline: O
// wait from corr, issue tma store on smem
pipeline.consumer_wait(pipeline_consumer_state);
++pipeline_consumer_state;
if (lane_predicate) {
// smem stage 0 -> global tile o0_index
copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index));
}
tma_store_arrive();
pipeline.consumer_wait(pipeline_consumer_state);
++pipeline_consumer_state;
if (lane_predicate) {
// smem stage 1 -> global tile o1_index
copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index));
}
tma_store_arrive();
// Presumably waits until at most one TMA store group is still pending, i.e. the
// first store no longer needs its smem stage - confirm against cute's tma_store_wait.
tma_store_wait<1>();
// Hand the first stage back to the producer.
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;
// Wait for the remaining store before releasing the second stage.
tma_store_wait<0>();
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
// Signal the named barrier shared by the load and epilogue warps
// (thread count = (NumWarpsLoad + NumWarpsEpilogue) warps).
cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
// Hand the second stage back to the producer.
pipeline.consumer_release(pipeline_release_state);
++pipeline_release_state;
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/sm100_fmha_load_tma_warpspecialized.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element_,
class ElementQK_,
class ElementPV_,
class TileShape_,
class StrideQ_,
class StrideK_,
class StrideV_,
class Mask_,
// shape here is QG K H
// and refers to the two softmax warps
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
// (1, 2, 1) means they sit side by side (best for small Q / large K)
class ThreadShape = Shape<_2, _1, _1>,
// Since shared memory is sufficient for FMHA, there is no need to reuse shared memory.
class OrderLoadEpilogue = cute::false_type
>
struct Sm100FmhaFwdMainloopTmaWarpspecialized {
using Element = Element_;
using ElementQK = ElementQK_;
using ElementPV = ElementPV_;
using TileShape = TileShape_;
using StrideQ = StrideQ_;
using StrideK = StrideK_;
using StrideV = StrideV_;
using Mask = Mask_;
static constexpr int StageCountQ = 2;
static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3;
using StagesQ = cutlass::gemm::collective::StageCount<StageCountQ>;
using StagesKV = cutlass::gemm::collective::StageCount<StageCountKV>;
using ClusterShape = Shape<_1, _1, _1>;
static const int Alignment = 128 / sizeof_bits_v<Element>;
using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{}));
using TileShapePV = decltype(select<0,2,1>(TileShapeQK{}));
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, StrideQ, Alignment,
Element, StrideK, Alignment,
ElementQK,
TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since we do not load from smem at all
Element, StrideK, Alignment,
Element, decltype(select<1,0,2>(StrideV{})), Alignment,
ElementPV,
TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StageCountQ>{}));
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountKV>{}));
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountKV>{}));
// Reuse shared memory for V and O.
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
// Shared-memory operands of the mainloop.
// smem_q holds the staged Q tiles (SmemLayoutQ). K and V tiles share one
// buffer through the anonymous union: both views are staged StageCountKV deep.
struct TensorStorage {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
// K view of the shared K/V staging buffer.
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
// V view of the same storage.
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
};
// Sizes and start offsets of the regions this mainloop uses in tensor memory
// (units are presumably tensor-memory columns - confirm).
// Resulting values: S0 = 0, S1 = 128, O0 = 256, O1 = 384, kEnd = 512.
// The V regions alias the start of the matching S region (V0 = 0, V1 = 128) and
// the P regions start kSizeP past it (P0 = 32, P1 = 160).
enum class TmemAllocation : uint32_t {
kSizeS = 128,
kSizeO = 128,
kSizeP = 32,
S0 = 0,
S1 = S0 + kSizeS,
V0 = S0, // stats storage from softmax to correction
V1 = S1,
P0 = S0 + kSizeP,
P1 = S1 + kSizeP,
O0 = S1 + kSizeS,
O1 = O0 + kSizeO,
kEnd = O1 + kSizeO
};
// indices for V0 / V1
// The two statistics slots are reused: during the loop softmax_step writes
// (old row max, new row max) into slots (0, 1); on its final call it writes
// (final row sum, final row max) into the same slots (0, 1).
enum : int {
kIdxOldRowMax = 0,
kIdxNewRowMax = 1,
kIdxFinalRowSum = 0,
kIdxFinalRowMax = 1
};
// from load to mma warp, protects q in smem
using PipelineQ = cutlass::PipelineTmaUmmaAsync<
StageCountQ,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// from load to mma warp, protects k/v in smem
using PipelineKV = cutlass::PipelineTmaUmmaAsync<
StageCountKV,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// from mma to softmax0/1 warp, protects S in tmem
// (not sure yet about the reverse direction)
// there is one pipe per softmax warp, and the mma warp alternates between them
using PipelineS = cutlass::PipelineUmmaAsync<1>;
// from softmax0/1/ to correction wg
using PipelineC = cutlass::PipelineAsync<1>;
// from mma to correction
using PipelineO = cutlass::PipelineUmmaAsync<2>;
// from corr to epilogue
using PipelineE = cutlass::PipelineAsync<2>;
using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier<
/*stages*/ 1, /*groups*/ 2>;
static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size");
using Load = Sm100FmhaLoadTmaWarpspecialized<
Element, StrideQ, StrideK, StrideV,
CollectiveMmaQK, CollectiveMmaPV,
SmemLayoutQ, SmemLayoutK, SmemLayoutV,
TensorStorage, PipelineQ, PipelineKV, Mask, TileShape
>;
// User-facing arguments of the mainloop, converted into Params by
// to_underlying_arguments().
struct Arguments {
// Arguments of the TMA load collective.
typename Load::Arguments load;
// if zero, defaults to 1/sqrt(D)
float scale_softmax = 0.0f;
// scaling factors to dequantize QKV
// (scale_q * scale_k is folded into the softmax scale, scale_v into the output scale)
float scale_q = 1.0f;
float scale_k = 1.0f;
float scale_v = 1.0f;
// scaling factor to quantize O (multiplied with scale_v to form Params::scale_output)
float inv_scale_o = 1.0f;
};
// Kernel-side parameters of the mainloop, produced by to_underlying_arguments().
struct Params {
// Parameters of the TMA load collective.
typename Load::Params load;
// scale_q * scale_k * softmax scale.
float scale_softmax;
// Same as scale_softmax, additionally multiplied by log2(e); softmax_step uses
// it so that exponentials can be evaluated with exp2f.
float scale_softmax_log2;
// scale_v * inv_scale_o.
float scale_output;
};
// Reports whether this mainloop supports the given problem.
// No constraint is checked here: both arguments are ignored and the answer is
// always true.
template<class ProblemShape>
static bool can_implement(
    [[maybe_unused]] ProblemShape const& problem_shape,
    [[maybe_unused]] Arguments const& args) {
  constexpr bool kAlwaysSupported = true;
  return kAlwaysSupported;
}
// Converts host-side Arguments into kernel Params.
//
//   problem_shape - problem shape; mode 2 is the D used for the default
//                   softmax scale.
//   args          - user arguments; scale_softmax == 0 selects 1/sqrt(D).
//   workspace     - forwarded to the load collective.
//
// The Q/K dequantization factors are folded into both softmax scales, and the
// V dequantization factor is combined with the O quantization factor.
template<class ProblemShape>
static Params to_underlying_arguments(
    ProblemShape const& problem_shape,
    Arguments const& args,
    void* workspace) {

  // A zero scale is the sentinel for "use the default".
  float const softmax_scale = (args.scale_softmax == 0.0f)
      ? 1.0f / (float) std::sqrt(get<2>(problem_shape))
      : args.scale_softmax;

  float const log2_e = static_cast<float>(std::log2(std::exp(1.0)));
  float const qk_dequant = args.scale_q * args.scale_k;

  return Params{
      Load::to_underlying_arguments(problem_shape, args.load, workspace),
      qk_dequant * softmax_scale,
      qk_dequant * log2_e * softmax_scale,
      args.scale_v * args.inv_scale_o
  };
}
// Prefetches the TMA descriptors owned by the load collective.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
  auto const& load_params = params.load;
  Load::prefetch_tma_descriptors(load_params);
}
// Load-warp entry point: delegates to the TMA load collective.
// Every argument is passed through as-is, except that only the `load` part of
// `params` is handed on. The producer states are taken by reference so the
// collective can advance them.
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
    BlkCoord const& blk_coord, ProblemShape const& problem_shape,
    Params const& params, ParamsProblemShape const& params_problem_shape,
    TensorStorage& storage,
    PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
    PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {

  Load loader;
  auto const& load_params = params.load;
  loader.load(
      blk_coord, problem_shape, load_params, params_problem_shape,
      storage,
      pipeline_q, pipeline_q_producer_state,
      pipeline_kv, pipeline_kv_producer_state);
}
// MMA-warp body for one output block.
//
// Issues the interleaved sequence of Q*K GEMMs (into S0/S1) and P*V GEMMs
// (into O0/O1) and drives the pipelines that connect this warp to the load,
// softmax and correction roles. The exact order of acquire / commit / wait /
// release calls below is the synchronization protocol; do not reorder.
//
// NOTE: the inline comments count from 1 (Q1, S1, O1, ...) while the variables
// count from 0 (tSrQ0, tStS0, tOtO0, ...): "Q1" is tSrQ0, "S2" is tStS1, etc.
//
//   blk_coord, problem_shape  - only used to ask Mask for the trip count.
//   params                    - not referenced in this function body.
//   storage                   - shared memory holding Q and the K/V union.
//   pipeline_q, pipeline_kv   - consumed here.
//   pipeline_s0, pipeline_s1  - produced here, one per softmax stage.
//   pipeline_corr             - produced here around every P*V GEMM.
//   All *_state arguments are taken by reference and advanced in place.
template<class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
mma(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state,
PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state,
PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state,
PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) {
// Release cursors trail the wait cursors: a stage is released only after the
// GEMMs that read it have been issued.
auto pipeline_q_release_state = pipeline_q_consumer_state;
auto pipeline_kv_release_state = pipeline_kv_consumer_state;
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
typename CollectiveMmaQK::TiledMma mma_qk;
ThrMMA thr_mma_qk = mma_qk.get_slice(0);
typename CollectiveMmaPV::TiledMma mma_pv;
// Variant of the PV MMA produced by to_tiled_mma_sm100_ts (helper defined
// outside this chunk; presumably makes operand A come from tensor memory - confirm).
TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv);
ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ);
Tensor tSrK = thr_mma_qk.make_fragment_B(sK);
Tensor tOrV = thr_mma_pv.make_fragment_B(sV);
// tmem layout is
// S0 S1 O0 O1
// sequential in memory, where S overlaps with P and V
Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{}));
Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{}));
// Per-stage accumulator views: same layout, base shifted by the TmemAllocation offset.
Tensor tStS0 = tStS;
tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0);
Tensor tStS1 = tStS;
tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1);
Tensor tOtO0 = tOtO;
tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0);
Tensor tOtO1 = tOtO;
tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1);
// sP is built on a null pointer: it is only used to derive the A-fragment
// layout, whose base is then replaced by the P0 / P1 offsets below.
Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{});
Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging
Tensor tOrP0 = tOrP;
tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);
Tensor tOrP1 = tOrP;
tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1);
int k_index = 0;
int v_index = 0;
int q_index = 0;
// wait for Q1
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
Tensor tSrQ0 = tSrQ(_,_,_,q_index);
// wait for K1
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * K1 -> S1
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
// release K1
// (only when the K mode of ThreadShape is split: then each stage consumes its
// own K tile; otherwise the same K tile is reused for the second stage below)
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// wait for Q2
// (skipped when neither mode 0 nor mode 2 of ThreadShape is split: the second
// stage then reuses q_index from above)
if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) {
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
}
Tensor tSrQ1 = tSrQ(_,_,_,q_index);
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
// gemm Q2 * K1 -> S2
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release K1
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for V1
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// this acquire returns the ownership of all of S0 to the mma warp
// including the P0 part
// acquire corr first to take it out of the critical
// path since softmax takes longer
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
// gemm P1 * V1 -> O1
gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// Reset the accumulate flag of the PV MMA before the first P2 * V GEMM.
mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero;
// loop:
// the first tile was handled above, hence the decrement before the loop
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// wait for Ki
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * Ki -> S1
// (no producer_acquire here: S0 was already acquired before the preceding P1 * V GEMM)
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// gemm P2 * V(i-1) -> O2
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release V(i-1)
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm Q2 * Ki -> S2
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release Ki
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for Vi
v_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm P1 * Vi -> O1
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
}
// release Q1
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
// release Q2
if constexpr (get<0>(ThreadShape{}) > 1) {
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
}
// wait for Vi
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm P2 * Vi -> O2
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release Vi
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// Extra commits with no GEMM in between. Presumably these pair with the final
// consumer_wait in softmax_step (final_call) and the trailing consumer_release
// in softmax - confirm.
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ...
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * V2 , ...
}
// One online-softmax step over a single S tile for one softmax stage.
//
// Reads S from tensor memory, updates the running row maximum, publishes
// (old max, new max) for the correction role, writes P = exp2(scale*S - scale*max)
// converted to Element back into tensor memory, and updates the running row sum.
// The order of pipeline / barrier calls below is the synchronization protocol;
// do not reorder.
//
//   need_apply_mask            - compile-time: apply Mask::apply_mask to the loaded scores.
//   row_max, row_sum           - running per-row statistics, updated in place.
//   stage                      - selects the S0/V0/P0 or S1/V1/P1 regions.
//   final_call                 - when true, also stores the final (row sum, row max).
//   blk_coord                  - not referenced in this function body.
//   cS                         - coordinate tensor of the tile (used for masking).
//   params                     - provides scale_softmax_log2.
//   pipeline_s / consumer state - S tiles coming from the MMA warp.
//   pipeline_c / producer state - statistics going to the correction role.
//   order_s                    - ordered barrier between the two softmax stages.
template<bool need_apply_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
CUTLASS_DEVICE auto
softmax_step(
float& row_max, float& row_sum,
Stage stage, bool final_call,
BlkCoord const& blk_coord, CoordTensor const& cS,
Params const& params, ProblemShape const& problem_shape,
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
OrderBarrierSoftmax& order_s) {
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
// View of the S accumulator for this stage.
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);
// 128 x 2 view for the two per-row statistics (V region aliases the start of S).
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
// Width of P expressed in 32-bit words: tile width scaled by sizeof(Element)/sizeof(float).
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
// Index of this thread within its group of four warps.
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS);
Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS);
auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v);
auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx);
Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v);
Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P);
tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get());
Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P);
// wait on tensor core pipe
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// read all of S from tmem into reg mem
Tensor tTMEM_LOADrS = make_tensor<ElementQK>(shape(tTMEM_LOADcS));
copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS);
if constexpr (need_apply_mask) {
Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape);
}
ElementQK old_row_max = row_max;
{
// compute rowmax
// four independent accumulators, combined after the loop
float row_max_0 = row_max;
float row_max_1 = row_max;
float row_max_2 = row_max;
float row_max_3 = row_max;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 4) {
row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i));
row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1));
row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2));
row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3));
}
row_max = ::fmax(row_max_0, row_max_1);
row_max = ::fmax(row_max, row_max_2);
row_max = ::fmax(row_max, row_max_3);
}
// Replace a -inf maximum by 0 for the arithmetic below (presumably to avoid
// (-inf) - (-inf) = NaN when every score seen so far in the row is -inf).
ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max;
// Publish (old max, new max) in the statistics slots of this stage.
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max;
tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
pipeline_c.producer_commit(pipeline_c_producer_state);
++pipeline_c_producer_state;
// notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's)
ElementQK scale = params.scale_softmax_log2;
ElementQK row_max_scale = row_max_safe * scale;
float2 scale_fp32x2 = make_float2(scale, scale);
float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale);
// Packed output registers: each 32-bit word is reinterpreted below as
// kConversionsPerStep Elements.
Tensor tTMEM_STORErS_x4 = make_tensor<uint32_t>(shape(tTMEM_STOREcS));
constexpr int kConversionsPerStep = 2;
Tensor tTMEM_STORErS_x4_e = recast<Array<Element, kConversionsPerStep>>(tTMEM_STORErS_x4);
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
const int kReleasePipeCount = 10; // must be multiple of 2
order_s.wait();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 2) {
float2 in = make_float2(
tTMEM_LOADrS(i + 0),
tTMEM_LOADrS(i + 1)
);
float2 out;
// out = scale * s - scale * row_max_safe, two elements at a time
cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);
tTMEM_LOADrS(i + 0) = out.x;
tTMEM_LOADrS(i + 1) = out.y;
tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0));
tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));
Array<ElementQK, kConversionsPerStep> in_conv;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kConversionsPerStep; j++) {
in_conv[j] = tTMEM_LOADrS(i + j);
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
// Signal the ordered barrier kReleasePipeCount elements before the end of the loop.
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
// this prevents register spills in fp16
// (when the store is split in two parts, the first part is issued early)
if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) {
if (i == size(tTMEM_LOADrS) - 6) {
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));
}
}
}
// tmem_store(reg_S8) -> op_P
// stores the last (or only) part of P
CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));
cutlass::arch::fence_view_async_tmem_store();
// notify tensor core warp that P is ready
pipeline_s.consumer_release(pipeline_s_consumer_state);
++pipeline_s_consumer_state;
pipeline_c.producer_acquire(pipeline_c_producer_state);
// Rescale the carried-over sum to the new maximum. The factor 0.5 compensates
// for the carried sum being placed in BOTH lanes of local_row_sum_f32x2 below,
// whose lanes are added together at the end.
ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
row_sum *= acc_scale;
// row_sum = sum(reg_S)
float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
float2 local_row_sum_1 = make_float2(0, 0);
float2 local_row_sum_2 = make_float2(0, 0);
float2 local_row_sum_3 = make_float2(0, 0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 8) {
// row_sum += tTMEM_LOADrS(i);
// four independent two-lane accumulators, eight elements per iteration
float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1));
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in);
in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1));
cute::add(local_row_sum_1, local_row_sum_1, in);
in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1));
cute::add(local_row_sum_2, local_row_sum_2, in);
in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1));
cute::add(local_row_sum_3, local_row_sum_3, in);
}
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1);
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
// re-acquire the S part in the final step
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// Overwrite the statistics slots with the final values. Note that the raw
// row_max is stored here, not row_max_safe.
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;
tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
}
}
// Softmax driver for one softmax stage (`stage` selects which one).
//
// Runs softmax_step once per K/V tile: first over the tiles Mask reports as
// unmasked (no mask applied), then over the masked tiles (mask applied). The
// last step overall is flagged as the final call. The running row maximum and
// row sum live in this function and are threaded through every step.
//
//   blk_coord, problem_shape    - forwarded to Mask and softmax_step; mode 0 of
//                                 blk_coord also positions the coordinate tile.
//   params                      - forwarded to softmax_step.
//   pipeline_s / consumer state - S tiles coming from the MMA warp.
//   pipeline_c / producer state - statistics going to the correction role.
//   order_s                     - ordered barrier between the two softmax stages.
template<class Stage, class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
softmax(
    Stage stage,
    BlkCoord const& blk_coord,
    Params const& params, ProblemShape const& problem_shape,
    PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
    PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
    OrderBarrierSoftmax& order_s) {

  int unmasked_tiles_left = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape);

  // Running statistics shared by all steps.
  ElementQK row_max = -INFINITY;
  ElementQK row_sum = 0;

  // Coordinates of this stage's S tile, shifted to its logical position.
  Tensor tile_coords = make_identity_tensor(select<0,1>(TileShapeQK{}));
  auto tile_origin = make_coord(
      get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}),
      0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})
  );
  Tensor cS = domain_offset(tile_origin, tile_coords);

  pipeline_c.producer_acquire(pipeline_c_producer_state);

  // Unmasked iterations
  CUTLASS_PRAGMA_NO_UNROLL
  for (; unmasked_tiles_left > 0; --unmasked_tiles_left) {
    // Final only if this is the last unmasked tile and no masked tile follows.
    bool is_final_step = (unmasked_tiles_left == 1) &&
        (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0);
    softmax_step<false /* need_apply_mask */>(
        row_max, row_sum, stage, is_final_step,
        blk_coord, cS, params, problem_shape,
        pipeline_s, pipeline_s_consumer_state,
        pipeline_c, pipeline_c_producer_state,
        order_s
    );
    // Move the coordinate tile to the next K/V tile of this stage.
    cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
  }

  // Masked iterations
  int masked_tiles_left = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape);
  CUTLASS_PRAGMA_NO_UNROLL
  for (; masked_tiles_left > 0; --masked_tiles_left) {
    bool is_final_step = masked_tiles_left == 1;
    softmax_step<true /* need_apply_mask */>(
        row_max, row_sum, stage, is_final_step,
        blk_coord, cS, params, problem_shape,
        pipeline_s, pipeline_s_consumer_state,
        pipeline_c, pipeline_c_producer_state,
        order_s
    );
    cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
  }

  pipeline_c.producer_commit(pipeline_c_producer_state);
  ++pipeline_c_producer_state;

  pipeline_c.producer_acquire(pipeline_c_producer_state);
  // empty step to sync against pipe s
  pipeline_s.consumer_release(pipeline_s_consumer_state);
  ++pipeline_s_consumer_state;
}
// Final pass over one output accumulator: reads the accumulator for `stage`
// from tensor memory in column chunks, multiplies every value by `scale`,
// converts it to the element type of `sO_01`, and writes it to the `stage`
// slice of `sO_01`. Ends with fence_view_async_shared().
//
//   scale  - multiplier applied to every accumulator value
//            (skipped when ONLY_SOFTMAX is defined)
//   stage  - compile-time _0 or _1; picks TmemAllocation::O0 / O1 and the
//            slice of sO_01 that is written
//   sO_01  - output tensor whose last mode is indexed by `stage`
template<class Stage, class TensorO>
CUTLASS_DEVICE auto
correction_epilogue(
    float scale,
    Stage stage,
    TensorO const& sO_01) {

  using ElementOut = typename TensorO::value_type;

  // The accumulator is not pulled into registers in one go; it is walked in
  // chunks of this many columns (the original note suggests 32 or 64 as good
  // values for the softmax-sized tile).
  constexpr int kColsPerChunk = 32 / sizeof(ElementOut);
  using TmemLoadOp = std::conditional_t<
      kColsPerChunk == 32,
      SM100_TMEM_LOAD_32dp32b32x,
      SM100_TMEM_LOAD_32dp32b16x>;

  // Only two accumulators exist, so the stage selects one of two offsets.
  static_assert(decltype(stage == _0{})::value || decltype(stage == _1{})::value, "stage is either 0 or 1");
  constexpr uint32_t kTmemOffsetO = decltype(stage == _0{})::value
      ? uint32_t(TmemAllocation::O0)
      : uint32_t(TmemAllocation::O1);

  const int wg_thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
  auto chunk_tiler = make_layout(make_shape(_128{}, Int<kColsPerChunk>{}));

  typename CollectiveMmaPV::TiledMma tiled_mma;
  auto thr_mma = tiled_mma.get_slice(0);

  // Accumulator view, matching coordinate tensor and the smem destination,
  // all partitioned the same way and then split into column chunks.
  auto out_stage = sO_01(_,_,stage);
  auto coord_tile = make_identity_tensor(select<0,1>(TileShapePV{}));
  auto acc_frag = partition_fragment_C(tiled_mma, select<0,1>(TileShapePV{}));
  auto acc_crd = thr_mma.partition_C(coord_tile);
  auto acc_smem = thr_mma.partition_C(out_stage);

  auto acc_frag_chunks = logical_divide(acc_frag, chunk_tiler);
  auto acc_crd_chunks = logical_divide(acc_crd, chunk_tiler);
  auto acc_smem_chunks = logical_divide(acc_smem, chunk_tiler);

  acc_frag_chunks.data() = acc_frag_chunks.data().get() + kTmemOffsetO;

  auto tmem_load_tiled = make_tmem_copy(TmemLoadOp{}, acc_frag_chunks(make_coord(_, _), _0{}));
  auto tmem_load_thr = tmem_load_tiled.get_slice(wg_thread_idx);

  auto src_tmem = tmem_load_thr.partition_S(acc_frag_chunks(make_coord(_, _), _));
  auto src_crd = tmem_load_thr.partition_D(acc_crd_chunks(make_coord(_, _), _));
  auto dst_smem = tmem_load_thr.partition_D(acc_smem_chunks(make_coord(_, _), _));

  const float2 scale_pair = make_float2(scale, scale);

  // Per chunk: TMEM load -> packed multiply by scale -> convert -> smem store
  CUTLASS_PRAGMA_UNROLL
  for (int chunk = 0; chunk < get<2>(TileShape{}) / kColsPerChunk; chunk++) {
    auto src_tmem_chunk = src_tmem(_, _0{}, _0{}, chunk);
    auto dst_smem_chunk = dst_smem(_, _0{}, _0{}, chunk);

    // Register staging buffer shaped like this thread's share of the chunk.
    auto acc_regs = make_tensor<ElementPV>(shape(src_crd(_, _0{}, _0{}, chunk)));
    copy(tmem_load_tiled, src_tmem_chunk, acc_regs);

#ifndef ONLY_SOFTMAX
    // Two values at a time so the multiply can use the packed float2 form.
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < size(acc_regs); k += 2) {
      float2 scaled;
      float2 raw = make_float2(acc_regs(k), acc_regs(k+1));
      cute::mul(scaled, scale_pair, raw);
      acc_regs(k) = scaled.x;
      acc_regs(k+1) = scaled.y;
    }
#endif

    // Convert in groups that fill one 32-bit word of output elements.
    constexpr int kElemsPerWord = 4 / sizeof(ElementOut);
    NumericArrayConverter<ElementOut, ElementPV, kElemsPerWord> to_out;

    auto out_regs = make_tensor_like<ElementOut>(acc_regs);
    auto cvt_src = recast<decltype(to_out)::source_type>(acc_regs);
    auto cvt_dst = recast<decltype(to_out)::result_type>(out_regs);

    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < size(cvt_src); k++) {
      cvt_dst(k) = to_out.convert(cvt_src(k));
    }

    // Store as 32-bit words with an assumed 128-bit alignment.
    auto smem_words = recast<uint32_t>(dst_smem_chunk);
    auto reg_words = recast<uint32_t>(out_regs);
    copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, reg_words, smem_words);
  }

  cutlass::arch::fence_view_async_shared();
}
// Multiplies the accumulator that starts at tensor-memory offset `tmem_O` by
// `scale`, in place: each 16-column chunk is loaded from TMEM into registers,
// scaled, and stored back to the same TMEM location.
//
//   scale  - multiplier applied to every accumulator value
//   tmem_O - TMEM offset added to the accumulator fragment's base
//
// NOTE(review): the register buffer holds 128 / 16 chunks (literal 128) while
// the trip count is get<2>(TileShape{}) / 16 — these only agree when that
// extent is at most 128; confirm against the supported tile shapes.
CUTLASS_DEVICE auto
correction_rescale(
    float scale,
    uint32_t tmem_O) {

  // The accumulator is processed in chunks of this many columns rather than
  // being held in registers all at once.
  constexpr int kColsPerChunk = 16;
  using TmemLoadOp = SM100_TMEM_LOAD_32dp32b16x;
  using TmemStoreOp = SM100_TMEM_STORE_32dp32b16x;

  const int wg_thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
  auto chunk_layout = make_layout(make_shape(_128{}, Int<kColsPerChunk>{}));

  typename CollectiveMmaPV::TiledMma tiled_mma;

  auto coord_tile = make_identity_tensor(select<0,1>(TileShapePV{}));
  auto acc_frag = partition_fragment_C(tiled_mma, select<0,1>(TileShapePV{}));
  auto acc_crd = tiled_mma.get_slice(0).partition_C(coord_tile);

  // One chunk-sized window over the accumulator, rebased onto tmem_O.
  auto acc_window = acc_frag.compose(chunk_layout);
  auto crd_window = acc_crd.compose(chunk_layout);
  acc_window.data() = acc_window.data().get() + tmem_O;

  auto tmem_load_tiled = make_tmem_copy(TmemLoadOp{}, acc_window);
  auto tmem_load_thr = tmem_load_tiled.get_slice(wg_thread_idx);
  auto tmem_store_tiled = make_tmem_copy(TmemStoreOp{}, acc_window);
  auto tmem_store_thr = tmem_store_tiled.get_slice(wg_thread_idx);

  auto load_src = tmem_load_thr.partition_S(acc_window);
  auto load_crd = tmem_load_thr.partition_D(crd_window);
  auto store_dst = tmem_store_thr.partition_D(acc_window);
  auto store_crd = tmem_store_thr.partition_S(crd_window);
  static_assert(shape(store_crd) == shape(load_crd));

  const float2 scale_pair = make_float2(scale, scale);

  // Register staging for every chunk: (per-thread chunk shape, chunk index).
  auto acc_regs = make_tensor<ElementPV>(make_shape(shape(load_crd), Int<128 / kColsPerChunk>{}));

  // View of the registers that belong to chunk `c`.
  auto regs_of = [&](int c) {
    return acc_regs(_, c).compose(make_layout(shape<0>(acc_regs)));
  };

  // TMEM -> registers for chunk `c`.
  auto fetch_chunk = [&](int c) {
    auto src = load_src;
    src.data() = src.data().get() + uint32_t(c * kColsPerChunk);
    auto dst = regs_of(c);
    copy(tmem_load_tiled, src, dst);
  };

  // Registers -> TMEM for chunk `c`.
  auto flush_chunk = [&](int c) {
    auto dst = store_dst;
    dst.data() = dst.data().get() + uint32_t(c * kColsPerChunk);
    auto src = regs_of(c);
    copy(tmem_store_tiled, src, dst);
  };

  // Software-pipelined order (L = load, M = multiply, S = store):
  //   L L M S L M S ... M S
  // i.e. the load of chunk c+1 is issued before chunk c is scaled and stored.
  fetch_chunk(0);

  int num_chunks = get<2>(TileShape{}) / kColsPerChunk;

  CUTLASS_PRAGMA_UNROLL
  for (int c = 0; c < num_chunks; c++) {
    if (c + 1 < num_chunks) {
      fetch_chunk(c + 1);
    }

    auto cur = regs_of(c);
    // Two values at a time so the multiply can use the packed float2 form.
    CUTLASS_PRAGMA_UNROLL
    for (int k = 0; k < size(cur); k += 2) {
      float2 scaled;
      float2 raw = make_float2(cur(k), cur(k+1));
      cute::mul(scaled, scale_pair, raw);
      cur(k) = scaled.x;
      cur(k+1) = scaled.y;
    }

    flush_chunk(c);
  }
}
// Correction routine for one output tile.
//
// While KV iterations remain it reads a two-value statistics record
// (old row max, new row max) from tensor memory at V0 / V1, turns it into a
// rescale factor and applies it to accumulator O0 / O1 via correction_rescale.
// After the loop it reads (final row sum, final row max) from V0 and then V1,
// runs correction_epilogue for stage 0 and stage 1, and, when
// epilogue.params.ptr_LSE is non-null, stores the per-row log-sum-exp.
//
//   blk_coord            - tile coordinate; mode 0 selects the row block,
//                          mode 2 is used to index gLSE / cumulative_length
//   params               - reads scale_softmax_log2, scale_softmax, scale_output
//   problem_shape        - bounds for the LSE row check and the gLSE shape
//   params_problem_shape - supplies cumulative_length for variable-length Q
//   shared_storage_epi   - provides smem_o, the epilogue output buffer
//   pipeline_s0_c/_s1_c  - consumer side of the two PipelineC instances
//   pipeline_o           - consumer side of PipelineO (guards the accumulators)
//   pipeline_epi         - producer side of PipelineE (guards smem_o stages)
//   epilogue             - provides params.ptr_LSE and params.dLSE
//
// The relative order of the waits/releases below is part of the protocol with
// the other warps; do not reorder without reviewing every pipeline involved.
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
// total number of KV iterations for this tile, as reported by the mask
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
// index of this thread within a group of 4 warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
// narrow the S fragment (and its coordinates) to a 128 x 2 window: the two
// statistics values read per row
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);
auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx);
Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v);
Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v);
// two copies of the source view, rebased onto the V0 and V1 TMEM offsets
Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS;
tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0);
Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS;
tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1);
// ignore first signal from softmax as no correction is required
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
// the matching first signal on the s1 pipe is waited on here but only
// released inside the loop (or right after it when the loop does not run)
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// handle the last iteration differently (i.e. tmem_load/stsm for epi)
mask_tile_count -= 1;
CUTLASS_PRAGMA_NO_UNROLL
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// ---- rescale O0 using the statistics at V0 ----
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
// read row_wise new global max
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
// 2^(scale_softmax_log2 * (old_max - new_max))
float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O0));
// releases the s1 stage that was waited on before this point
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
// fence the TMEM stores issued by correction_rescale before handing O back
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
// ---- rescale O1 using the statistics at V1 ----
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O1));
// releases the s0 stage that was waited on at the top of this iteration
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
}
// release the s1 stage still held from before / from the last loop iteration
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
// do the final correction to O1
// better to somehow special-case it in the loop above
// doesn't matter for non-persistent code, but if it were
// persistent we do not want to release O too early
// NOTE(review): the comment above mentions O1, but the code that follows
// finalizes stage 0 (V0 / _0{}) first and stage 1 afterwards — confirm wording.
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
// read from V0
// read row_sum and final row_max here
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
// need both the accumulator (pipeline_o) and a free smem_o stage (pipeline_epi)
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
// store to epi smem
// loop:
// TMEM_LOAD
// FMUL2 scale = 1 / global_sum * out_quant_scale
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
// stage 0: scale by scale_output / row_sum while converting and storing to sO
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
// row of this thread inside the problem: coordinate within the tile plus
// the tile's row offset
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
// variable-length Q: shift by the cumulative length looked up with get<2,1>(blk_coord)
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
// lse = log(row_sum) + scale_softmax * row_max
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
// skip rows past the end of the problem
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
// fence the TMEM loads before releasing the accumulator
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
// ---- same finalization for stage 1, using the statistics at V1 ----
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// load from V1
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
// rows of stage 1 are shifted by one QK tile height relative to stage 0
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
// Produces the output of a tile that has no work: both stages of the
// epilogue smem buffer are filled with zeros and, when
// epilogue.params.ptr_LSE is non-null, the LSE of every in-range row is set
// to -INFINITY.
//
//   blk_coord            - tile coordinate; mode 0 selects the row block,
//                          mode 2 is used to index gLSE / cumulative_length
//   params               - unused here; kept for signature parity with correction()
//   problem_shape        - bounds for the LSE row check and the gLSE shape
//   params_problem_shape - supplies cumulative_length for variable-length Q
//   shared_storage_epi   - provides smem_o, the epilogue output buffer
//   pipeline_epi         - producer side of PipelineE (guards smem_o stages)
//   epilogue             - provides params.ptr_LSE and params.dLSE
//
// Each stage follows the same producer protocol as correction():
//   producer_acquire -> write smem -> fence_view_async_shared -> producer_commit
// The smem write for a stage is never issued before that stage has been
// acquired, and it is always fenced before the stage is committed.
template<
  class BlkCoord, class ProblemShape, class ParamsProblemShape,
  class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction_empty(
    BlkCoord const& blk_coord,
    Params const& params, ProblemShape const& problem_shape,
    ParamsProblemShape const& params_problem_shape,
    TensorStorageEpi& shared_storage_epi,
    PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
    CollectiveEpilogue& epilogue) {

  // ---- stage 0 ----
  pipeline_epi.producer_acquire(pipeline_epi_producer_state);

  Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
  Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
  float lse = -INFINITY;
  // index of this thread within a group of 4 warps
  int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);

  // Copy that moves one 32-bit word of output elements per access, laid out
  // over 128 threads to match sO's layout.
  using ElementOut = typename CollectiveEpilogue::ElementOut;
  auto tiled_copy = make_cotiled_copy(
      Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
      make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
      sO.layout());

  auto thr_copy = tiled_copy.get_slice(thread_idx);
  auto tOgO = thr_copy.partition_D(sO);
  // zero-filled register source, reused for both stages
  auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
  clear(tOrO);

  copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
  // fence the smem writes before the stage is handed to the consumer
  cutlass::arch::fence_view_async_shared();

  if (epilogue.params.ptr_LSE != nullptr) {
    int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);

    int row_offset = 0;
    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
      // variable-length Q: shift by the cumulative length looked up with get<2,1>(blk_coord)
      row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
    }

    // skip rows past the end of the problem
    if (row_idx < get<0>(problem_shape)) {
      gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
    }
  }

  pipeline_epi.producer_commit(pipeline_epi_producer_state);
  ++pipeline_epi_producer_state;

  // ---- stage 1 ----
  // Acquire first: the stage may still be in use by the consumer until then.
  pipeline_epi.producer_acquire(pipeline_epi_producer_state);

  copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
  cutlass::arch::fence_view_async_shared();

  if (epilogue.params.ptr_LSE != nullptr) {
    // rows of stage 1 are shifted by one QK tile height relative to stage 0
    int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});

    int row_offset = 0;
    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
      row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
    }

    if (row_idx < get<0>(problem_shape)) {
      gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
    }
  }

  pipeline_epi.producer_commit(pipeline_epi_producer_state);
  ++pipeline_epi_producer_state;
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
// Load stage of the SM100 FMHA kernel: builds the TMA descriptors for Q, K
// and V on the host (to_underlying_arguments) and, on the device, issues the
// TMA copies from global to shared memory in the order Q1, K1, Q2, V1, K2,
// V2, ... through two pipelines (one for Q, one shared by K and V).
template<
class Element,
class StrideQ,
class StrideK,
class StrideV,
class CollectiveMmaQK,
class CollectiveMmaPV,
class SmemLayoutQ,
class SmemLayoutK,
class SmemLayoutV,
class TensorStorage,
class PipelineQ,
class PipelineKV,
class Mask,
class TileShape
>
struct Sm100FmhaLoadTmaWarpspecialized {
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TileShapePV = typename CollectiveMmaPV::TileShape;
// User-facing arguments: global pointers and strides of Q, K and V.
struct Arguments {
const Element* ptr_Q;
StrideQ dQ;
const Element* ptr_K;
StrideK dK;
const Element* ptr_V;
StrideV dV;
};
// Q and K use the A / B descriptors of the QK collective; V uses the B
// descriptor of the PV collective.
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
// Kernel-side parameters: the three TMA descriptors.
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
};
// Converts Arguments into Params by delegating descriptor construction to
// the QK and PV collectives.
//   problem_shape - (Q, K, D, ((H_a, H_b), B))-shaped tuple; modes 0 and 1 may
//                   be variable-length objects exposing cumulative_length and
//                   total_length
//   args          - pointers and strides of Q, K, V
//   workspace     - not used by this function
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace) {
auto ptr_Q = args.ptr_Q;
auto ptr_K = args.ptr_K;
auto ptr_V = args.ptr_V;
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
// all-integer problem shape handed to the collectives
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
// variable-length: the descriptors span the total (packed) lengths
// NOTE(review): when either cumulative_length pointer is null nothing is
// assigned here, so problem_shape_qk keeps whatever its default
// construction produced before the clamps below — confirm this is intended.
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
problem_shape_qk = problem_shape;
}
// clamp the Q and K extents to at least 1
get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
problem_shape_qk,
typename CollectiveMmaQK::Arguments {
ptr_Q, dQ,
ptr_K, dK,
}, /*workspace=*/ nullptr);
// PV problem shape: modes 1 and 2 of the QK shape swapped
auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
problem_shape_pv,
typename CollectiveMmaPV::Arguments {
ptr_K, dK, // never used, dummy
ptr_V, select<1,0,2>(dV),
}, /*workspace=*/ nullptr);
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b
};
}
// Prefetches the three TMA descriptors held in params.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
// Issues all TMA loads for one tile.
//   blk_coord_in         - tile coordinate in units of TileShape
//   problem_shape        - problem shape used to build the TMA tensors and
//                          to obtain the trip count from Mask
//   params               - TMA descriptors
//   params_problem_shape - supplies cumulative_length for variable-length Q/KV
//   storage              - shared memory holding smem_q, smem_k, smem_v
//   pipeline_q / _kv     - producer side of the Q and K/V pipelines; the
//                          passed states are advanced once per issued load
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
// separate copies so the batch index can be zeroed independently for Q and KV
BlkCoord blk_coord_q = blk_coord_in;
BlkCoord blk_coord_kv = blk_coord_in;
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
using X = Underscore;
// this one is only executed by one thread, no need to elect_one
// NOTE(review): the code below does call elect_one_sync() and predicates
// every copy on it — the comment above looks stale; confirm.
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
// two pipes: Q and KV
// from Memory (prod) to TensorCore (cons)
// compute gQ, sQ
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
int q_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
// variable-length Q: address by row offset into the packed tensor and
// clear the coordinate that was used for the lookup
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto [tQgQ_qdl, tQsQ] = tma_partition(
params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
);
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
// compute gK, sK
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape));
int kv_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
// variable-length KV: same treatment as for Q above
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
auto [tKgK_kdl, tKsK] = tma_partition(
params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
);
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
// compute gV, sV
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape));
// V is addressed as (D, K, L), so the KV offset goes into mode 1
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
auto [tVgV_dkl, tVsV] = tma_partition(
params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
);
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
// blk_coord is decomposed in terms of TileShape, not TileShapeQK
// As such, it needs to be transformed as
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
// every acquire below is executed by all calling threads; only the elected
// lane issues the TMA copy
uint32_t lane_predicate = cute::elect_one_sync();
// Q1
int q0_index = 2 * get<0>(blk_coord_q);
int q1_index = 2 * get<0>(blk_coord_q) + 1;
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// K1
int k_index = 0;
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
// Q2
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// V1
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
k_index += 1;
// loop:
// the first K/V pair was issued above, so one iteration is already accounted for
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// Ki
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
// Vi
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index()));
}
++pipeline_kv_producer_state;
k_index += 1;
}
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_fusion.hpp"
#include "../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp"
#include "../common/pipeline_mla.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
template<
class Element_,
class ElementQK_,
class ElementPV_,
class ComposedTileShape_,
class StrideQ_,
class StrideK_,
class StrideV_,
class Mask_,
// shape here is QG K H
// and refers to the two softmax warps
// (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V)
// (1, 2, 1) means they sit side by side (best for small Q / large K)
class ThreadShape = Shape<_2, _1, _1>,
class OrderLoadEpilogue = cute::false_type
>
struct Sm100MlaFwdMainloopTmaWarpspecialized {
using Element = Element_;
using ElementQK = ElementQK_;
using ElementPV = ElementPV_;
using ComposedTileShape = ComposedTileShape_;
using StrideQ = StrideQ_;
using StrideK = StrideK_;
using StrideV = StrideV_;
using Mask = Mask_;
static constexpr int StageCountQ = 2;
static constexpr int StageCountK = 1;
static constexpr int StageCountV = 1;
static constexpr int StageCountKV = StageCountK + StageCountV;
// Support StageCountKV > 2 in the future.
static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!");
static_assert(std::is_same_v<ThreadShape, Shape<_2, _1, _1>>, "Only support ThreadShape = Shape<_2, _1, _1>");
using ClusterShape = Shape<_1, _1, _1>;
static const int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{});
static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{});
static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope;
static constexpr auto HeadDimPV = HeadDimLatent;
using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{}));
using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{})));
using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent));
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, StrideQ, Alignment,
Element, StrideK, Alignment,
ElementQK,
TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
// the stride for A does not matter since we do not load from smem at all
Element, StrideK, Alignment,
Element, decltype(select<1,0,2>(StrideV{})), Alignment,
ElementPV,
TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/,
cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp;
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<StageCountQ>{}));
using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int<StageCountK>{}));
using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int<StageCountV>{}));
using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{})));
// Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory,
// we reuse shared memory for V and O to address this problem,
// and a barrier has been added to coordinate access to shared memory.
static constexpr bool IsOrderLoadEpilogue = std::is_same_v<OrderLoadEpilogue, cute::true_type>;
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
struct TensorStorageQKVO {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_o; // use as O0
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v; // use as V0 and O1
};
struct TensorStorageQKV {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Element, cute::cosize_v<SmemLayoutV>> smem_v;
};
using TensorStorage = std::conditional_t<IsOrderLoadEpilogue, TensorStorageQKVO, TensorStorageQKV>;
enum class TmemAllocation : uint32_t {
kSizeS = 128,
kSizeO = 128,
kSizeP = 32,
S0 = 0,
S1 = S0 + kSizeS,
V0 = S0, // stats storage from softmax to correction
V1 = S1,
P0 = S0 + kSizeP,
P1 = S1 + kSizeP,
O0 = S1 + kSizeS,
O1 = O0 + kSizeO,
kEnd = O1 + kSizeO
};
// Indices into the two-element statistics slot stored at V0 / V1.
// The same two positions carry different meanings depending on the step:
//  - regular softmax steps store (old row max, new row max) at (0, 1);
//  - the final softmax step stores (final row sum, final row max) at (0, 1).
enum : int {
kIdxOldRowMax = 0,
kIdxNewRowMax = 1,
kIdxFinalRowSum = 0,
kIdxFinalRowMax = 1
};
// Pipeline types connecting the warp roles of this kernel.

// From load to mma warp; protects Q in smem.
using PipelineQ = cutlass::PipelineTmaUmmaAsync<
StageCountQ,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// From load to mma warp; protects K/V in smem. K and V tiles share this one pipeline
// (mma() waits on it for both K and V).
using PipelineKV = cutlass::PipelineTmaAsyncMla<
StageCountKV,
typename CollectiveMmaQK::AtomThrShapeMNK
>;
// From mma to softmax0/1 warp; protects S in tmem.
// (not sure yet about the reverse direction)
// There is one pipe per softmax warp, and the mma warp alternates between them.
using PipelineS = cutlass::PipelineUmmaAsync<1>;
// From softmax0/1 to the correction warp group.
using PipelineC = cutlass::PipelineAsync<1>;
// From mma to correction.
using PipelineO = cutlass::PipelineUmmaAsync<2>;
// From correction to epilogue.
using PipelineE = cutlass::PipelineAsync<2>;
// Ordering barrier passed to softmax()/softmax_step() as order_s; declared with
// one stage and two groups.
using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier<
/*stages*/ 1, /*groups*/ 2>;
// Byte counts for the Q / K / V loads: the cosize of the first three modes of each smem
// layout times the element width, converted from bits to bytes.
// NOTE(review): the first three modes presumably describe a single stage, making these
// the per-stage TMA transaction sizes -- confirm against the Load collective.
static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
// Load collective used by load(), prefetch_tma_descriptors() and
// to_underlying_arguments(). It is instantiated with this mainloop's element type,
// strides, MMA collectives, smem layouts, storage type, Q and K/V pipelines, mask,
// tile shape and the OrderLoadEpilogue switch.
using Load = Sm100MlaFwdLoadTmaWarpspecialized<
Element, StrideQ, StrideK, StrideV,
CollectiveMmaQK, CollectiveMmaPV,
SmemLayoutQ, SmemLayoutK, SmemLayoutV,
TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue
>;
// Host-side arguments for the mainloop; converted to Params by
// to_underlying_arguments().
struct Arguments {
// arguments forwarded to the Load collective
typename Load::Arguments load;
// if zero, defaults to 1/sqrt(D)
// (D is computed as get<2,0>(problem_shape) + get<2,1>(problem_shape))
float scale_softmax = 0.0f;
// scaling factors to dequantize QKV
float scale_q = 1.0f;
float scale_k = 1.0f;
float scale_v = 1.0f;
// scaling factor to quantize O
float inv_scale_o = 1.0f;
};
// Device-side parameters produced by to_underlying_arguments().
// Member order must match the brace-initialization in to_underlying_arguments().
struct Params {
// parameters of the Load collective
typename Load::Params load;
// scale_q * scale_k * softmax scale
float scale_softmax;
// scale_q * scale_k * log2(e) * softmax scale; used with exp2f in the softmax
float scale_softmax_log2;
// scale_v * inv_scale_o; applied to O in the correction epilogue
float scale_output;
};
// Reports whether this mainloop can handle the given problem.
// No constraint is checked here: every (problem_shape, args) pair is accepted.
template<class ProblemShape>
static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) {
  constexpr bool kAlwaysSupported = true;
  return kAlwaysSupported;
}
// Converts host-side Arguments into the Params consumed by the kernel.
//
// A scale_softmax of exactly 0 is treated as "use the default", which is
// 1/sqrt(get<2,0>(problem_shape) + get<2,1>(problem_shape)).
// The returned Params hold, in order:
//   load               - Load::to_underlying_arguments(problem_shape, args.load, workspace)
//   scale_softmax      - scale_q * scale_k * softmax scale
//   scale_softmax_log2 - scale_q * scale_k * log2(e) * softmax scale
//   scale_output       - scale_v * inv_scale_o
template<class ProblemShape>
static Params to_underlying_arguments(
    ProblemShape const& problem_shape,
    Arguments const& args,
    void* workspace) {
  // Resolve the softmax scale, substituting the default for the zero sentinel.
  float const softmax_scale = (args.scale_softmax == 0.0f)
      ? 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape))
      : args.scale_softmax;

  // log2(e), needed because the softmax is evaluated with exp2 rather than exp.
  float const log2_of_e = static_cast<float>(std::log2(std::exp(1.0)));

  // Combined dequantization factor of the Q*K^T product.
  float const qk_dequant = args.scale_q * args.scale_k;

  return Params{
    Load::to_underlying_arguments(problem_shape, args.load, workspace),
    qk_dequant * softmax_scale,
    qk_dequant * log2_of_e * softmax_scale,
    args.scale_v * args.inv_scale_o
  };
}
// Prefetches the TMA descriptors by delegating to the Load collective with the
// load portion of the mainloop parameters.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
  auto const& load_params = params.load;
  Load::prefetch_tma_descriptors(load_params);
}
// Load-warp entry point: forwards to the Load collective, which fills the Q and K/V
// buffers in `storage` and signals them through pipeline_q / pipeline_kv.
// The producer states are passed by reference so the Load collective can advance them.
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
    BlkCoord const& blk_coord, ProblemShape const& problem_shape,
    Params const& params, ParamsProblemShape const& params_problem_shape,
    TensorStorage& storage,
    PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
    PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
  Load tma_loader;
  auto const& load_params = params.load;
  tma_loader.load(
      blk_coord, problem_shape, load_params, params_problem_shape,
      storage,
      pipeline_q, pipeline_q_producer_state,
      pipeline_kv, pipeline_kv_producer_state);
}
// MMA-warp main loop. Issues the Q*K GEMMs into the two S accumulators in tmem and the
// P*V GEMMs into the two O accumulators in tmem, interleaving the two softmax stages.
//
// Pipeline usage in this function:
//   pipeline_q / pipeline_kv  - consumer side: wait for Q and K/V tiles in smem and
//                               release them once the GEMMs that read them were issued
//   pipeline_s0 / pipeline_s1 - producer side: acquired before S0/S1 is written (and
//                               before P0/P1 is read), committed after each Q*K GEMM
//   pipeline_corr             - producer side: acquired before and committed after each
//                               P*V GEMM
// All pipeline state arguments are taken by reference and advanced in place.
//
// Naming: comments in the body are 1-based. Q1/S1/P1/O1 refer to the "...0" tensors
// (tSrQ0, tStS0, tOrP0, tOtO0); Q2/S2/P2/O2 refer to the "...1" tensors.
// Blocks guarded by `get<N>(ThreadShape{}) > 1` only exist for thread shapes where the
// two softmax stages use distinct Q tiles (modes 0/2) or distinct K/V tiles (mode 1).
//
// NOTE(review): the first K/V tile is processed unconditionally before the loop, so
// this assumes Mask::get_trip_count(...) returns at least 1 -- confirm with the caller.
// NOTE(review): smem stages are indexed with (pipeline index / 2); presumably one smem
// stage spans two K/V pipeline stages -- confirm against the Load collective.
template<class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
mma(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state,
PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state,
PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state,
PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) {
// Separate "release" cursors trail the "wait" cursors, because a tile is released
// later than it is waited for.
auto pipeline_q_release_state = pipeline_q_consumer_state;
auto pipeline_kv_release_state = pipeline_kv_consumer_state;
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
typename CollectiveMmaQK::TiledMma mma_qk;
ThrMMA thr_mma_qk = mma_qk.get_slice(0);
typename CollectiveMmaPV::TiledMma mma_pv;
// The PV MMA is converted with to_tiled_mma_sm100_ts; its A operand (P) is addressed
// in tmem below rather than in smem.
TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv);
ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ);
Tensor tSrK = thr_mma_qk.make_fragment_B(sK);
Tensor tOrV = thr_mma_pv.make_fragment_B(sV);
// tmem layout is
// S0 S1 O0 O1
// sequential in memory, where S overlaps with P and V
Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{}));
Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{}));
// Per-stage accumulator views: same layout, base address shifted by the tmem offset.
Tensor tStS0 = tStS;
tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0);
Tensor tStS1 = tStS;
tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1);
Tensor tOtO0 = tOtO;
tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0);
Tensor tOtO1 = tOtO;
tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1);
// sP is built on a null pointer: it is only used to derive the A-operand fragment
// layout; the fragment's address is then replaced by the P0/P1 tmem offsets.
Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{});
Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging
Tensor tOrP0 = tOrP;
tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0);
Tensor tOrP1 = tOrP;
tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1);
int k_index = 0;
int v_index = 0;
int q_index = 0;
// ---- prologue: first K/V tile ----
// wait for Q1
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
Tensor tSrQ0 = tSrQ(_,_,_,q_index);
// wait for K1
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * K1 -> S1
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
// release K1
// (only when the second stage uses its own K tile; otherwise K1 is still needed below)
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// wait for Q2
// (only when the second stage uses its own Q tile; otherwise q_index still refers to Q1)
if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) {
q_index = pipeline_q_consumer_state.index();
pipeline_q.consumer_wait(pipeline_q_consumer_state);
++pipeline_q_consumer_state;
}
Tensor tSrQ1 = tSrQ(_,_,_,q_index);
// wait for the second stage's own K tile, if it has one
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
// gemm Q2 * K1 -> S2
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release K1
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for V1
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// this acquire returns the ownership of all of S0 to the mma warp
// including the P0 part
// acquire corr first to take it out of the critical
// path since softmax takes longer
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
// gemm P1 * V1 -> O1
gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release V1 only when the second stage uses its own V tile; otherwise V1 is still
// needed for the P2 * V1 GEMM issued later
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// NOTE(review): the accumulate flag is set to Zero here and the remaining PV GEMMs go
// through gemm_reset_zero_acc (defined elsewhere) -- confirm how that helper uses it.
mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero;
// loop:
// The first tile was handled by the prologue above; each iteration issues
// Q1*Ki, P2*V(i-1), Q2*Ki, P1*Vi in that order.
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// wait for Ki
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm Q1 * Ki -> S1
// (no acquire here: S0 was already acquired before the preceding P1 * V GEMM)
gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0);
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
// gemm P2 * V(i-1) -> O2
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release V(i-1)
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
k_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm Q2 * Ki -> S2
// (S1 was acquired just above, before the P2 * V(i-1) GEMM)
gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1);
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// release Ki
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// wait for Vi
v_index = (pipeline_kv_consumer_state.index());
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
// gemm P1 * Vi -> O1
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s0.producer_acquire(pipeline_s0_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
if constexpr (get<1>(ThreadShape{}) > 1) {
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
}
}
// ---- epilogue: last P2 * V GEMM and final hand-offs ----
// All Q*K GEMMs have been issued, so Q can be released.
// release Q1
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
// release Q2
if constexpr (get<0>(ThreadShape{}) > 1) {
pipeline_q.consumer_release(pipeline_q_release_state);
++pipeline_q_release_state;
}
// wait for Vi
if constexpr (get<1>(ThreadShape{}) > 1) {
v_index = pipeline_kv_consumer_state.index();
pipeline_kv.consumer_wait(pipeline_kv_consumer_state);
++pipeline_kv_consumer_state;
}
// gemm P2 * Vi -> O2
pipeline_corr.producer_acquire(pipeline_corr_producer_state);
pipeline_s1.producer_acquire(pipeline_s1_producer_state);
gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1);
pipeline_corr.producer_commit(pipeline_corr_producer_state);
++pipeline_corr_producer_state;
// release Vi
pipeline_kv.consumer_release(pipeline_kv_release_state);
++pipeline_kv_release_state;
// Final commits on both S pipelines without a new Q*K GEMM. softmax_step's final
// call waits on pipeline_s once more before writing the final row statistics.
pipeline_s0.producer_commit(pipeline_s0_producer_state);
++pipeline_s0_producer_state;
pipeline_s1.producer_commit(pipeline_s1_producer_state);
++pipeline_s1_producer_state;
// Resulting issue order (first line: original scheduling notation, second line: GEMMs):
// T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ...
// Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * V2 , ...
}
// One online-softmax step for a single S tile of one softmax stage.
//
// Sequence:
//   1. wait on pipeline_s, load the S tile from tmem into registers, optionally mask it
//   2. update the running row max; store (old max, new max) to the stage's V slot in
//      tmem and commit pipeline_c so the correction warp group can read them
//   3. compute exp2(scale * S - scale * row_max), convert to Element and store the
//      result as P into tmem (overlapping S); release pipeline_s to the mma warp
//   4. re-acquire pipeline_c and fold the tile's row sum into row_sum
//   5. on the final call, wait on pipeline_s again and store (final row sum,
//      final row max) to the V slot
//
// Template / parameters:
//   need_mask       - compile-time switch; when false the masking code is not instantiated
//   need_apply_mask - run-time switch: whether this particular tile must be masked
//   row_max/row_sum - running statistics of the thread's row, updated in place
//   stage           - softmax stage; selects the S0/V0/P0 or S1/V1/P1 tmem regions
//   final_call      - true for the last tile; triggers step 5
//   cS              - coordinate tensor of the current S tile (consumed by the mask)
//   order_s         - ordering barrier shared by the softmax warp groups
// blk_coord is not used in this function.
template<bool need_mask, class Stage, class BlkCoord, class CoordTensor, class ProblemShape>
CUTLASS_DEVICE auto
softmax_step(
bool need_apply_mask,
float& row_max, float& row_sum,
Stage stage, bool final_call,
BlkCoord const& blk_coord, CoordTensor const& cS,
Params const& params, ProblemShape const& problem_shape,
PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
OrderBarrierSoftmax& order_s) {
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
// tmem view of the full S tile of this stage
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1);
// 128 x 2 view for the per-row statistics slot (V0 / V1)
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1);
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
// Width of the P tile measured in float-sized units: the tile's mode-1 extent scaled
// by sizeof(Element) / sizeof(float).
auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int<sizeof(float)>{} * Int<sizeof(Element)>{};
// tmem view of the P region (P0 / P1) of this stage
Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1));
Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32)));
// Each thread owns a single row
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem
using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem
using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
// thread index within a group of four warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS);
Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS);
auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v);
auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx);
Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v);
Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P);
tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get());
Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P);
// wait on tensor core pipe
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// read all of S from tmem into reg mem
Tensor tTMEM_LOADrS = make_tensor<ElementQK>(shape(tTMEM_LOADcS));
copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS);
if constexpr (need_mask) {
if(need_apply_mask) {
Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape);
}
}
ElementQK old_row_max = row_max;
{
// compute rowmax
// Four independent accumulators, each seeded with the running max, are combined
// after the loop. The loop steps by 4, so size(tTMEM_LOADrS) is expected to be a
// multiple of 4.
float row_max_0 = row_max;
float row_max_1 = row_max;
float row_max_2 = row_max;
float row_max_3 = row_max;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 4) {
row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i));
row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1));
row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2));
row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3));
}
row_max = ::fmax(row_max_0, row_max_1);
row_max = ::fmax(row_max, row_max_2);
row_max = ::fmax(row_max, row_max_3);
}
// A row whose max is still -inf gets 0 as the subtraction reference, which keeps
// -inf minus -inf (NaN) out of the exponent below.
ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max;
// Publish (old max, new max) for the correction warp group via the V slot in tmem.
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max;
tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
pipeline_c.producer_commit(pipeline_c_producer_state);
++pipeline_c_producer_state;
// notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's)
ElementQK scale = params.scale_softmax_log2;
ElementQK row_max_scale = row_max_safe * scale;
float2 scale_fp32x2 = make_float2(scale, scale);
float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale);
// Staging registers for P: 32-bit words, also viewed as packed arrays of
// kConversionsPerStep Element values.
Tensor tTMEM_STORErS_x4 = make_tensor<uint32_t>(shape(tTMEM_STOREcS));
constexpr int kConversionsPerStep = 2;
Tensor tTMEM_STORErS_x4_e = recast<Array<Element, kConversionsPerStep>>(tTMEM_STORErS_x4);
NumericArrayConverter<Element, ElementQK, kConversionsPerStep> convert;
// order_s.arrive() is issued this many elements before the end of the loop below.
constexpr int kReleasePipeCount = 10; // must be multiple of 2
order_s.wait();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 2) {
float2 in = make_float2(
tTMEM_LOADrS(i + 0),
tTMEM_LOADrS(i + 1)
);
float2 out;
// out = scale * S - scale * row_max, two elements at a time
cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2);
tTMEM_LOADrS(i + 0) = out.x;
tTMEM_LOADrS(i + 1) = out.y;
tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0));
tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1));
// convert the two exponentials to Element and pack them into the staging registers
Array<ElementQK, kConversionsPerStep> in_conv;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kConversionsPerStep; j++) {
in_conv[j] = tTMEM_LOADrS(i + j);
}
tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv);
if (i == size(tTMEM_LOADrS) - kReleasePipeCount) {
order_s.arrive();
}
// this prevents register spills in fp16
// (when the staging tensor has two slices along mode 2, the first slice is stored
// early, six elements before the end of the loop)
if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) {
if (i == size(tTMEM_LOADrS) - 6) {
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0));
}
}
}
// tmem_store(reg_S8) -> op_P
// store the last (or only) slice of P
CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{});
CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{});
copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1));
cutlass::arch::fence_view_async_tmem_store();
// notify tensor core warp that P is ready
pipeline_s.consumer_release(pipeline_s_consumer_state);
++pipeline_s_consumer_state;
pipeline_c.producer_acquire(pipeline_c_producer_state);
// Rescale factor for the previous row sum. The extra 0.5 is intentional: the scaled
// row_sum seeds BOTH lanes of local_row_sum_f32x2 below, and the two lanes are added
// together at the end, so the previous sum is counted exactly once.
ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe));
row_sum *= acc_scale;
// row_sum = sum(reg_S)
// Four float2 accumulators (eight elements per iteration) are combined after the
// loop; the loop steps by 8, so size(tTMEM_LOADrS) is expected to be a multiple of 8.
float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum);
float2 local_row_sum_1 = make_float2(0, 0);
float2 local_row_sum_2 = make_float2(0, 0);
float2 local_row_sum_3 = make_float2(0, 0);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTMEM_LOADrS); i += 8) {
// row_sum += tTMEM_LOADrS(i);
float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1));
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in);
in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1));
cute::add(local_row_sum_1, local_row_sum_1, in);
in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1));
cute::add(local_row_sum_2, local_row_sum_2, in);
in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1));
cute::add(local_row_sum_3, local_row_sum_3, in);
}
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1);
cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3);
cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2);
float local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y;
row_sum = local_row_sum;
if (final_call) {
// re-acquire the S part in the final step
pipeline_s.consumer_wait(pipeline_s_consumer_state);
// Overwrite the V slot with the final statistics. Note that the unclamped row_max
// (not row_max_safe) is stored here.
Tensor tTMEM_STOREVrS = make_tensor<ElementQK>(shape(tTMEM_STOREVcS));
tTMEM_STOREVrS(kIdxFinalRowMax) = row_max;
tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum;
copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS);
}
}
// Softmax driver for one softmax stage.
//
// Runs softmax_step once per K/V tile reported by Mask::get_trip_count, keeping the
// running row max / row sum across tiles. The trailing get_masked_trip_count tiles are
// flagged as needing the mask, and the very last tile is flagged as the final call.
// After every step the S coordinate tensor advances along mode 1 by
// get<1>(ThreadShape) * get<1>(TileShapeQK).
//
// pipeline_c is acquired once before the loop; after the loop it is committed and
// acquired once more, and pipeline_s is released once more as an empty step.
// Both pipeline states are advanced in place.
template<class Stage, class BlkCoord, class ProblemShape>
CUTLASS_DEVICE auto
softmax(
    Stage stage,
    BlkCoord const& blk_coord,
    Params const& params, ProblemShape const& problem_shape,
    PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state,
    PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state,
    OrderBarrierSoftmax& order_s) {
  // Masking code is only instantiated when a real mask type is in use.
  constexpr bool kMaskEnabled = !std::is_same_v<Mask, NoMask>;

  const int masked_tiles = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape);
  const int total_tiles = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);

  // Running statistics of the row owned by this thread.
  ElementQK running_max = -INFINITY;
  ElementQK running_sum = 0;

  // Coordinates of this stage's first S tile within the problem.
  Tensor identity_tile = make_identity_tensor(select<0,1>(TileShapeQK{}));
  auto tile_origin = make_coord(
    get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}),
    0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{})
  );
  Tensor cS = domain_offset(tile_origin, identity_tile);

  pipeline_c.producer_acquire(pipeline_c_producer_state);

  // tiles_left counts down from total_tiles to 1.
  CUTLASS_PRAGMA_NO_UNROLL
  for (int tiles_left = total_tiles; tiles_left > 0; tiles_left -= 1) {
    bool const tile_needs_mask = tiles_left <= masked_tiles;
    bool const is_last_tile = tiles_left == 1;

    softmax_step<kMaskEnabled /* need_mask */>(
      tile_needs_mask,
      running_max, running_sum, stage,
      is_last_tile,
      blk_coord, cS, params, problem_shape,
      pipeline_s, pipeline_s_consumer_state,
      pipeline_c, pipeline_c_producer_state,
      order_s
    );

    // Move to the next K/V tile handled by this stage.
    cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{});
  }

  pipeline_c.producer_commit(pipeline_c_producer_state);
  ++pipeline_c_producer_state;

  pipeline_c.producer_acquire(pipeline_c_producer_state);

  // empty step to sync against pipe s
  pipeline_s.consumer_release(pipeline_s_consumer_state);
  ++pipeline_s_consumer_state;
}
// Final correction for one O accumulator: reads O (O0 for stage 0, O1 for stage 1)
// from tmem in chunks of kCorrectionTileSize, multiplies by `scale`, converts to the
// output element type and writes the result to slice `stage` of the smem tensor sO_01.
//
//   scale - factor applied to every O element (skipped when ONLY_SOFTMAX is defined)
//   stage - compile-time 0 or 1; selects the tmem accumulator and the smem slice
//   sO_01 - smem output tensor whose mode 2 indexes the stage
// Ends with fence_view_async_shared() after the smem writes.
template<class Stage, class TensorO>
CUTLASS_DEVICE auto
correction_epilogue(
float scale,
Stage stage,
TensorO const& sO_01) {
using ElementOut = typename TensorO::value_type;
// thread index within a group of four warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
Tensor sO = sO_01(_,_,stage);
// As opposed to the softmax, we do not have enough registers here
// to load all of the values (for tile kv = 128), so we loop
// good values would be either 32 or 64
// The chunk width is 32 / sizeof(ElementOut); the 32x tmem load variant is used when
// that equals 32, the 16x variant otherwise.
constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut);
using TMEM_LOAD = std::conditional_t<kCorrectionTileSize == 32, SM100_TMEM_LOAD_32dp32b32x, SM100_TMEM_LOAD_32dp32b16x>; // 4x32 threads with 64 cols of 32b elem
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
Tensor tOsO = mma.get_slice(0).partition_C(sO);
// Split the tmem, coordinate and smem views into (128 x kCorrectionTileSize) chunks.
Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
// Point the tmem view at the accumulator of the requested stage.
if constexpr (decltype(stage == _0{})::value) {
tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0);
}
else {
static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1");
tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1);
}
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{}));
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _));
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _));
Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _));
float2 scale_f32x2 = make_float2(scale, scale);
// loop:
// TMEM_LOAD, FMUL2 scale, convert, store to smem
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i);
Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i);
Tensor tTMrO = make_tensor<ElementPV>(shape(tTMEM_LOADcO(_, _0{}, _0{}, i)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO);
// The scaling is compiled out when ONLY_SOFTMAX is defined.
#ifndef ONLY_SOFTMAX
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO); j += 2) {
float2 in = make_float2(tTMrO(j), tTMrO(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO(j) = out.x;
tTMrO(j+1) = out.y;
}
#endif
// Convert N = 4 / sizeof(ElementOut) values at a time, i.e. one 32-bit word of output.
constexpr int N = 4 / sizeof(ElementOut);
NumericArrayConverter<ElementOut, ElementPV, N> convert;
Tensor tSMrO = make_tensor_like<ElementOut>(tTMrO);
Tensor tCs = recast<decltype(convert)::source_type>(tTMrO);
Tensor tCd = recast<decltype(convert)::result_type>(tSMrO);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tCs); j++) {
tCd(j) = convert.convert(tCs(j));
}
// Write the converted chunk to smem as 32-bit words; the copy assumes 128-bit alignment.
Tensor tSMsO_i = recast<uint32_t>(tTMEM_LOADsO_i);
Tensor tSMrO_i = recast<uint32_t>(tSMrO);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i);
}
cutlass::arch::fence_view_async_shared();
}
// In-place rescale of an O accumulator in tmem: every element of the accumulator
// starting at tmem offset `tmem_O` is multiplied by `scale`.
// The accumulator is processed in chunks of kCorrectionTileSize (16); the load of
// chunk i+1 is issued before chunk i is multiplied and stored back.
//
//   scale  - multiplicative correction factor
//   tmem_O - tmem offset of the accumulator (callers pass TmemAllocation::O0 or O1)
// This function issues tmem stores but no fence; callers fence afterwards.
CUTLASS_DEVICE auto
correction_rescale(
float scale,
uint32_t tmem_O) {
// thread index within a group of four warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
// As opposed to the softmax, we do not have enough registers here
// to load all of the values (for tile kv = 128), so we loop
// good values would be either 32 or 64
constexpr int kCorrectionTileSize = 16;
// 16x variants of the tmem load/store, matching kCorrectionTileSize
using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x;
using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x;
typename CollectiveMmaPV::TiledMma mma;
Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{}));
Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{}));
Tensor tOcO = mma.get_slice(0).partition_C(cO);
// (128 x kCorrectionTileSize) views of one chunk; the tmem view is based at tmem_O
Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int<kCorrectionTileSize>{})));
tOtO_i.data() = tOtO_i.data().get() + tmem_O;
auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i);
auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx);
auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i);
auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx);
Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i);
Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i);
Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i);
Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i);
static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO));
float2 scale_f32x2 = make_float2(scale, scale);
// Register buffer with one slot per chunk.
// NOTE(review): the slot count is hard-coded as 128 / kCorrectionTileSize, while the
// loop below runs get<2>(TileShape{}) / kCorrectionTileSize times; the two agree only
// if get<2>(TileShape{}) <= 128 -- confirm.
Tensor tTMrO = make_tensor<ElementPV>(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{}));
// Loads chunk i from tmem (at offset i * kCorrectionTileSize) into register slot i.
auto copy_in = [&](int i) {
Tensor tTMEM_LOADtO_i = tTMEM_LOADtO;
tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i);
};
// Stores register slot i back to chunk i in tmem.
auto copy_out = [&](int i) {
Tensor tTMEM_STOREtO_i = tTMEM_STOREtO;
tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize);
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i);
};
// Issue order: load chunk 0 first; then for each chunk i: load chunk i+1 (except on
// the last iteration), multiply chunk i by scale, store chunk i.
// loop:
// TMEM_LOAD, FMUL2 scale, TMEM_STORE
copy_in(0);
constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < count; i++) {
if (i != count - 1) {
copy_in(i+1);
}
Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO)));
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(tTMrO_i); j += 2) {
float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1));
float2 out;
cute::mul(out, scale_f32x2, in);
tTMrO_i(j) = out.x;
tTMrO_i(j+1) = out.y;
}
copy_out(i);
}
}
// Correction warp-group main loop.
//
// For every K/V tile after the first, it reads (old row max, new row max) published by
// each softmax stage in the V0 / V1 tmem slots and rescales the matching O accumulator
// (O0 / O1) by exp2(scale_softmax_log2 * (old_max - new_max)).
// After the last tile it reads (final row sum, final row max) for each stage, writes
// O * (scale_output / row_sum) to shared memory through correction_epilogue, and, if
// epilogue.params.ptr_LSE is non-null, stores the per-row log-sum-exp to global memory.
//
// Pipeline usage in this function:
//   pipeline_s0_c / pipeline_s1_c - consumer side of the softmax -> correction pipes
//   pipeline_o                    - consumer side of the mma -> correction pipe
//   pipeline_epi                  - producer side of the correction -> epilogue pipe
// All pipeline state arguments are taken by reference and advanced in place.
//
// NOTE(review): like mma(), this assumes Mask::get_trip_count(...) >= 1 -- confirm.
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state,
PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state,
PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape);
// thread index within a group of four warps
int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp);
Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{}));
Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{}));
Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS);
// 128 x 2 views of the per-row statistics slot
Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{})));
Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{})));
using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem
auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v);
auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx);
Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v);
Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v);
// statistics sources of softmax stage 0 (V0) and stage 1 (V1)
Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS;
tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0);
Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS;
tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1);
// ignore first signal from softmax as no correction is required
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
// The first stage-1 signal is waited for here but only released inside the loop
// (or right after it when the loop does not run).
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// handle the last iteration differently (i.e. tmem_load/stsm for epi)
mask_tile_count -= 1;
CUTLASS_PRAGMA_NO_UNROLL
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// ---- stage 0: rescale O0 ----
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
// read row_wise new global max
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
// e^(scale * (old_max - new_max)
float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O0));
// release the stage-1 signal that was waited for earlier
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
// fence the tmem stores issued by correction_rescale before releasing O
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
// ---- stage 1: rescale O1 ----
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax)));
pipeline_o.consumer_wait(pipeline_o_consumer_state);
correction_rescale(scale, uint32_t(TmemAllocation::O1));
// release the stage-0 signal waited for at the top of this iteration
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
cutlass::arch::fence_view_async_tmem_store();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
}
// release the stage-1 signal that is still held from before / inside the loop
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
// do the final correction to O1
// better to somehow special-case it in the loop above
// doesn't matter for non-persistent code, but if it were
// persistent we do not want to release O too early
pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state);
// read from V0
// read row_sum and final row_max here
Tensor tTMEM_LOADVrS = make_tensor<ElementQK>(shape(tTMEM_LOADVcS));
copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS);
pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state);
++pipeline_s0_c_consumer_state;
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
// store to epi smem
// loop:
// TMEM_LOAD
// FMUL2 scale = 1 / global_sum * out_quant_scale
// F2FP
// store to smem
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
// gLSE is only dereferenced below when ptr_LSE is non-null.
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
// ---- stage 0 output: O0 * (scale_output / row_sum) -> sO slice 0 ----
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
// row of this thread within the problem
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
// variable-length problems: add the cumulative length entry selected by get<2,1>(blk_coord)
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
// lse = log(row_sum) + scale_softmax * row_max
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
// skip rows beyond the problem's extent in mode 0
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
// fence the tmem loads before releasing O0 back to the mma warp
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
// ---- stage 1 output: O1 * (scale_output / row_sum) -> sO slice 1 ----
pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state);
// load from V1
copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS);
pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state);
++pipeline_s1_c_consumer_state;
pipeline_o.consumer_wait(pipeline_o_consumer_state);
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO);
if (epilogue.params.ptr_LSE != nullptr) {
// stage 1 rows are offset by one QK tile (get<0>(TileShapeQK{})) relative to stage 0
int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_o.consumer_release(pipeline_o_consumer_state);
++pipeline_o_consumer_state;
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
// Produces two epilogue-pipeline stages without reading any accumulator:
// each stage gets a zero-filled slice of the shared-memory output tile and,
// when an LSE pointer is provided, an LSE value of -INFINITY for its rows.
// (Presumably called for tiles that have nothing to attend to; the call site
// is not visible here - TODO confirm.)
//
// blk_coord                    block coordinate; mode 0 selects the row tile, mode <2,1> indexes
//                              cumulative_length, mode 2 indexes the second mode of gLSE
// params                       not referenced in this function body
// problem_shape                modes <0,3> give the shape of the LSE tensor; mode 0 bounds row_idx
// params_problem_shape         supplies cumulative_length for the variable-length case
// shared_storage_epi           provides smem_o, viewed through TensorStorageEpi::SmemLayoutO
// pipeline_epi                 epilogue pipeline; this function is its producer
// pipeline_epi_producer_state  advanced twice (once per slice)
// epilogue                     supplies params.ptr_LSE / params.dLSE and the ElementOut type
//
// The return type is deduced as void (there is no return statement).
template<
class BlkCoord, class ProblemShape, class ParamsProblemShape,
class TensorStorageEpi, class CollectiveEpilogue
>
CUTLASS_DEVICE auto
correction_empty(
BlkCoord const& blk_coord,
Params const& params, ProblemShape const& problem_shape,
ParamsProblemShape const& params_problem_shape,
TensorStorageEpi& shared_storage_epi,
PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state,
CollectiveEpilogue& epilogue) {
// ---- slice 0 -------------------------------------------------------------
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
// Shared-memory output tile and global-memory LSE tensor views.
Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{});
Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE);
float lse = -INFINITY;
// Thread index folded into the range [0, 4 * NumThreadsPerWarp).
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
#if 1
using ElementOut = typename CollectiveEpilogue::ElementOut;
// Tiled copy that stores 32-bit words (sizeof(uint32_t) / sizeof(ElementOut)
// output elements per word), co-tiled against the layout of sO.
auto tiled_copy = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint32_t>, ElementOut>{},
make_ordered_layout(make_shape(_128{}, Int<sizeof(uint32_t) / sizeof(ElementOut)>{}), Step<_1, _0>{}),
sO.layout());
auto thr_copy = tiled_copy.get_slice(thread_idx);
auto tOgO = thr_copy.partition_D(sO);
// Per-thread fragment, cleared to zero and reused as the source for both slices.
auto tOrO = make_tensor<ElementOut>(shape(tOgO(_,_,_,_0{})));
clear(tOrO);
// Zero-fill slice 0 of the shared-memory output tile.
copy(tiled_copy, tOrO, tOgO(_,_,_,_0{}));
#endif
if (epilogue.params.ptr_LSE != nullptr) {
// Row handled by this thread within the first sub-tile.
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord);
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
// Rows past the end of the problem are not written.
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
// NOTE(review): unlike the slice-1 commit below, this commit is not preceded
// by fence_view_async_shared(); confirm that the slice-0 stores are visible
// to the consumer without it.
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
// ---- slice 1 -------------------------------------------------------------
// NOTE(review): the slice-1 store is issued before the producer_acquire for
// this stage; confirm that this ordering is intentional.
copy(tiled_copy, tOrO, tOgO(_,_,_,_1{}));
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_acquire(pipeline_epi_producer_state);
if (epilogue.params.ptr_LSE != nullptr) {
// Same row computation as above, shifted by the row extent of TileShapeQK.
int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{});
int row_offset = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)];
}
if (row_idx < get<0>(problem_shape)) {
gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse;
}
}
cutlass::arch::fence_view_async_shared();
pipeline_epi.producer_commit(pipeline_epi_producer_state);
++pipeline_epi_producer_state;
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"
#include "cute/layout.hpp"
#include "../collective/fmha_common.hpp"
#include "../collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective {
using namespace cute;
// Load collective that stages Q, K and V tiles from global memory into shared
// memory with TMA copies, acting as the producer of two pipelines
// (pipeline_q for Q, pipeline_kv shared by K and V).
//
// Template parameters, as used by the code in this struct:
//   Element             element type of the Q/K/V pointers; also sizes the TMA transactions
//   StrideQ/K/V         global-memory stride types stored in Arguments
//   CollectiveMmaQK/PV  provide TileShape, TiledMma, Arguments, Params (TMA_A / TMA_B)
//                       and to_underlying_arguments
//   SmemLayoutQ/K/V     shared-memory layouts of the staged tiles
//   TensorStorage       must expose smem_q, smem_k and smem_v
//   PipelineQ/KV        pipeline types driven from the producer side
//   Mask                supplies get_trip_count(blk_coord, TileShape{}, problem_shape)
//   TileShape           tile shape handed to Mask::get_trip_count
//   OrderLoadEpilogue   when cute::true_type, load() synchronizes on the reserved
//                       EpilogueBarrier before issuing the first V load
template<
class Element,
class StrideQ,
class StrideK,
class StrideV,
class CollectiveMmaQK,
class CollectiveMmaPV,
class SmemLayoutQ,
class SmemLayoutK,
class SmemLayoutV,
class TensorStorage,
class PipelineQ,
class PipelineKV,
class Mask,
class TileShape,
class OrderLoadEpilogue = cute::false_type
>
struct Sm100MlaFwdLoadTmaWarpspecialized {
using TileShapeQK = typename CollectiveMmaQK::TileShape;
using TileShapePV = typename CollectiveMmaPV::TileShape;
// Byte counts covering the first three modes of the K / V shared-memory layouts.
// Only TransactionBytesLoadV is referenced by the code visible in this struct.
static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v<Element>);
static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v<Element>);
// Warp counts used to size the named-barrier sync in load().
static const int NumWarpsEpilogue = 1;
static const int NumWarpsLoad = 1;
// Host-side arguments: global-memory pointers and strides of Q, K and V.
struct Arguments {
const Element* ptr_Q;
StrideQ dQ;
const Element* ptr_K;
StrideK dK;
const Element* ptr_V;
StrideV dV;
};
// TMA object types taken from the two collective MMAs: Q is the A operand and
// K the B operand of the QK MMA; V is the B operand of the PV MMA.
using TMA_Q = typename CollectiveMmaQK::Params::TMA_A;
using TMA_K = typename CollectiveMmaQK::Params::TMA_B;
using TMA_V = typename CollectiveMmaPV::Params::TMA_B;
// Kernel-side parameters: one TMA object per loaded tensor.
struct Params {
TMA_Q tma_load_q;
TMA_K tma_load_k;
TMA_V tma_load_v;
};
// Builds Params from Arguments by delegating to the two collective MMAs.
//   problem_shape  problem extents; mode 2 is a pair whose two entries are summed
//                  for the QK problem, while only entry 0 is used for PV
//   args           pointers and strides of Q, K and V
//   workspace      not referenced in this function
template<class ProblemShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape,
Arguments const& args,
void* workspace) {
auto ptr_Q = args.ptr_Q;
auto ptr_K = args.ptr_K;
auto ptr_V = args.ptr_V;
auto dQ = args.dQ;
auto dK = args.dK;
auto dV = args.dV;
using IntProblemShape = cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
IntProblemShape problem_shape_qk;
if constexpr (is_variable_length_v<tuple_element_t<0, ProblemShape>>) {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
auto cumulative_length_k = get<1>(problem_shape).cumulative_length;
// NOTE(review): when either cumulative_length pointer is null, problem_shape_qk
// is left as default-constructed before the clamping below - confirm this is intended.
if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) {
// Variable-length case: modes 0 and 1 use the total lengths.
get<0>(problem_shape_qk) = get<0>(problem_shape).total_length;
get<1>(problem_shape_qk) = get<1>(problem_shape).total_length;
get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape);
get<3>(problem_shape_qk) = get<3>(problem_shape);
}
} else {
// Fixed-length case: only mode 2 changes (sum of its two entries).
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;
}
// Clamp the first two extents to at least 1.
get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));
// PV problem: modes 1 and 2 of the QK problem swapped, then mode 1 replaced
// by entry 0 of the original mode 2.
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
auto params_qk = CollectiveMmaQK::to_underlying_arguments(
problem_shape_qk,
typename CollectiveMmaQK::Arguments {
ptr_Q, dQ,
ptr_K, dK,
}, /*workspace=*/ nullptr);
auto params_pv = CollectiveMmaPV::to_underlying_arguments(
problem_shape_pv,
typename CollectiveMmaPV::Arguments {
ptr_K, dK, // never used, dummy
// V's stride modes are permuted with select<1,0,2>.
ptr_V, select<1,0,2>(dV),
}, /*workspace=*/ nullptr);
return Params{
params_qk.tma_load_a,
params_qk.tma_load_b,
params_pv.tma_load_b
};
}
// Prefetches the TMA descriptors of the Q, K and V loads.
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& params) {
cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor());
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor());
}
// Issues the TMA loads for one block coordinate in the order
// Q0, K0, Q1, V0, then (K_i, V_i) for each remaining mask tile.
//   blk_coord_in                block coordinate (decomposed in terms of TileShape)
//   problem_shape               problem extents used to build the TMA tensors
//   params                      TMA objects built by to_underlying_arguments
//   params_problem_shape        supplies cumulative_length for variable-length Q / KV
//   storage                     shared-memory storage with smem_q / smem_k / smem_v
//   pipeline_q / pipeline_kv    pipelines driven as producer; their states are advanced in place
template<class BlkCoord, class ProblemShape, class ParamsProblemShape>
CUTLASS_DEVICE void
load(
BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
Params const& params, ParamsProblemShape const& params_problem_shape,
TensorStorage& storage,
PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
// Separate copies because the variable-length path zeroes mode <2,1> independently for Q and KV.
BlkCoord blk_coord_q = blk_coord_in;
BlkCoord blk_coord_kv = blk_coord_in;
// QK uses the sum of both entries of mode 2; V uses only entry 0.
auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));
auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape));
// Number of K/V tiles to load for this block, as decided by the mask.
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape);
using X = Underscore;
// this one is only executed by one thread, no need to elect_one
// NOTE(review): the copies below are nevertheless guarded by elect_one_sync() - confirm
// which statement reflects the actual launch configuration.
// Q1, K1, Q2, V1, K2, V2, K3, V3, ...
// two pipes: Q and KV
// from Memory (prod) to TensorCore (cons)
// compute gQ, sQ
// we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk));
int q_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) {
// Variable length: offset mode 0 by the cumulative length and drop the <2,1> coordinate.
q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)];
get<2,1>(blk_coord_q) = 0;
}
}
Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p);
Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
auto [tQgQ_qdl, tQsQ] = tma_partition(
params.tma_load_q, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
);
Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
// compute gK, sK
Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk));
int kv_offs_0 = 0;
if constexpr (is_variable_length_v<tuple_element_t<1, ParamsProblemShape>>) {
auto cumulative_length = get<1>(params_problem_shape).cumulative_length;
if (cumulative_length != nullptr) {
// Same treatment as Q, using the KV cumulative lengths.
kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)];
get<2,1>(blk_coord_kv) = 0;
}
}
Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p);
Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
auto [tKgK_kdl, tKsK] = tma_partition(
params.tma_load_k, _0{}, make_layout(_1{}),
group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
);
Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv));
// compute gV, sV
ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
// V reuses the KV offset, applied to its second mode.
Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p);
Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
auto [tVgV_dkl, tVsV] = tma_partition(
params.tma_load_v, _0{}, make_layout(_1{}),
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
);
auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv));
// blk_coord is decomposed in terms of TileShape, not TileShapeQK
// As such, it needs to be transformed as
// (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
// b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
// Only the elected lane issues the TMA copies and prefetches below.
uint32_t lane_predicate = cute::elect_one_sync();
// Q1
int q0_index = 2 * get<0>(blk_coord_q);
int q1_index = 2 * get<0>(blk_coord_q) + 1;
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
// K1
int k_index = 0;
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
// K and V each take their own pipeline_kv stage; the shared-memory slot is the stage index / 2.
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
// Q2
pipeline_q.producer_acquire(pipeline_q_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
}
++pipeline_q_producer_state;
if constexpr (cute::is_same_v<OrderLoadEpilogue, cute::true_type>) {
// Rendezvous of the load and epilogue warps on the reserved EpilogueBarrier.
cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
}
// V1
// V stages are acquired with an explicit transaction byte count.
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
}
++pipeline_kv_producer_state;
k_index += 1;
// loop:
// The first K/V tile was loaded above, so one trip is already accounted for.
mask_tile_count -= 1;
for (; mask_tile_count > 0; mask_tile_count -= 1) {
// Ki
pipeline_kv.producer_acquire(pipeline_kv_producer_state);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2));
// prefetch vi
cute::prefetch(params.tma_load_v, tVgV(_, k_index));
}
++pipeline_kv_producer_state;
// Vi
pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV);
if (lane_predicate) {
auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2));
// prefetch ki+1
// Skipped on the last iteration, where there is no next K tile.
if(mask_tile_count > 1) {
cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1));
}
}
++pipeline_kv_producer_state;
k_index += 1;
}
}
};
} // namespace cutlass::fmha::collective
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/layout.hpp"
#include "cute/tensor.hpp"
#include "cute/util/print.hpp"
namespace example {
using namespace cute;
/// Tag type that disables gather/scatter for a GEMM argument.
/// It is constructible from any argument list; every argument is discarded.
struct NoGather
{
  template<class... Args>
  NoGather(Args...) {}
};
/// Gather functor backed by an index table: position `pos` maps to `indices_[pos]`.
template <class Index>
struct IndexedGather
{
  /// Stores the table pointer; the table is not copied. Defaults to a null pointer.
  CUTE_HOST_DEVICE constexpr
  IndexedGather(Index const *indices = {}): indices_(indices) {}

  /// Looks up the table entry at `pos`.
  template <typename Pos>
  CUTE_HOST_DEVICE constexpr
  Index
  operator()(Pos pos) const { return indices_[pos]; }

  /// Prints a fixed tag; the table contents are not printed.
  CUTE_HOST_DEVICE friend
  void
  print(IndexedGather const &) {
    cute::print("Indexed");
  }

  Index const *indices_;
};
/// Gather functor that scales its argument by a fixed stride: `pos -> pos * stride_`.
/// Example: StridedGather<_2> selects every other row/column.
template <class Stride>
struct StridedGather
{
  /// Stores the stride; defaults to a value-initialized Stride.
  CUTE_HOST_DEVICE constexpr
  StridedGather(Stride stride = {}): stride_(stride) {}

  /// Returns `pos` multiplied by the stored stride.
  template <class Pos>
  CUTE_HOST_DEVICE constexpr
  auto
  operator()(Pos pos) const { return pos * stride_; }

  /// Prints the functor as "Strided{<stride>}".
  CUTE_HOST_DEVICE friend
  void
  print(StridedGather const &self) {
    cute::print("Strided{");
    print(self.stride_);
    cute::print("}");
  }

  Stride stride_;
};
/// Stride-like object: multiplying it by an index first maps the index
/// through `func_` and then multiplies the mapped value by `stride_`.
template <class Func, class Stride>
struct CustomStride
{
  CUTE_HOST_DEVICE constexpr
  CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {}

  /// index * CustomStride
  template <class Idx>
  CUTE_HOST_DEVICE constexpr friend
  auto
  operator*(Idx idx, CustomStride const &self) { return self.func_(idx) * self.stride_; }

  /// CustomStride * index (same result as the overload above)
  template <class Idx>
  CUTE_HOST_DEVICE constexpr friend
  auto
  operator*(CustomStride const &self, Idx idx) { return self.func_(idx) * self.stride_; }

  /// Prints the object as "Custom{<func>,<stride>}".
  CUTE_HOST_DEVICE friend
  void
  print(CustomStride const &self) {
    cute::print("Custom{");
    print(self.func_);
    cute::print(",");
    print(self.stride_);
    cute::print("}");
  }

  /// Divides only the wrapped stride; the mapping function is carried over unchanged.
  template<class Div>
  CUTE_HOST_DEVICE constexpr friend
  auto
  safe_div(CustomStride const &self, Div const &div)
  {
    using DividedStride = decltype(safe_div(self.stride_, div));
    return CustomStride<Func, DividedStride>(self.func_, safe_div(self.stride_, div));
  }

  // Circumvent the requirement on make_layout that shape and stride are integral
  template <class Shape>
  CUTE_HOST_DEVICE constexpr friend
  auto
  make_layout(Shape const &shape, CustomStride const &stride)
  {
    return Layout<Shape, CustomStride>(shape, stride);
  }

  Func func_;
  Stride stride_;
};
/// Builds a layout whose shape is all ones (shaped like `stride`) and whose
/// stride equals `stride`, except that the first mode whose stride is not the
/// constant 1 is wrapped in a CustomStride applying `func`.
template<class Stride, class Func>
CUTLASS_HOST_DEVICE
auto
make_custom_stride_layout(Stride const &stride, Func&& func)
{
  // Position of the first stride entry that is not the static constant 1.
  auto gather_pos = find_if(stride, [](auto entry){ return !is_constant<1, decltype(entry)>{}; });
  constexpr int kGatherMode = decltype(gather_pos)::value;

  // Dummy unit shape; only the strides carry information here.
  auto unit_shape = repeat_like(stride, _1{});
  return make_layout(unit_shape,
                     replace<kGatherMode>(stride, CustomStride{static_cast<Func&&>(func), get<kGatherMode>(stride)}));
}
/// Creates a tensor over `iter`. When `func` is a NoGather the result is the
/// plain `make_tensor(iter, shape, stride)`; otherwise the tensor uses a
/// composed layout that routes coordinates through a custom gather stride.
template<class Iterator, class Shape, class Stride, class Func>
CUTLASS_HOST_DEVICE
auto
make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func)
{
  if constexpr (cutlass::platform::is_same<remove_cvref_t<Func>, NoGather>::value) {
    // Gather disabled: ordinary strided tensor.
    return make_tensor(iter, shape, stride);
  } else {
    // Inner layout: identity over the logical shape.
    auto identity_layout = make_identity_layout(shape);
    // Zero offset, expressed as an arithmetic tuple shaped like `shape`.
    auto zero_offset = as_arithmetic_tuple(repeat_like(shape, _0{}));
    // Outer layout: strides with the gather function applied.
    auto gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
    return make_tensor(iter, ComposedLayout{gather_layout, zero_offset, identity_layout});
  }
}
} // namespace example
namespace cute
{
// Upcast by N restricted to basis mode I:
//  - tuple shapes are handled entry by entry (recursively),
//  - a scaled-basis stride of mode I has both its shape and stride ceil-divided by N,
//  - a scaled-basis stride of any other mode is returned unchanged,
//  - every other stride is forwarded to upcast<N>(shape, stride).
template<int N, int I, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
upcast(Shape const& shape, Stride const& stride)
{
  if constexpr (is_tuple<Shape>::value) {
    return transform_layout(shape, stride,
                            [](auto const& sub_shape, auto const& sub_stride) { return upcast<N,I>(sub_shape, sub_stride); });
  } else if constexpr (is_scaled_basis<Stride>::value) {
    if constexpr (Stride::mode() == I) {
      constexpr Int<N> factor{};
      return make_layout(ceil_div(shape, factor), ceil_div(stride, factor));
    } else {
      return make_layout(shape, stride);
    }
  } else {
    return upcast<N>(shape, stride);
  }

  CUTE_GCC_UNREACHABLE;
}
// Upcast of a composed layout (outer layout, offset, inner layout) by N.
template <int N, class OuterShape, class OuterStride, class Offset, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
upcast(ComposedLayout<Layout<OuterShape,OuterStride>,Offset,Layout<Shape,Stride>> const& layout)
{
  // Locate the stride-1 mode of the outer layout: it is the only mode whose
  // inner shape and offset need to be rescaled.
  auto unit_pos = find_if(layout.layout_a().stride(), [](auto entry){ return is_constant<1, decltype(entry)>{}; });
  constexpr int kUnitMode = decltype(unit_pos)::value;

  // The outer layout is upcast in the usual way.
  auto upcast_outer = upcast<N>(layout.layout_a());

  // The offset is upcast only along the stride-1 mode.
  auto unit_offset   = upcast<N>(get<kUnitMode>(layout.offset()));
  auto upcast_offset = as_arithmetic_tuple(replace<kUnitMode>(layout.offset(), unit_offset));

  // The inner layout is upcast only along the stride-1 mode.
  auto upcast_inner = upcast<N,kUnitMode>(layout.layout_b().shape(), layout.layout_b().stride());

  return composition(upcast_outer, upcast_offset, upcast_inner);
}
} // namespace cute
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cuda_runtime.h"
#include <iostream>
/**
 * Panic wrapper for unwinding CUTLASS errors.
 *
 * Evaluates `status` exactly once. If the result is not
 * cutlass::Status::kSuccess, prints the status string and the current line
 * number to std::cerr and terminates the process with EXIT_FAILURE.
 *
 * The body is wrapped in do { } while (0) so that the macro behaves as a
 * single statement (e.g. inside an unbraced if/else). The temporary has a
 * macro-specific name so that a call such as CUTLASS_CHECK(error) does not
 * expand to a self-initialisation of the temporary.
 */
#define CUTLASS_CHECK(status)                                                                 \
  do {                                                                                        \
    cutlass::Status cutlass_check_status_ = (status);                                         \
    if (cutlass_check_status_ != cutlass::Status::kSuccess) {                                 \
      std::cerr << "Got cutlass error: " << cutlassGetStatusString(cutlass_check_status_)     \
                << " at: " << __LINE__ << std::endl;                                          \
      exit(EXIT_FAILURE);                                                                     \
    }                                                                                         \
  } while (0)
/**
 * Panic wrapper for unwinding CUDA runtime errors.
 *
 * Evaluates `status` exactly once. If the result is not cudaSuccess, prints
 * the CUDA error string and the current line number to std::cerr and
 * terminates the process with EXIT_FAILURE.
 *
 * The body is wrapped in do { } while (0) so that the macro behaves as a
 * single statement (e.g. inside an unbraced if/else). The temporary has a
 * macro-specific name so that a call such as CUDA_CHECK(error) does not
 * expand to a self-initialisation of the temporary.
 */
#define CUDA_CHECK(status)                                                             \
  do {                                                                                 \
    cudaError_t cuda_check_status_ = (status);                                         \
    if (cuda_check_status_ != cudaSuccess) {                                           \
      std::cerr << "Got bad cuda status: " << cudaGetErrorString(cuda_check_status_)   \
                << " at line: " << __LINE__ << std::endl;                              \
      exit(EXIT_FAILURE);                                                              \
    }                                                                                  \
  } while (0)
/* Always-on assertion: when `cond` is false, reports the stringified condition
 * together with the file and line to std::cerr and aborts the process. */
#define FLASH_MLA_ASSERT(cond)                                            \
  do {                                                                    \
    if (!(cond)) {                                                        \
      std::cerr << "FLASH_MLA_ASSERT: " << #cond << " failed at ";        \
      std::cerr << __FILE__ << ":" << __LINE__ << std::endl;              \
      std::abort();                                                       \
    }                                                                     \
  } while (0)
\ No newline at end of file
#pragma once
/// Mask selection mode.
enum class MaskMode {
  kNone   = 0,  ///< no mask
  kCausal = 1,  ///< causal mask
  kCustom = 2,  ///< custom mask
};
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Support the producer to acquire specific bytes of data.
*/
#pragma once
#include "cutlass/pipeline/sm100_pipeline.hpp"
namespace cutlass {
using namespace cute;
template <
int Stages_,
class ClusterShape = Shape<int,int,_1>,
class AtomThrShape_MNK_ = Shape<_1,_1,_1>
>
class PipelineTmaAsyncMla {
public:
static constexpr uint32_t Stages = Stages_;
using AtomThrShape_MNK = AtomThrShape_MNK_;
private:
using Impl = PipelineTmaUmmaAsync<Stages_, ClusterShape, AtomThrShape_MNK_>;
public:
using FullBarrier = typename Impl::FullBarrier;
using EmptyBarrier = typename Impl::EmptyBarrier;
using ProducerBarrierType = typename Impl::ProducerBarrierType;
using ConsumerBarrierType = typename Impl::ConsumerBarrierType;
using PipelineState = typename Impl::PipelineState;
using SharedStorage = typename Impl::SharedStorage;
using ThreadCategory = typename Impl::ThreadCategory;
using Params = typename Impl::Params;
using McastDirection = McastDirection;
// Initializes the FULL and EMPTY barrier arrays held in `storage`.
// Only the warp whose index equals params.initializing_warp performs the
// initialization; every calling thread then executes fence_barrier_init().
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) {
  if (canonical_warp_idx_sync() == params.initializing_warp) {
    using FullArray  = decltype(storage.full_barrier_);
    using EmptyArray = decltype(storage.empty_barrier_);

    // Arrival count used for the FULL barriers.
    constexpr int kProducerArrivals = 1;

    // Arrival count used for the EMPTY barriers: cluster extents 0 and 1, each
    // divided by the matching atom-thread extent, summed, minus one.
    auto const atom_shape = AtomThrShape_MNK{};
    auto const extent_0 = cute::size<0>(cluster_shape) / cute::size<0>(atom_shape);
    auto const extent_1 = cute::size<1>(cluster_shape) / cute::size<1>(atom_shape);
    uint32_t const consumer_arrivals = extent_0 + extent_1 - 1;

    cutlass::arch::detail::initialize_barrier_array_pair_aligned<FullArray, EmptyArray, Stages>(
        storage.full_barrier_, storage.empty_barrier_, kProducerArrivals, consumer_arrivals);
  }
  cutlass::arch::fence_barrier_init();
}
// Variant of init_barriers for multicast along a single direction: the EMPTY
// barrier arrival count is taken from cluster extent 1 for McastDirection::kRow
// and from cluster extent 0 otherwise (each divided by the matching
// atom-thread extent).
static
CUTLASS_DEVICE
void
init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) {
  auto const atom_shape = AtomThrShape_MNK{};
  if (canonical_warp_idx_sync() == params.initializing_warp) {
    using FullArray  = decltype(storage.full_barrier_);
    using EmptyArray = decltype(storage.empty_barrier_);

    // Arrival count used for the FULL barriers.
    constexpr int kProducerArrivals = 1;

    uint32_t consumer_arrivals;
    if (mcast_direction == McastDirection::kRow) {
      // Mcast with row ctas
      consumer_arrivals = cute::size<1>(cluster_shape) / cute::size<1>(atom_shape);
    } else {
      // Mcast with col ctas
      consumer_arrivals = cute::size<0>(cluster_shape) / cute::size<0>(atom_shape);
    }

    cutlass::arch::detail::initialize_barrier_array_pair_aligned<FullArray, EmptyArray, Stages>(
        storage.full_barrier_, storage.empty_barrier_, kProducerArrivals, consumer_arrivals);
  }
  cutlass::arch::fence_barrier_init();
}
// Computes the row+column multicast mask (block_id_mask_) for this block.
// The mask is only computed when params_.role is Consumer; for every other
// role block_id_mask_ is left untouched.
//   cluster_shape        cluster extents passed to calculate_multicast_mask
//   block_id_in_cluster  this block's position; defaults to cute::block_id_in_cluster()
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) {
  // Calculate consumer mask
  if (params_.role == ThreadCategory::Consumer) {
    block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRowCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
  }
}
// Computes the multicast mask (block_id_mask_) for a single multicast
// direction: kRow selects the row mask, any other value selects the column
// mask. Unlike the overload without a direction, no role check is made here.
CUTLASS_DEVICE
void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) {
  // Calculate consumer mask
  dim3 block_id_in_cluster = cute::block_id_in_cluster();
  if (mcast_direction == McastDirection::kRow) {
    block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kRow>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
  }
  else {
    block_id_mask_ = detail::calculate_multicast_mask<McastDirection::kCol>(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster);
  }
}
public:
// Constructs the pipeline over `storage`.
// The wrapped Impl is always constructed with two cute::false_type flags
// (presumably its own barrier/mask initialization switches - TODO confirm);
// this class then runs its own init_barriers / init_masks, each of which can
// be switched off by passing cute::false_type for InitBarriers / InitMasks.
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {})
    : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
    , params_(params)
    , empty_barrier_ptr_(&storage.empty_barrier_[0])
    , full_barrier_ptr_(&storage.full_barrier_[0]) {
  constexpr bool kInitBarriers = cute::is_same_v<InitBarriers, cute::true_type>;
  constexpr bool kInitMasks    = cute::is_same_v<InitMasks, cute::true_type>;

  // Both switches must be exactly cute::true_type or cute::false_type.
  static_assert(kInitBarriers || cute::is_same_v<InitBarriers, cute::false_type>);
  static_assert(kInitMasks || cute::is_same_v<InitMasks, cute::false_type>);

  if constexpr (kInitBarriers) {
    init_barriers(storage, params_, cluster_shape);
  }
  if constexpr (kInitMasks) {
    init_masks(cluster_shape);
  }
}
// Constructs the pipeline over `storage` for a single multicast direction.
// Same as the constructor without `mcast_direction`, except that the
// direction-aware overloads of init_barriers / init_masks are used.
template<typename InitBarriers = cute::true_type, typename InitMasks = cute::true_type>
CUTLASS_DEVICE
PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {})
    : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{})
    , params_(params)
    , empty_barrier_ptr_(&storage.empty_barrier_[0])
    , full_barrier_ptr_(&storage.full_barrier_[0]) {
  constexpr bool kInitBarriers = cute::is_same_v<InitBarriers, cute::true_type>;
  constexpr bool kInitMasks    = cute::is_same_v<InitMasks, cute::true_type>;

  // Both switches must be exactly cute::true_type or cute::false_type.
  static_assert(kInitBarriers || cute::is_same_v<InitBarriers, cute::false_type>);
  static_assert(kInitMasks || cute::is_same_v<InitMasks, cute::false_type>);

  if constexpr (kInitBarriers) {
    init_barriers(storage, params_, cluster_shape, mcast_direction);
  }
  if constexpr (kInitMasks) {
    init_masks(cluster_shape, mcast_direction);
  }
}
// Producer-side acquire of the stage described by `state`; forwards unchanged
// to the wrapped Impl pipeline. `barrier_token` defaults to WaitAgain.
CUTLASS_DEVICE
void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.producer_acquire(state, barrier_token);
}
// Producer-side acquire with a caller-supplied transaction byte count.
//
//   stage          index into the empty/full barrier arrays
//   bytes          value passed to arrive_and_expect_tx() on the stage's full barrier
//   phase          phase passed to wait() on the stage's empty barrier
//   barrier_token  when equal to BarrierStatus::WaitDone the empty-barrier wait is skipped
//
// Only a thread with params_.is_leader set performs the arrive_and_expect_tx().
CUTLASS_DEVICE
void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) {
  detail::pipeline_check_is_producer(params_.role);

  auto& stage_empty_barrier = empty_barrier_ptr_[stage];
  auto& stage_full_barrier = full_barrier_ptr_[stage];

  bool const must_wait = (barrier_token != BarrierStatus::WaitDone);
  if (must_wait) {
    stage_empty_barrier.wait(phase);
  }

  if (params_.is_leader) {
    stage_full_barrier.arrive_and_expect_tx(bytes);
  }

#ifndef NDEBUG
  // Debug builds only: trap if a non-producer thread reached this point.
  bool const is_consumer = (params_.role == ThreadCategory::Consumer);
  bool const is_bystander = (params_.role == ThreadCategory::NonParticipant);
  if (is_consumer || is_bystander) {
    asm volatile ("brkpt;\n" ::);
  }

  // Most likely you have elected more than one leader
  bool const is_first_lane = (threadIdx.x % 32 == 0);
  if (params_.is_leader && !is_first_lane) {
    asm volatile ("brkpt;\n" ::);
  }
#endif
}
// Convenience overload: unpacks the stage index and phase from `state` and calls the
// (stage, bytes, phase, token) overload. `barrier_token` defaults to BarrierStatus::WaitAgain.
CUTLASS_DEVICE
void producer_acquire_bytes(PipelineState state, uint32_t bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) {
  auto const stage_index = state.index();
  auto const stage_phase = state.phase();
  producer_acquire_bytes(stage_index, bytes, stage_phase, barrier_token);
}
// Returns the producer barrier pointer that the wrapped impl_ pipeline reports for `state`.
CUTLASS_DEVICE
ProducerBarrierType* producer_get_barrier(PipelineState state) {
return impl_.producer_get_barrier(state);
}
// Consumer-side wait for the pipeline stage described by `state`.
// Forwards both arguments unchanged to the wrapped impl_ pipeline.
// `barrier_token` defaults to BarrierStatus::WaitAgain.
CUTLASS_DEVICE
void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) {
impl_.consumer_wait(state, barrier_token);
}
// Consumer-side release of the stage `state.index()`.
// Calls the private (stage, skip) overload with skip = false, i.e. the arrive is always issued.
CUTLASS_DEVICE
void consumer_release(PipelineState state) {
consumer_release(state.index(), false);
}
private:
// Wrapped pipeline; producer_acquire, producer_get_barrier and consumer_wait forward to it.
Impl impl_;
// Copy of the constructor's `params`; `role` and `is_leader` are read by the methods of this class.
Params params_;
// Address of storage.empty_barrier_[0]; indexed by stage.
EmptyBarrier *empty_barrier_ptr_;
// Address of storage.full_barrier_[0]; indexed by stage.
FullBarrier *full_barrier_ptr_;
// Mask passed to the umma_arrive_multicast* calls in consumer_release.
// Assigned from detail::calculate_multicast_mask (presumably inside init_masks -- confirm).
uint16_t block_id_mask_ = 0;
// True when AtomThrShape_MNK has more than one element; selects the 2x1SM arrive in consumer_release.
static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1;
// Consumer signalling Producer of completion.
// Ensures all blocks in the same row and column get notified.
//
//   stage  index of the empty barrier to arrive on
//   skip   when non-zero, no arrive is issued (the role check still runs)
//
// The arrive flavour is chosen at compile time:
//   - is_2sm_mma                               -> umma_arrive_multicast_2x1SM with block_id_mask_
//   - static ClusterShape of size 1            -> umma_arrive (no multicast)
//   - anything else                            -> umma_arrive_multicast with block_id_mask_
CUTLASS_DEVICE
void consumer_release(uint32_t stage, uint32_t skip) {
  detail::pipeline_check_is_consumer(params_.role);
  if (skip) {
    return;
  }

  uint64_t* barrier_addr = reinterpret_cast<uint64_t*>(&empty_barrier_ptr_[stage]);

  if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1
    cutlass::arch::umma_arrive_multicast_2x1SM(barrier_addr, block_id_mask_);
  }
  else if constexpr (cute::is_static_v<ClusterShape> && size(ClusterShape{}) == 1) {
    cutlass::arch::umma_arrive(barrier_addr);
  }
  else {
    cutlass::arch::umma_arrive_multicast(barrier_addr, block_id_mask_);
  }
}
};
}
/***************************************************************************************************
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/numeric/integral_constant.hpp>
#include <cuda_runtime.h>
namespace cutlass::fmha {
// Integer that is expected to be a positive power of two and that caches its own log2, so the
// free operators in this namespace can turn `x / p` into a shift and `x % p` into a mask.
struct Pow2 {
  int n;       // the value itself
  int log2_n;  // index of the lowest set bit of n (== log2(n) for a power of two); -1 when n == 0

  // Builds a Pow2 from `n`. No check is made that `n` really is a power of two.
  // Callable from host and device so that log2_n is always initialized; the device path uses the
  // __ffs intrinsic, the host path computes the same value with a short loop.
  explicit CUTE_HOST_DEVICE Pow2(int n) : n(n), log2_n(-1) {
#ifdef __CUDA_ARCH__
    log2_n = __ffs(n) - 1;
#else
    // Mirror __ffs(n) - 1: position of the least significant set bit, or -1 if none is set.
    for (unsigned bits = static_cast<unsigned>(n); bits != 0u; bits >>= 1) {
      ++log2_n;
      if (bits & 1u) {
        break;
      }
    }
#endif
  }

  // Scales an arbitrary value by n; the result has the type of the right-hand operand.
  template<class T>
  CUTE_HOST_DEVICE T operator *(T const& b) const {
    return n * b;
  }

  // Scales by a compile-time constant.
  // Returns a Pow2 when N is itself a positive power of two (the product then still is one),
  // otherwise a plain int.
  template<int N>
  CUTE_HOST_DEVICE auto operator *(Int<N> const&) const {
    // Parenthesized on purpose: `==` binds tighter than `&`, so `N & (N - 1) == 0` would be
    // evaluated as `N & ((N - 1) == 0)`.
    if constexpr (N > 0 && (N & (N - 1)) == 0) {
      return Pow2{n * N};
    }
    else {
      // Must live in the else branch: with `auto` return type every non-discarded return
      // statement has to deduce the same type.
      return n * N;
    }
  }
};
// Division by a Pow2, implemented as a right shift by the cached log2_n.
template<class T>
CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) {
  auto const shift_amount = b.log2_n;
  return a >> shift_amount;
}
// Remainder modulo a Pow2, implemented as a bitwise AND with (n - 1).
template<class T>
CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) {
  auto const low_bits_mask = b.n - 1;
  return a & low_bits_mask;
}
// Compares a value against the integer held by a Pow2.
template<class T>
CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) {
  bool const is_below = a < b.n;
  return is_below;
}
// Prints a Pow2 in the form "2^<log2_n>".
CUTE_HOST_DEVICE void print(Pow2 const& a) {
  int const exponent = a.log2_n;
  printf("2^%d", exponent);
}
} // end namespace cutlass::fmha
namespace cute {
// Specializes cute's is_integral trait to true_type for cutlass::fmha::Pow2.
// Presumably this lets cute code that dispatches on is_integral treat a Pow2 like an integer -- confirm against cute.
template <>
struct is_integral<cutlass::fmha::Pow2> : true_type {};
} // end namespace cute
#pragma once
#include <torch/extension.h>
#include "cutlass/numeric_types.h"
#include "helper.h"
// Type map from a scalar type to the type CUTLASS code should use for it.
// Primary template: identity mapping; specializations override it for specific types.
template <typename T>
struct cutlass_dtype {
  typedef T type;
};
// half -> cutlass::half_t
template <>
struct cutlass_dtype<half> {
  typedef cutlass::half_t type;
};
// nv_bfloat16 -> cutlass::bfloat16_t
template <>
struct cutlass_dtype<nv_bfloat16> {
  typedef cutlass::bfloat16_t type;
};
// __nv_fp8_e4m3 -> cutlass::float_e4m3_t
template <>
struct cutlass_dtype<__nv_fp8_e4m3> {
  typedef cutlass::float_e4m3_t type;
};
// __nv_fp8_e5m2 -> cutlass::float_e5m2_t
template <>
struct cutlass_dtype<__nv_fp8_e5m2> {
  typedef cutlass::float_e5m2_t type;
};
// Shorthand for `typename cutlass_dtype<T>::type`.
template <typename T>
using cutlass_dtype_t = typename cutlass_dtype<T>::type;
\ No newline at end of file
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief A universal device layer for CUTLASS 3.x-style kernels.
*/
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template <class Kernel_>
class FMHA {
public:
using Kernel = Kernel_;
static int const kThreadCount = Kernel::MaxThreadsPerBlock;
/// Argument structure: User API
using Arguments = typename Kernel::Arguments;
/// Argument structure: Kernel API
using Params = typename Kernel::Params;
private:
/// Kernel API parameters object
Params params_;
bool is_initialized(bool set = false) {
static bool initialized = false;
if (set) initialized = true;
return initialized;
}
public:
/// Access the Params structure
Params const& params() const {
return params_;
}
/// Determines whether the GEMM can execute the given problem.
static Status
can_implement(Arguments const& args) {
if (Kernel::can_implement(args)) {
return Status::kSuccess;
}
else {
return Status::kInvalid;
}
}
/// Gets the workspace size
static size_t
get_workspace_size(Arguments const& args) {
size_t workspace_bytes = 0;
workspace_bytes += Kernel::get_workspace_size(args);
return workspace_bytes;
}
/// Computes the grid shape
static dim3
get_grid_shape(Params const& params) {
return Kernel::get_grid_shape(params);
}
/// Computes the maximum number of active blocks per multiprocessor
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes GEMM state from arguments.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("FMHA::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
// Initialize the Params structure
params_ = Kernel::to_underlying_arguments(args, workspace);
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
params_ = Kernel::to_underlying_arguments(args, workspace);
return Status::kSuccess;
}
/// Primary run() entry point API that is static allowing users to create and manage their own params.
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
static Status
run(Params& params, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("FMHA::run()");
dim3 const block = Kernel::get_block_shape();
dim3 const grid = get_grid_shape(params);
// No need to launch the kernel
if(grid.x == 0 || grid.y == 0 || grid.z == 0) {
return Status::kSuccess;
}
// configure smem size and carveout
int smem_size = Kernel::SharedStorageSize;
Status launch_result;
// Use extended launch API only for mainloops that use it
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
cute::size<1>(typename Kernel::ClusterShape{}),
cute::size<2>(typename Kernel::ClusterShape{}));
void const* kernel = (void const*) device_kernel<Kernel>;
void* kernel_params[] = {&params};
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
}
else {
launch_result = Status::kSuccess;
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params);
}
cudaError_t result = cudaGetLastError();
if (cudaSuccess == result && Status::kSuccess == launch_result) {
return Status::kSuccess;
}
else {
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
return Status::kErrorInternal;
}
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
Status status = initialize(args, workspace, stream);
if (Status::kSuccess == status) {
status = run(params_, stream);
}
return status;
}
/// Launches the kernel after first constructing Params internal state from supplied arguments.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment