Commit 082094b7 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Multiple updates and refactorings (#150)

* Multiple updates and refactorings

* Remove dead code
parent 1408756a
......@@ -41,7 +41,7 @@
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "utils.h" // for IS_SM100
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
#include "../collective/fmha_common.hpp"
#include <cmath>
......@@ -954,8 +954,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
TensorC const& coord,
TensorShape const& tensor_shape) {
// TODO: Performance of FlashMLA on sm90 dropped with the latest cutlass, so here we revert to the old version.
// Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); });
auto copy_op = make_cotiled_copy(
Copy_Atom<UniversalCopy<uint128_t>, Element>{},
......@@ -965,23 +964,11 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
auto thr_copy = copy_op.get_slice(_0{});
Tensor quantized_regs = quantize(regs);
auto tCg = thr_copy.partition_D(gmem);
auto tCr = thr_copy.partition_S(quantize(regs));
auto tCc = thr_copy.partition_D(coord);
constexpr int R = decltype(tCr.layout())::rank;
auto tCg_v = group_modes<1, R>(tCg);
auto tCr_v = group_modes<1, R>(tCr);
auto tCc_v = group_modes<1, R>(tCc);
auto tCp_v = make_tensor<bool>(shape<1>(tCc_v));
for (int i = 0; i < size(tCp_v); ++i) {
tCp_v(i) = elem_less(tCc_v(_0{},i), tensor_shape);
}
copy_if(copy_op, tCp_v, tCr_v, tCg_v);
Tensor tCr = thr_copy.partition_S(quantized_regs);
Tensor tCg = thr_copy.partition_D(gmem);
Tensor tPc = thr_copy.partition_D(preds);
copy_if(copy_op, tPc, tCr, tCg);
}
......@@ -1494,7 +1481,7 @@ struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized {
CUTLASS_DEVICE void operator()(Params const& params, char* smem) {
#if IS_SM100
#if defined(KERUTILS_ENABLE_SM100A)
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
......
......@@ -37,7 +37,7 @@
#include "cutlass/pipeline/pipeline.hpp"
#include "cute/arch/tmem_allocator_sm100.hpp"
#include "utils.h" // for IS_SM100
#include <kerutils/kerutils.cuh> // for KERUTILS_ENABLE_SM100A
#include "../kernel/fmha_options.hpp"
#include "../kernel/fmha_tile_scheduler.hpp"
#include "../kernel/fmha_causal_tile_scheduler.hpp"
......@@ -252,7 +252,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized {
}
CUTLASS_DEVICE void operator()(const Params &params, char* smem) {
#if IS_SM100
#if defined(KERUTILS_ENABLE_SM100A)
TileScheduler tile_scheduler{params.tile_scheduler};
......
#pragma once
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
namespace sm100 {
/*
Load K/V indices from global memory, and generate validity mask.

Each calling lane loads the 8 consecutive indices starting at gIndices + lane_idx*8.
Bit i of the returned mask is set iff the i-th loaded index lies in [0, s_kv) and its
absolute position (abs_pos_start + lane_idx*8 + i) is smaller than topk_length.

Should be called by lanes 0 ~ (BLOCK_TOPK/8)
*/
CUTE_DEVICE
char load_indices_and_generate_mask(
    int lane_idx,
    int* gIndices,
    int s_kv,
    int abs_pos_start,
    int topk_length
) {
    int indices[8];
    KU_LDG_256(
        gIndices + lane_idx*8,
        indices,
        ".nc",
        "no_allocate",
        "evict_normal",
        "256B"
    );

    // Absolute position of the first index handled by this lane
    int lane_pos_base = abs_pos_start + lane_idx*8;

    int mask_bits = 0;
    CUTE_UNROLL
    for (int slot = 0; slot < 8; ++slot) {
        int index = indices[slot];
        bool index_in_range = index >= 0 && index < s_kv;
        bool pos_in_topk = lane_pos_base + slot < topk_length;
        if (index_in_range && pos_in_topk) {
            mask_bits |= 1 << slot;
        }
    }
    return (char)mask_bits;
}
/*
Get P from Tensor Memory, reduce P within shared memory, perform masking, and store back if necessary

Initially, since dual gemm is used, we have two P pieces in Tensor Memory, one occupying rows 0 ~ 63 while the other occupying rows 64 ~ 127. We'd like to have them reduced into one single P piece, stored in registers with layout:

        N       N      --- (topk)
    +-------+-------+
    |       |       |
 32 | Warp0 | Warp2 |
    |       |       |
    +-------+-------+
    |       |       |
 32 | Warp1 | Warp3 |
    |       |       |
    +-------+-------+
    |
  (head)

where N = NUM_ELEMS_PER_THREAD

Template parameters:
- NUM_ELEMS_PER_THREAD: number of P elements every thread ends up with (must be 32, see the static_assert below)
- TMEM_COL_START: first Tensor Memory column of P
- BARRIER_WARP02_SYNC_ID / BARRIER_WARP13_SYNC_ID: named barrier ids used by the warp pairs (0, 2) and (1, 3); must be consecutive
- STORE_BACK_P: if true, the reduced P is written to p_exchange_buf[local_warp_idx] at the end

Parameters:
- k_validness_base: validity bitmask, one bit per column; warps 0/1 read the 32 bits at offset 0, warps 2/3 read the 32 bits at byte offset N/8
- local_warp_idx, lane_idx: warp index (0 ~ 3) and lane index of the calling thread
- slot_bar_P_empty_arrival: callable invoked once the loads from Tensor Memory have completed
- p_exchange_buf: shared memory scratch used to exchange the peer halves
- p: output, the reduced and masked P values of this thread
*/
template<
int NUM_ELEMS_PER_THREAD,
int TMEM_COL_START,
int BARRIER_WARP02_SYNC_ID,
int BARRIER_WARP13_SYNC_ID,
bool STORE_BACK_P
>
CUTE_DEVICE
void retrieve_mask_and_reduce_p(
char* k_validness_base,
int local_warp_idx,
int lane_idx,
auto slot_bar_P_empty_arrival,
float p_exchange_buf[4][32*NUM_ELEMS_PER_THREAD],
float p[NUM_ELEMS_PER_THREAD]
) {
using namespace cute;
using cutlass::arch::NamedBarrier;
// The barrier id is computed below as BARRIER_WARP02_SYNC_ID + (local_warp_idx&1), so the two ids must be adjacent
static_assert(BARRIER_WARP13_SYNC_ID == BARRIER_WARP02_SYNC_ID+1);
// p_peer holds the half that this warp does not keep; it is handed to the partner warp (local_warp_idx^2)
float p_peer[NUM_ELEMS_PER_THREAD];
if (local_warp_idx < 2) {
// Warp 0, 1 keep columns [0, N) and give away columns [N, 2N)
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START, p);
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p_peer);
} else {
// Warp 2, 3 keep columns [N, 2N) and give away columns [0, N)
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START, p_peer);
ku::tmem_ld_32dp32bNx<NUM_ELEMS_PER_THREAD>(TMEM_COL_START + NUM_ELEMS_PER_THREAD, p);
}
// Make sure the Tensor Memory loads have completed before announcing it through the callback
cutlass::arch::fence_view_async_tmem_load();
ku::tcgen05_before_thread_sync();
// NOTE(review): presumably tells the producer that P in Tensor Memory may be overwritten - confirm against the caller
slot_bar_P_empty_arrival();
// Mask invalid tokens
// We put masking before reduction, since (-inf) + anything (except nan and +inf) is (-inf), which guarantees correctness, and this can overlap with smem load
// The validity mask of this thread's N columns is read as one single uint32_t, hence N must be 32
static_assert(NUM_ELEMS_PER_THREAD == 32);
uint32_t is_k_valid = *(uint32_t*)(k_validness_base + (local_warp_idx>=2?NUM_ELEMS_PER_THREAD/8:0));
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD; i += 1) {
if (!(is_k_valid >> i & 1))
p[i] = -CUDART_INF_F;
}
// Reduce P between the paired warps (0 <-> 2, 1 <-> 3) through shared memory
{
// Store
// Warp 0, 1 store their right (col 32 ~ 63) part, while warp 2, 3 store their left (col 0 ~ 31) part
// Every warp writes into the buffer of its partner (local_warp_idx^2), one float4 per lane per iteration
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {
ku::st_shared(&p_exchange_buf[local_warp_idx^2][i*32*4 + lane_idx*4], *(float4*)(p_peer + i*4));
}
// 64 threads per barrier: warps 0 and 2 share one barrier id, warps 1 and 3 share the next one
NamedBarrier::arrive_and_wait(64, BARRIER_WARP02_SYNC_ID + (local_warp_idx&1));
// Load what the partner stored into our own buffer, and accumulate it into p
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {
float2 t[2];
*(float4*)t = *(float4*)(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4]);
float2* cur_p = (float2*)(p + i*4);
cur_p[0] = ku::float2_add(cur_p[0], t[0]);
cur_p[1] = ku::float2_add(cur_p[1], t[1]);
}
}
if constexpr (STORE_BACK_P) {
// Write the reduced (and masked) P into this warp's own buffer, using the same per-lane float4 layout as above
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/4; ++i) {
ku::st_shared(&p_exchange_buf[local_warp_idx][i*32*4 + lane_idx*4], *(float4*)(p+i*4));
}
}
}
/*
Rescale O in Tensor Memory: every element of O is multiplied by scale_factor, in place.
O should occupy 128 rows x (D_V/2) columns in Tensor Memory, starting at column TMEM_COL_START.

The columns are processed CHUNK_SIZE at a time (load from Tensor Memory -> multiply -> store back).

Template parameters:
- D_V: head dimension of V; (D_V/2) columns are rescaled
- CHUNK_SIZE: number of columns handled per iteration. Must be even and must divide (D_V/2)
- TMEM_COL_START: first Tensor Memory column of O

Parameters:
- scale_factor: the factor applied to every element
*/
template<
    int D_V,
    int CHUNK_SIZE,
    int TMEM_COL_START
>
CUTE_DEVICE
void rescale_O(
    float scale_factor
) {
    // The register buffer below holds CHUNK_SIZE/2 float2 values, so an odd CHUNK_SIZE would make it too small for the CHUNK_SIZE-column load
    static_assert(CHUNK_SIZE > 0 && CHUNK_SIZE % 2 == 0, "CHUNK_SIZE must be a positive even number");
    // The loop below runs (D_V/2)/CHUNK_SIZE times; a remainder would leave the trailing columns of O unscaled
    static_assert((D_V/2) % CHUNK_SIZE == 0, "(D_V/2) must be a multiple of CHUNK_SIZE");

    float2 scale_factor_float2 = {scale_factor, scale_factor};
    float2 o[CHUNK_SIZE/2];
    CUTE_UNROLL
    for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
        // Load O
        ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o);
        cutlass::arch::fence_view_async_tmem_load();

        // Mult
        // Unrolled like the other loops over register arrays in this file, so that `o` is indexed with compile-time constants
        CUTE_UNROLL
        for (int i = 0; i < CHUNK_SIZE/2; ++i) {
            o[i] = ku::float2_mul(o[i], scale_factor_float2);
        }

        // Store O
        ku::tmem_st_32dp32bNx<CHUNK_SIZE>(TMEM_COL_START + chunk_idx*CHUNK_SIZE, o);
        cutlass::arch::fence_view_async_tmem_store();
    }
}
/*
Return the maximum of the NUM_ELEMS_PER_THREAD values in p.
The running maximum starts from -inf, so -inf is returned if no element is larger than that.
*/
template<int NUM_ELEMS_PER_THREAD>
CUTE_DEVICE
float get_max(
    float p[NUM_ELEMS_PER_THREAD]
) {
    float running_max = -CUDART_INF_F;
    CUTE_UNROLL
    for (int elem_idx = 0; elem_idx < NUM_ELEMS_PER_THREAD; elem_idx += 1) {
        float candidate = p[elem_idx];
        running_max = max(running_max, candidate);
    }
    return running_max;
}
/*
Calculate s := exp2f(p*scale - new_max) and its sum

Elements are processed in pairs: pair i covers p[2*i] and p[2*i+1], and is written to s[i] after being
rounded to bf16. The returned sum is accumulated from the fp32 values (i.e. before the rounding to bf16).
*/
template<int NUM_ELEMS_PER_THREAD>
CUTE_DEVICE
float get_s_from_p(
    nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2],
    float p[NUM_ELEMS_PER_THREAD],
    float scale,
    float new_max
) {
    float2 scale2 = float2 {scale, scale};
    float2 neg_max2 = float2 {-new_max, -new_max};
    float2 pair_sum = float2 {0.0f, 0.0f};

    CUTE_UNROLL
    for (int pair_idx = 0; pair_idx < NUM_ELEMS_PER_THREAD/2; ++pair_idx) {
        // p*scale - new_max for both elements of the pair
        float2 exponent = ku::float2_fma(float2{p[pair_idx*2], p[pair_idx*2+1]}, scale2, neg_max2);
        float2 value = float2 {exp2f(exponent.x), exp2f(exponent.y)};
        pair_sum = ku::float2_add(pair_sum, value);
        s[pair_idx] = __float22bfloat162_rn(value);
    }

    // The two lanes of the accumulator are only merged once, at the very end
    return pair_sum.x + pair_sum.y;
}
}
#pragma once
#include "params.h"
namespace sm100 {
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably the host-side entry point that launches the sm100 sparse prefill forward kernel described by `params` - confirm against the definition.
void run_fwd_kernel(const SparsePrefillParams& params);
}
#pragma once
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "params.h"
#include "defines.h"
namespace sm100::fwd::head128 {
using namespace cute;
// Bundle of TMA-related kernel arguments: a (shape, TMA object) pair for Q and for O, plus a raw tensor map for KV.
// NOTE(review): the member order is presumably relied upon by aggregate initialization on the host side - keep it unchanged.
template<
typename Shape_Q, typename TMA_Q,
typename Shape_O, typename TMA_O
>
struct TmaParams {
// Shape of Q and the TMA object used to copy it
Shape_Q shape_Q; TMA_Q tma_Q;
// Shape of O and the TMA object used to copy it
Shape_O shape_O; TMA_O tma_O;
// Raw tensor map for the KV tensor
CUtensorMap tensor_map_kv;
};
// Two float2 values (four floats in total) grouped into one struct
struct float2x2 {
float2 lo, hi;
};
// Compile-time configuration of the head128 forward kernel, parameterized by the head dimension of Q/K (D_QK).
// It collects the tile sizes, the Tensor Memory column map, the shared memory layouts, the shared memory plan,
// the tiled MMA types, and the declaration of the device-side kernel body.
template<int D_QK>
struct KernelTemplate {
// Head dimensions of Q, K (both D_QK) and V (fixed to 512)
static constexpr int D_Q = D_QK;
static constexpr int D_K = D_QK;
static constexpr int D_V = 512;
static constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan
static constexpr int B_H = 128; // For 2 CTAs
static constexpr int B_TOPK = 128; // For 2 CTAs
// Number of buffers the per-buffer barriers and is_k_valid below are replicated over
static constexpr int NUM_BUFS = 2;
static constexpr int NUM_THREADS = 256 + 128 + 128; // 128 scale & exp threads, 128x2 TMA threads, 32 UTCMMA threads
// Q is split along the head dimension into a "tQ" part (fixed to 384 columns) and an "sQ" part (the remaining D_QK-384 columns), both counted in tiles of 64 columns.
// NOTE(review): presumably tQ is the part kept in Tensor Memory (see tmem_cols::q) while sQ stays in shared memory (see SharedMemoryPlan::u.s.sq) - confirm against the kernel.
static constexpr int D_tQ = 384, NUM_tQ_TILES = D_tQ / 64;
static constexpr int D_sQ = D_QK-D_tQ, NUM_sQ_TILES = D_sQ / 64;
static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q);
// Tensor memory columns
struct tmem_cols {
// 0 ~ 256: output
// 256 ~ 320: P
// 320 ~ 512: Q[D_QK-D_tQ:]
static constexpr int o = 0;
static constexpr int p = 256;
static constexpr int q = 512 - D_tQ/2;
// The 64 columns of P must not run into the columns of Q
static_assert(p+64 <= q);
};
// Shared memory layout of NUM_TILES Q tiles: (B_H/2) rows x (64*NUM_TILES) columns, built from the 128B-swizzled K-major atom
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Shared memory layout of NUM_TILES O tiles (same construction as SmemLayoutQTiles)
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// The full O tile: 8 tiles of 64 columns
using SmemLayoutO = SmemLayoutOTiles<8>;
// Shared memory layout of NUM_TILES K tiles: (B_TOPK/2) rows x (64*NUM_TILES) columns
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Shared memory layout of V: 256 x B_TOPK, built from the 128B-swizzled MN-major atom
using SmemLayoutV = decltype(coalesce(tile_to_shape(
UMMA::Layout_MN_SW128_Atom<bf16>{},
Shape<Int<256>, Int<B_TOPK>>{},
Step<_2, _1>{}
), Shape<_1, _1>{}));
// Shared memory layout of NUM_TILES S tiles: (B_H/2) rows x (64*NUM_TILES) columns, built from the non-swizzled (interleaved) K-major atom
template<int NUM_TILES>
using SmemLayoutSTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Layout of the kernel's shared memory
struct SharedMemoryPlan {
// The members of this union share the same storage: the full Q tile (q_full), the {sq, v, k} working set, and the O tile
union {
array_aligned<bf16, cosize_v<SmemLayoutQTiles<D_Q/64>>> q_full;
struct {
array_aligned<bf16, cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>>> sq;
array_aligned<bf16, cosize_v<SmemLayoutV>> v;
// NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q
// (the static_assert checks that q_full fits into sq + v, i.e. it ends before k starts)
static_assert(cosize_v<SmemLayoutQTiles<D_Q/64>> <= cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>> + cosize_v<SmemLayoutV>);
array_aligned<bf16, cosize_v<SmemLayoutKTiles<D_K/64>>> k;
} s;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} u;
// S buffer (2 tiles of 64 columns), separate from the union above
array_aligned<bf16, cosize_v<SmemLayoutSTiles<2>>> s;
// fp32 buffer of (B_H/2) x B_TOPK elements for P
float p[(B_H/2)*B_TOPK];
// Validity bitmask of the K tokens: one bit per token, i.e. B_TOPK/8 bytes per buffer
char is_k_valid[NUM_BUFS][B_TOPK/8];
// Barriers. Those declared with [NUM_BUFS] exist once per buffer
transac_bar_t bar_prologue_q, bar_prologue_utccp;
transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free)
transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. Vi free)
transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS];
transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready
transac_bar_t bar_p_free[NUM_BUFS];
transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready
transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];
// Receives the start address of the Tensor Memory allocation
array_aligned<uint32_t, 1> tmem_start_addr;
// Row-wise scratch buffers (128 floats each) for the max and for li
float rowwise_max_buf[128], rowwise_li_buf[128];
};
// Tiled MMA for P with a TS atom (NOTE(review): presumably "A from Tensor Memory, B from shared memory", used with the tQ part - confirm)
using TiledMMA_P_tQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
// Tiled MMA for P with an SS atom (NOTE(review): presumably both operands from shared memory, used with the sQ part - confirm)
using TiledMMA_P_sQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
// Tiled MMA for O (B_H x 256), with a K-major first operand and an MN-major second operand
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{},
Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
));
// Device-side body of the forward kernel. Only declared here; the definition lives outside of this struct.
template<typename TmaParams>
static __device__ void
sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params);
};
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head128 {
// Explicit instantiation of run_fwd_phase1_kernel with template argument 512
// NOTE(review): the argument is presumably D_QK, the head dimension of Q/K - confirm against phase1.cuh
template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head128 {
// Explicit instantiation of run_fwd_phase1_kernel with template argument 576
// NOTE(review): the argument is presumably D_QK, the head dimension of Q/K - confirm against phase1.cuh
template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params);
}
#include "fwd.h"
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
......@@ -9,12 +10,11 @@
#include "params.h"
#include "utils.h"
#include "sm100/ws_gemm.h"
#include "sm100/helpers.h"
#include "sm100/intrinsics.h"
#include "sm100/tma_cta_group2_nosplit.h"
namespace sm100 {
#include "config.h"
namespace sm100::fwd::head128 {
using namespace cute;
......@@ -28,120 +28,6 @@ CUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) {
return val;
}
// Bundle of TMA-related kernel arguments: a (shape, TMA object) pair for Q and for O, plus a raw tensor map for KV.
template<
typename Shape_Q, typename TMA_Q,
typename Shape_O, typename TMA_O
>
struct TmaParams {
// Shape of Q and the TMA object used to copy it
Shape_Q shape_Q; TMA_Q tma_Q;
// Shape of O and the TMA object used to copy it
Shape_O shape_O; TMA_O tma_O;
// Raw tensor map for the KV tensor
CUtensorMap tensor_map_kv;
};
// Two float2 values (four floats in total) grouped into one struct
struct float2x2 {
float2 lo, hi;
};
// Head dimensions of Q, K and V
constexpr int D_Q = 576;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan
constexpr int B_H = 128; // For 2 CTAs
constexpr int B_TOPK = 128; // For 2 CTAs
// Number of buffers the per-buffer barriers and is_k_valid are replicated over
constexpr int NUM_BUFS = 2;
constexpr int NUM_THREADS = 256 + 128 + 128; // 128 TMA threads, 128 scale & exp threads, 32 UTCMMA threads
// Q is split along the head dimension into an "sQ" part (256 columns) and a "tQ" part (the remaining D_Q-256 = 320 columns), both counted in tiles of 64 columns
constexpr int D_sQ = 256, NUM_sQ_TILES = D_sQ / 64;
constexpr int D_tQ = D_Q - D_sQ, NUM_tQ_TILES = D_tQ / 64;
static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q);
// Tensor memory columns
namespace tmem_cols {
// 0 ~ 256: output
// 256 ~ 320: P
// (512 - D_tQ/2) ~ 512: the tQ part of Q. With D_tQ = D_Q - D_sQ = 320 this is 352 ~ 512.
// NOTE(review): the previous comment said "320 ~ 512: Q[192:576]", which does not match D_sQ = 256; the tQ part is presumably Q[256:576] - confirm
constexpr int o = 0;
constexpr int p = 256;
constexpr int q = 512 - D_tQ/2;
// The 64 columns of P must not run into the columns of Q
static_assert(p+64 <= q);
}
// Shared memory layout of NUM_TILES Q tiles: (B_H/2) rows x (64*NUM_TILES) columns, built from the 128B-swizzled K-major atom
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Shared memory layout of NUM_TILES O tiles (same construction as SmemLayoutQTiles)
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// The full O tile: 8 tiles of 64 columns
using SmemLayoutO = SmemLayoutOTiles<8>;
// Shared memory layout of NUM_TILES K tiles: (B_TOPK/2) rows x (64*NUM_TILES) columns
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Shared memory layout of V: 256 x B_TOPK, built from the 128B-swizzled MN-major atom
using SmemLayoutV = decltype(coalesce(tile_to_shape(
UMMA::Layout_MN_SW128_Atom<bf16>{},
Shape<Int<256>, Int<B_TOPK>>{},
Step<_2, _1>{}
), Shape<_1, _1>{}));
// Shared memory layout of NUM_TILES S tiles: (B_H/2) rows x (64*NUM_TILES) columns, built from the non-swizzled (interleaved) K-major atom
template<int NUM_TILES>
using SmemLayoutSTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Layout of the kernel's shared memory
struct SharedMemoryPlan {
// The members of this union share the same storage: the full Q tile (q_full, 9 tiles of 64 columns), the {sq, v, k} working set, and the O tile
union {
array_aligned<bf16, cosize_v<SmemLayoutQTiles<9>>> q_full;
struct {
array_aligned<bf16, cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>>> sq;
array_aligned<bf16, cosize_v<SmemLayoutV>> v;
// NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q
array_aligned<bf16, cosize_v<SmemLayoutKTiles<9>>> k;
} s;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} u;
// S buffer (2 tiles of 64 columns), separate from the union above
array_aligned<bf16, cosize_v<SmemLayoutSTiles<2>>> s;
// Validity bitmask of the K tokens: one bit per token, i.e. B_TOPK/8 bytes per buffer
char is_k_valid[NUM_BUFS][B_TOPK/8];
// Barriers. Those declared with [NUM_BUFS] exist once per buffer
transac_bar_t bar_prologue_q, bar_prologue_utccp;
transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free)
transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. Vi free)
transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS];
transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready
transac_bar_t bar_p_free[NUM_BUFS];
transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready
transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];
// Receives the start address of the Tensor Memory allocation
array_aligned<uint32_t, 1> tmem_start_addr;
// Row-wise scratch buffers (128 floats each) for the max and for li
float rowwise_max_buf[128], rowwise_li_buf[128];
};
// Tiled MMA for P with a TS atom (NOTE(review): presumably "A from Tensor Memory, B from shared memory", used with the tQ part - confirm)
using TiledMMA_P_tQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
// Tiled MMA for P with an SS atom (NOTE(review): presumably both operands from shared memory, used with the sQ part - confirm)
using TiledMMA_P_sQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
// Tiled MMA for O (B_H x 256), with a K-major first operand and an MN-major second operand
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{},
Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
));
/*
Pipeline Overview:
......@@ -176,15 +62,17 @@ V(n-1) scale(O) w.r.t P(n-1)
O += S(n-1)V(n-1)
*/
template<int D_QK>
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 2)
sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) {
#if IS_SM100
__device__ void
KernelTemplate<D_QK>::sparse_attn_fwd_kernel_devfunc(const SparseAttnFwdParams &params, const TmaParams &tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
const int cta_idx = blockIdx.x % 2;
const int s_q_idx = blockIdx.x / 2;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
const int num_k_blocks = params.topk / B_TOPK;
const int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1
const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const int idx_in_warpgroup = threadIdx.x % 128;
......@@ -198,7 +86,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<9>{});
Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<D_Q/64>{});
int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]
......@@ -248,17 +136,17 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
Tile<Int<B_H/2>>{}
)(_, cta_idx, _);
launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST);
ku::launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST);
}
// Initialize TMEM
// We put this before cluster_arrive to make sure that the TMEM allocation is done before UTCCP
cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data());
TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator2Sm().release_allocation_lock();
__syncwarp();
}
__syncthreads(); // Wait for TMEM allocation
if (warpgroup_idx == 0) {
cutlass::arch::warpgroup_reg_alloc<144>();
// Scale & Exp warps
......@@ -276,18 +164,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};
uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8);
float* sP_base = plan.p + idx_in_warpgroup%64*4 + (idx_in_warpgroup/64)*((B_H/2)*(B_TOPK/2));
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
// Wait for P
plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
tcgen05_after_thread_sync();
ku::tcgen05_after_thread_sync();
// Load P
float2 p[(B_TOPK/2)/2];
tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::p, p);
ku::tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::p, p);
cutlass::arch::fence_view_async_tmem_load();
tcgen05_before_thread_sync();
ku::tcgen05_before_thread_sync();
plan.bar_p_free[k%NUM_BUFS].arrive(0u);
// Mask
......@@ -330,6 +219,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
......@@ -348,10 +238,10 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
float2 neg_new_max = float2 {-new_max, -new_max};
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
float2 d = float2_fma(p[i], scale, neg_new_max);
float2 d = ku::float2_fma(p[i], scale, neg_new_max);
d.x = exp2f(d.x);
d.y = exp2f(d.y);
li += d.x + d.y; // NOTE Theorically we can have use FFMA2 here but actually this is faster...
li += d.x + d.y; // NOTE: Theoretically we could use FFMA2 here but actually this is faster...
s[i] = __float22bfloat162_rn(d);
}
......@@ -367,27 +257,27 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
// Scale O
if (k > 0 && should_scale_o) {
float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old};
// plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before
tcgen05_after_thread_sync();
// plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE: We have waited for last SV gemm before
ku::tcgen05_after_thread_sync();
static constexpr int CHUNK_SIZE = 32;
float2 o[CHUNK_SIZE/2];
CUTE_UNROLL
for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
// Load O
tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
ku::tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_load();
// Mult
for (int i = 0; i < CHUNK_SIZE/2; ++i) {
o[i] = float2_mul(o[i], scale_for_old_float2);
o[i] = ku::float2_mul(o[i], scale_for_old_float2);
}
// Store O
tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
ku::tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_store();
}
tcgen05_before_thread_sync();
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
......@@ -411,17 +301,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
// Store mi and li
if (idx_in_warpgroup < 64) {
int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup;
float cur_lse = log2f(li) + mi;
params.max_logits[global_index] = real_mi;
float cur_lse = logf(li) + mi*CUDART_LN2_F;
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
params.max_logits[global_index] = real_mi*CUDART_LN2_F;
params.lse[global_index] = cur_lse;
}
// Wait for the last GEMM
plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);
tcgen05_after_thread_sync();
ku::tcgen05_after_thread_sync();
// Store O
float output_scale = __fdividef(1.0f, li);
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + cta_idx*B_H/2 + (idx_in_warpgroup%64))*CUDART_L2E_F;
float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi));
Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});
constexpr int B_EPI = 64;
Tensor tma_gO = flat_divide(
......@@ -435,7 +327,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
auto thr_tma = tma_params.tma_O.get_slice(_0{});
float2 o[B_EPI/2];
bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during tmem_ld
bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld
if (!have_valid_indices) {
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
CUTE_UNROLL
......@@ -450,7 +342,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
for (int k = 0; k < (D_V/2)/B_EPI; ++k) {
// Load O from tO
if (have_valid_indices) {
tmem_ld_32dp32bNx<B_EPI>(tmem_cols::o + k*B_EPI, o);
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::o + k*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
}
......@@ -460,7 +352,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
__nv_bfloat162 o_bf16[4];
CUTE_UNROLL
for (int j = 0; j < 4; ++j) {
float2 d = float2_mul(o[i*4+j], output_scale_float2);
float2 d = ku::float2_mul(o[i*4+j], output_scale_float2);
o_bf16[j] = __float22bfloat162_rn(d);
}
int smem_row = idx_in_warpgroup % 64;
......@@ -503,22 +395,28 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int4 indices[NUM_LOCAL_ROWS_PER_WARP];
int max_indices = -1, min_indices = params.s_kv;
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row)
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx);
max_indices = max(max_indices, int4_max(indices[local_row]));
min_indices = min(min_indices, int4_min(indices[local_row]));
}
bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1;
bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS;
auto load_part_ki = [&](transac_bar_t* bar, int local_col_start, int local_col_end) {
auto load_part_ki = [&](transac_bar_t &bar, int local_col_start, int local_col_end) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
CUTE_UNROLL
for (int local_col = local_col_start; local_col < local_col_end; ++local_col)
tma_gather4<true>(
ku::tma_gather4_cta_group_2<true>(
&(tma_params.tensor_map_kv),
bar,
sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64),
local_col*64,
indices[local_row],
TMA::CacheHintSm90::EVICT_LAST
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
};
......@@ -527,12 +425,23 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
if (k > 0) {
plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_ki(plan.bar_k_part0_ready+cur_buf, 0, D_sQ/64);
if (!should_skip_tma) {
load_part_ki(plan.bar_k_part0_ready[cur_buf], 0, D_sQ/64);
} else {
// NOTE: TMA has performance issues when all indices are the same (even if those indices are invalid), so we detect whether all indices in our block are invalid (by inspecting their MIN and MAX, for performance reasons), and skip the copy if all indices are invalid.
// NOTE: We can also skip the initial zero-fill procedure (which prevents NaN from appearing in K/V buf if the first TMA copy is skipped) by disabling skipping on the first NUM_BUFS TMAs.
// NOTE: We only do this for K to save some checking overhead, since after doing this for K, cases where topk indices are all invalid are faster than the other cases
plan.bar_k_part0_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_sQ*sizeof(bf16), 1u);
}
if (k > 0) {
plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_ki(plan.bar_k_part1_ready+cur_buf, D_sQ/64, D_K/64);
if (!should_skip_tma) {
load_part_ki(plan.bar_k_part1_ready[cur_buf], D_sQ/64, D_K/64);
} else {
plan.bar_k_part1_ready[cur_buf].complete_transaction(0u, NUM_LOCAL_ROWS_PER_WARP*4*D_tQ*sizeof(bf16), 1u);
}
}
}
} else if (warpgroup_idx == 2) {
......@@ -549,19 +458,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
auto load_part_vi = [&](transac_bar_t* bar, int local_row_start, int local_row_end) {
auto load_part_vi = [&](transac_bar_t &bar, int local_row_start, int local_row_end) {
CUTE_UNROLL
for (int local_row = local_row_start; local_row < local_row_end; ++local_row) {
int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);
CUTE_UNROLL
for (int local_col = 0; local_col < (D_V/2)/64; ++local_col)
tma_gather4<true>(
ku::tma_gather4_cta_group_2<true>(
&(tma_params.tensor_map_kv),
bar,
sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),
local_col*64 + (cta_idx?256:0),
token_idxs,
TMA::CacheHintSm90::EVICT_LAST
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
};
......@@ -570,12 +479,12 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
if (k > 0) {
plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_vi(plan.bar_v_part0_ready+cur_buf, 0, (B_TOPK/2)/4/NUM_WARPS);
load_part_vi(plan.bar_v_part0_ready[cur_buf], 0, (B_TOPK/2)/4/NUM_WARPS);
if (k > 0) {
plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_vi(plan.bar_v_part1_ready+cur_buf, (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS);
load_part_vi(plan.bar_v_part1_ready[cur_buf], (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS);
}
}
} else {
......@@ -595,7 +504,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
);
plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16));
plan.bar_prologue_q.wait(0);
tcgen05_after_thread_sync();
ku::tcgen05_after_thread_sync();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) {
// A tile is 64 rows * 64 cols (128B)
......@@ -608,7 +517,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
);
}
}
umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2);
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks+1; ++k) {
......@@ -625,18 +534,18 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
if (k > 0) {
plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
tcgen05_after_thread_sync();
ku::tcgen05_after_thread_sync();
utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true);
umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2);
ku::utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2);
// Wait for K (part1)
plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16));
plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1);
tcgen05_after_thread_sync();
ku::tcgen05_after_thread_sync();
utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false);
umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2);
ku::utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2);
}
if (k > 0) {
// O += S(i-1)V(i-1)
......@@ -653,17 +562,17 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
// Wait for V (part0), and issue O += sS @ sV
plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
tcgen05_after_thread_sync();
ku::tcgen05_after_thread_sync();
utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1);
umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2);
ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2);
// Wait for V (part1), and issue O += sS @ sV
plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
tcgen05_after_thread_sync();
utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false);
umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2);
ku::tcgen05_after_thread_sync();
ku::utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false);
ku::umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2);
}
}
} else if (warp_idx == 13) {
......@@ -674,18 +583,19 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
for (int k = 0; k < num_k_blocks; ++k) {
int cur_buf = k%NUM_BUFS;
int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8);
auto is_valid = [&](int index) -> char {
return index >= 0 && index < params.s_kv;
auto is_valid = [&](int rel_pos_in_lane, int index) -> char {
int abs_pos = k*B_TOPK + lane_idx*8 + rel_pos_in_lane;
return index >= 0 && index < params.s_kv && abs_pos < topk_length;
};
char is_ks_valid_mask = \
is_valid(indices.a7) << 7 |
is_valid(indices.a6) << 6 |
is_valid(indices.a5) << 5 |
is_valid(indices.a4) << 4 |
is_valid(indices.a3) << 3 |
is_valid(indices.a2) << 2 |
is_valid(indices.a1) << 1 |
is_valid(indices.a0) << 0;
is_valid(7, indices.a7) << 7 |
is_valid(6, indices.a6) << 6 |
is_valid(5, indices.a5) << 5 |
is_valid(4, indices.a4) << 4 |
is_valid(3, indices.a3) << 3 |
is_valid(2, indices.a2) << 2 |
is_valid(1, indices.a1) << 1 |
is_valid(0, indices.a0) << 0;
plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);
plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask;
......@@ -695,6 +605,7 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
......@@ -702,10 +613,21 @@ sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __gri
#endif
}
void run_fwd_kernel(const SparsePrefillParams& params) {
FLASH_ASSERT(params.h_kv == 1);
FLASH_ASSERT(params.topk % B_TOPK == 0); // To save some boundry checkings
FLASH_ASSERT(params.h_q == B_H); // To save some calculation
// Thin __global__ entry point: forwards both launch arguments unchanged to the static device
// function of the selected Kernel specialization.
// The thread count in the launch bounds comes from Kernel::NUM_THREADS.
// NOTE(review): the third launch-bounds argument (2) presumably corresponds to the cluster
// dimension of 2 that the host launcher passes to the cluster launch - confirm.
template<typename Kernel, typename TmaParams>
__global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)
sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) {
Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);
}
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
static_assert(D_QK == 576 || D_QK == 512);
using Kernel = KernelTemplate<D_QK>;
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk % Kernel::B_TOPK == 0); // To save some boundry checkings
KU_ASSERT(params.h_q == Kernel::B_H); // To save some calculation
KU_ASSERT(params.d_qk == D_QK);
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
auto tma_Q = cute::make_tma_copy(
......@@ -717,7 +639,7 @@ void run_fwd_kernel(const SparsePrefillParams& params) {
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
SmemLayoutQTiles<9>{}
(typename Kernel::template SmemLayoutQTiles<D_QK/64>){}
);
auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);
......@@ -730,12 +652,12 @@ void run_fwd_kernel(const SparsePrefillParams& params) {
make_stride(params.d_v, _1{}, params.h_q*params.d_v)
)
),
SmemLayoutOTiles<1>{}
(typename Kernel::template SmemLayoutOTiles<1>){}
);
CUtensorMap tensor_map_kv;
{
uint64_t size[2] = {D_K, (unsigned long)params.s_kv};
uint64_t size[2] = {D_QK, (unsigned long)params.s_kv};
uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};
uint32_t box_size[2] = {64, 1};
uint32_t elem_stride[2] = {1, 1};
......@@ -753,7 +675,7 @@ void run_fwd_kernel(const SparsePrefillParams& params) {
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
FLASH_ASSERT(res == CUresult::CUDA_SUCCESS);
KU_ASSERT(res == CUresult::CUDA_SUCCESS);
}
TmaParams<
......@@ -764,22 +686,21 @@ void run_fwd_kernel(const SparsePrefillParams& params) {
shape_O, tma_O,
tensor_map_kv
};
auto kernel = &sparse_attn_fwd_kernel<decltype(tma_params)>;
auto kernel = &sparse_attn_fwd_kernel<Kernel, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
constexpr size_t smem_size = sizeof(typename Kernel::SharedMemoryPlan);
KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
cutlass::ClusterLaunchParams launch_params = {
dim3(2*params.s_q, 1, 1),
dim3(NUM_THREADS, 1, 1),
dim3(Kernel::NUM_THREADS, 1, 1),
dim3(2, 1, 1),
smem_size,
params.stream
};
cutlass::launch_kernel_on_cluster(
KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
);
CHECK_CUDA_KERNEL_LAUNCH();
));
}
}
#pragma once
#include "params.h"
namespace sm100::fwd::head128 {
// Host-side launcher of the head128 phase-1 sparse attention forward kernel.
// D_QK is the compile-time Q/K head dimension; only the declaration lives in this header.
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);
}
#pragma once
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
namespace sm100::fwd::head64 {
using namespace cute;
// Bundle of TMA-related state handed to the kernel as a second (grid-constant) argument.
// For each of Q-NoPE, Q-RoPE and O it carries the global tensor shape together with the
// matching CuTe TMA copy object; the KV NoPE part uses a raw CUtensorMap instead.
// The member order matters: the host launcher fills this struct with a positional brace-init.
template<
typename Shape_Q_NoPE, typename TMA_Q_NoPE,
typename Shape_Q_RoPE, typename TMA_Q_RoPE,
typename Shape_O, typename TMA_O
>
struct TmaParams {
Shape_Q_NoPE shape_Q_nope; TMA_Q_NoPE tma_Q_nope;
Shape_Q_RoPE shape_Q_rope; TMA_Q_RoPE tma_Q_rope;
Shape_O shape_O; TMA_O tma_O;
CUtensorMap tensor_map_kv_nope;
};
// Two float2 values grouped as a (lo, hi) pair.
struct float2x2 {
float2 lo, hi;
};
// Compile-time problem and tiling constants of the head64 kernel.
// Head dimensions: Q and K carry D_V NoPE features followed by (D_Q - D_V) RoPE features.
constexpr int D_Q = 576;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan
// Number of query heads handled by one CTA
constexpr int B_H = 64;
// Number of top-k KV tokens processed per pipeline step
constexpr int B_TOPK = 64;
// Number of K NoPE buffers (and of the per-buffer barriers) in shared memory
constexpr int NUM_BUFS = 3;
// Three warpgroups of 128 threads each:
// - warpgroup 0: scale & exp threads
// - warpgroup 1: KV producer threads (TMA gather)
// - warpgroup 2: warp 8 issues UTCMMA, warp 9 builds the KV validity masks, warps 10-11 load K RoPE
constexpr int NUM_THREADS = 128 + 128 + 128;
// Tensor memory columns
// Starting column of each operand inside the tensor-memory allocation used by the kernel.
namespace tmem_cols {
// 0 ~ 256: output
// 256 ~ 400: Q
// 400 ~ 464: P
constexpr int O = 0;
constexpr int Q = 256;
// The RoPE part of Q starts 128 columns after the NoPE part (column 384)
constexpr int Q_RoPE = 256 + 128;
constexpr int P = 400;
}
// Shared-memory layouts. Each one tiles a UMMA K-major atom over the given (rows, cols) shape
// and coalesces the result.

// Q NoPE tile: B_H x D_V, built from the SW128 atom
using SmemLayoutQNoPE = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_V>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Q RoPE tile: B_H x (D_Q - D_V), built from the SW64 atom
using SmemLayoutQRoPE = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H>, Int<D_Q-D_V>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Output tile made of NUM_TILES column tiles of 64 elements: B_H x (64*NUM_TILES)
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Full output tile: 8 column tiles, i.e. B_H x 512
using SmemLayoutO = SmemLayoutOTiles<8>;
// K tile made of NUM_TILES column tiles of 64 elements: B_TOPK x (64*NUM_TILES)
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// K NoPE tile: B_TOPK x 512
using SmemLayoutKNoPE = SmemLayoutKTiles<8>;
// V is the K NoPE buffer re-indexed as (D_V, B_TOPK): coordinate (d, t) maps to K NoPE element (t, d)
using SmemLayoutV = decltype(coalesce(
composition(
SmemLayoutKNoPE{},
Layout<Shape<Int<D_V>, Int<B_TOPK>>, Stride<Int<B_TOPK>, _1>>{}
)
, Shape<_1, _1>{}));
// K RoPE tile: B_TOPK x 64, built from the SW64 atom
using SmemLayoutKRoPE = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_TOPK>, Int<64>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
using SmemLayoutKNoPE_TiledMMA = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK*2>, Int<D_V/2>>{},
Step<_1, _2>{}
), Shape<_1, _1>{})); // Re-view K-NoPE as B_TOPK*2 x D_V/2 for dual gemm
// Same re-view for K RoPE: B_TOPK*2 x 64/2
using SmemLayoutKRoPE_TiledMMA = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_TOPK*2>, Int<64/2>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// S tile: B_H x B_TOPK, built from the INTER atom
using SmemLayoutS = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Layout of the dynamic shared memory of one CTA. The kernel reinterprets the start of its
// dynamic shared-memory buffer as this struct, and the host launcher requests
// sizeof(SharedMemoryPlan) bytes.
struct SharedMemoryPlan {
// The Q NoPE tile, the K buffers and the output tile share the same storage (union).
union {
struct {
array_aligned<bf16, cosize_v<SmemLayoutKRoPE>> _k_rope_pad;
array_aligned<bf16, cosize_v<SmemLayoutKNoPE>> _k_pad[2]; // So that q_nope covers k[2]
array_aligned<bf16, cosize_v<SmemLayoutQNoPE>> q_nope;
} q_full;
struct {
array_aligned<bf16, cosize_v<SmemLayoutKRoPE>> k_rope;
array_aligned<bf16, cosize_v<SmemLayoutKNoPE>> k_nope[NUM_BUFS];
} k;
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} u;
// Scratch buffer handed to the P retrieval/reduction helper of the scale & exp warps
float p_exchange_buf[4][32 * (B_TOPK/2)];
// S (written by the scale & exp warps) shares storage with the Q RoPE tile
union {
bf16 s[B_H*B_TOPK];
array_aligned<bf16, cosize_v<SmemLayoutQRoPE>> q_rope;
} s_q_rope;
// One validity bit per KV token of a block (8 tokens per byte), per buffer
char is_k_valid[NUM_BUFS][B_TOPK/8];
// Prologue barriers: Q loaded into shared memory / Q copied into tensor memory
transac_bar_t bar_prologue_q_nope, bar_prologue_q_rope, bar_prologue_utccp_nope, bar_prologue_utccp_rope;
transac_bar_t bar_qk_nope_done[NUM_BUFS], bar_qk_rope_done; // Pi = QKi^T (the nope part) done
transac_bar_t bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. O, Si and Vi are free)
// KV data arrival: two halves of the NoPE part per buffer, plus the single RoPE buffer
transac_bar_t bar_kv_nope_ready[NUM_BUFS][2], bar_kv_rope_ready;
transac_bar_t bar_p_free;
transac_bar_t bar_so_ready; // S and O are ready
// Hand-shake for is_k_valid between the mask-producing warp and the scale & exp warps
transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];
// Receives the base address returned by the tensor-memory allocator
array_aligned<uint32_t, 1> tmem_start_addr;
// Exchange buffers used to combine the row max / row sum of the two threads sharing a row
float rowwise_max_buf[128], rowwise_li_buf[128];
};
// MMA used for P = Q K^T: A operand from tensor memory (TS), B operand from shared memory, both K-major.
using TiledMMA_P = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_TS_NOELECT<bf16, bf16, float, B_H, 128, UMMA::Major::K, UMMA::Major::K>{} // Here we use N = 128 = 2*B_TOPK since we're going to use implicit dual gemm: <TODO Fill link here>
));
// MMA used for O += S V: both operands from shared memory (SS), with B (V) MN-major.
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{}
));
// IDs of the named barriers used inside a CTA.
enum NamedBarriers : int {
// All 128 threads of warpgroup 0
wg0_sync = 0,
// 64-thread barriers for warps {0, 2} and warps {1, 3} of warpgroup 0; the kernel selects
// between them with wg0_warp02_sync + (warp_idx & 1), so the two IDs must stay consecutive
wg0_warp02_sync = 1,
wg0_warp13_sync = 2,
// NOTE(review): not referenced by the kernel code visible in this file section - confirm usage
pepi_sync = 3,
};
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head64 {
// Explicit instantiation of the head64 phase-1 launcher for D_QK == 512.
template void run_fwd_phase1_kernel<512>(const SparseAttnFwdParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
namespace sm100::fwd::head64 {
// Explicit instantiation of the head64 phase-1 launcher for D_QK == 576.
template void run_fwd_phase1_kernel<576>(const SparseAttnFwdParams& params);
}
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include <cutlass/cuda_host_adapter.hpp>
#include <kerutils/kerutils.cuh>
#include "params.h"
#include "utils.h"
#include "sm100/helpers.h"
#include "sm100/prefill/sparse/common_subroutine.h"
#include "config.h"
namespace sm100::fwd::head64 {
using namespace cute;
/*
Pipeline Overview:

| Copy |        MMA         |      Scale & Exp      |

  KV0
  KV1
  KV2
         P0 = QK0^T
                              S0 = exp(P0)
                              scale(O) w.r.t P0
         P1 = QK1^T
                              S1 = exp(P1)
         O += S0V0
  KV3                         scale(O) w.r.t P1
         P2 = QK2^T
                              S2 = exp(P2)
         O += S1V1
  KV4                         scale(O) w.r.t P2
         P3 = QK3^T
                              S3 = exp(P3)
         O += S2V2
  KV5                         scale(O) w.r.t P3
  ...
         O += S(n-3)V(n-3)
                              scale(O) w.r.t P(n-2)
         P(n-1) = QK(n-1)^T
                              S(n-1) = exp(P(n-1))
         O += S(n-2)V(n-2)
                              scale(O) w.r.t P(n-1)
         O += S(n-1)V(n-1)
*/
using FwdMode = SparseAttnFwdMode;
// Sparse attention forward kernel (phase 1) of the head64 configuration.
//
// Launch layout (see the host launcher): one CTA per query token (grid.x == s_q), NUM_THREADS
// threads per CTA, and dynamic shared memory that is reinterpreted as one SharedMemoryPlan.
//
// Template parameters:
//   HAVE_ROPE - when true, the RoPE columns of Q and K are loaded and contribute to P = QK^T;
//               when false, all RoPE loads and the RoPE GEMM are compiled out
//   TmaParams - the TmaParams<...> instantiation built by the host launcher
//
// Warp roles inside a CTA:
//   - warpgroup 0 (warps 0-3): reads P, applies the validity mask, maintains the running
//     max / sum, writes S, rescales O when needed, and stores O, lse and max_logits
//   - warpgroup 1 (warps 4-7): gathers the NoPE part of the selected KV rows via TMA gather4
//   - warp 8: copies Q from shared memory into tensor memory and issues every UTCMMA
//   - warp 9: builds the per-block KV validity bit masks
//   - warps 10-11: load the RoPE part of K with cp.async (only when HAVE_ROPE)
//
// The body is only compiled for __CUDA_ARCH__ in [1000, 1200) (or inside the listed IDEs);
// on other targets thread 0 hits CUTE_INVALID_CONTROL_PATH.
template<bool HAVE_ROPE, typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 1)
sparse_attn_fwd_kernel(__grid_constant__ const SparseAttnFwdParams params, __grid_constant__ const TmaParams tma_params) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000 && __CUDA_ARCH__ < 1200)) || (defined(__CLION_IDE__) || defined(__VSCODE_IDE__))
// Grid shape: [s_q, 1, 1]
const int s_q_idx = blockIdx.x;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const int idx_in_warpgroup = threadIdx.x % 128;
// Per-query top-k length when params.topk_length is provided, otherwise the global params.topk
const int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + s_q_idx) : params.topk;
const int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]
// Allocate tmem tensors
TiledMMA tiled_mma_P = TiledMMA_P{};
TiledMMA tiled_mma_O = TiledMMA_O{};
// NOTE These tXXX tensors are only for a forged layout (so that CuTe is able to generate correct address in cute::gemm)
Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<B_H>, _128>{});
Tensor tQ_nope_part0 = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<(D_V/2)/2>>{})
);
Tensor tQ_nope_part1 = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<(D_V/2)/2>>{})
);
Tensor tQ_rope = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<B_H>, Int<64/2>>{})
);
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H>, Int<D_V>>{});
// Point every forged fragment at its tensor-memory column offset (see tmem_cols)
tP.data().get() = tmem_cols::P;
tQ_nope_part0.data().get() = tmem_cols::Q;
tQ_nope_part1.data().get() = tmem_cols::Q + 64;
tQ_rope.data().get() = tmem_cols::Q_RoPE;
tO.data().get() = tmem_cols::O;
// Prologue: one elected lane of warp 0 initializes the barriers and starts the Q loads;
// afterwards the whole of warp 0 allocates tensor memory
if (warp_idx == 0) {
if (elect_one_sync()) {
// Copy Q
if constexpr (HAVE_ROPE) {
cute::prefetch_tma_descriptor(tma_params.tma_Q_rope.get_tma_descriptor());
}
cute::prefetch_tma_descriptor(tma_params.tma_Q_nope.get_tma_descriptor());
plan.bar_prologue_q_nope.init(1);
plan.bar_prologue_q_rope.init(1);
fence_barrier_init();
if constexpr (HAVE_ROPE) {
Tensor gQ_rope = tma_params.tma_Q_rope.get_tma_tensor(tma_params.shape_Q_rope)(_, _, s_q_idx);
Tensor sQ_rope = make_tensor(make_smem_ptr(plan.s_q_rope.q_rope.data()), SmemLayoutQRoPE{});
ku::launch_tma_copy(tma_params.tma_Q_rope, gQ_rope, sQ_rope, plan.bar_prologue_q_rope, TMA::CacheHintSm90::EVICT_FIRST);
}
Tensor gQ_nope = tma_params.tma_Q_nope.get_tma_tensor(tma_params.shape_Q_nope)(_, _, s_q_idx);
Tensor sQ_nope = make_tensor(make_smem_ptr(plan.u.q_full.q_nope.data()), SmemLayoutQNoPE{});
ku::launch_tma_copy(tma_params.tma_Q_nope, gQ_nope, sQ_nope, plan.bar_prologue_q_nope, TMA::CacheHintSm90::EVICT_FIRST);
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv_nope));
// Initialize other barriers
plan.bar_prologue_utccp_rope.init(1);
plan.bar_prologue_utccp_nope.init(1);
CUTE_UNROLL
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_qk_nope_done[i].init(1);
plan.bar_sv_done[i].init(1);
plan.bar_kv_nope_ready[i][0].init(1);
plan.bar_kv_nope_ready[i][1].init(1);
plan.bar_k_valid_ready[i].init(B_TOPK/8);
plan.bar_k_valid_free[i].init(128);
}
plan.bar_p_free.init(128);
plan.bar_so_ready.init(128);
plan.bar_qk_rope_done.init(1);
plan.bar_kv_rope_ready.init(64);
fence_barrier_init();
}
// Initialize TMEM
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());
TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator1Sm().release_allocation_lock();
}
// Every warp waits here until the prologue of warp 0 has finished
__syncthreads();
if (warpgroup_idx == 0) {
// Scale & Exp warps
// The following three numbers are
// - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V)
// - li: sumexp, i.e. li := sum(exp(Pi*scale - mi))
// - real_mi: real max logits, i.e. real_mi := max(Pi*scale)
// where Pi is the i-th row of P, P := QK^T
// mi and real_mi are always consistent within the two threads that
// controls one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
bf16* sS_base = plan.s_q_rope.s + lane_idx*8 + (warp_idx&1)*(B_H/2)*8 + (warp_idx/2)*B_H*(B_TOPK/2);
static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2;
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
// Wait for P
NamedBarrier::arrive_and_wait(64, NamedBarriers::wg0_warp02_sync+(warp_idx&1));
plan.bar_qk_nope_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1); // Put the barrier wait here for more code reordering space
ku::tcgen05_after_thread_sync();
// Load P
float p[NUM_ELEMS_PER_THREAD];
retrieve_mask_and_reduce_p<
NUM_ELEMS_PER_THREAD,
tmem_cols::P,
NamedBarriers::wg0_warp02_sync,
NamedBarriers::wg0_warp13_sync,
false
>(
plan.is_k_valid[k%NUM_BUFS],
warp_idx, lane_idx,
[&]() {plan.bar_p_free.arrive();},
plan.p_exchange_buf,
p
);
plan.bar_k_valid_free[k%NUM_BUFS].arrive();
// Get rowwise max of Pi
float cur_pi_max = get_max<NUM_ELEMS_PER_THREAD>(p);
cur_pi_max *= params.sm_scale_div_log2;
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
// O is only rescaled when some lane of the warp sees its row max exceed mi by more than 6.0;
// otherwise the old mi is kept as the scaling reference
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// By this point:
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among every warp, and is identical among threads that controls the same row (i.e. among threads 0~31+64~95; and is identical among threads 32~63+96~127)
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
// Calculate S
nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2];
float cur_sum = get_s_from_p<NUM_ELEMS_PER_THREAD>(s, p, params.sm_scale_div_log2, new_max);
li = fma(li, scale_for_old, cur_sum);
// Wait for last SV gemm, write S
if (k > 0) {
plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; i += 1) {
*(uint128_t*)(sS_base + B_H*8*i) = *(uint128_t*)(s + i*4);
}
// Scale O
if (k > 0 && should_scale_o) {
// plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before
ku::tcgen05_after_thread_sync();
rescale_O<D_V, 32, tmem_cols::O>(scale_for_old);
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
plan.bar_so_ready.arrive();
}
// Epilogue
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Exchange li
plan.rowwise_li_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
li += plan.rowwise_li_buf[idx_in_warpgroup^64];
// Store mi and li
if (idx_in_warpgroup < 64) {
int global_index = s_q_idx*params.h_q + idx_in_warpgroup;
float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));
// An lse of -inf is stored as +inf instead
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
params.max_logits[global_index] = real_mi*CUDART_LN2_F;
params.lse[global_index] = cur_lse;
}
// Wait for the last GEMM
plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
// Fetch dO if necessary
// Store O
// attn_sink is optional: when the pointer is null it is -inf; otherwise the per-head value is
// multiplied by log2(e) before being added to the softmax denominator through exp2f
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + (idx_in_warpgroup%64))*CUDART_L2E_F;
float output_scale = __fdividef(1.0f, li + exp2f(attn_sink - mi));
Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});
constexpr int B_EPI = 64;
Tensor tma_gO = flat_divide(
tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx),
Shape<Int<B_H>, Int<B_EPI>>{}
)(_, _, _0{}, _);
Tensor sO_divided = flat_divide(
sO,
Shape<Int<B_H>, Int<B_EPI>>{}
)(_, _, _0{}, _);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
float2 o[B_EPI/2];
bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during ku::tmem_ld
if (!have_valid_indices) {
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
CUTE_UNROLL
for (int i = 0; i < B_EPI/2; ++i)
o[i].x = o[i].y = 0.0f;
output_scale = 1.0f;
}
float2 output_scale_float2 = make_float2(output_scale, output_scale);
bf16* sO_addrs[8];
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
sO_addrs[i] = &sO(idx_in_warpgroup%64, i*8);
}
CUTE_UNROLL
for (int c = 0; c < 2; ++c) {
// Each tile: 64 x 256
CUTE_UNROLL
for (int k = 0; k < (D_V/4)/B_EPI; ++k) {
// Load O from tO
if (have_valid_indices) {
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + c*128 + k*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
}
// Convert and store
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
nv_bfloat162 o_bf16[4];
CUTE_UNROLL
for (int j = 0; j < 4; ++j) {
o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[i*4+j]);
}
*(uint128_t*)(sO_addrs[i] + (c*(D_V/2) + (idx_in_warpgroup/64)*(D_V/4) + k*B_EPI)*64) = *(uint128_t*)(o_bf16);
}
// Sync
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, NamedBarriers::wg0_sync);
// Warps 0 and 1 each issue the TMA store of one of the two column chunks written above
if (warp_idx == 0 && elect_one_sync()) {
int epi_chunk_idx = c*(D_V/2/B_EPI) + k;
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)),
thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx))
);
}
if (warp_idx == 1 && elect_one_sync()) {
int epi_chunk_idx = c*(D_V/2/B_EPI) + (D_V/B_EPI/4) + k;
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, epi_chunk_idx)),
thr_tma.partition_D(tma_gO(_, _, epi_chunk_idx))
);
}
}
}
// Release the tensor memory allocated in the prologue (same warp that allocated it)
if (warp_idx == 0) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
// Producer warp for KV
// Warp index local to this warpgroup (0..3); it shadows the outer warp_idx
int warp_idx = cutlass::canonical_warp_idx_sync() - 4;
constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/4)/NUM_WARPS;
if (elect_one_sync()) {
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int4 indices[NUM_LOCAL_ROWS_PER_WARP];
int max_indices = -1, min_indices = params.s_kv;
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);
max_indices = max(max_indices, int4_max(indices[local_row]));
min_indices = min(min_indices, int4_min(indices[local_row]));
}
// min_indices stays at s_kv only if every index of this warp is >= s_kv, and max_indices
// stays at -1 only if every index is <= -1. The TMA is only skipped from block NUM_BUFS onwards.
bool is_all_rows_invalid = min_indices == params.s_kv || max_indices == -1;
bool should_skip_tma = is_all_rows_invalid && k >= NUM_BUFS;
if (k == 2) {
plan.bar_prologue_utccp_nope.wait(0); // Since q_nope overlaps with k[2] in shared memory
}
// Copy NoPE
int cur_buf = k%NUM_BUFS;
plan.bar_sv_done[cur_buf].wait((k/NUM_BUFS)&1^1);
bf16* sK_nope_base = plan.u.k.k_nope[cur_buf].data() + warp_idx*4*64;
auto load_kv_nope_part = [&](int part_idx) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
CUTE_UNROLL
for (int local_col = part_idx*(D_V/2/64); local_col < (part_idx+1)*(D_V/2/64); ++local_col) {
ku::tma_gather4(
&(tma_params.tensor_map_kv_nope),
plan.bar_kv_nope_ready[cur_buf][part_idx],
sK_nope_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),
local_col*64,
indices[local_row],
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
}
};
if (!should_skip_tma) {
load_kv_nope_part(0);
load_kv_nope_part(1);
} else {
// NOTE See head128/phase1.cuh for this TMA skipping technique
CUTE_UNROLL
for (int part_idx = 0; part_idx < 2; ++part_idx)
plan.bar_kv_nope_ready[cur_buf][part_idx].complete_transaction(NUM_LOCAL_ROWS_PER_WARP*4*D_V/2*sizeof(bf16));
}
}
}
} else {
// Third warpgroup (warps 8-11): MMA issue, validity masks and K RoPE loading
// MMA warp
if (warp_idx == 8 && elect_one_sync()) {
// S -> T copy for Q
UMMA::SmemDescriptor sQ_nope_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.q_full.q_nope.data()),
tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H*2>, Int<64>>{} // We use this shape for dual gemm (TODO Link)
)
)
);
UMMA::SmemDescriptor sQ_rope_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.s_q_rope.q_rope.data()),
tile_to_shape(
UMMA::Layout_K_SW64_Atom<bf16>{},
Shape<Int<B_H*2>, Int<32>>{}
)
)
);
if constexpr (HAVE_ROPE) {
// Copy the RoPE tile: 128 rows * 32 cols (64B) (in UTCCP's view), or 64 rows * 64 cols (in our view)
plan.bar_prologue_q_rope.arrive_and_expect_tx(B_H*(D_Q-D_V)*sizeof(bf16));
plan.bar_prologue_q_rope.wait(0);
ku::tcgen05_after_thread_sync();
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 2; ++subtile_idx) {
// A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_rope_desc + (subtile_idx*32) / 16,
tmem_cols::Q_RoPE + subtile_idx*8
);
}
ku::umma_arrive_noelect(plan.bar_prologue_utccp_rope);
}
plan.bar_prologue_q_nope.arrive_and_expect_tx(B_H*D_V*sizeof(bf16));
plan.bar_prologue_q_nope.wait(0);
ku::tcgen05_after_thread_sync();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_V/64/2; ++tile_idx) {
// A tile is 128 rows * 64 cols (128B) (in UTCCP's view), or 64 rows * 128 cols (in our view)
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) {
// A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)
SM100_UTCCP_128dp256bit_1cta::copy(
sQ_nope_desc + (tile_idx*(B_H*128*2) + subtile_idx*32) / 16, // Remember that 4 LSBs are not included
tmem_cols::Q + tile_idx*32 + subtile_idx*8
);
}
}
ku::umma_arrive_noelect(plan.bar_prologue_utccp_nope);
if constexpr (HAVE_ROPE) {
plan.bar_prologue_utccp_rope.wait(0);
}
// Main loop: iteration k issues P(k) = Q K(k)^T (for k < num_k_blocks) followed by
// O += S(k-1) V(k-1) (for k > 0), hence num_k_blocks+1 iterations
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks+1; ++k) {
if (k < num_k_blocks) {
// Pi = QKi^T
int cur_buf = k%NUM_BUFS;
Tensor sK_nope = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutKNoPE_TiledMMA{});
Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE_TiledMMA{});
plan.bar_p_free.wait(k&1^1);
ku::tcgen05_after_thread_sync();
// Wait for K (RoPE)
// P = Q(rope) @ K(rope)^T
if constexpr (HAVE_ROPE) {
plan.bar_kv_rope_ready.wait(k&1);
ku::tcgen05_after_thread_sync();
ku::utcmma_ts(tiled_mma_P, tQ_rope, sK_rope, tP, true);
ku::umma_arrive_noelect(plan.bar_qk_rope_done);
}
// Wait for K (NoPE)
if (k == 0) {
plan.bar_prologue_utccp_nope.wait(0);
}
Tensor sK_nope_divided = flat_divide(sK_nope, Tile<Int<B_TOPK*2>, Int<D_V/4>>{})(_, _, _0{}, _);
CUTE_UNROLL
for (int kv_nope_part_idx = 0; kv_nope_part_idx < 2; ++kv_nope_part_idx) {
plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].arrive_and_expect_tx(B_TOPK*D_V/2*sizeof(bf16));
plan.bar_kv_nope_ready[cur_buf][kv_nope_part_idx].wait((k/NUM_BUFS)&1);
ku::tcgen05_after_thread_sync();
// P += Q(nope) @ K(nope)^T
// With RoPE, the RoPE GEMM above already started P with its last argument set to true, so
// both NoPE parts pass false; without RoPE the first NoPE part passes true instead
bool clear_accum = (!HAVE_ROPE) && kv_nope_part_idx == 0;
ku::utcmma_ts(tiled_mma_P, kv_nope_part_idx ? tQ_nope_part1 : tQ_nope_part0, sK_nope_divided(_, _, kv_nope_part_idx), tP, clear_accum);
}
ku::umma_arrive_noelect(plan.bar_qk_nope_done[cur_buf]);
}
if (k > 0) {
// O += S(i-1)V(i-1)
int cur_buf = (k-1)%NUM_BUFS;
Tensor sS = make_tensor(make_smem_ptr(plan.s_q_rope.s), SmemLayoutS{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.k.k_nope[cur_buf].data()), SmemLayoutV{});
// Wait for S(i-1) and O to be scaled
plan.bar_so_ready.wait((k-1)&1);
ku::tcgen05_after_thread_sync();
// O += sS @ sV
ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == 1);
ku::umma_arrive_noelect(plan.bar_sv_done[cur_buf]);
}
}
} else if (warp_idx == 9) {
// KV valid loading warp
// Only the first B_TOPK/8 lanes take part; each one writes one mask byte per block
if (lane_idx < B_TOPK/8) {
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
char k_validness_mask = load_indices_and_generate_mask(
lane_idx,
gIndices + k*B_TOPK,
params.s_kv,
k*B_TOPK,
topk_length
);
int cur_buf = k%NUM_BUFS;
plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);
plan.is_k_valid[cur_buf][lane_idx] = k_validness_mask;
plan.bar_k_valid_ready[cur_buf].arrive();
}
}
} else if (warp_idx == 10 || warp_idx == 11) {
// K RoPE loading warps: 64 threads split into 8 groups of 8 threads
if constexpr (HAVE_ROPE) {
int thread_idx = threadIdx.x - 10*32;
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 64/GROUP_SIZE, ROWS_PER_THREAD = B_TOPK/NUM_GROUPS;
int group_idx = thread_idx / GROUP_SIZE, idx_in_group = thread_idx % GROUP_SIZE;
Tensor sK_rope = make_tensor(make_smem_ptr(plan.u.k.k_rope.data()), SmemLayoutKRoPE{});
bf16* sK_rope_base = &sK_rope(group_idx, idx_in_group*8);
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int indices[ROWS_PER_THREAD];
CUTE_UNROLL
for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row)
indices[local_row] = __ldg(gIndices + k*B_TOPK + group_idx + local_row*NUM_GROUPS);
plan.bar_qk_rope_done.wait(k&1^1);
CUTE_UNROLL
for (int local_row = 0; local_row < ROWS_PER_THREAD; ++local_row) {
int index = indices[local_row];
ku::cp_async_cacheglobal<ku::PrefetchSize::B128>(
params.kv + (int64_t)index*params.stride_kv_s_kv + 512 + idx_in_group*8,
sK_rope_base + local_row*NUM_GROUPS*32,
index >= 0 && index < params.s_kv
); // NOTE Using cp.async instead of TMA is faster here
// NOTE Here we only consider the range of `index` instead of also checking against topk_length, as it's noted that under this scenario (i.e. there exists a valid index among indices[topk_length: ] that points to a token who has NaN inside)
}
cutlass::arch::cpasync_barrier_arrive_noinc((uint64_t*)&(plan.bar_kv_rope_ready));
}
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
}
#endif
}
// Host-side launcher of the head64 phase-1 sparse attention forward kernel.
//
// Validates the problem configuration, builds the TMA descriptors for Q (NoPE and RoPE views),
// O and the NoPE part of KV, and launches one CTA per query token on params.stream.
// D_QK (576 or 512) selects whether the kernel is instantiated with RoPE (only for 576).
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params) {
    // Only these two Q/K head dimensions may be instantiated.
    static_assert(D_QK == 576 || D_QK == 512);

    // Runtime preconditions assumed by the kernel.
    KU_ASSERT(params.h_kv == 1);
    KU_ASSERT(params.topk % B_TOPK == 0);   // whole top-k blocks only, which saves some boundary checks
    KU_ASSERT(params.h_q == B_H);           // fixed number of query heads, which saves some calculation
    KU_ASSERT(params.d_qk == D_QK);

    // Q is viewed as (head, feature, token); the NoPE and RoPE views use the same strides.
    auto stride_Q = make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q);

    // TMA load for the first D_V (NoPE) features of Q.
    auto shape_Q_nope = make_shape(params.h_q, D_V, params.s_q);
    auto gmem_Q_nope = make_tensor(
        make_gmem_ptr((bf16*)params.q),
        make_layout(shape_Q_nope, stride_Q)
    );
    auto tma_Q_nope = cute::make_tma_copy(SM90_TMA_LOAD{}, gmem_Q_nope, SmemLayoutQNoPE{});

    // TMA load for the remaining D_Q - D_V (RoPE) features of Q, which start D_V elements in.
    // It is built for both instantiations; the kernel only touches it when compiled with RoPE.
    auto shape_Q_rope = make_shape(params.h_q, D_Q-D_V, params.s_q);
    auto gmem_Q_rope = make_tensor(
        make_gmem_ptr((bf16*)params.q + D_V),
        make_layout(shape_Q_rope, stride_Q)
    );
    auto tma_Q_rope = cute::make_tma_copy(SM90_TMA_LOAD{}, gmem_Q_rope, SmemLayoutQRoPE{});

    // TMA store for O, viewed as (head, feature, token) over a densely packed output buffer.
    auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);
    auto gmem_O = make_tensor(
        make_gmem_ptr((bf16*)params.out),
        make_layout(shape_O, make_stride(params.d_v, _1{}, params.h_q*params.d_v))
    );
    auto tma_O = cute::make_tma_copy(SM90_TMA_STORE{}, gmem_O, SmemLayoutOTiles<1>{});

    // Raw 2D tensor map over the first D_V features of every KV token (boxes of 64 x 1 elements).
    CUtensorMap tensor_map_kv_nope;
    {
        uint64_t global_dims[2] = {D_V, (unsigned long)params.s_kv};
        uint64_t global_strides[1] = {params.stride_kv_s_kv*sizeof(bf16)};
        uint32_t box_dims[2] = {64, 1};
        uint32_t element_strides[2] = {1, 1};
        CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
            &tensor_map_kv_nope,
            CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
            2,
            params.kv,
            global_dims,
            global_strides,
            box_dims,
            element_strides,
            CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
            CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
            CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
            CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
        );
        KU_ASSERT(res == CUresult::CUDA_SUCCESS);
    }

    // Pack everything into the struct that is passed to the kernel by value.
    using TmaParamsT = TmaParams<
        decltype(shape_Q_nope), decltype(tma_Q_nope),
        decltype(shape_Q_rope), decltype(tma_Q_rope),
        decltype(shape_O), decltype(tma_O)
    >;
    TmaParamsT tma_params = {
        shape_Q_nope, tma_Q_nope,
        shape_Q_rope, tma_Q_rope,
        shape_O, tma_O,
        tensor_map_kv_nope
    };

    // The kernel needs the whole SharedMemoryPlan as dynamic shared memory, so opt in explicitly.
    auto kernel = &sparse_attn_fwd_kernel<D_QK == 576, decltype(tma_params)>;
    constexpr size_t smem_size = sizeof(SharedMemoryPlan);
    KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    // One CTA per query token.
    kernel<<<params.s_q, NUM_THREADS, smem_size, params.stream>>>(params, tma_params);
    KU_CHECK_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace sm100::fwd::head64 {
// Host-side launcher of the head64 phase-1 sparse attention forward kernel.
// D_QK is the compile-time Q/K head dimension; only the declaration lives in this header.
template<int D_QK>
void run_fwd_phase1_kernel(const SparseAttnFwdParams& params);
}
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cutlass/float8.h>
#include <cute/tensor.hpp>
#include <kerutils/kerutils.cuh>
#include "defines.h"
#include "params.h"
namespace sm100::fwd_for_small_topk::head128 {

using namespace cute;

/*
 * Compile-time configuration shared by the host launcher (`run`) and the
 * device entry point (`sparse_attn_fwd_kernel_devfunc`) of the head128
 * "small top-k" sparse attention forward kernel.
 *
 * Template parameters:
 *   FWD_MODE - forward mode; decides (through `is_decode_v`) whether the
 *              decode or the prefill flavour of every type/constant is used.
 *   D_QK     - Q/K head dimension; only 512 is accepted (see static_assert).
 *
 * The kernel runs on a cluster of 2 CTAs, which is why many sizes below are
 * expressed as "half" quantities (H_Q/2 rows, D_K/2 columns, ...).
 */
template<SparseAttnFwdMode FWD_MODE, int D_QK>
struct KernelTemplate {
    // Argument structure handed to `run` and to the device function.
    using ArgT = SparseFwdArgT<FWD_MODE>;

    // Exactly one of these two flags is true.
    static constexpr bool IS_DECODE = is_decode_v<FWD_MODE>;
    static constexpr bool IS_PREFILL = !IS_DECODE;

    // Short aliases for the FP8 element types used for raw KV data / scales.
    using fp8_e4m3 = cutlass::float_e4m3_t;
    using fp8_e8m0 = __nv_fp8_e8m0;

    // TMA descriptors needed in prefill mode.
    struct TmaParamsForPrefill {
        CUtensorMap tensor_map_q;
        CUtensorMap tensor_map_kv;
        CUtensorMap tensor_map_o;
    };

    // TMA descriptors needed in decode mode. NoPE and RoPE parts of the KV
    // cache get separate descriptors, and so does the optional "extra" KV.
    struct TmaParamsForDecode {
        CUtensorMap tensor_map_q;
        CUtensorMap tensor_map_o;
        CUtensorMap tensor_map_o_accum;        // Destination of the float32 output in the split-KV epilogue
        CUtensorMap tensor_map_kv_nope;
        CUtensorMap tensor_map_kv_rope;
        CUtensorMap tensor_map_extra_kv_nope;  // Only available if extra_kv is enabled
        CUtensorMap tensor_map_extra_kv_rope;
    };

    // Descriptor bundle matching the selected mode.
    using TmaParams = std::conditional_t<
        IS_DECODE,
        TmaParamsForDecode,
        TmaParamsForPrefill
    >;

    static_assert(D_QK == 512);

    // Head dimensions of Q, K and V.
    static constexpr int D_Q = D_QK;
    static constexpr int D_K = D_QK;
    static constexpr int D_V = 512;

    // Initial value of mi (running max of the logits). A large finite number
    // is used instead of -inf so that -inf - (-inf) = nan can never occur.
    static constexpr float MAX_INIT_VAL = -1e30;

    static constexpr int H_Q = 128;     // Query heads per cluster (2 CTAs), i.e. H_Q/2 per CTA
    static constexpr int B_TOPK = 64;   // KV tokens per block, for 2 CTAs

    // 4 warpgroups of 128 threads per CTA.
    static constexpr int NUM_THREADS = 128*4;

    // Arrival count used for `bar_clc_empty`.
    // NOTE(review): this looks like the number of threads, over both CTAs of
    // the cluster, that take part in the outer (job) loop -- confirm the
    // per-term breakdown against the kernel before changing any role.
    static constexpr int NUM_WORKER_THREADS = IS_PREFILL ? (128 + 4 + (B_TOPK/8) + 1 + 128)*2 + 1 : (128 + 128 + 1 + 32 + 2 + 128)*2;

    // For non-decode mode, we have 4 (half-)KV buffers
    // For decode mode, we have 3 (half-)KV buffers with two raw KV buffers
    static constexpr int NUM_K_BUFS = IS_DECODE ? 3 : 4;
    static constexpr int NUM_RAW_K_BUFS = IS_DECODE ? 2 : 0;
    // Depth of the is_k_valid / tma_coord / scales rings (currently 4 in both modes).
    static constexpr int NUM_INDEX_BUFS = IS_DECODE ? 4 : 4;

    // Split of the head dimension into its NoPE and RoPE parts.
    static constexpr int D_NOPE = 448;
    static constexpr int D_ROPE = 64;
    // Unit in which decode-mode TMA row coordinates are expressed
    // (the kernel divides the KV block stride by this value).
    static constexpr int TMA_K_STRIDE_FOR_DECODING = D_NOPE + 2*D_ROPE;
    static constexpr int NUM_SCALES_EACH_TOKEN = 8; // 7 scales + 1 padding

    static constexpr int B_EPI = 64;                // Epilogue block size for normal case (i.e. prefill or non-splitkv decoding)
    static constexpr int B_EPI_SPLITKV = 32;        // Epilogue block size for splitkv decoding
    static constexpr int NUM_EPI_SPLITKV_BUFS = 4;  // The number of epilogue buffers for splitkv decoding

    // The split-KV epilogue re-uses the Q buffer as float32 staging buffers,
    // so the Q buffer has to be at least as large as all of them together.
    static_assert((H_Q/2)*D_Q*sizeof(bf16) >= NUM_EPI_SPLITKV_BUFS*(H_Q/2)*(B_EPI_SPLITKV*2)*sizeof(float));

    // Tensor memory columns
    struct tmem_cols {
        //   0 ~ 256: Output accumulator
        // 256 ~ 384: Q
        // 384 ~ 448: P
        static constexpr int O = 0;
        static constexpr int Q = 256;
        static constexpr int P = 384;
    };

    // Layout of the dynamic shared memory of one CTA. The kernel casts the
    // raw shared-memory buffer to this struct, so the order of the members
    // is the memory layout: do not reorder them casually.
    struct SharedMemoryPlan {
        // ---- Data buffers ----
        array_aligned<bf16, (H_Q/2)*D_Q> Q;	// Will be output for epilogue
        array_aligned<bf16, B_TOPK*(D_K/2)> K[NUM_K_BUFS];              // Per-CTA half of a KV block (bf16)
        array_aligned<fp8_e4m3, B_TOPK*(D_K/2)> K_raw[NUM_RAW_K_BUFS];  // Raw FP8 KV before dequantisation (zero-sized when NUM_RAW_K_BUFS == 0)
        array_aligned<bf16, (H_Q/2)*B_TOPK> S;
        float P_exchange[4][(H_Q/2/2)*(B_TOPK/2)];
        float rowwise_max_buf[128];
        float rowwise_li_buf[128];

        // ---- Per-block index metadata, NUM_INDEX_BUFS deep ----
        CUTE_ALIGNAS(16) char is_k_valid[NUM_INDEX_BUFS][B_TOPK/8];     // One validity bit per token
        CUTE_ALIGNAS(16) int tma_coord[NUM_INDEX_BUFS][B_TOPK];         // One TMA gather coordinate per token
        CUTE_ALIGNAS(16) fp8_e8m0 scales[NUM_INDEX_BUFS][B_TOPK][NUM_SCALES_EACH_TOKEN/2];

        // ---- Barriers ----
        transac_bar_t bar_sQ_full;
        transac_bar_t bar_tQ_empty;
        transac_bar_t bar_tQ_full;
        transac_bar_t bar_tOut_full;
        transac_bar_t bar_tOut_empty;
        transac_bar_t bar_KV_full[NUM_K_BUFS];
        transac_bar_t bar_KV_empty[NUM_K_BUFS];
        transac_bar_t bar_P_empty;
        transac_bar_t bar_QK_done;
        transac_bar_t bar_SV_done;
        transac_bar_t bar_S_O_full;
        transac_bar_t bar_li_full;
        transac_bar_t bar_li_empty;
        // The following barriers are prefill-only
        transac_bar_t bar_clc_full;
        transac_bar_t bar_clc_empty;
        // The following barriers are decode-only
        transac_bar_t bar_raw_KV_full[NUM_RAW_K_BUFS];
        transac_bar_t bar_raw_KV_empty[NUM_RAW_K_BUFS];
        transac_bar_t bar_valid_coord_scales_full[NUM_INDEX_BUFS];
        transac_bar_t bar_valid_coord_scales_empty[NUM_INDEX_BUFS];

        // ---- Miscellaneous ----
        ku::CLCResponseObj clc_response_obj;
        array_aligned<uint32_t, 1> tmem_start_addr;     // Receives the base address of the tensor memory allocation
    };

    // MMA producing P = Q K^T: A operand from tensor memory, B operand from shared memory.
    using TiledMMA_P = decltype(make_tiled_mma(
        SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, H_Q, B_TOPK*2, UMMA::Major::K, UMMA::Major::K>{}
    )); // *2 for dual gemm

    // MMA producing O += S V: both operands from shared memory.
    using TiledMMA_O = decltype(make_tiled_mma(
        SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, H_Q, 256, UMMA::Major::K, UMMA::Major::MN>{},
        Layout<Shape<_1, _1, _1>>{},
        Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{}	// We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
    ));

    // Identifiers of the named barriers used inside the kernel.
    struct barrier_ids {
        static constexpr int WG0_SYNC = 0;
        static constexpr int WG2_SYNC = 1;
        static constexpr int WG2_WARP02_SYNC = 2;
        static constexpr int WG2_WARP13_SYNC = 3;
    };

    // Device-side body of the kernel.
    static __device__ void
    sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params);

    // Host-side entry point; its definition is not part of this header.
    static void run(const ArgT& params);
};

}
#include "../phase1.h"
#include "../phase1.cuh"
// Instantiation unit for the split-KV decode variant of the head128 small-topk kernel.
namespace sm100::fwd_for_small_topk::head128 {
// Explicit instantiation: forces the function template (presumably defined in the
// included ../phase1.cuh -- TODO confirm) to be compiled in this translation unit for
// FWD_MODE = DecodeWithSplitKV and D_QK = 512, taking SparseAttnDecodeParams.
template void run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::DecodeWithSplitKV, 512>(const SparseAttnDecodeParams& params);
}
#include "../phase1.h"
#include "../phase1.cuh"
// Instantiation unit for the prefill variant of the head128 small-topk kernel.
namespace sm100::fwd_for_small_topk::head128 {
// Explicit instantiation: forces the function template (presumably defined in the
// included ../phase1.cuh -- TODO confirm) to be compiled in this translation unit for
// FWD_MODE = Prefill and D_QK = 512, taking SparseAttnFwdParams.
template void run_fwd_for_small_topk_phase1_kernel<SparseAttnFwdMode::Prefill, 512>(const SparseAttnFwdParams& params);
}
#pragma once
#include "phase1.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include "params.h"
#include "utils.h"
#include "sm100/prefill/sparse/common_subroutine.h"
#include "sm100/helpers.h"
#include "config.h"
namespace sm100::fwd_for_small_topk::head128 {
using namespace cute;
using FwdMode = SparseAttnFwdMode;
template<FwdMode FWD_MODE, int D_QK>
__device__ void
KernelTemplate<FWD_MODE, D_QK>::sparse_attn_fwd_kernel_devfunc(const ArgT &params, const TmaParams &tma_params) {
#ifdef KERUTILS_ENABLE_SM100A
// Grid shape: [2*s_q, 1, 1] for prefilling, [2*s_q, num_sm_parts, 1] for decoding
// Cluster shape: [2, 1, 1]
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int cta_idx = block_id_in_cluster().x;
extern __shared__ char wksp_buf[];
SharedMemoryPlan &smem = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(&tma_params.tensor_map_q);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_o);
if constexpr (IS_DECODE) {
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_nope);
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv_rope);
} else {
cute::prefetch_tma_descriptor(&tma_params.tensor_map_kv);
}
} else if (warp_idx == 1 && elect_one_sync()) {
smem.bar_sQ_full.init(1);
smem.bar_tQ_empty.init(1);
smem.bar_tQ_full.init(1);
smem.bar_tOut_full.init(1);
smem.bar_tOut_empty.init(256);
smem.bar_P_empty.init(256);
smem.bar_QK_done.init(1);
smem.bar_SV_done.init(1);
smem.bar_S_O_full.init(256);
smem.bar_li_full.init(H_Q/2);
smem.bar_li_empty.init(128);
if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) {
smem.bar_clc_full.init(1);
smem.bar_clc_empty.init(NUM_WORKER_THREADS);
}
fence_barrier_init();
} else if (warp_idx == 2) {
cute::TMEM::Allocator2Sm().allocate(512, smem.tmem_start_addr.data());
KU_TRAP_ONLY_DEVICE_ASSERT(smem.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator2Sm().release_allocation_lock();
} else if (warp_idx == 3 && elect_one_sync()) {
CUTE_UNROLL
for (int i = 0; i < NUM_K_BUFS; ++i) {
smem.bar_KV_full[i].init(IS_PREFILL ? 1 : (128/32)*2+1);
smem.bar_KV_empty[i].init(1);
}
CUTE_UNROLL
for (int i = 0; i < NUM_INDEX_BUFS; ++i) {
smem.bar_valid_coord_scales_full[i].init(IS_PREFILL ? B_TOPK/8 : 32);
smem.bar_valid_coord_scales_empty[i].init(IS_PREFILL ? 128 : (128 + (cta_idx==1) + 2 + 128));
}
if constexpr (IS_DECODE) {
CUTE_UNROLL
for (int i = 0; i < NUM_RAW_K_BUFS; ++i) {
smem.bar_raw_KV_full[i].init(1);
smem.bar_raw_KV_empty[i].init(128);
}
}
fence_barrier_init();
}
ku::barrier_cluster_arrive_relaxed();
ku::barrier_cluster_wait_acquire();
struct OuterloopArgs {
bool outer_loop_phase;
int batch_idx, s_q_idx;
int start_block_idx, end_block_idx;
int topk_length;
int extra_topk_length, num_orig_kv_blocks; // extra-KV related
bool is_no_split; int n_split_idx; // splitkv related
};
auto run_outer_loop = [&](auto loop_body) -> bool {
int outer_loop_phase = false;
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
int s_q_idx = blockIdx.x / 2;
DecodingSchedMeta sched_meta;
KU_LDG_256(
params.tile_scheduler_metadata_ptr + blockIdx.y,
&sched_meta,
".nc",
"no_allocate",
"evict_normal",
"256B"
);
if (sched_meta.begin_req_idx >= params.b) {
return 0;
}
#pragma unroll 1
for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx) {
int topk_length = params.topk_length ? __ldg(params.topk_length + batch_idx) : params.topk;
int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0
int start_block_idx = batch_idx == sched_meta.begin_req_idx ? sched_meta.begin_block_idx : 0;
int end_block_idx = batch_idx == sched_meta.end_req_idx ? sched_meta.end_block_idx : total_topk_padded / B_TOPK;
bool is_split = batch_idx == sched_meta.begin_req_idx ? sched_meta.is_first_req_splitted : (batch_idx == sched_meta.end_req_idx ? sched_meta.is_last_req_splitted : false);
int n_split_idx = batch_idx == sched_meta.begin_req_idx ? (__ldg(params.num_splits_ptr+batch_idx) + sched_meta.begin_split_idx) : __ldg(params.num_splits_ptr+batch_idx);
// start_block_idx = 0;
// end_block_idx = total_topk_padded / B_TOPK;
// is_split = false;
// n_split_idx = 0;
OuterloopArgs args = {
(bool)outer_loop_phase,
batch_idx, s_q_idx,
start_block_idx, end_block_idx,
topk_length,
extra_topk_length, orig_topk_padded / B_TOPK,
!is_split, n_split_idx
};
loop_body(args);
outer_loop_phase ^= 1;
}
} else {
// Prefill mode. Use CLC to allocate different s_q (for decoding, different batches + s_q) to different workers
ku::CLCResult next_job = {true, (int)blockIdx.x, IS_PREFILL ? 0 : (int)blockIdx.y, 0};
CUTE_NO_UNROLL
while (next_job.is_valid) {
int s_q_idx = next_job.x / 2;
int batch_idx = IS_PREFILL ? 0 : next_job.y;
int topk_length = params.topk_length != nullptr ? __ldg(params.topk_length + (IS_PREFILL?s_q_idx:batch_idx)) : params.topk;
if constexpr (IS_PREFILL) {
int num_k_blocks = max(cute::ceil_div(topk_length, (int)B_TOPK), 1); // num_k_blocks always >= 1
OuterloopArgs args = {
(bool)outer_loop_phase,
0, s_q_idx,
0, num_k_blocks,
topk_length
};
loop_body(args);
} else {
int orig_topk_padded = max(ku::ceil(topk_length, (int)B_TOPK), (int)B_TOPK);
int extra_topk_length = params.extra_topk_length ? __ldg(params.extra_topk_length + batch_idx) : params.extra_topk;
int total_topk_padded = orig_topk_padded + ku::ceil(extra_topk_length, (int)B_TOPK); // % B_TOPK == 0
OuterloopArgs args = {
(bool)outer_loop_phase,
batch_idx, s_q_idx,
0, total_topk_padded / B_TOPK,
topk_length,
extra_topk_length, orig_topk_padded / B_TOPK,
false, 0
};
loop_body(args);
}
smem.bar_clc_full.wait(outer_loop_phase);
next_job = ku::get_clc_query_response<true>(smem.clc_response_obj);
smem.bar_clc_empty.arrive(0u);
outer_loop_phase ^= 1;
}
}
return outer_loop_phase;
};
if (warpgroup_idx == 0) {
// Q fetching and O writing back warpgroup
cutlass::arch::warpgroup_reg_alloc<176>();
bf16* sO_addrs[B_EPI/8];
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
Tensor sO = make_tensor(make_smem_ptr(smem.Q.data()), ku::make_umma_canonical_k_major_layout<H_Q/2, D_V, 128>());
sO_addrs[i] = &sO(idx_in_warpgroup%64, (idx_in_warpgroup/64)*(D_V/2) + i*8);
}
float* sO_accum_addrs[B_EPI_SPLITKV/4];
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
// If split-KV is enabled, we need to store back O in float32
// We view Q buffer (with shape 64 x 512, bf16) as 4 buffers with shape (H_Q/2) x (B_EPI_SPLITKV*2), float32
Tensor sO_accum = make_tensor(make_smem_ptr((float*)smem.Q.data()), ku::make_umma_canonical_k_major_layout<H_Q/2, D_V, 128, float>());
CUTE_UNROLL
for (int i = 0; i < B_EPI_SPLITKV/4; ++i) {
sO_accum_addrs[i] = &sO_accum(idx_in_warpgroup%64, i*4) + (idx_in_warpgroup >= 64 ? (H_Q/2)*B_EPI_SPLITKV : 0);
}
}
auto perform_o_copy_out = [&](const OuterloopArgs &args, bool is_last_o) {
// outer_loop_phase is the loop phase corresponding to s_q_idx
// Get li (output_scale actually)
smem.bar_li_full.wait(args.outer_loop_phase);
float output_scale = smem.rowwise_li_buf[idx_in_warpgroup%64];
float2 output_scale_float2 = float2 {output_scale, output_scale};
smem.bar_li_empty.arrive();
// Retrieve and store O, and calculate delta := sum(O*dO, dim=-1) if FWD_MODE is Recompute
smem.bar_tOut_full.wait(args.outer_loop_phase);
if (is_last_o && elect_one_sync()) {
cudaTriggerProgrammaticLaunchCompletion();
}
if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {
CUTE_UNROLL
for (int k = 0; k < (D_V/2)/B_EPI; ++k) {
float2 o[B_EPI/2];
ku::tmem_ld_32dp32bNx<B_EPI>(tmem_cols::O + k*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
if (k == (D_V/2)/B_EPI-1) {
smem.bar_tOut_empty.arrive(0u);
}
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
nv_bfloat162 o_bf16[4];
CUTE_UNROLL
for (int j = 0; j < 4; ++j) {
o[i*4+j] = ku::float2_mul(o[i*4+j], output_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[i*4+j]);
}
bf16* o_do_addr = sO_addrs[i] + k*B_EPI*(H_Q/2);
if (k == 0 && i == 0) {
smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability
}
ku::st_shared(o_do_addr, *(__int128_t*)o_bf16);
}
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
if (warp_idx == 0 && elect_one_sync()) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o,
smem.Q.data(),
0, cta_idx*(H_Q/2), 0, args.s_q_idx, IS_DECODE ? args.batch_idx : 0
);
cute::tma_store_arrive();
}
} else {
CUTE_UNROLL
for (int k = 0; k < (D_V/2)/B_EPI_SPLITKV; ++k) {
int cur_buf_idx = k % NUM_EPI_SPLITKV_BUFS;
if (k == 0) {
cute::tma_store_wait<0>();
} else {
cute::tma_store_wait<NUM_EPI_SPLITKV_BUFS-1>();
}
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
float o[B_EPI_SPLITKV];
ku::tmem_ld_32dp32bNx<B_EPI_SPLITKV>(tmem_cols::O + k*B_EPI_SPLITKV, o);
cutlass::arch::fence_view_async_tmem_load();
if (k == (D_V/2)/B_EPI_SPLITKV-1) {
smem.bar_tOut_empty.arrive(0u);
}
CUTE_UNROLL
for (int i = 0; i < B_EPI_SPLITKV/4; ++i) {
CUTE_UNROLL
for (int j = 0; j < 4; j += 2) {
*(float2*)(o + i*4 + j) = ku::float2_mul(float2 {o[i*4+j], o[i*4+j+1]}, output_scale_float2);
}
if (k == 0 && i == 0) {
smem.bar_tQ_full.wait(args.outer_loop_phase^1^is_last_o); // Wait for sQ's availability
}
ku::st_shared(
sO_accum_addrs[i] + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2),
*(__int128_t*)(o + i*4)
);
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
if constexpr (IS_DECODE) { // Otherwise nvcc complains about `tma_params` doesn't have `tensor_map_o_accum`
float* cur_buf_base = (float*)smem.Q.data() + cur_buf_idx*((H_Q/2)*B_EPI_SPLITKV*2);
if (warp_idx == 0 && elect_one_sync()) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o_accum,
cur_buf_base,
0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32), args.s_q_idx, args.n_split_idx
);
cute::tma_store_arrive();
} else if (warp_idx == 1 && elect_one_sync()) {
SM90_TMA_STORE_5D::copy(
&tma_params.tensor_map_o_accum,
cur_buf_base + (H_Q/2)*B_EPI_SPLITKV,
0, cta_idx*(H_Q/2), k*(B_EPI_SPLITKV/32) + (D_V/2)/32, args.s_q_idx, args.n_split_idx
);
cute::tma_store_arrive();
}
}
}
}
};
OuterloopArgs last_args;
last_args.batch_idx = -1;
bool final_outer_loop_phase = \
run_outer_loop([&](const OuterloopArgs &args) {
// Copy Q for this round
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
cute::tma_store_wait<0>();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC); // Since we use two warps to issue TMA during FwdMode::DecodeWithSplitKV
}
if (warp_idx == 0 && elect_one_sync()) {
// Wait for sQ to become empty, and issue G -> S copy for Q
if constexpr (FWD_MODE != FwdMode::DecodeWithSplitKV) {
cute::tma_store_wait<0>(); // This thread must be the same one as o copy out thread (since `elect_one_sync()` always returns the same thread for the same `mask`, according to PTX document)
}
int stride_q_b_div_stride_q_s_q = 0;
if constexpr (IS_DECODE) {
stride_q_b_div_stride_q_s_q = params.stride_q_b / params.stride_q_s_q;
}
SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(
&tma_params.tensor_map_q,
(uint64_t*)&smem.bar_sQ_full,
(uint64_t)TMA::CacheHintSm90::EVICT_FIRST,
smem.Q.data(),
0, cta_idx*(H_Q/2), 0, 0, (IS_DECODE ? args.batch_idx*stride_q_b_div_stride_q_s_q : 0) + args.s_q_idx
);
// Wait for sQ to be ready, and issue S -> T copy for Q
if (cta_idx == 0) {
smem.bar_sQ_full.arrive_and_expect_tx(H_Q*D_Q*sizeof(bf16));
smem.bar_sQ_full.wait(args.outer_loop_phase);
smem.bar_tQ_empty.wait(args.outer_loop_phase^1);
ku::tcgen05_after_thread_sync();
UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(smem.Q.data()),
ku::make_umma_canonical_k_major_layout<(H_Q/2)*2, 64, 128>()
)
);
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < D_Q/64/2; ++tile_idx) {
// A tile is 128 rows * 64 cols in UTCCP's view, or 64 rows * 128 cols in our view
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 4; ++subtile_idx) {
// A subtile is 128 rows * 16 cols (256b, 32B) (in UTCCP's view), or 64 rows * 16 cols * 2 (in our view)
// NOTE Using `sQ_desc+((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4)` leads to IMA, doesn't know why
UMMA::SmemDescriptor cur_sQ_desc = sQ_desc;
cur_sQ_desc.lo += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4);
// uint64_t cur_sQ_desc = sQ_desc;
// cur_sQ_desc += ((tile_idx*((H_Q/2)*128*2) + subtile_idx*32) >> 4);
SM100_UTCCP_128dp256bit_2cta::copy(
cur_sQ_desc,
tmem_cols::Q + tile_idx*32 + subtile_idx*8
);
}
}
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tQ_full, 1|2);
}
}
if (last_args.batch_idx != -1) {
perform_o_copy_out(last_args, false);
} else {
smem.bar_tQ_full.wait(args.outer_loop_phase); // To prevent double arrive
}
last_args = args;
});
if (last_args.batch_idx != -1) {
cute::tma_store_wait<0>();
NamedBarrier::arrive_and_wait(128, barrier_ids::WG0_SYNC);
perform_o_copy_out(last_args, true);
}
if (warp_idx == 0) {
cute::TMEM::Allocator2Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
// KV fetching threads for prefill, dequant threads for decoding
cutlass::arch::warpgroup_reg_dealloc<80>();
RingBufferState rs;
if constexpr (!IS_DECODE) {
const int warp_idx = cutlass::canonical_warp_idx(); // Using `warp_idx` without `__shfl_sync` is faster
if (elect_one_sync()) {
// KV fetching threads
run_outer_loop([&](const OuterloopArgs &args) {
int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q;
int64_t cache_hint = ku::create_simple_cache_policy<ku::CacheHint::EVICT_LAST>();
static constexpr int NUM_ROWS_PER_THREAD = B_TOPK / 4;
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
int cur_indices[NUM_ROWS_PER_THREAD];
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/8; local_row += 1) {
int row = local_row*(4*8) + (warp_idx-4)*8;
KU_LDG_256(
gIndices + k*B_TOPK + row,
cur_indices + local_row*8,
".nc",
"no_allocate",
"evict_first",
"256B"
);
}
smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_ROWS_PER_THREAD/4; local_row += 1) {
int row = (warp_idx-4)*8 + (local_row/2)*(4*8) + (local_row%2)*4;
int4 indices = *(int4*)(cur_indices+local_row*4);
static_assert(D_K == 512);
CUTE_UNROLL
for (int local_col = 0; local_col < (D_K/64)/2; ++local_col) {
ku::tma_gather4_cta_group_2<true>(
&tma_params.tensor_map_kv,
smem.bar_KV_full[k_buf_idx],
smem.K[k_buf_idx].data() + row*64 + local_col*64*B_TOPK,
local_col*64 + cta_idx*(D_K/2),
indices,
cache_hint
);
}
}
rs.update();
}
});
}
} else {
// 8 threads per token
struct IsCTA0 {};
struct IsCTA1 {};
auto launch_dequant_wg = [&](auto cta_id_t) {
static constexpr bool IS_CTA1 = std::is_same<decltype(cta_id_t), IsCTA1>::value;
constexpr int GROUP_SIZE = 8, NUM_GROUPS = 128/8, ROWS_PER_GROUP = B_TOPK / NUM_GROUPS, COLS_PER_GROUP = (IS_CTA1 ? 256-64 : 256) / (GROUP_SIZE*8);
int group_idx = idx_in_warpgroup/GROUP_SIZE, idx_in_group = idx_in_warpgroup%GROUP_SIZE;
Tensor nope0 = make_tensor(make_smem_ptr(smem.K[0].data()), ku::make_umma_canonical_k_major_layout<B_TOPK, D_K/2, 128>());
bf16* nope0_base = &nope0(group_idx, idx_in_group*8);
fp8_e4m3* raw_nope0_base = smem.K_raw[0].data() + group_idx*(D_K/2) + idx_in_group*8;
run_outer_loop([&](const OuterloopArgs &args) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get<NUM_RAW_K_BUFS>();
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
fp8_e4m3* raw_nope_base = raw_nope0_base + raw_k_buf_idx * (B_TOPK*(D_K/2));
auto get_raw_fp8 = [&](int local_row_idx, int local_col_idx) -> uint64_t {
return *(uint64_t*)(raw_nope_base + local_row_idx*NUM_GROUPS*(D_K/2) + local_col_idx*(GROUP_SIZE*8));
};
bf16* nope_base = nope0_base + k_buf_idx * (B_TOPK*(D_K/2));
uint32_t cur_nope_base_uint_addr = cute::cast_smem_ptr_to_uint(nope_base);
auto st_128b = [&](int local_row_idx, int local_col_idx, __int128_t &data) {
asm volatile ("st.weak.shared::cta.b128 [%0], %1;\n"
:
: "r"(cur_nope_base_uint_addr + 2*(local_row_idx*NUM_GROUPS*64 + local_col_idx*B_TOPK*64)), "q"(data) // 2 for sizeof(bf16)
); // We have this `asm volatile` here, otherwise the compiler generates ST.E instead of STS
};
smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);
smem.bar_raw_KV_full[raw_k_buf_idx].wait(raw_k_bar_phase);
CUTE_UNROLL
for (int local_row_idx = 0; local_row_idx < ROWS_PER_GROUP; ++local_row_idx) {
int row_idx = local_row_idx*NUM_GROUPS + group_idx;
bf16 scales[4];
fp8_e8m0 scales_e8m0[4];
*(uint32_t*)scales_e8m0 = *(uint32_t*)(smem.scales[index_buf_idx][row_idx]);
*(__nv_bfloat162_raw*)(scales+0) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+0));
*(__nv_bfloat162_raw*)(scales+2) = __nv_cvt_e8m0x2_to_bf162raw(*(unsigned short*)(scales_e8m0+2));
uint64_t cur_data_fp8x8 = get_raw_fp8(local_row_idx, 0);
CUTE_UNROLL
for (int local_col_idx = 0; local_col_idx < COLS_PER_GROUP; ++local_col_idx) {
ku::nve4m3x2 data_fp8[4];
ku::nvbf16x2 data_bf16[4];
*(uint64_t*)data_fp8 = cur_data_fp8x8;
if (local_col_idx+1 < COLS_PER_GROUP)
cur_data_fp8x8 = get_raw_fp8(local_row_idx, local_col_idx+1);
bf16 scale = scales[local_col_idx];
CUTE_UNROLL
for (int i = 0; i < 4; ++i) {
data_bf16[i] = fp8x2_to_bf16x2_with_scale(data_fp8[i], *(ku::nvbf16*)(&scale));
}
if (local_row_idx == 0 && local_col_idx == 0) {
smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);
}
st_128b(local_row_idx, local_col_idx, *(__int128_t*)data_bf16);
}
}
fence_view_async_shared(); // NOTE Should we use shared::cluster here?
__syncwarp();
smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();
smem.bar_raw_KV_empty[raw_k_buf_idx].arrive();
if (elect_one_sync()) {
smem.bar_KV_full[k_buf_idx].arrive(0u);
}
rs.update();
}
});
};
if (cta_idx == 0) {
launch_dequant_wg(IsCTA0{});
} else {
launch_dequant_wg(IsCTA1{});
}
}
} else if (warpgroup_idx == 2) {
cutlass::arch::warpgroup_reg_dealloc<80>();
RingBufferState rs;
if (warp_idx == 8 && cta_idx == 0 && elect_one_sync()) {
// UMMA thread
TiledMMA tiled_mma_P = TiledMMA_P{};
TiledMMA tiled_mma_O = TiledMMA_O{};
Tensor tP = partition_fragment_C(tiled_mma_P, Shape<Int<H_Q/2>, Int<B_TOPK*2>>{});
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<H_Q/2>, Int<D_V>>{});
Tensor tQ = tiled_mma_P.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P, Shape<Int<H_Q/2>, Int<D_Q/2>>{})
);
tP.data().get() = tmem_cols::P;
tO.data().get() = tmem_cols::O;
tQ.data().get() = tmem_cols::Q;
run_outer_loop([&](const OuterloopArgs &args) {
smem.bar_tQ_full.wait(args.outer_loop_phase);
// Issue P = Q K^T
auto issue_P = [&](int k, int rs_offset) {
auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get<NUM_K_BUFS>();
auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>();
smem.bar_P_empty.wait(bar_phase^1);
if constexpr (IS_PREFILL) {
smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_K*sizeof(bf16));
} else {
// RoPE only
smem.bar_KV_full[k_buf_idx].arrive_and_expect_tx(B_TOPK*D_ROPE*sizeof(bf16));
}
smem.bar_KV_full[k_buf_idx].wait(k_bar_phase);
ku::tcgen05_after_thread_sync();
Tensor sK = make_tensor(
make_smem_ptr(smem.K[k_buf_idx].data()),
ku::make_umma_canonical_k_major_layout<B_TOPK, D_K/2, 128>()
);
ku::utcmma_ts(tiled_mma_P, tQ, sK, tP, true);
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_QK_done, 1|2);
};
// Issue O += S V
auto issue_O = [&](int k, int rs_offset) {
auto [k_buf_idx, k_bar_phase] = rs.offset_by(rs_offset).get<NUM_K_BUFS>();
auto [_, bar_phase] = rs.offset_by(rs_offset).get<1>();
smem.bar_S_O_full.wait(bar_phase);
if (k == args.start_block_idx) {
smem.bar_tOut_empty.wait(args.outer_loop_phase^1);
}
ku::tcgen05_after_thread_sync();
Tensor sS = make_tensor(
make_smem_ptr(smem.S.data()),
ku::make_umma_canonical_k_major_layout<H_Q/2, B_TOPK, 0>()
);
Tensor sV = make_tensor(
make_smem_ptr(smem.K[k_buf_idx].data()),
ku::make_umma_canonical_mn_major_layout<D_V/2, B_TOPK, 128>()
);
ku::utcmma_ss(tiled_mma_O, sS, sV, tO, k == args.start_block_idx);
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_SV_done, 1|2);
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_KV_empty[k_buf_idx], 1|2);
};
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx+1; ++k) {
if (k < args.end_block_idx) {
issue_P(k, 0);
}
if (k == args.end_block_idx-1) {
ku::umma_arrive_2x1SM_noelect(smem.bar_tQ_empty);
}
if (k > args.start_block_idx) {
issue_O(k-1, -1);
}
if (k != args.end_block_idx) {
rs.update();
}
}
ku::tcgen05_before_thread_sync();
ku::umma_arrive_multicast_2x1SM_noelect(smem.bar_tOut_full, 1|2);
});
} else if (warp_idx == 8 && cta_idx == 1 && elect_one_sync()) {
if constexpr (IS_DECODE) {
// KV RoPE fetching warp
run_outer_loop([&](const OuterloopArgs &args) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);
smem.bar_KV_empty[k_buf_idx].wait(k_bar_phase^1);
CUTE_UNROLL
for (int row = 0; row < B_TOPK; row += 4) {
int4 cur_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row);
ku::tma_gather4_cta_group_2<true>(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_rope : &tma_params.tensor_map_kv_rope,
smem.bar_KV_full[k_buf_idx],
smem.K[k_buf_idx].data() + (D_NOPE-D_K/2)*B_TOPK + row*D_ROPE,
0,
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();
rs.update();
}
});
}
} else if (warp_idx == 9) {
// KV validness loading warp (for prefill), Indices transformation warp (for decode, Responsible for generating: TMA coordinates, scale factors, and valid masks)
if constexpr (IS_PREFILL) {
if (lane_idx < B_TOPK/8) {
run_outer_loop([&](const OuterloopArgs &args) {
int* gIndices = params.indices + args.s_q_idx*params.stride_indices_s_q;
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {
char k_validness_mask = load_indices_and_generate_mask(
lane_idx,
gIndices + k*B_TOPK,
params.s_kv,
k*B_TOPK,
args.topk_length
);
auto [indices_buf_idx, indices_bar_phase] = rs.get<NUM_INDEX_BUFS>();
smem.bar_valid_coord_scales_empty[indices_buf_idx].wait(indices_bar_phase^1);
smem.is_k_valid[indices_buf_idx][lane_idx] = k_validness_mask;
smem.bar_valid_coord_scales_full[indices_buf_idx].arrive();
rs.update();
}
});
}
} else {
static_assert(B_TOPK == 64);
// Each thread is responsible for 2 tokens
static constexpr int tma_coords_step_per_token = 576/TMA_K_STRIDE_FOR_DECODING;
int tma_coords_step_per_block = params.stride_kv_block / TMA_K_STRIDE_FOR_DECODING; // must < 2G since k_batch_stride < 1T and TMA_K_STRIDE_FOR_DECODING > 512
int tma_coords_step_per_extra_block = params.stride_extra_kv_block / TMA_K_STRIDE_FOR_DECODING;
uint8_t* k_scales_ptr = (uint8_t*)params.kv + params.page_block_size*(D_NOPE+2*D_ROPE);
uint8_t* extra_k_scales_ptr = (uint8_t*)params.extra_kv + params.extra_page_block_size*(D_NOPE+2*D_ROPE);
run_outer_loop([&](const OuterloopArgs &args) {
int* indices = (int*)params.indices + params.stride_indices_b*args.batch_idx + params.stride_indices_s_q*args.s_q_idx;
int* extra_indices = (int*)params.extra_indices + params.stride_extra_indices_b*args.batch_idx + params.stride_extra_indices_s_q*args.s_q_idx;
struct IsOrigBlock {};
struct IsExtraBlock {};
auto process_one_block = [&](int block_idx, auto is_extra_block_t) {
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
static constexpr bool IS_EXTRA_BLOCK = std::is_same_v<decltype(is_extra_block_t), IsExtraBlock>;
int cur_block_size = IS_EXTRA_BLOCK ? params.extra_page_block_size : params.page_block_size;
int64_t cur_k_block_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_block : params.stride_kv_block;
[[maybe_unused]] int cur_k_row_stride = IS_EXTRA_BLOCK ? params.stride_extra_kv_row : params.stride_kv_row;
uint8_t* cur_k_scales_ptr = IS_EXTRA_BLOCK ? extra_k_scales_ptr : k_scales_ptr;
int cur_tma_coords_step_per_block = IS_EXTRA_BLOCK ? tma_coords_step_per_extra_block : tma_coords_step_per_block;
int abs_pos, my_indices[2];
if (!IS_EXTRA_BLOCK) {
abs_pos = block_idx*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(indices + abs_pos));
} else {
abs_pos = (block_idx-args.num_orig_kv_blocks)*B_TOPK + lane_idx*2;
*(int2*)my_indices = __ldg((int2*)(extra_indices + abs_pos));
}
smem.bar_valid_coord_scales_empty[index_buf_idx].wait(index_bar_phase^1);
int tma_coords[2];
fp8_e8m0 scales[2*(NUM_SCALES_EACH_TOKEN/2)];
char valid_mask = 0;
CUTE_UNROLL
for (int i = 0; i < 2; ++i) {
int block_idx, idx_in_block;
block_idx = (unsigned int)my_indices[i] / cur_block_size;
idx_in_block = (unsigned int)my_indices[i] % cur_block_size;
bool is_token_valid = my_indices[i] != -1 && (abs_pos+i < (IS_EXTRA_BLOCK?args.extra_topk_length:args.topk_length));
valid_mask |= is_token_valid << i;
tma_coords[i] = is_token_valid ? block_idx*cur_tma_coords_step_per_block + idx_in_block*tma_coords_step_per_token : -1; // If the token is invalid because it topk position exceeds topk_length, we must manually fill tma_coords with -1 to avoid copying-in NaN.
int64_t offset = block_idx*cur_k_block_stride + (idx_in_block*8 + (cta_idx == 1 ? 4 : 0)); // Each token has 7 scale factors with an extra 1B padding
uint32_t scalesx4 = is_token_valid ? __ldg((uint32_t*)(cur_k_scales_ptr + offset)) : 0;
*(uint32_t*)(scales+i*(NUM_SCALES_EACH_TOKEN/2)) = scalesx4;
}
valid_mask <<= lane_idx%4*2;
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x1);
valid_mask |= __shfl_xor_sync(0xFFFFFFFF, valid_mask, 0x2);
*(uint64_t*)(smem.scales[index_buf_idx] + lane_idx*2) = *(uint64_t*)scales;
*(int2*)(smem.tma_coord[index_buf_idx] + lane_idx*2) = *(int2*)tma_coords;
if (lane_idx%4 == 0)
smem.is_k_valid[index_buf_idx][lane_idx/4] = valid_mask;
smem.bar_valid_coord_scales_full[index_buf_idx].arrive();
rs.update();
};
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < min(args.num_orig_kv_blocks, args.end_block_idx); ++block_idx) {
process_one_block(block_idx, IsOrigBlock{});
}
CUTE_NO_UNROLL
for (int block_idx = max(args.start_block_idx, args.num_orig_kv_blocks); block_idx < args.end_block_idx; ++block_idx) {
process_one_block(block_idx, IsExtraBlock{});
}
});
}
} else if (warp_idx >= 10 && elect_one_sync()) {
if constexpr (IS_PREFILL) {
if (warp_idx == 10) {
// CLC Producer thread
run_outer_loop([&](const OuterloopArgs &args) {
if (cta_idx == 0) {
smem.bar_clc_empty.wait(args.outer_loop_phase^1);
ku::issue_clc_query_multicast_cluster_all(smem.bar_clc_full, smem.clc_response_obj);
}
smem.bar_clc_full.arrive_and_expect_tx(sizeof(smem.clc_response_obj));
});
}
} else {
// Raw KV NoPE Producer thread
run_outer_loop([&](const OuterloopArgs &args) {
CUTE_NO_UNROLL
for (int block_idx = args.start_block_idx; block_idx < args.end_block_idx; ++block_idx) {
auto [raw_k_buf_idx, raw_k_bar_phase] = rs.get<NUM_RAW_K_BUFS>();
auto [index_buf_idx, index_bar_phase] = rs.get<NUM_INDEX_BUFS>();
smem.bar_valid_coord_scales_full[index_buf_idx].wait(index_bar_phase);
smem.bar_raw_KV_empty[raw_k_buf_idx].wait(raw_k_bar_phase^1);
int4 nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + (warp_idx == 10 ? 0 : 4));
CUTE_UNROLL
for (int row = (warp_idx == 10 ? 0 : 4); row < B_TOPK; row += 8) {
int4 cur_indices = nxt_indices;
if (row+8 < B_TOPK)
nxt_indices = *(int4*)(smem.tma_coord[index_buf_idx] + row + 8);
ku::tma_gather4(
block_idx >= args.num_orig_kv_blocks ? &tma_params.tensor_map_extra_kv_nope : &tma_params.tensor_map_kv_nope,
smem.bar_raw_KV_full[raw_k_buf_idx],
smem.K_raw[raw_k_buf_idx].data() + row*(D_K/2),
cta_idx*(D_K/2),
cur_indices,
(int64_t)TMA::CacheHintSm90::EVICT_LAST
);
}
if (warp_idx == 10) {
smem.bar_raw_KV_full[raw_k_buf_idx].arrive_and_expect_tx(B_TOPK*(D_K/2)*sizeof(fp8_e4m3));
}
smem.bar_valid_coord_scales_empty[index_buf_idx].arrive();
rs.update();
}
});
}
}
} else {
// Scale & Exp threads
cutlass::arch::warpgroup_reg_alloc<176>();
int local_warp_idx = warp_idx - 12;
bf16* sS_base = smem.S.data() + (local_warp_idx >= 2 ? (H_Q/2)*(B_TOPK/2) : 0) + (idx_in_warpgroup%64)*8;
RingBufferState rs;
run_outer_loop([&](const OuterloopArgs &args) {
// For definition and consistency about `mi`, `li`, and `real_mi`, plz refer to head64 prefill
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
static constexpr int NUM_ELEMS_PER_THREAD = B_TOPK / 2;
CUTE_NO_UNROLL
for (int k = args.start_block_idx; k < args.end_block_idx; ++k) {
auto [k_buf_idx, k_bar_phase] = rs.get<NUM_K_BUFS>();
auto [indices_buf_idx, indices_bar_phase] = rs.get<NUM_INDEX_BUFS>();
auto [_, bar_phase] = rs.get<1>();
// NOTE We don't need to sync for Prefill mode, since we have two synchronizations inside the loop body (one for p_exchange_buf sync, another one for rowwise_max_buf sync). The latter one guarantees the emptyness of p_exchange_buf and the former one guarantees the emptyness of rowwise_max_buf
smem.bar_valid_coord_scales_full[indices_buf_idx].wait(indices_bar_phase);
// Get P from TMEM
float p[NUM_ELEMS_PER_THREAD];
smem.bar_QK_done.wait(bar_phase);
ku::tcgen05_after_thread_sync();
retrieve_mask_and_reduce_p<
NUM_ELEMS_PER_THREAD,
tmem_cols::P,
barrier_ids::WG2_WARP02_SYNC,
barrier_ids::WG2_WARP13_SYNC,
false
>(
smem.is_k_valid[indices_buf_idx],
local_warp_idx,
lane_idx,
[&]() {smem.bar_P_empty.arrive(0u);},
smem.P_exchange,
p
);
// Get rowwise max of P
float cur_pi_max = get_max<NUM_ELEMS_PER_THREAD>(p);
cur_pi_max *= params.sm_scale_div_log2;
smem.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(64, barrier_ids::WG2_WARP02_SYNC + (local_warp_idx&1));
cur_pi_max = max(cur_pi_max, smem.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
// Calculate S
nv_bfloat162 s[NUM_ELEMS_PER_THREAD/2];
float cur_sum = get_s_from_p<NUM_ELEMS_PER_THREAD>(s, p, params.sm_scale_div_log2, new_max);
li = fmaf(li, scale_for_old, cur_sum);
// Store S
smem.bar_SV_done.wait(bar_phase^1);
CUTE_UNROLL
for (int i = 0; i < NUM_ELEMS_PER_THREAD/8; ++i) {
ku::st_shared(sS_base + i*8*(H_Q/2), *(__int128_t*)(s + i*4));
}
// Rescale O
if (k > 0 && should_scale_o) {
ku::tcgen05_after_thread_sync();
rescale_O<D_V, 32, tmem_cols::O>(scale_for_old);
ku::tcgen05_before_thread_sync();
}
fence_view_async_shared();
smem.bar_S_O_full.arrive(0u);
smem.bar_valid_coord_scales_empty[indices_buf_idx].arrive();
rs.update();
}
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Reduce li
smem.bar_li_empty.wait(args.outer_loop_phase^1);
smem.rowwise_li_buf[idx_in_warpgroup^64] = li;
NamedBarrier::arrive_and_wait(128, barrier_ids::WG2_SYNC);
li += smem.rowwise_li_buf[idx_in_warpgroup];
if (idx_in_warpgroup < H_Q/2) {
// Calculate output_scale and save
int head_idx = cta_idx*(H_Q/2) + idx_in_warpgroup;
float attn_sink = params.attn_sink == nullptr ? -CUDART_INF_F : __ldg(params.attn_sink + head_idx);
float output_scale;
if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {
output_scale = __fdividef(1.0f, li + exp2f(fmaf(attn_sink, CUDART_L2E_F, -mi)));
} else {
output_scale = __fdividef(1.0f, li);
}
smem.rowwise_li_buf[idx_in_warpgroup] = li == 0.0f ? 0.0f : output_scale;
smem.bar_li_full.arrive();
float cur_lse = fmaf(mi, CUDART_LN2_F, logf(li));
cur_lse = cur_lse == -CUDART_INF_F ? +CUDART_INF_F : cur_lse;
if constexpr (IS_PREFILL) {
int global_index = args.s_q_idx*params.h_q + head_idx;
params.max_logits[global_index] = real_mi*CUDART_LN2_F;
params.lse[global_index] = cur_lse;
} else {
if (FWD_MODE != FwdMode::DecodeWithSplitKV || args.is_no_split) {
params.lse[args.batch_idx*params.stride_lse_b + args.s_q_idx*params.stride_lse_s_q + head_idx] = cur_lse;
} else {
float cur_lse_2base = log2f(li) + mi;
params.lse_accum[args.n_split_idx*params.stride_lse_accum_split + args.s_q_idx*params.stride_lse_accum_s_q + head_idx] = cur_lse_2base;
}
}
}
});
}
ku::barrier_cluster_arrive_relaxed();
ku::barrier_cluster_wait_acquire();
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
}
#endif
}
// We have two launchers with different kernel names to distinguish prefill and decode
// Kernel entry point selected by KernelTemplate::run when IS_PREFILL is true.
// It only forwards both arguments to Kernel::sparse_attn_fwd_kernel_devfunc; the separate
// kernel name exists so that prefill and decode launches can be told apart.
//
// __launch_bounds__(Kernel::NUM_THREADS, 1, 2): at most Kernel::NUM_THREADS threads per block,
// at least 1 resident block per SM, and at most 2 blocks per cluster.
// Both parameters are __grid_constant__, i.e. read-only kernel parameters.
template<typename Kernel>
static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)
sparse_attn_fwd_for_small_topk_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) {
Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);
}
// Kernel entry point selected by KernelTemplate::run when IS_PREFILL is false (decode modes).
// It only forwards both arguments to Kernel::sparse_attn_fwd_kernel_devfunc; the separate
// kernel name exists so that prefill and decode launches can be told apart.
//
// __launch_bounds__(Kernel::NUM_THREADS, 1, 2): at most Kernel::NUM_THREADS threads per block,
// at least 1 resident block per SM, and at most 2 blocks per cluster.
// Both parameters are __grid_constant__, i.e. read-only kernel parameters.
template<typename Kernel>
static __global__ void __launch_bounds__(Kernel::NUM_THREADS, 1, 2)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const typename Kernel::ArgT params, __grid_constant__ const typename Kernel::TmaParams tma_params) {
Kernel::sparse_attn_fwd_kernel_devfunc(params, tma_params);
}
// Host-side launcher for this kernel configuration.
//
// Steps, in order:
//   1. Validate the arguments (KU_ASSERT / static_assert).
//   2. Build the TMA tensor maps: Q, KV (NoPE/RoPE pairs for decode, a single map for prefill),
//      O, and O_accum (split-KV decode only).
//   3. Pack the tensor maps into TmaParams (the member set differs between decode and prefill).
//   4. Raise the kernel's dynamic shared memory limit to sizeof(SharedMemoryPlan).
//   5. Launch on params.stream with NUM_THREADS threads per block and a cluster shape of (2, 1, 1).
//
// params: kernel arguments; passed to the kernel together with the TMA descriptors.
// Errors are reported through KU_ASSERT / KU_CUDA_CHECK / KU_CUTLASS_CHECK (defined elsewhere).
//
// NOTE(review): the argument order of ku::make_tensor_map is presumably
// (global shape, global strides, box shape, base pointer, data type, swizzle, L2 promotion)
// -- confirm against its declaration.
template<FwdMode FWD_MODE, int D_QK>
void KernelTemplate<FWD_MODE, D_QK>::run(const ArgT& params) {
// ---- 1. Argument validation ----
static_assert(D_QK == 576 || D_QK == 512);
KU_ASSERT(params.h_kv == 1);
KU_ASSERT(params.topk % B_TOPK == 0); // To save some boundary checks
KU_ASSERT(params.h_q == H_Q); // To save some calculation
KU_ASSERT(params.d_qk == D_QK);
static_assert(D_Q == 512);
// ---- 2a. Q tensor map ----
CUtensorMap tensor_map_q;
if constexpr (IS_DECODE) {
// The last dimension below has extent b * (stride_q_b / stride_q_s_q) and uses stride_q_s_q as
// its stride, which is why stride_q_b must be a multiple of stride_q_s_q.
KU_ASSERT(params.stride_q_b % params.stride_q_s_q == 0, "In decode mode for MODEL1 sparse fp8 decoding on sm100f, q.stride(0) (on the batch dimension) must be divisible by q.stride(1) (on the sequence dimension).");
tensor_map_q = ku::make_tensor_map(
{64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.b * (params.stride_q_b / params.stride_q_s_q)},
ku::make_stride_helper<int>({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)),
{64, H_Q/2, 2, (D_Q/64)/2, 1},
params.q,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
} else {
// Prefill: same layout as above, but the last dimension is simply s_q.
tensor_map_q = ku::make_tensor_map(
{64ul, H_Q, 2ul, (D_Q/64ul)/2ul, (unsigned long)params.s_q},
ku::make_stride_helper<int>({params.stride_q_h_q, D_Q/2, 64, params.stride_q_s_q}, sizeof(bf16)),
{64, H_Q/2, 2, (D_Q/64)/2, 1},
params.q,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
); // We use this layout to group Q[0:64] and Q[256:256+64] together, for UTCCP for dual gemm
}
// ---- 2b. KV tensor maps ----
// tensor_map_kv is only filled in for prefill; the NoPE/RoPE maps are only filled in for decode.
// The "extra" maps are value-initialized and stay that way unless params.extra_topk > 0.
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_nope, tensor_map_kv_rope, tensor_map_extra_kv_nope = {}, tensor_map_extra_kv_rope = {};
if constexpr (IS_DECODE) {
// Builds the (NoPE, RoPE) tensor-map pair for one KV cache.
// `is_extra` is only used to select the "extra_" prefix in the assertion messages.
// NOTE(review): `stride_kv_row` is accepted but not used inside this lambda.
auto get_kv_tensormap = [&](bool is_extra, void* k_ptr, int num_blocks, int64_t stride_kv_block, int64_t stride_kv_row) -> std::pair<CUtensorMap, CUtensorMap> {
KU_ASSERT((int64_t)k_ptr % 16 == 0, "The base address of %sk_ptr (%p) must be 16B aligned for sparse fp8 attention on sm100f", is_extra?"extra_":"", k_ptr);
KU_ASSERT(stride_kv_block % TMA_K_STRIDE_FOR_DECODING == 0, "%sk_cache.stride(0) (%ld) must be a multiple of %d. Padding might be necessary", is_extra?"extra_":"", stride_kv_block, TMA_K_STRIDE_FOR_DECODING);
// The cache is addressed as rows of TMA_K_STRIDE_FOR_DECODING bytes; the row count is
// num_blocks * (stride_kv_block / TMA_K_STRIDE_FOR_DECODING).
CUtensorMap tensor_map_kv_nope = ku::make_tensor_map(
{D_NOPE + D_ROPE*2, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)},
{TMA_K_STRIDE_FOR_DECODING},
{D_K/2, 1},
k_ptr,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
); // NOTE: Here we use `D_NOPE+D_ROPE*2` as the box shape instead of D_NOPE because it's actually faster. I think that's because, if we use `D_NOPE+D_ROPE*2`, we can prefetch part of the RoPE part of the selected tokens.
// The RoPE map starts D_NOPE bytes after the base pointer and is typed as bf16.
CUtensorMap tensor_map_kv_rope = ku::make_tensor_map(
{D_ROPE, (uint64_t)num_blocks * (stride_kv_block/TMA_K_STRIDE_FOR_DECODING)},
{TMA_K_STRIDE_FOR_DECODING},
{64, 1},
(uint8_t*)k_ptr + D_NOPE,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B
);
return {tensor_map_kv_nope, tensor_map_kv_rope};
};
std::tie(tensor_map_kv_nope, tensor_map_kv_rope) = get_kv_tensormap(false, params.kv, params.num_blocks, params.stride_kv_block, params.stride_kv_row);
// The extra KV cache is optional; its maps are only built when extra_topk is positive.
if (params.extra_topk > 0)
std::tie(tensor_map_extra_kv_nope, tensor_map_extra_kv_rope) = get_kv_tensormap(true, params.extra_kv, params.extra_num_blocks, params.stride_extra_kv_block, params.stride_extra_kv_row);
} else {
// Prefill: a single bf16 KV map of shape (D_QK, s_kv) with a row stride of stride_kv_s_kv elements.
tensor_map_kv = ku::make_tensor_map(
{D_QK, (unsigned long)params.s_kv},
{(unsigned long)params.stride_kv_s_kv*sizeof(bf16)},
{64, 1},
params.kv,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
}
// ---- 2c. Output tensor map ----
CUtensorMap tensor_map_o;
if constexpr (IS_DECODE) {
// Decode: strides come from params (stride_o_h_q / stride_o_s_q / stride_o_b).
tensor_map_o = ku::make_tensor_map(
{64, H_Q, D_V/64, (unsigned long)params.s_q, (unsigned long)params.b},
ku::make_stride_helper<int>({params.stride_o_h_q, 64, params.stride_o_s_q, params.stride_o_b}, sizeof(bf16)),
{64, H_Q/2, D_V/64, 1, 1},
params.out,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
} else {
// Prefill: strides are hard-coded from H_Q and D_V, i.e. a fixed output layout is assumed here.
tensor_map_o = ku::make_tensor_map(
{64, H_Q, D_V/64, (unsigned long)params.s_q, 1ul},
ku::make_stride_helper<int>({D_V, 64, H_Q*D_V, H_Q*D_V}, sizeof(bf16)),
{64, H_Q/2, D_V/64, 1, 1},
params.out,
CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
}
// ---- 2d. Split-KV accumulator tensor map (fp32); left value-initialized in the other modes ----
CUtensorMap tensor_map_o_accum = {};
if constexpr (FWD_MODE == FwdMode::DecodeWithSplitKV) {
// The outermost dimension has extent num_sm_parts + b.
tensor_map_o_accum = ku::make_tensor_map(
{32, H_Q, D_V/32, (unsigned long)params.s_q, (unsigned long)params.num_sm_parts + params.b},
ku::make_stride_helper<int>({params.stride_o_accum_h_q, 32, params.stride_o_accum_s_q, params.stride_o_accum_split}, sizeof(float)),
{32, H_Q/2, B_EPI_SPLITKV/32, 1, 1},
params.o_accum,
CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
CU_TENSOR_MAP_SWIZZLE_128B,
CU_TENSOR_MAP_L2_PROMOTION_L2_256B
);
}
// ---- 3. Pack the descriptors; the initializer lists follow the member order of TmaParams ----
TmaParams tma_params;
if constexpr (IS_DECODE) {
tma_params = {
tensor_map_q,
tensor_map_o,
tensor_map_o_accum,
tensor_map_kv_nope,
tensor_map_kv_rope,
tensor_map_extra_kv_nope,
tensor_map_extra_kv_rope
};
} else {
tma_params = {
tensor_map_q,
tensor_map_kv,
tensor_map_o
};
}
// ---- 4. Select the kernel (prefill and decode use differently named entry points) and size shared memory ----
auto kernel = IS_PREFILL ? &sparse_attn_fwd_for_small_topk_kernel<KernelTemplate<FWD_MODE, D_QK>> : &flash_fwd_splitkv_mla_fp8_sparse_kernel<KernelTemplate<FWD_MODE, D_QK>>;
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
KU_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// ---- 5. Launch ----
// grid.x is 2 * s_q and the cluster shape is (2, 1, 1), i.e. two CTAs per cluster along x.
// grid.y is num_sm_parts for split-KV decode, b for the other decode mode, and 1 for prefill.
dim3 grid_shape;
if constexpr (IS_DECODE) {
grid_shape = dim3(2*params.s_q, FWD_MODE == FwdMode::DecodeWithSplitKV ? params.num_sm_parts : params.b, 1);
} else {
grid_shape = dim3(2*params.s_q, 1, 1);
}
cutlass::ClusterLaunchParams launch_params = {
grid_shape,
dim3(NUM_THREADS, 1, 1),
dim3(2, 1, 1),
smem_size,
params.stream
};
KU_CUTLASS_CHECK(cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
));
}
// Public entry point: dispatches to KernelTemplate<FWD_MODE, D_QK>::run with the
// caller's arguments. No work is done here besides the forwarding call.
template<FwdMode FWD_MODE, int D_QK>
void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT<FWD_MODE>& params) {
    KernelTemplate<FWD_MODE, D_QK>::run(params);
}
}
#pragma once
#include "params.h"
namespace sm100::fwd_for_small_topk::head128 {
// Declaration of the host-side entry point for the phase-1 forward kernel.
//
// Template parameters:
//   FWD_MODE - forward mode; also selects the argument type SparseFwdArgT<FWD_MODE>.
//   D_QK     - compile-time integer template parameter (presumably the Q/K head dimension -- confirm).
// params: kernel arguments, taken by const reference.
//
// Only the declaration lives here. NOTE(review): the definition is presumably provided
// (and explicitly instantiated) in a separate translation unit -- confirm.
template<SparseAttnFwdMode FWD_MODE, int D_QK>
void run_fwd_for_small_topk_phase1_kernel(const SparseFwdArgT<FWD_MODE>& params);
}
#pragma once
#include <cute/tensor.hpp>
#include "defines.h"
namespace sm100 {
using namespace cute;
// Plain aggregate of eight `int` members, a0 through a7, declared in index order.
struct int32x8_t {
    int a0;
    int a1;
    int a2;
    int a3;
    int a4;
    int a5;
    int a6;
    int a7;
};
// Eight floats stored as four `float2` pairs; the member name gives the element
// indices each pair holds (a01 -> elements 0 and 1, ..., a67 -> elements 6 and 7).
struct float8 {
    float2 a01;
    float2 a23;
    float2 a45;
    float2 a67;
};
// Issues a single 16-byte `cp.async.cg` copy from global memory into shared memory,
// with the `.L2::256B` qualifier.
//
// src: global-memory source address.
// dst: shared-memory destination; converted to a 32-bit shared address before use.
//
// Only the copy is issued here: this function performs no commit or wait itself.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
    const uint32_t smem_dst = cute::cast_smem_ptr_to_uint(dst);
    asm volatile(
        "cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
        :
        : "r"(smem_dst), "l"(src), "n"(16)
    );
}
// Asynchronously stores a 16-byte value to shared memory in the cluster window and
// reports completion to an mbarrier as transaction bytes
// (`st.async ... mbarrier::complete_tx::bytes.v2.s64`).
//
// dst_ptr:  shared-memory destination of the 16-byte store.
// data:     value to store; sizeof(T) must be exactly 16 (enforced at compile time).
// mbar_ptr: shared-memory mbarrier that receives the complete_tx notification.
//
// The value is split into two 64-bit halves via `longlong2`. `long long` is 64 bits on
// every CUDA host platform, whereas `long` (and hence `long2`) is only 32 bits per lane on
// LLP64 targets, which would carry just half of the payload into the `.v2.s64` operands.
template<typename T>
CUTE_DEVICE
static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
    static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b.");
    static_assert(sizeof(longlong2) == sizeof(T), "longlong2 must cover the full 128-bit payload.");
    const longlong2 payload = *reinterpret_cast<const longlong2*>(&data);
    const uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
    const uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
    asm volatile (
        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
        :
        : "r"(dst_addr), "l"(payload.x), "l"(payload.y), "r"(mbar_addr)
    );
}
// Issues `tcgen05.commit` with `.cta_group::1`, arriving (arrive::one) on the mbarrier at
// `smem_ptr` and multicasting the arrival to the CTAs selected by `cta_mask`.
//
// smem_ptr: shared-memory mbarrier address.
// cta_mask: 16-bit CTA mask passed as the multicast operand.
//
// "noelect": no thread election happens here; every calling thread issues the instruction.
CUTE_DEVICE
void umma_arrive_multicast_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) {
    const uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    asm volatile(
        "{\n\t"
        "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
        "}"
        :
        : "r"(mbar_addr), "h"(cta_mask)
    );
}
// Overload for barriers typed as `transac_bar_t`: reinterprets the pointer as a pointer to
// a 64-bit word and forwards to the `uint64_t const*` overload with the same mask.
CUTE_DEVICE
void umma_arrive_multicast_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) {
    uint64_t const* raw_bar = reinterpret_cast<uint64_t const*>(smem_ptr);
    umma_arrive_multicast_noelect(raw_bar, cta_mask);
}
// Issues `tcgen05.commit` with `.cta_group::2`, arriving (arrive::one) on the mbarrier at
// `smem_ptr` and multicasting the arrival to the CTAs selected by `cta_mask`.
//
// smem_ptr: shared-memory mbarrier address.
// cta_mask: 16-bit CTA mask passed as the multicast operand.
//
// "noelect": no thread election happens here; every calling thread issues the instruction.
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) {
    const uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    asm volatile(
        "{\n\t"
        "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
        "}"
        :
        : "r"(mbar_addr), "h"(cta_mask)
    );
}
// Overload for barriers typed as `transac_bar_t`: reinterprets the pointer as a pointer to
// a 64-bit word and forwards to the `uint64_t const*` overload with the same mask.
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) {
    uint64_t const* raw_bar = reinterpret_cast<uint64_t const*>(smem_ptr);
    umma_arrive_multicast_2x1SM_noelect(raw_bar, cta_mask);
}
// Creates an L2 cache policy via `createpolicy.fractional.L2::evict_last` with
// fraction 1.0 and returns the resulting 64-bit policy value.
CUTE_DEVICE
int64_t createpolicy_evict_last() {
    int64_t policy;
    asm volatile(
        "createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
        : "=l"(policy)
        :
    );
    return policy;
}
// Adds four floats to global memory with one vectorized reduction
// (`red.relaxed.gpu.global.add ... .v4.f32`) that carries an L2 cache-policy hint.
// Being a `red` instruction, it returns no value.
//
// global_addr:  destination address in global memory.
//               NOTE(review): presumably must be 16-byte aligned for the .v4.f32 form -- confirm.
// data:         the four addends (x, y, z, w).
// cache_policy: 64-bit L2 cache policy operand.
CUTE_DEVICE
void atomicadd_f32x4_with_policy(void* global_addr, const float4 &data, int64_t cache_policy) {
    const int64_t dst_addr = reinterpret_cast<int64_t>(global_addr);
    asm volatile(
        "red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t"
        :
        : "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w),
          "l"(dst_addr), "l"(cache_policy)
    );
}
// Issues a bulk asynchronous f32 add-reduction from shared memory into global memory
// (`cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32`).
//
// src_ptr:     shared-memory source; converted to a 32-bit shared address.
// dst_ptr:     global-memory destination that is accumulated into.
// store_bytes: size operand of the transfer, in bytes.
//
// The operation uses the bulk_group completion mechanism; no commit or wait is issued here.
CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
    const uint32_t src_smem_addr = cast_smem_ptr_to_uint(src_ptr);
    asm volatile(
        "cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
        :
        : "l"(dst_ptr), "r"(src_smem_addr), "r"(store_bytes)
        : "memory"
    );
}
// Returns the sum of `a` and `b` as computed by cute::add on float2 operands.
CUTE_DEVICE
float2 float2_add(const float2 &a, const float2 &b) {
    float2 sum;
    cute::add(sum, a, b);
    return sum;
}
// Returns the product of `a` and `b` as computed by cute::mul on float2 operands.
CUTE_DEVICE
float2 float2_mul(const float2 &a, const float2 &b) {
    float2 product;
    cute::mul(product, a, b);
    return product;
}
// Returns a*b + c as computed by cute::fma on float2 operands.
CUTE_DEVICE
float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {
    float2 fused;
    cute::fma(fused, a, b, c);
    return fused;
}
// Negates both lanes of `a`; implemented as a float2 multiplication by (-1, -1).
CUTE_DEVICE
float2 float2_neg(const float2 &a) {
    return float2_mul(a, float2{-1.0f, -1.0f});
}
// Emits the PTX instruction `tcgen05.fence::before_thread_sync`.
// Takes no arguments and returns nothing; the fence is the only effect.
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
// Emits the PTX instruction `tcgen05.fence::after_thread_sync`.
// Takes no arguments and returns nothing; the fence is the only effect.
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// Issues a 2D TMA `tile::gather4` load (`cp.async.bulk.tensor.2d ... tile::gather4`) with
// `.cta_group::2` and an L2 cache hint. Completion is reported to an mbarrier as transaction bytes.
//
// Template parameter:
//   USE_CTA0_MBAR - when true, the mbarrier address is ANDed with Sm100MmaPeerBitMask before use.
//                   NOTE(review): presumably this redirects the completion to CTA 0's copy of
//                   the barrier -- confirm against the definition of Sm100MmaPeerBitMask.
// Parameters:
//   desc_ptr   - tensor-map descriptor the load reads from.
//   mbar_ptr   - shared-memory mbarrier that receives the complete_tx notification.
//   smem_ptr   - shared-memory destination.
//   col_idx    - first tensor coordinate passed to the instruction.
//   row_idxs   - four further coordinates (x, y, z, w), one per gathered row.
//   cache_hint - L2 cache hint, passed to the instruction as a 64-bit value.
template<bool USE_CTA0_MBAR = false>
CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) {
// Both shared-memory pointers are converted to 32-bit shared addresses for the "r" operands.
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
if constexpr (USE_CTA0_MBAR) {
mbar_addr &= Sm100MmaPeerBitMask;
}
// Operand order: %0 = smem dst, %1 = descriptor, %2 = col_idx, %3..%6 = row indices,
// %7 = mbarrier, %8 = cache hint.
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(uint64_t(cache_hint))
: "memory"
);
}
// 32 data path lanes, 32-bit pattern, repeated N times
//
// Wraps `tcgen05.ld.sync.aligned.32x32b.xN.b32`: loads N 32-bit values per thread from the
// tensor-memory address `src_addr` into dst_ptr_[0 .. N-1].
//
// Template parameters:
//   N - repeat count; must be a power of two in [1, 128] (checked by static_assert).
//   T - element type of the destination; the destination is accessed as uint32_t words.
// Parameters:
//   src_addr - tensor-memory source address.
//   dst_ptr_ - destination for N 32-bit words.
//
// Each supported N has its own branch because the operand list of the asm statement must be
// spelled out explicitly.
// NOTE(review): no tcgen05.wait::ld is issued here; presumably the caller must wait before
// consuming the loaded values -- confirm at the call sites.
template <int N, typename T>
CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(dst_ptr_);
if constexpr (N == 1) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32"
"{%0},"
"[%1];\n"
: "=r"(dst_ptr[0])
: "r"(src_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32"
"{%0, %1},"
"[%2];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1])
: "r"(src_addr));
} else if constexpr (N == 4) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32"
"{%0, %1, %2, %3},"
"[%4];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3])
: "r"(src_addr));
} else if constexpr (N == 8) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"[%8];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7])
: "r"(src_addr));
} else if constexpr (N == 16) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15},"
"[%16];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15])
: "r"(src_addr));
} else if constexpr (N == 32) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31},"
"[%32];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31])
: "r"(src_addr));
} else if constexpr (N == 64) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x64.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63},"
"[%64];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63])
: "r"(src_addr));
} else if constexpr (N == 128) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x128.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127},"
"[%128];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]),
"=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]),
"=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]),
"=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]),
"=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]),
"=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]),
"=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]),
"=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]),
"=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]),
"=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]),
"=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]),
"=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]),
"=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]),
"=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]),
"=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]),
"=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]),
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]),
"=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]),
"=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]),
"=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]),
"=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]),
"=r"(dst_ptr[126]), "=r"(dst_ptr[127])
: "r"(src_addr));
} else {
// Not reachable for any N accepted by the static_assert above; traps if ever instantiated.
asm volatile ("trap");
}
}
// 16 data path lanes, 256-bit pattern, repeated N times
//
// Wraps `tcgen05.ld.sync.aligned.16x256b.xN.b32`: loads 4*N 32-bit values per thread from the
// tensor-memory address `src_addr` into dst_ptr_[0 .. 4*N-1] (x1 fills 4 words, x32 fills 128).
//
// Template parameters:
//   N - repeat count; must be a power of two in [1, 32] (checked by static_assert).
//   T - element type of the destination; the destination is accessed as uint32_t words.
// Parameters:
//   src_addr - tensor-memory source address.
//   dst_ptr_ - destination for 4*N 32-bit words.
//
// Each supported N has its own branch because the operand list of the asm statement must be
// spelled out explicitly.
// NOTE(review): no tcgen05.wait::ld is issued here; presumably the caller must wait before
// consuming the loaded values -- confirm at the call sites.
template <int N, typename T>
CUTE_DEVICE void tmem_ld_16dp256bNx(uint32_t const &src_addr, T* dst_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32,
"N must be a power of 2 and lies between 1 ~ 32");
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(dst_ptr_);
if constexpr (N == 1) {
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32"
"{%0, %1, %2, %3},"
"[%4];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3])
: "r"(src_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"[%8];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7])
: "r"(src_addr));
} else if constexpr (N == 4) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x4.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15},"
"[%16];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15])
: "r"(src_addr));
} else if constexpr (N == 8) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x8.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31},"
"[%32];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31])
: "r"(src_addr));
} else if constexpr (N == 16) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x16.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, "
"%28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, "
"%42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, "
"%56, "
"%57, %58, %59, %60, %61, %62, %63},"
"[%64];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63])
: "r"(src_addr));
} else if constexpr (N == 32) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x32.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, "
"%28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, "
"%42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, "
"%56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, "
"%70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, "
"%84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, "
"%98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127},"
"[%128];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]),
"=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]),
"=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]),
"=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]),
"=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]),
"=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]),
"=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]),
"=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]),
"=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]),
"=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]),
"=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]),
"=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]),
"=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]),
"=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]),
"=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]),
"=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]),
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]),
"=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]),
"=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]),
"=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]),
"=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]),
"=r"(dst_ptr[126]), "=r"(dst_ptr[127])
: "r"(src_addr));
} else {
// Not reachable for any N accepted by the static_assert above; traps if ever instantiated.
asm volatile("trap");
}
}
// Store helper wrapping the PTX instruction
// `tcgen05.st.sync.aligned.32x32b.x{N}.b32` (32 data path lanes, 32-bit
// pattern, repeated N times).
//
// Template parameters:
//   N - repeat count; must be a power of two in [1, 128] (enforced by the
//       static_assert below). It selects the `.xN` variant of the instruction
//       and is also the number of 32-bit source registers passed to it.
//   T - element type of the caller's source buffer. The buffer is
//       reinterpreted as `uint32_t*`, so exactly N 32-bit words
//       (src_ptr_[0 .. N*4 bytes)) are read regardless of T.
//
// Parameters:
//   dst_addr - 32-bit value used as the `[addr]` operand of the instruction
//              (per the PTX ISA this is a tensor-memory address).
//   src_ptr_ - pointer to the data to store; only read, never written.
//
// Operand numbering in every asm statement: %0 .. %(N-1) are the data words
// src_ptr[0] .. src_ptr[N-1]; %N is dst_addr. There are no output operands.
//
// NOTE(review): this function only issues the store; it contains no wait or
// fence instruction. Callers presumably have to order/complete the store
// themselves (e.g. with a tcgen05 wait) before dependent work - confirm at the
// call sites.
// NOTE(review): the `.sync.aligned` qualifiers suggest the instruction must be
// executed convergently by the whole warp - confirm against the PTX ISA.
template <int N, typename T>
CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
// View the source buffer as raw 32-bit words so each one maps to an "r" register operand.
uint32_t* src_ptr = reinterpret_cast<uint32_t*>(src_ptr_);
if constexpr (N == 1) {
// .x1: one data word (%0), address in %1.
asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32"
"[%1], {%0};\n"
:
: "r"(src_ptr[0]),
"r"(dst_addr));
} else if constexpr (N == 2) {
// .x2: data words %0-%1, address in %2.
asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32"
"[%2], {%0, %1};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]),
"r"(dst_addr));
} else if constexpr (N == 4) {
// .x4: data words %0-%3, address in %4.
asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32"
"[%4], {%0, %1, %2, %3};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]),
"r"(dst_addr));
} else if constexpr (N == 8) {
// .x8: data words %0-%7, address in %8.
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32"
"[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]),
"r"(dst_addr));
} else if constexpr (N == 16) {
// .x16: data words %0-%15, address in %16.
asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32"
"[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]),
"r"(dst_addr));
} else if constexpr (N == 32) {
// .x32: data words %0-%31, address in %32.
asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32"
"[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]),
"r"(dst_addr));
} else if constexpr (N == 64) {
// .x64: data words %0-%63, address in %64.
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x64.b32"
"[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]),
"r"(dst_addr));
} else if constexpr (N == 128) {
// .x128: data words %0-%127, address in %128.
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x128.b32"
"[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]),
"r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]),
"r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]),
"r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]),
"r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]),
"r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]),
"r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]),
"r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]),
"r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]),
"r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]),
"r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]),
"r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]),
"r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]),
"r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]),
"r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]),
"r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]),
"r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]),
"r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]),
"r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]),
"r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]),
"r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]),
"r"(src_ptr[126]), "r"(src_ptr[127]),
"r"(dst_addr));
} else {
// Every N accepted by the static_assert above is handled by one of the
// branches, so this trap is only a defensive fallback.
asm volatile ("trap");
}
}
// XOR mask that maps an address to its peer's address:
//   peer_addr = my_addr ^ PEER_ADDR_MASK
// The value is 1 << 24 (= 16777216). It is not confirmed that this number is
// the same on every GPU model.
static constexpr int PEER_ADDR_MASK = 1 << 24;

// Returns the pointer obtained by XOR-ing the numeric value of `p` with
// PEER_ADDR_MASK. The result is a non-const `T*` even though `p` points to
// const data.
// NOTE(review): "peer" presumably refers to the partner CTA of a CTA pair -
// confirm against the call sites.
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
    const int64_t own_bits = reinterpret_cast<int64_t>(p);
    const int64_t peer_bits = own_bits ^ PEER_ADDR_MASK;
    return reinterpret_cast<T*>(peer_bits);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment