Unverified Commit fd249aac authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

Add Sparse Decoding Kernel and Sparse Prefill Kernel for Blackwell


Signed-off-by: default avatarsimon-mo <simon.mo@hey.com>
parent 17944550
......@@ -33,6 +33,8 @@ python tests/test_flash_mla_decoding.py
The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8.
On Blackwell GPUs, the token-level sparse MLA decoding kernel achieves up to 350 TFlops on B200; note that this kernel has not been fully optimized yet.
#### Test & benchmark MHA prefill (Dense):
```bash
......@@ -47,7 +49,7 @@ It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation
python tests/test_flash_mla_prefill.py
```
It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8.
It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9.
## Requirements
......@@ -60,9 +62,9 @@ Support matrix:
| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: |
| Dense Decoding | Hopper | MQA | BF16 |
| Sparse Decoding | Hopper | MQA | FP8 [1] |
| Sparse Decoding | Hopper & Blackwell | MQA | FP8 [1] |
| Dense Prefill | Blackwell | MHA | |
| Sparse Prefill | Hopper | MQA | |
| Sparse Prefill | Hopper & Blackwell | MQA | |
[1]: For more details on using FP8 KV cache, see documents below.
......
......@@ -16,7 +16,9 @@
#include "sm90/decode/dense/splitkv_mla.h"
#include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm90/prefill/sparse/fwd.h"
#include "sm100/decode/sparse_fp8/splitkv_mla.h"
#include "sm100/prefill/dense/interface.h"
#include "sm100/prefill/sparse/fwd.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
......@@ -31,7 +33,7 @@ struct Arch {
}
bool is_sm100() const {
return major == 10 && minor == 0;
return major == 10;
}
void assert_is_supported() const {
......@@ -86,7 +88,31 @@ DecodingAttnImplMeta get_attn_impl_meta(
}
}
} else if (arch.is_sm100()) {
TORCH_CHECK(false, "Unsupported GPU architecture");
if (is_sparse_attn) {
if (is_fp8_kvcache) {
TORCH_CHECK(h_q_.has_value());
int h_q = h_q_.value();
TORCH_CHECK(h_q % h_k == 0);
int s_q = num_q_tokens_per_head_k * h_k / h_q;
// FP8 + Sparse MLA
return {
std::max(sm_count / h_k / (cutlass::ceil_div(h_q/h_k, 64) * s_q), 1),
5,
64
};
} else {
// Sparse BF16 MLA
TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM100");
}
} else {
if (is_fp8_kvcache) {
// FP8 MLA
TORCH_CHECK(false, "FP8 Dence MLA is not supported on SM100");
} else {
// Normal BF16 MLA
TORCH_CHECK(false, "BF16 Dence MLA is not supported on SM100");
}
}
} else {
TORCH_CHECK(false, "Unsupported GPU architecture");
}
......@@ -326,7 +352,8 @@ fwd_kvcache_mla(
}
}
} else if (arch.is_sm100()) {
TORCH_CHECK(false, "Unsupported GPU architecture");
TORCH_CHECK(is_fp8 && is_sparse_attn, "Only FP8 + Sparse attention is supported on SM100");
sm100::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
} else {
TORCH_CHECK(false, "Unsupported GPU architecture");
}
......@@ -366,7 +393,8 @@ std::vector<at::Tensor> sparse_prefill_fwd(
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9;
TORCH_CHECK(is_sm90, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 architectures");
bool is_sm100 = dprops->major == 10;
TORCH_CHECK(is_sm90 || is_sm100, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 or SM100 architectures");
CHECK_DEVICE(q);
CHECK_DEVICE(kv);
......@@ -423,6 +451,8 @@ std::vector<at::Tensor> sparse_prefill_fwd(
if (is_sm90) {
sm90::run_fwd_kernel(params);
} else if (is_sm100) {
sm100::run_fwd_kernel(params);
} else {
TORCH_CHECK(false, "Unknown architecture");
}
......
#pragma once
#include <cuda_fp8.h>
#include <cuda_bf16.h>
#include "sm100/defines.h"
namespace sm100 {
// Eight FP8 (e4m3) values held as two packed groups of four: `lo` first, then `hi`.
struct fp8x8 {
    __nv_fp8x4_e4m3 lo, hi;
};
// Thirty-two FP8 values, laid out as four consecutive fp8x8 groups.
struct fp8x32 {
    fp8x8 a0;
    fp8x8 a1;
    fp8x8 a2;
    fp8x8 a3;
};
// Sixteen FP8 values, laid out as two consecutive fp8x8 groups.
struct fp8x16 {
    fp8x8 a0;
    fp8x8 a1;
};
// Dequantizes eight FP8 (e4m3) values to bf16 and multiplies each by `scale`.
//
// The FP8 values are first widened to fp32, rounded to bf16, and only then
// multiplied by the scale. The scale itself is rounded to bf16 beforehand, so
// the multiplication is carried out in bf16 precision.
//
// inputs: eight packed FP8 values (`lo` holds elements 0-3, `hi` holds 4-7).
// scale:  dequantization scale shared by all eight elements.
// returns: the eight scaled values as four bf16x2 pairs.
__device__ __forceinline__
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) {
    // Broadcast the scale into both halves of a bf16x2 so it can be applied pairwise.
    __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale);

    // Converts one packed group of four FP8 values into two scaled bf16x2 pairs.
    #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
    { \
        float4 fp32x4 = (float4)(FP8x4); \
        OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \
        OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \
    }

    bf16x8 result;
    DEQUANT_FP8x4(result.a01, result.a23, inputs.lo);
    DEQUANT_FP8x4(result.a45, result.a67, inputs.hi);

    // Keep the helper macro local: this is a header, so a leaked macro would
    // be visible in every translation unit that includes it.
    #undef DEQUANT_FP8x4
    return result;
}
// Loads 32 FP8 values (eight 32-bit registers, i.e. 256 bits) from global memory
// with a single vectorized `ld.global.nc` instruction.
//
// The instruction carries `evict_normal` hints for both L1 and L2 plus an
// `L2::256B` qualifier, exactly as spelled in the asm string below.
// NOTE(review): a v8 32-bit vector load presumably requires `src_ptr` to be
// 32-byte aligned and an architecture that supports 256-bit loads — confirm
// against the PTX ISA before reusing this helper elsewhere.
__device__ __forceinline__
fp8x32 ldg_256_fp8x32(void* src_ptr) {
int32x8_t val;
asm volatile("ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
: "=r"(val.a0), "=r"(val.a1), "=r"(val.a2), "=r"(val.a3),
"=r"(val.a4), "=r"(val.a5), "=r"(val.a6), "=r"(val.a7)
: "l"(src_ptr)
);
// Reinterpret the eight raw 32-bit registers as 32 packed FP8 values.
return *reinterpret_cast<fp8x32*>(&val);
}
// Loads 16 bytes (four 32-bit registers) from global memory with a single
// `ld.global.nc` instruction that carries an `L1::evict_first` hint, and returns
// the raw bits as an fp8x16.
//
// The kernel in this commit also uses it to fetch 16 bytes of bf16 data and
// reinterprets the result afterwards, so the "fp8" in the name only describes
// the return type, not a requirement on the data.
// NOTE(review): a v4 32-bit vector load presumably requires `src_ptr` to be
// 16-byte aligned — confirm.
__device__ __forceinline__
fp8x16 ldg_128_fp8x16(void* src_ptr) {
int4 ret;
asm volatile("ld.global.nc.L1::evict_first.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w)
: "l"(src_ptr));
return *reinterpret_cast<fp8x16*>(&ret);
}
}
#include "splitkv_mla.h"
#include <cutlass/barrier.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/tensor.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include "utils.h"
#include "dequant.h"
#include "sm100/defines.h"
#include "sm100/helpers.h"
#include "sm100/intrinsics.h"
#include "sm100/ws_gemm.h"
namespace sm100 {
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::NamedBarrier;
using namespace cute;
// Tile sizes.
constexpr int B_H = 64;       // Query heads handled by one thread block
constexpr int B_TOPK = 64;    // Selected KV tokens processed per block iteration
constexpr int D_K = 576;      // Head dimension of Q / K
constexpr int D_V = 512;      // Head dimension of V / O

// Number of K buffers that are cycled through (indexed by block_idx % NUM_BUFS).
constexpr int NUM_BUFS = 2;

// Launch configuration: three warpgroups of 128 threads each.
constexpr int NUM_THREADS = 3 * 128;
// Threads that take part in the per-request barrier:
// producer warpgroup + softmax warpgroup + the single MMA-issuing warp.
constexpr int NUM_WORKING_THREADS = 2 * 128 + 32;

// Initial running maximum. A large finite value is used instead of -inf
// to avoid (-inf) - (-inf) = NaN.
constexpr float MAX_INIT_VAL = -1e30f;
// Bundle of the TMA objects and tensor shapes handed to the kernel:
// one (shape, TMA copy) pair for Q and one for O.
// Member order matters: the launcher brace-initializes this aggregate positionally.
template<typename Shape_Q, typename TMA_Q, typename Shape_O, typename TMA_O>
struct TmaParams {
    Shape_Q shape_Q;
    TMA_Q tma_Q;
    Shape_O shape_O;
    TMA_O tma_O;
};
// Tensor-memory offsets used by the kernel.
namespace tmem_addr {
    constexpr int o = 0;          // O accumulator: [0, 256)
    constexpr int p = o + 256;    // P = Q @ K^T:   [256, 288)
}
// Shared-memory layout of the Q tile: a (B_H, D_K) bf16 tile built by tiling the
// UMMA K-major SW128 atom to that shape and coalescing the result.
using SmemLayoutQ = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H>, Int<D_K>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Shared-memory staging layout for the bf16 output tile (B_H, D_V), built from
// the UMMA K-major INTER atom. This is the layout the output TMA store is
// created with in the launcher.
using SmemLayoutOBuf = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{}, // TODO This may lead to TMA double traffic
Shape<Int<B_H>, Int<D_V>>{}
));
// Shared-memory staging layout for the fp32 partial-output tile (B_H, D_V) used
// on the split path. Rows are contiguous along D_V and padded to a stride of
// 520 elements (rather than 512).
using SmemLayoutOAccumBuf = Layout<
Shape<Int<B_H>, Int<D_V>>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>;
// Shared-memory layout of the S tile (B_H, B_TOPK) in bf16, i.e. the
// exponentiated scores that feed the S @ V GEMM. Built from the UMMA K-major
// INTER atom.
using SmemLayoutS = decltype(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<B_TOPK>>{},
Step<_1, _2>{}
));
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Transposed view of SmemLayoutKTiles<NUM_TILES>: composes the K layout with a
// (64*NUM_TILES, B_TOPK) : (B_TOPK, 1) index mapping so the same bytes can be
// addressed with the feature dimension first. No data is moved; only the
// indexing changes.
template<int NUM_TILES>
using SmemLayoutKTilesTransposed = decltype(composition(
SmemLayoutKTiles<NUM_TILES>{},
Layout<
Shape<Int<64*NUM_TILES>, Int<B_TOPK>>,
Stride<Int<B_TOPK>, _1>
>{}
));
// K uses 9 tiles of 64 columns (9*64 = 576 = D_K); V is the transposed view of
// the first 8 tiles (8*64 = 512 = D_V) of the same buffer.
using SmemLayoutK = SmemLayoutKTiles<9>;
using SmemLayoutV = SmemLayoutKTilesTransposed<8>;
// Layout of the kernel's dynamic shared memory. The kernel reinterprets the
// start of its dynamic shared-memory buffer as this struct, and the launcher
// requests sizeof(SharedMemoryPlan) bytes.
struct SharedMemoryPlan {
// Q tile for the current request (filled by TMA).
array_aligned<bf16, cosize_v<SmemLayoutQ>> q;
// The K buffers and the two output staging buffers share storage: the output
// is only staged in the epilogue, after the last S @ V GEMM has been waited for.
union {
array_aligned<bf16, cosize_v<SmemLayoutOBuf>> o_buf;
array_aligned<float, cosize_v<SmemLayoutOAccumBuf>> o_accum_buf;
array_aligned<bf16, cosize_v<SmemLayoutK>> k[NUM_BUFS];
} u;
// Exponentiated scores S in bf16 (input of the S @ V GEMM).
array_aligned<bf16, cosize_v<SmemLayoutS>> s;
// Barrier signalling that Q has arrived.
transac_bar_t bar_q;
// Per-buffer barriers: K written by the producer / K buffer free for reuse.
transac_bar_t bar_k_ready[NUM_BUFS], bar_k_free[NUM_BUFS];
// Per-buffer barriers: Q @ K^T finished / S written and O rescaled.
transac_bar_t bar_qk_done[NUM_BUFS], bar_so_ready[NUM_BUFS];
// Scratch for exchanging per-thread row maxima / row sums inside the
// 128-thread softmax warpgroup (indexed by idx_in_warpgroup).
float rowwise_max_buf[128], rowwise_li_buf[128];
// Per-buffer validity flag of each token row (false when the index is -1).
bool is_token_valid[NUM_BUFS][B_TOPK];
// Receives the base address returned by the tensor-memory allocator.
array_aligned<uint32_t, 1> tmem_start_addr;
};
// Tiled MMA for P = Q @ K^T: bf16 inputs, fp32 accumulation, tile B_H x B_TOPK,
// with both operands K-major and read from shared memory (the "SS" variant).
using TiledMMA_QK = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{},
Layout<Shape<_1, _1, _1>>{}
)); // TODO Use TS?
// Tiled MMA for O += S @ V: bf16 inputs, fp32 accumulation. The instruction tile
// is B_H x 256 (A is K-major, B is MN-major) and it is tiled to cover the full
// B_H x D_V output.
using TiledMMA_SV = decltype(make_tiled_mma(
SM100_MMA_F16BF16_WS_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{},
Tile<Int<B_H>, Int<D_V>>{}
));
// Writes a 16-byte object to `smem_ptr` as one 128-bit store.
// T must be exactly 16 bytes; the destination is written as a single __int128.
template<typename T>
CUTE_DEVICE
void store_128b(void* smem_ptr, const T &data) {
    static_assert(sizeof(T) == 16);
    using raw128_t = __int128;
    const raw128_t bits = *reinterpret_cast<const raw128_t*>(&data);
    *reinterpret_cast<raw128_t*>(smem_ptr) = bits;
}
// Token-level sparse MLA decoding kernel with an FP8 KV cache.
//
// Grid mapping (see the launcher): blockIdx.x = head block (B_H heads each),
// blockIdx.y = query position s_q, blockIdx.z = scheduler partition.
//
// Each partition reads its work range from the tile-scheduler metadata:
// requests begin_idx..end_idx (inclusive), where the first and last request
// may be restricted to a sub-range of top-k blocks (a "split").
//
// The block runs three warpgroups:
//   - warpgroup 0 (producer): gathers the selected K rows by index, dequantizes
//     the FP8 part to bf16, copies the trailing bf16 part, and writes into
//     shared memory.
//   - warpgroup 1 (scale & exp): reads P from tensor memory, maintains the
//     running max/sum, writes S = exp2(P*scale - max) to shared memory, rescales
//     O, and runs the epilogue (LSE + O write-out).
//   - warp 8 (first warp of warpgroup 2): loads Q via TMA and issues the
//     Q @ K^T and S @ V MMAs. The other warps of warpgroup 2 do no work.
//
// `bar_phase_k` is a per-thread bitmask holding one phase bit per K buffer; it
// is toggled once per processed block in each role so that all roles agree on
// barrier parity.
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams params, __grid_constant__ const TmaParams tma_params) {
#if IS_SM100
const int head_block_idx = blockIdx.x;
const int s_q_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int warpgroup_idx = cutlass::canonical_warp_group_idx();
const int idx_in_warpgroup = threadIdx.x % 128;
const int warp_idx = cutlass::canonical_warp_idx_sync();
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.q.data()), SmemLayoutQ{});
if (warp_idx == 0 && elect_one_sync()) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
}
// Warp 0: initialize all barriers (one elected thread) and allocate tensor memory.
if (warp_idx == 0) {
if (elect_one_sync()) {
plan.bar_q.init(1);
for (int i = 0; i < NUM_BUFS; ++i) {
// k_ready / so_ready are arrived at by all 128 threads of a warpgroup;
// k_free / qk_done are signalled once by the MMA warp.
plan.bar_k_ready[i].init(128);
plan.bar_k_free[i].init(1);
plan.bar_qk_done[i].init(1);
plan.bar_so_ready[i].init(128);
}
cutlass::arch::fence_barrier_init();
}
cute::TMEM::Allocator1Sm().allocate(512, plan.tmem_start_addr.data());
// The fixed offsets in tmem_addr rely on the allocation starting at 0.
TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator1Sm().release_allocation_lock();
}
__syncthreads();
int bar_phase_k = 0;
// Scheduler metadata for this partition: (begin_idx, begin block, end_idx, end block).
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int sched_begin_block_idx = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int sched_end_block_idx = tile_scheduler_metadata.w;
// Nothing assigned to this partition: release tensor memory and exit.
if (begin_idx >= params.b) {
if (warp_idx == 0) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
return;
}
// Returns (first block, one-past-last block, whether this partition covers the
// whole top-k range of the request) for request `batch_idx`.
auto get_cur_req_info = [&](int batch_idx) -> std::tuple<int, int, bool> {
int start_block_idx = batch_idx == begin_idx ? sched_begin_block_idx : 0;
int end_block_idx = batch_idx == end_idx ? sched_end_block_idx : params.topk / B_TOPK;
bool is_no_split = start_block_idx == 0 && end_block_idx == params.topk / B_TOPK;
return {start_block_idx, end_block_idx, is_no_split};
};
if (warpgroup_idx == 0) {
// Producer warpgroup
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
int* gIndices = params.indices_ptr + batch_idx*params.indices_batch_stride + s_q_idx*params.indices_row_stride; // (topk) : (1)
// The 128 producer threads form 32 groups of 4; each group handles
// ROWS_PER_GROUP token rows and the 4 threads split the columns of a row.
constexpr int GROUP_SIZE = 4, NUM_GROUPS = 128 / GROUP_SIZE;
constexpr int ROWS_PER_GROUP = B_TOPK / NUM_GROUPS;
int group_idx = idx_in_warpgroup / GROUP_SIZE;
int idx_in_group = idx_in_warpgroup % GROUP_SIZE;
// Per-request rendezvous of producer, softmax warpgroup and MMA warp.
NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1);
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = block_idx % NUM_BUFS;
// Wait for buffer to be available
plan.bar_k_free[buf_idx].wait(bar_phase_k>>buf_idx&1^1);
// Load
Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});
CUTE_UNROLL
for (int local_row = 0; local_row < ROWS_PER_GROUP; ++local_row) {
int smem_row = group_idx + local_row*NUM_GROUPS;
int token_index = __ldg(gIndices + block_idx*B_TOPK + smem_row);
bool is_token_invalid = token_index == -1;
if (idx_in_group == 0)
plan.is_token_valid[buf_idx][smem_row] = !is_token_invalid;
if (is_token_invalid) {
// Invalid token: fill the whole row with zeros.
uint128_t zeros = uint128_t{};
CUTE_UNROLL
for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) {
int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16;
store_128b(&sK(smem_row, col_base ), zeros);
store_128b(&sK(smem_row, col_base+8), zeros);
}
CUTE_UNROLL
for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) {
int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8;
store_128b(&sK(smem_row, D_V+col_base), zeros);
}
} else {
// Valid token. Row layout as read below: D_V FP8 values, then four
// fp32 scales, then (D_K-D_V) bf16 values.
// NOTE(review): token_index != -1 in this branch, so the "-1" handling
// described in the comment on rel_idx_in_block appears to be a leftover.
int block_index = token_index/B_TOPK;
int rel_idx_in_block = (token_index+B_TOPK) % B_TOPK; // NOTE When token_index is -1, -1/B_TOPK = 0 and (-1+B_TOPK)%B_TOPK = 63, so there will be no illegal-memory-access error. However, masking is necessary to prevent NaN (TODO Skip some rows instead?) TODO Masking
fp8* gK_base = (fp8*)params.k_ptr + block_index*params.k_batch_stride + rel_idx_in_block*params.k_row_stride;
float4 scales = __ldg((float4*)(gK_base + D_V));
CUTE_UNROLL
for (int local_col = 0; local_col < D_V / (GROUP_SIZE*16); ++local_col) {
int col_base = local_col*(GROUP_SIZE*16) + idx_in_group*16;
fp8x16 cur_fp8s = ldg_128_fp8x16(gK_base + col_base);
// One scale per 128 columns: x for [0,128), y for [128,256), z for [256,384), w for [384,512).
float cur_scale = local_col < (256/(GROUP_SIZE*16)) ?
(local_col < (128/(GROUP_SIZE*16)) ? scales.x : scales.y) :
(local_col < (384/(GROUP_SIZE*16)) ? scales.z : scales.w);
store_128b(&sK(smem_row, col_base ), cvt_fp8x8_bf16x8(cur_fp8s.a0, cur_scale));
store_128b(&sK(smem_row, col_base+8), cvt_fp8x8_bf16x8(cur_fp8s.a1, cur_scale));
}
CUTE_UNROLL
for (int local_col = 0; local_col < (D_K-D_V) / (GROUP_SIZE*8); ++local_col) {
int col_base = local_col*(GROUP_SIZE*8) + idx_in_group*8;
// The trailing part is already bf16: load 16 raw bytes and copy them through unchanged.
fp8x16 cur_k_rope_fp8s = ldg_128_fp8x16(gK_base + D_V + 4*sizeof(float) + col_base*sizeof(bf16));
bf16x8 cur_k_rope = *reinterpret_cast<bf16x8*>(&cur_k_rope_fp8s);
store_128b(&sK(smem_row, D_V+col_base), cur_k_rope);
}
}
}
fence_view_async_shared();
// Signal
plan.bar_k_ready[buf_idx].arrive();
bar_phase_k ^= 1<<buf_idx;
}
}
} else if (warpgroup_idx == 1) {
// Scale & Exp warpgroup
cutlass::arch::warpgroup_reg_alloc<240>();
// Fifth metadata word: split index offset for the first request of this partition.
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1);
// Running softmax state: li = running sum, mi = running max (in log2 units).
float li = 0.0f;
float mi = MAX_INIT_VAL;
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = block_idx % NUM_BUFS;
// Wait for P
plan.bar_qk_done[buf_idx].wait(bar_phase_k>>buf_idx&1);
tcgen05_after_thread_sync();
// Load P from TMEM
float p[B_TOPK/2];
float2* p_float2 = reinterpret_cast<float2*>(p);
tmem_ld_32dp32bNx<B_TOPK/2>(tmem_addr::p, p);
cutlass::arch::fence_view_async_tmem_load();
// Get rowwise max
// Each thread holds half of the B_TOPK columns; the half is selected by
// idx_in_warpgroup/64. Invalid tokens are masked to -inf.
float cur_max = -INFINITY;
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2; ++i) {
if (!plan.is_token_valid[buf_idx][(idx_in_warpgroup/64)*(B_TOPK/2)+i]) p[i] = -INFINITY;
cur_max = max(cur_max, p[i]);
}
cur_max *= params.scale_softmax_log2;
// Exchange the partial max with the partner thread (idx ^ 64) through shared memory.
NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers
plan.rowwise_max_buf[idx_in_warpgroup] = cur_max;
NamedBarrier::arrive_and_wait(128, 0);
cur_max = max(cur_max, plan.rowwise_max_buf[idx_in_warpgroup ^ 64]);
float new_max = max(mi, cur_max);
float scale_for_old = exp2f(mi - new_max);
float2 scale_for_old_float2 = {scale_for_old, scale_for_old};
// Get S
float2 scale_softmax_log2_float2 = {params.scale_softmax_log2, params.scale_softmax_log2};
float2 neg_new_max_float2 = {-new_max, -new_max};
bf16 s[B_TOPK/2];
float2 cur_sum = {0.0f, 0.0f};
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; ++i) {
float2 t = float2_fma(p_float2[i], scale_softmax_log2_float2, neg_new_max_float2);
t.x = exp2(t.x);
t.y = exp2(t.y);
*(__nv_bfloat162*)&s[i*2] = __float22bfloat162_rn(t);
cur_sum = float2_add(cur_sum, t);
}
// Save S
// NOTE We don't need a barrier here, since the current QK^T has finished implies that the previous SV has finished
bf16* sS_base = plan.s.data() + (idx_in_warpgroup/64)*(B_H*B_TOPK/2) + (idx_in_warpgroup%64) * 8;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/8; i += 1) {
store_128b(sS_base + i*8*B_H, *((bf16x8*)s + i));
}
fence_view_async_shared();
// Rescale O
// Skipped for the first block: the first S @ V GEMM overwrites O instead of accumulating.
if (block_idx != start_block_idx) {
constexpr int B_SCALE_O = 64;
float2 o[B_SCALE_O/2];
CUTE_UNROLL
for (int b = 0; b < (D_V/2)/B_SCALE_O; ++b) {
tmem_ld_32dp32bNx<B_SCALE_O>(tmem_addr::o + b*B_SCALE_O, o);
cutlass::arch::fence_view_async_tmem_load();
CUTE_UNROLL
for (int i = 0; i < B_SCALE_O/2; ++i)
o[i] = float2_mul(o[i], scale_for_old_float2);
tmem_st_32dp32bNx<B_SCALE_O>(tmem_addr::o + b*B_SCALE_O, o);
cutlass::arch::fence_view_async_tmem_store();
}
}
plan.bar_so_ready[buf_idx].arrive();
// Update mi and li
mi = new_max;
li = li * scale_for_old + cur_sum.x + cur_sum.y;
bar_phase_k ^= 1<<buf_idx;
}
// Epilogue
// Deal with no valid token cases
if (mi == MAX_INIT_VAL) {
mi = -INFINITY;
li = 0.0f;
}
// Reduce li
plan.rowwise_li_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, 0);
li += plan.rowwise_li_buf[idx_in_warpgroup ^ 64];
// Save li
int num_valid_heads = min(B_H, params.q_head_per_hk - head_block_idx*B_H);
int start_seq_idx = s_q_idx*params.q_head_per_hk + head_block_idx*B_H;
int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int split_idx = is_no_split ? 0 : (__ldg(params.num_splits_ptr+batch_idx) + n_split_idx);
if (idx_in_warpgroup < num_valid_heads) {
if (is_no_split) {
// Final LSE in natural-log units (mi is kept in log2 units, hence the division by log2(e)).
float* gSoftmaxLse = (float*)params.softmax_lse_ptr + batch_idx*params.q_seq_per_hk + start_seq_idx + idx_in_warpgroup;
*gSoftmaxLse = li == 0.0f ? INFINITY : logf(li) + mi / (float)M_LOG2E; // NOTE Follows Flash MLA's approach, which returns +inf when there are no valid indices
} else {
// Partial LSE for this split, kept in log2 units.
float* gSoftmaxLseAccum = (float*)params.softmax_lseaccum_ptr + split_idx*params.q_seq_per_hk + start_seq_idx + idx_in_warpgroup;
*gSoftmaxLseAccum = li == 0.0f ? -INFINITY : log2f(li) + mi;
}
}
// Wait for the last SV gemm
// NOTE(review): this indexes with (end_block_idx-1)%NUM_BUFS, which assumes the
// block range of every request is non-empty (end_block_idx >= 1) — confirm the
// scheduler guarantees this.
plan.bar_k_free[(end_block_idx-1)%NUM_BUFS].wait(bar_phase_k>>((end_block_idx-1)%NUM_BUFS)&1^1);
tcgen05_after_thread_sync();
// Save O
float o_scale = li == 0.0f ? 0.0f : 1.0f / li;
float2 o_scale_float2 = {o_scale, o_scale};
if (is_no_split) {
// Unsplit path: normalize, convert to bf16, stage in shared memory and write
// the final output with one TMA store.
constexpr int B_EPI = 32;
float2 o[B_EPI/2];
__nv_bfloat162 o_bf16[B_EPI/2];
Tensor sO = make_tensor(make_smem_ptr(plan.u.o_buf.data()), SmemLayoutOBuf{});
bf16* sO_base = plan.u.o_buf.data() + ((idx_in_warpgroup/64)*128)*B_H + (idx_in_warpgroup%64)*8;
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
tmem_ld_32dp32bNx<B_EPI>(tmem_addr::o + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j) {
o[j] = float2_mul(o[j], o_scale_float2);
o_bf16[j] = __float22bfloat162_rn(o[j]);
}
// Store
int col_base = (i*B_EPI>=D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 8; ++j)
store_128b(sO_base + (col_base+j*8)*B_H, *reinterpret_cast<bf16x8*>(&o_bf16[j*4]));
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, 0);
if (warp_idx == 4 && elect_one_sync()) {
Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx, batch_idx);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<B_H>, Int<D_V>>{})(_, _, head_block_idx, _0{});
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO),
thr_tma.partition_D(my_tma_gO)
);
cute::tma_store_arrive();
}
} else {
// Split path: normalize, keep fp32, stage in shared memory and write each valid
// row to the partial-output buffer with a bulk copy.
constexpr int B_EPI = 64;
float2 o[B_EPI/2];
Tensor sO = make_tensor(make_smem_ptr(plan.u.o_accum_buf.data()), SmemLayoutOAccumBuf{});
CUTE_UNROLL
for (int i = 0; i < (D_V/2) / B_EPI; ++i) {
// Load
tmem_ld_32dp32bNx<B_EPI>(tmem_addr::o + i*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
// Scale & Convert
CUTE_UNROLL
for (int j = 0; j < B_EPI/2; ++j)
o[j] = float2_mul(o[j], o_scale_float2);
// Store
int col_base = (idx_in_warpgroup/64)*128 + (i*B_EPI >= D_V/4 ? D_V/2 : 0) + (i*B_EPI%(D_V/4));
CUTE_UNROLL
for (int j = 0; j < B_EPI / 4; ++j)
store_128b(&sO(idx_in_warpgroup%64, col_base + j*4), *reinterpret_cast<float4*>(&o[j*2]));
}
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, 0);
// One elected thread per warp; the four warps of this warpgroup interleave rows
// (row = local_row*4 + warp index within the warpgroup).
if (elect_one_sync()) {
CUTE_UNROLL
for (int local_row = 0; local_row < B_H/4; ++local_row) {
int smem_row = local_row*4 + (warp_idx-4);
if (smem_row < num_valid_heads) {
SM90_BULK_COPY_S2G::copy(
&sO(smem_row, _0{}),
(float*)params.oaccum_ptr + (split_idx*params.q_seq_per_hk + start_seq_idx + smem_row)*D_V,
D_V*sizeof(float)
);
}
}
cute::tma_store_arrive();
}
}
// The staging buffer aliases the K buffers, so the stores must finish before
// the next request starts.
cute::tma_store_wait<0>();
}
if (warp_idx == 4) {
cute::TMEM::Allocator1Sm().free(0, 512);
}
} else {
cutlass::arch::warpgroup_reg_dealloc<96>();
if (warp_idx == 8) {
// UTCMMA warp
bool bar_phase_q = 0;
TiledMMA tiled_mma_qk = TiledMMA_QK{};
TiledMMA tiled_mma_sv = TiledMMA_SV{};
Tensor tP = partition_fragment_C(tiled_mma_qk, Shape<Int<B_H>, Int<B_TOPK>>{});
Tensor tO = partition_fragment_C(tiled_mma_sv, Shape<Int<B_H>, Int<D_V>>{});
// Point the accumulator fragments at their fixed tensor-memory offsets.
tO.data().get() = tmem_addr::o;
tP.data().get() = tmem_addr::p;
Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutS{});
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
auto [start_block_idx, end_block_idx, is_no_split] = get_cur_req_info(batch_idx);
if (elect_one_sync()) {
// Copy Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx, batch_idx),
Tile<Int<B_H>, Int<D_K>>{}
)(_, _, head_block_idx, _0{});
launch_tma_copy(tma_params.tma_Q, gQ, sQ, plan.bar_q, TMA::CacheHintSm90::EVICT_FIRST);
plan.bar_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16));
}
NamedBarrier::arrive_and_wait(NUM_WORKING_THREADS, 1);
if (elect_one_sync()) {
// Wait for Q
plan.bar_q.wait(bar_phase_q);
bar_phase_q ^= 1;
tcgen05_after_thread_sync();
CUTE_NO_UNROLL
for (int block_idx = start_block_idx; block_idx < end_block_idx; block_idx++) {
int buf_idx = block_idx % NUM_BUFS;
// Wait for K
plan.bar_k_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
tcgen05_after_thread_sync();
Tensor sK = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutK{});
// Issue P = Q @ K^T
utcmma_ss(tiled_mma_qk, sQ, sK, tP, true);
umma_arrive_noelect(plan.bar_qk_done[buf_idx]);
// Wait for S
plan.bar_so_ready[buf_idx].wait(bar_phase_k>>buf_idx&1);
tcgen05_after_thread_sync();
// V is a transposed view of the same K buffer.
Tensor sV = make_tensor(make_smem_ptr(plan.u.k[buf_idx].data()), SmemLayoutV{});
// Issue O += S @ V
utcmma_ss(tiled_mma_sv, sS, sV, tO, block_idx == start_block_idx);
umma_arrive_noelect(plan.bar_k_free[buf_idx]);
bar_phase_k ^= 1<<buf_idx;
}
}
__syncwarp();
// NOTE If we reach this point, we must have done the QK gemm (since we've waited for bar_so_ready)
// So we can launch the copy of the next Q block immediately
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100 ~ sm119");
}
#endif
}
// Host-side launcher for the SM100 sparse FP8 MLA decoding kernel.
//
// Builds the TMA descriptors for Q (load) and O (store), configures the dynamic
// shared-memory size and launches the kernel on `stream` with a grid of
// (ceil(q_head_per_hk / B_H), s_q, num_sm_parts) blocks of NUM_THREADS threads.
//
// Preconditions (checked with FLASH_ASSERT):
//   - params.h_k == 1
//   - params.topk is a multiple of B_TOPK
//   - params.d == D_K and params.d_v == D_V, because the kernel hard-codes
//     these dimensions in its shared-memory layouts, TMA tile shapes and
//     bulk-copy sizes.
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream) {
    FLASH_ASSERT(params.h_k == 1);
    FLASH_ASSERT(params.topk % B_TOPK == 0);
    // The kernel is specialized for fixed head dimensions; reject anything else
    // instead of silently reading/writing with the wrong extents.
    FLASH_ASSERT(params.d == D_K);
    FLASH_ASSERT(params.d_v == D_V);

    // Q viewed as (head, d, s_q, batch).
    auto shape_Q = make_shape(params.q_head_per_hk, params.d, params.s_q, params.b);
    auto tma_Q = cute::make_tma_copy(
        SM90_TMA_LOAD{},
        make_tensor(
            make_gmem_ptr((bf16*)params.q_ptr),
            make_layout(
                shape_Q,
                make_stride(params.q_row_stride, _1{}, params.q_head_per_hk*params.q_row_stride, params.q_batch_stride)
            )
        ),
        SmemLayoutQ{}
    );

    // O viewed as (head, d_v, s_q, batch).
    auto shape_O = make_shape(params.q_head_per_hk, params.d_v, params.s_q, params.b);
    auto tma_O = cute::make_tma_copy(
        SM90_TMA_STORE{},
        make_tensor(
            make_gmem_ptr((bf16*)params.o_ptr),
            make_layout(
                shape_O,
                make_stride(params.o_row_stride, _1{}, params.q_head_per_hk*params.o_row_stride, params.o_batch_stride)
            )
        ),
        SmemLayoutOBuf{}
    );

    TmaParams<
        decltype(shape_Q), decltype(tma_Q),
        decltype(shape_O), decltype(tma_O)
    > tma_params = {
        shape_Q, tma_Q,
        shape_O, tma_O
    };
    auto mla_kernel = &flash_fwd_splitkv_mla_fp8_sparse_kernel<decltype(tma_params)>;

    // The whole SharedMemoryPlan lives in dynamic shared memory.
    constexpr size_t smem_size = sizeof(SharedMemoryPlan);
    CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    const int num_m_blocks = cute::ceil_div(params.q_head_per_hk, B_H);
    // NOTE Don't use PDL because of potential compiler bugs!
    mla_kernel<<<dim3(num_m_blocks, params.s_q, params.num_sm_parts), dim3(NUM_THREADS, 1, 1), smem_size, stream>>>(params, tma_params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
}
\ No newline at end of file
#pragma once
#include "params.h"
namespace sm100 {
// Launches the SM100 sparse FP8 MLA decoding kernel described by `params` on `stream`.
void run_flash_splitkv_mla_fp8_sparse_kernel(DecodingParams &params, cudaStream_t stream);
}
#pragma once
#include <cutlass/bfloat16.h>
#include <cutlass/arch/barrier.h>
namespace sm100 {
// Short names for the element and barrier types used throughout the sm100 kernels.
using bf16 = cutlass::bfloat16_t;
using fp8 = cutlass::float_e4m3_t;
using transac_bar_t = cutlass::arch::ClusterTransactionBarrier;
using cutlass::arch::fence_view_async_shared;
using cutlass::arch::fence_barrier_init;
using cutlass::arch::NamedBarrier;
// Eight consecutive 32-bit integers (32 bytes); used as the register
// destination of 256-bit vector loads.
struct int32x8_t {
    int a0;
    int a1;
    int a2;
    int a3;
    int a4;
    int a5;
    int a6;
    int a7;
};
// Eight floats stored as four float2 pairs (elements 0-1, 2-3, 4-5, 6-7).
struct float8 {
    float2 a01;
    float2 a23;
    float2 a45;
    float2 a67;
};
// Eight bf16 values stored as four bf16x2 pairs (elements 0-1, 2-3, 4-5, 6-7).
struct bf16x8 {
    __nv_bfloat162 a01, a23, a45, a67;
};
}
#pragma once
#include <cute/tensor.hpp>
#include "defines.h"
namespace sm100 {
using namespace cute;
// Compile-time integer constants in the style of CuTe's `_1`, `_2`, ... aliases.
using _72 = Int<72>;
using _576 = Int<576>;
// Issues a TMA copy from `src` to `dst` that reports completion on `bar`.
//
// tma_copy:   the TMA copy object describing the transfer.
// src / dst:  source and destination tensors; both are partitioned with slice 0.
// bar:        transaction barrier handed to the copy.
// cache_hint: cache hint forwarded to the copy (defaults to EVICT_NORMAL).
//
// The function only launches the copy; it neither arms nor waits on the barrier.
template<typename TMA, typename Tensor0, typename Tensor1>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    using RawBarrier = typename transac_bar_t::ValueType;
    auto slice = tma_copy.get_slice(_0{});
    cute::copy(
        tma_copy.with(reinterpret_cast<RawBarrier&>(bar), 0, cache_hint),
        slice.partition_S(src),
        slice.partition_D(dst)
    );
}
// Issues C (+)= A @ B with both operands taken from shared memory.
//
// tiled_mma:   MMA object; its `accumulate_` flag is modified by this call.
// sA / sB:     shared-memory operand tensors.
// tC_frag:     accumulator fragment.
// clear_accum: when true the first k-step overwrites C, otherwise it
//              accumulates. All later k-steps always accumulate.
template<typename TiledMMA, typename TensorA, typename TensorB, typename TensorFragC>
CUTE_DEVICE
void utcmma_ss(
    TiledMMA &tiled_mma,
    TensorA sA,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }

    ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
    auto frag_a = thr_mma.partition_fragment_A(sA);
    auto frag_b = thr_mma.partition_fragment_B(sB);

    // The operand fragments must agree with each other and with the accumulator.
    static_assert(size<2>(frag_a) == size<2>(frag_b));
    static_assert(size<1>(frag_a) == size<1>(tC_frag));
    static_assert(size<1>(frag_b) == size<2>(tC_frag));

    CUTE_UNROLL
    for (int k_step = 0; k_step < size<2>(frag_a); ++k_step) {
        cute::gemm(tiled_mma, frag_a(_, _, k_step), frag_b(_, _, k_step), tC_frag);
        // After the first k-step, keep accumulating into C.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Issues C (+)= A @ B where A is passed as an already-partitioned fragment
// and B is taken from shared memory.
//
// tiled_mma:   MMA object; its `accumulate_` flag is modified by this call.
// tA_frag:     A operand fragment.
// sB:          shared-memory B operand tensor.
// tC_frag:     accumulator fragment.
// clear_accum: when true the first k-step overwrites C, otherwise it
//              accumulates. All later k-steps always accumulate.
template<typename TiledMMA, typename TensorA, typename TensorB, typename TensorFragC>
CUTE_DEVICE
void utcmma_ts(
    TiledMMA &tiled_mma,
    TensorA tA_frag,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }

    ThrMMA thr_mma = tiled_mma.get_slice(_0{}); // Since A/B/C are already CTA-local tiles, this number does not matter
    auto frag_b = thr_mma.partition_fragment_B(sB);
    static_assert(size<2>(tA_frag) == size<2>(frag_b));

    CUTE_UNROLL
    for (int k_step = 0; k_step < size<2>(tA_frag); ++k_step) {
        cute::gemm(tiled_mma, tA_frag(_, _, k_step), frag_b(_, _, k_step), tC_frag);
        // After the first k-step, keep accumulating into C.
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
}
#pragma once
#include <cute/tensor.hpp>
#include <cute/arch/simd_sm100.hpp>
#include "defines.h"
namespace sm100 {
using namespace cute;
// Issues one 16-byte `cp.async.cg` copy from global memory (`src`) to shared
// memory (`dst`) with the `L2::256B` qualifier.
// Only the copy itself is issued here; no commit or wait is performed, so the
// caller is responsible for completing the async copy before reading `dst`.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
:: "r"(dst_addr),
"l"(src),
"n"(16));
}
// Creates and returns a 64-bit cache-policy value via
// `createpolicy.fractional.L2::evict_last` with fraction 1.0.
CUTE_DEVICE
int64_t createpolicy_evict_last() {
int64_t res;
asm volatile(
"createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
: "=l"(res)
:
);
return res;
}
// Asynchronously stores a 16-byte object to shared memory (`shared::cluster`
// address space) and signals completion, in bytes, on the given mbarrier.
//
// dst_ptr:  destination shared-memory address.
// data:     the 16-byte value to store (sizeof(T) must be 16).
// mbar_ptr: mbarrier that receives the complete_tx notification.
//
// The payload is split into two 64-bit halves. `longlong2` is used rather than
// `long2` because `long` is only 32 bits wide on LLP64 platforms, which would
// make `long2` 8 bytes and silently drop half of the data.
template<typename T>
CUTE_DEVICE
static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
    static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b.");
    static_assert(sizeof(longlong2) == 16, "longlong2 must cover the full 128-bit payload.");
    const longlong2 data_ll2 = *reinterpret_cast<const longlong2*>(&data);
    uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
    uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
    asm volatile (
        "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
        :
        : "r"(dst_addr), "l"(data_ll2.x), "l"(data_ll2.y), "r"(mbar_addr)
        // The instruction writes memory the compiler cannot see; keep it from
        // reordering surrounding memory accesses across this statement.
        : "memory"
    );
}
// Emits the `tcgen05.fence::before_thread_sync` instruction.
// NOTE(review): presumably meant to be issued before a thread synchronization
// that other threads' tcgen05 operations depend on — confirm against the PTX ISA.
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
// Emits the `tcgen05.fence::after_thread_sync` instruction.
// In this commit's kernel it is issued right after each barrier wait and before
// the following tensor-memory accesses or MMA issues.
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// Emits `tcgen05.commit` (cta_group::1) with a multicast mbarrier arrive: the
// arrive is delivered to the barrier at the same shared-memory offset in every
// CTA selected by `cta_mask`.
// No thread election is done here ("noelect"); the caller presumably must
// ensure that only one thread executes this call.
CUTE_DEVICE
void umma_arrive_multicast_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile(
"{\n\t"
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
"}"
:
:"r"(bar_intptr), "h"(cta_mask));
}
// Same as umma_arrive_multicast_noelect but emits the `cta_group::2` form of
// `tcgen05.commit`.
// No thread election is done here ("noelect"); the caller presumably must
// ensure that only one thread executes this call.
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(transac_bar_t &smem_ptr, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile(
"{\n\t"
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
"}"
:
:"r"(bar_intptr), "h"(cta_mask));
}
// Emits `tcgen05.commit` (cta_group::1) with an mbarrier arrive on `smem_ptr`.
// The kernel in this commit calls it right after issuing MMAs to signal their
// completion on a barrier.
// No thread election is done here ("noelect"); the kernel calls it from inside
// an elect_one_sync() region.
CUTE_DEVICE
void umma_arrive_noelect(transac_bar_t &smem_ptr) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
:
:"r"(bar_intptr));
}
// Same as umma_arrive_noelect but emits the `cta_group::2` form of
// `tcgen05.commit`.
// No thread election is done here ("noelect"); the caller presumably must
// ensure that only one thread executes this call.
CUTE_DEVICE
void umma_arrive_2x1SM_noelect(transac_bar_t &smem_ptr) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(&smem_ptr);
asm volatile("tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.b64 [%0];"
:
:"r"(bar_intptr));
}
// Returns the element-wise sum a + b, computed by cute::add.
CUTE_DEVICE
float2 float2_add(const float2 &a, const float2 &b) {
    float2 sum;
    cute::add(sum, a, b);
    return sum;
}
// Returns the element-wise product a * b, computed by cute::mul.
CUTE_DEVICE
float2 float2_mul(const float2 &a, const float2 &b) {
    float2 product;
    cute::mul(product, a, b);
    return product;
}
// Returns the element-wise fused multiply-add a * b + c, computed by cute::fma.
CUTE_DEVICE
float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {
    float2 fused;
    cute::fma(fused, a, b, c);
    return fused;
}
// Returns the element-wise negation of `a`, implemented as a * (-1, -1).
CUTE_DEVICE
float2 float2_neg(const float2 &a) {
    const float2 minus_one = {-1.0f, -1.0f};
    return float2_mul(a, minus_one);
}
// Issues a 2D TMA `tile::gather4` load: four rows (`row_idxs.x/y/z/w`) at column
// `col_idx` of the tensor described by `desc_ptr` are copied into shared memory
// at `smem_ptr`, with completion bytes reported on `mbar_ptr` and the given L2
// cache hint.
//
// USE_CTA0_MBAR: when true, the barrier address is masked with
// Sm100MmaPeerBitMask before use (presumably so that the barrier of the first
// CTA of the pair is signalled — confirm).
// NOTE(review): the instruction is always emitted with `.cta_group::2`,
// regardless of the template argument — confirm this is intended for all callers.
template<bool USE_CTA0_MBAR = false>
CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
if constexpr (USE_CTA0_MBAR) {
mbar_addr &= Sm100MmaPeerBitMask;
}
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(uint64_t(cache_hint))
: "memory"
);
}
// Loads N 32-bit values per thread from tensor memory into dst_ptr_.
// 32 data path lanes, 32-bit pattern, repeated N times
// (PTX: tcgen05.ld.sync.aligned.32x32b.xN.b32).
//
// Template parameters:
//   N - repeat count; must be a power of two in [1, 128] (enforced by the static_assert).
//   T - element type of the destination buffer, which is reinterpreted as uint32_t[N].
// Parameters:
//   src_addr - tensor memory address passed as the instruction's source operand.
//   dst_ptr_ - receives the N loaded 32-bit words in dst_ptr_[0 .. N*4 bytes).
// NOTE(review): the load appears to be asynchronous -- the sparse prefill kernel issues
// cutlass::arch::fence_view_async_tmem_load() after each call before using the values; confirm
// against the PTX documentation.
template <int N, typename T>
CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
// The instruction writes 32-bit registers, so view the destination as raw 32-bit words.
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(dst_ptr_);
// One branch per supported N: the register list of the PTX instruction must be spelled out explicitly.
if constexpr (N == 1) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32"
"{%0},"
"[%1];\n"
: "=r"(dst_ptr[0])
: "r"(src_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32"
"{%0, %1},"
"[%2];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1])
: "r"(src_addr));
} else if constexpr (N == 4) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32"
"{%0, %1, %2, %3},"
"[%4];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3])
: "r"(src_addr));
} else if constexpr (N == 8) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"[%8];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7])
: "r"(src_addr));
} else if constexpr (N == 16) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15},"
"[%16];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15])
: "r"(src_addr));
} else if constexpr (N == 32) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31},"
"[%32];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31])
: "r"(src_addr));
} else if constexpr (N == 64) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x64.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63},"
"[%64];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63])
: "r"(src_addr));
} else if constexpr (N == 128) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x128.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127},"
"[%128];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]),
"=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]),
"=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]),
"=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]),
"=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]),
"=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]),
"=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]),
"=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]),
"=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]),
"=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]),
"=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]),
"=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]),
"=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]),
"=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]),
"=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]),
"=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]),
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]),
"=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]),
"=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]),
"=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]),
"=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]),
"=r"(dst_ptr[126]), "=r"(dst_ptr[127])
: "r"(src_addr));
} else {
// Not reachable: the static_assert above already rejects every other N.
asm volatile ("trap");
}
}
// Stores N 32-bit values per thread from src_ptr_ into tensor memory.
// 32 data path lanes, 32-bit pattern, repeated N times
// (PTX: tcgen05.st.sync.aligned.32x32b.xN.b32).
//
// Template parameters:
//   N - repeat count; must be a power of two in [1, 128] (enforced by the static_assert).
//   T - element type of the source buffer, which is reinterpreted as uint32_t[N].
// Parameters:
//   dst_addr - tensor memory address passed as the instruction's destination operand.
//   src_ptr_ - source of the N 32-bit words, read from src_ptr_[0 .. N*4 bytes).
// NOTE(review): the store appears to be asynchronous -- the sparse prefill kernel issues
// cutlass::arch::fence_view_async_tmem_store() after each call; confirm against the PTX
// documentation.
template <int N, typename T>
CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
// The instruction reads 32-bit registers, so view the source as raw 32-bit words.
uint32_t* src_ptr = reinterpret_cast<uint32_t*>(src_ptr_);
// One branch per supported N: the register list of the PTX instruction must be spelled out explicitly.
// In every branch the data registers are operands %0..%(N-1) and the tensor memory address is operand %N.
if constexpr (N == 1) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32"
"[%1], {%0};\n"
:
: "r"(src_ptr[0]),
"r"(dst_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32"
"[%2], {%0, %1};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]),
"r"(dst_addr));
} else if constexpr (N == 4) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32"
"[%4], {%0, %1, %2, %3};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]),
"r"(dst_addr));
} else if constexpr (N == 8) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32"
"[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]),
"r"(dst_addr));
} else if constexpr (N == 16) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32"
"[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]),
"r"(dst_addr));
} else if constexpr (N == 32) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32"
"[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]),
"r"(dst_addr));
} else if constexpr (N == 64) {
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x64.b32"
"[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]),
"r"(dst_addr));
} else if constexpr (N == 128) {
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x128.b32"
"[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]),
"r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]),
"r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]),
"r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]),
"r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]),
"r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]),
"r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]),
"r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]),
"r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]),
"r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]),
"r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]),
"r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]),
"r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]),
"r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]),
"r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]),
"r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]),
"r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]),
"r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]),
"r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]),
"r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]),
"r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]),
"r"(src_ptr[126]), "r"(src_ptr[127]),
"r"(dst_addr));
} else {
// Not reachable: the static_assert above already rejects every other N.
asm volatile ("trap");
}
}
// peer_addr = my_addr ^ PEER_ADDR_MASK.
// NOTE: it is not certain that this value is the same on every GPU.
static constexpr int PEER_ADDR_MASK = 1 << 24;  // == 16777216

// Returns `p` with the PEER_ADDR_MASK bit of its address flipped. The result
// is a pointer to non-const T even though the input points to const T.
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
    const int64_t flipped_addr = reinterpret_cast<int64_t>(p) ^ PEER_ADDR_MASK;
    return reinterpret_cast<T*>(flipped_addr);
}
}
#include "fwd.h"
#include <math_constants.h>
#include <cute/tensor.hpp>
#include <cutlass/cluster_launch.hpp>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/arch/arch.h>
#include <cutlass/cuda_host_adapter.hpp>
#include "params.h"
#include "utils.h"
#include "sm100/ws_gemm.h"
#include "sm100/helpers.h"
#include "sm100/intrinsics.h"
#include "sm100/tma_cta_group2_nosplit.h"
namespace sm100 {
using namespace cute;
// Reads eight consecutive signed 32-bit integers (256 bits) from global memory
// at `src_ptr` with a single `ld.global.nc ... v8.s32` instruction and returns
// them in fields a0..a7 of an int32x8_t.
CUTE_DEVICE int32x8_t ldg_256_indices(void* src_ptr) {
    int32x8_t loaded;
    asm volatile(
        "ld.global.nc.L1::evict_normal.L2::evict_normal.L2::256B.v8.s32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];"
        : "=r"(loaded.a0), "=r"(loaded.a1), "=r"(loaded.a2), "=r"(loaded.a3),
          "=r"(loaded.a4), "=r"(loaded.a5), "=r"(loaded.a6), "=r"(loaded.a7)
        : "l"(src_ptr));
    return loaded;
}
// Bundle of TMA-related kernel arguments: the tensor shape and TMA copy object
// for Q and for O, plus the raw tensor map used by the gather4 loads of KV.
template<
    typename Shape_Q, typename TMA_Q,
    typename Shape_O, typename TMA_O
>
struct TmaParams {
    // Q: shape handed to tma_Q.get_tma_tensor(), and the TMA copy itself.
    Shape_Q shape_Q;
    TMA_Q tma_Q;
    // O: shape handed to tma_O.get_tma_tensor(), and the TMA copy itself.
    Shape_O shape_O;
    TMA_O tma_O;
    // Descriptor passed to tma_gather4 when loading K and V.
    CUtensorMap tensor_map_kv;
};
// Two float2 values grouped together as a low and a high half.
struct float2x2 {
    float2 lo;
    float2 hi;
};
// Compile-time configuration of the sparse prefill kernel.
// Head dimensions of Q, K and V (the output O has D_V columns as well).
constexpr int D_Q = 576;
constexpr int D_K = 576;
constexpr int D_V = 512;
constexpr float MAX_INIT_VAL = -1e30; // We use this number as the initial value for mi (max logits) to avoid -inf - (-inf) = nan
constexpr int B_H = 128; // For 2 CTAs
constexpr int B_TOPK = 128; // For 2 CTAs
// Number of slots for the per-block barriers and is_k_valid; block k uses slot k % NUM_BUFS.
constexpr int NUM_BUFS = 2;
constexpr int NUM_THREADS = 256 + 128 + 128; // 4 warpgroups of 128 threads: scale & exp, K producer (TMA), V producer (TMA), and MMA
// Q is split column-wise: the first D_sQ columns stay in shared memory, the remaining
// D_tQ columns are copied into tensor memory. Both parts are handled in tiles of 64 columns.
constexpr int D_sQ = 256, NUM_sQ_TILES = D_sQ / 64;
constexpr int D_tQ = D_Q - D_sQ, NUM_tQ_TILES = D_tQ / 64;
static_assert(D_sQ%64 == 0 && D_tQ%64 == 0 && D_sQ + D_tQ == D_Q);
// Tensor memory columns
namespace tmem_cols {
// 0 ~ 256: output
// 256 ~ 320: P
// 352 ~ 512: Q[256:576], i.e. the D_tQ columns of Q that live in tensor memory (D_tQ/2 = 160 columns)
constexpr int o = 0;
constexpr int p = 256;
constexpr int q = 512 - D_tQ/2;
// P takes 64 columns and must end before the Q region starts.
static_assert(p+64 <= q);
}
// Shared-memory layout of a (B_H/2) x (64*NUM_TILES) bf16 block of Q, i.e. one CTA's
// half of the B_H rows and NUM_TILES column tiles of 64 elements each. It is built by
// tiling the K-major 128B-swizzle UMMA atom over that shape and coalescing the result.
// NOTE(review): Step<_1, _2> presumably tiles along the row mode first and the column
// mode second -- confirm against CuTe's tile_to_shape.
template<int NUM_TILES>
using SmemLayoutQTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Shared-memory layout of a (B_H/2) x (64*NUM_TILES) bf16 block of the output O,
// built the same way as SmemLayoutQTiles (K-major 128B-swizzle UMMA atom).
template<int NUM_TILES>
using SmemLayoutOTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Full output tile of one CTA: 8 tiles of 64 columns = D_V columns.
using SmemLayoutO = SmemLayoutOTiles<8>;
// Shared-memory layout of a (B_TOPK/2) x (64*NUM_TILES) bf16 block of K, i.e. one CTA's
// half of the B_TOPK gathered keys and NUM_TILES column tiles of 64 elements each
// (K-major 128B-swizzle UMMA atom).
template<int NUM_TILES>
using SmemLayoutKTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_TOPK/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Shared-memory layout of a 256 x B_TOPK bf16 block of V: one CTA's 256 of the D_V value
// columns for all B_TOPK gathered tokens, using the MN-major 128B-swizzle UMMA atom.
// NOTE(review): Step<_2, _1> presumably tiles along the token mode first -- confirm
// against CuTe's tile_to_shape.
using SmemLayoutV = decltype(coalesce(tile_to_shape(
UMMA::Layout_MN_SW128_Atom<bf16>{},
Shape<Int<256>, Int<B_TOPK>>{},
Step<_2, _1>{}
), Shape<_1, _1>{}));
// Shared-memory layout of a (B_H/2) x (64*NUM_TILES) bf16 block of S (the exponentiated
// scores written by the scale & exp warpgroup), using the K-major interleaved UMMA atom
// instead of the 128B-swizzle atom used for Q/K/O.
template<int NUM_TILES>
using SmemLayoutSTiles = decltype(coalesce(tile_to_shape(
UMMA::Layout_K_INTER_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64*NUM_TILES>>{},
Step<_1, _2>{}
), Shape<_1, _1>{}));
// Layout of the kernel's dynamic shared memory; the kernel reinterprets its shared-memory
// buffer as one instance of this struct.
struct SharedMemoryPlan {
// The members of this union alias each other: the full Q tile used by the prologue,
// the steady-state {sq, v, k} buffers, and the output staging buffer of the epilogue.
union {
// All 9 column tiles (576 columns) of Q as loaded by the prologue TMA copy.
array_aligned<bf16, cosize_v<SmemLayoutQTiles<9>>> q_full;
struct {
// First NUM_sQ_TILES column tiles of Q (the part that stays in shared memory).
array_aligned<bf16, cosize_v<SmemLayoutQTiles<NUM_sQ_TILES>>> sq;
array_aligned<bf16, cosize_v<SmemLayoutV>> v;
// NOTE K is not overlapped with q_full, so we can do k copy-in while performing S->T copy for q
array_aligned<bf16, cosize_v<SmemLayoutKTiles<9>>> k;
} s;
// Output tile staged here before the TMA store to global memory.
array_aligned<bf16, cosize_v<SmemLayoutO>> o;
} u;
// S, written by the scale & exp warpgroup.
array_aligned<bf16, cosize_v<SmemLayoutSTiles<2>>> s;
// Per-buffer validity bitmask used to mask P: one bit per key of a B_TOPK block.
char is_k_valid[NUM_BUFS][B_TOPK/8];
// bar_prologue_q: completion of the Q TMA load. bar_prologue_utccp: signalled after the
// shared-memory -> tensor-memory copy of Q has been issued; the V producer waits on it.
transac_bar_t bar_prologue_q, bar_prologue_utccp;
transac_bar_t bar_qk_part_done[NUM_BUFS], bar_qk_done[NUM_BUFS]; // Pi = QKi^T done (i.e. Ki free)
transac_bar_t bar_sv_part_done[NUM_BUFS], bar_sv_done[NUM_BUFS]; // O += SiVi done (i.e. Vi free)
// K is loaded in two column ranges (part0: columns [0, D_sQ), part1: columns [D_sQ, D_K)), each with its own barrier.
transac_bar_t bar_k_part0_ready[NUM_BUFS], bar_k_part1_ready[NUM_BUFS];
transac_bar_t bar_v_part0_ready[NUM_BUFS], bar_v_part1_ready[NUM_BUFS]; // Vi is ready
// Arrived on by the scale & exp warpgroup once it has loaded P out of tensor memory.
transac_bar_t bar_p_free[NUM_BUFS];
transac_bar_t bar_so_ready[NUM_BUFS]; // S and O are ready
// Handshake for is_k_valid: the scale & exp warpgroup waits on *_ready before reading
// the mask and arrives on *_free when it is done with it.
transac_bar_t bar_k_valid_ready[NUM_BUFS], bar_k_valid_free[NUM_BUFS];
// Written by the tensor memory allocator.
array_aligned<uint32_t, 1> tmem_start_addr;
// Scratch used by the two threads that share a row (thread i and i^64 of the scale & exp
// warpgroup) to exchange their partial row maximum and partial sum-of-exponentials.
float rowwise_max_buf[128], rowwise_li_buf[128];
};
// Tiled MMAs of the kernel. All use the 2x1SM (CTA pair) NOELECT bf16 x bf16 -> float atoms.
// P = Q K^T for the part of Q held in tensor memory ("TS" atom variant).
using TiledMMA_P_tQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_TS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
// P = Q K^T for the part of Q held in shared memory ("SS" atom variant).
using TiledMMA_P_sQ = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, B_TOPK, UMMA::Major::K, UMMA::Major::K>{}
));
// O += S V, with S K-major and V MN-major, both in shared memory.
using TiledMMA_O = decltype(make_tiled_mma(
SM100_MMA_F16BF16_2x1SM_SS_NOELECT<bf16, bf16, float, B_H, 256, UMMA::Major::K, UMMA::Major::MN>{},
Layout<Shape<_1, _1, _1>>{},
Tile<Int<128>, Layout<Shape<_128, _2, _2>, Stride<_1, _256, _128>>, _16>{} // We use this permutation layout to let CTA0 takes V[:, 0:256] and CTA1 takes V[:, 256:512]
));
/*
Pipeline Overview:
| Copy | MMA | Scale & Exp |
K0
V0
P0 = QK0^T
K1 S0 = exp(P0)
scale(O) w.r.t P0
P1 = QK1^T
K2 S1 = exp(P1)
O += S0V0
V1 scale(O) w.r.t P1
P2 = QK2^T
K3 S2 = exp(P2)
O += S1V1
V2 scale(O) w.r.t P2
P3 = QK3^T
K4 S3 = exp(P3)
O += S2V2
V3 scale(O) w.r.t P3
...
O += S(n-3)V(n-3)
V(n-2) scale(O) w.r.t P(n-2)
P(n-1) = QK(n-1)^T
S(n-1) = exp(P(n-1))
O += S(n-2)V(n-2)
V(n-1) scale(O) w.r.t P(n-1)
O += S(n-1)V(n-1)
*/
template<typename TmaParams>
__global__ void __launch_bounds__(NUM_THREADS, 1, 2)
sparse_attn_fwd_kernel(__grid_constant__ const SparsePrefillParams params, __grid_constant__ const TmaParams tma_params) {
#if IS_SM100
const int cta_idx = blockIdx.x % 2;
const int s_q_idx = blockIdx.x / 2;
const int warp_idx = cutlass::canonical_warp_idx_sync();
const int lane_idx = threadIdx.x % 32;
const int num_k_blocks = params.topk / B_TOPK;
const int warpgroup_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const int idx_in_warpgroup = threadIdx.x % 128;
// Prefetch TMA descriptors
if (threadIdx.x == 0) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
cute::prefetch_tma_descriptor(&(tma_params.tensor_map_kv));
}
// Define shared tensors
extern __shared__ char wksp_buf[];
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ_full = make_tensor(make_smem_ptr(plan.u.q_full.data()), SmemLayoutQTiles<9>{});
int* gIndices = params.indices + s_q_idx*params.stride_indices_s_q; // [topk]
// Allocate tmem tensors
TiledMMA tiled_mma_P_tQ = TiledMMA_P_tQ{};
TiledMMA tiled_mma_P_sQ = TiledMMA_P_sQ{};
TiledMMA tiled_mma_O = TiledMMA_O{};
Tensor tP = partition_fragment_C(tiled_mma_P_tQ, Shape<Int<B_H/2>, Int<B_TOPK>>{});
Tensor tQr = tiled_mma_P_tQ.get_slice(_0{}).make_fragment_A(
partition_shape_A(tiled_mma_P_tQ, Shape<Int<B_H/2>, Int<D_tQ>>{})
);
Tensor tO = partition_fragment_C(tiled_mma_O, Shape<Int<B_H/2>, Int<D_V>>{});
tP.data().get() = tmem_cols::p;
tQr.data().get() = tmem_cols::q;
tO.data().get() = tmem_cols::o;
if (warp_idx == 0) {
if (elect_one_sync()) {
// Initialize barriers
plan.bar_prologue_q.init(1);
plan.bar_prologue_utccp.init(1);
CUTE_UNROLL
for (int i = 0; i < NUM_BUFS; ++i) {
plan.bar_qk_part_done[i].init(1);
plan.bar_qk_done[i].init(1);
plan.bar_sv_part_done[i].init(1);
plan.bar_sv_done[i].init(1);
plan.bar_k_part0_ready[i].init(1);
plan.bar_k_part1_ready[i].init(1);
plan.bar_v_part0_ready[i].init(1);
plan.bar_v_part1_ready[i].init(1);
plan.bar_p_free[i].init(128*2);
plan.bar_so_ready[i].init(128*2);
plan.bar_k_valid_ready[i].init(16);
plan.bar_k_valid_free[i].init(128);
}
fence_barrier_init();
}
}
cute::cluster_sync(); // We must add a cluster_sync() here, or TMA from CTA1 may launch before barrier initialization in CTA0
if (warp_idx == 0) {
if (elect_one_sync()) {
// Copy Q
Tensor gQ = flat_divide(
tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, s_q_idx),
Tile<Int<B_H/2>>{}
)(_, cta_idx, _);
launch_tma_copy(tma_params.tma_Q, gQ, sQ_full, plan.bar_prologue_q, TMA::CacheHintSm90::EVICT_FIRST);
}
// Initialize TMEM
// We put this before cluster_arrive to make sure that the TMEM allocation is done before UTCCP
cute::TMEM::Allocator2Sm().allocate(512, plan.tmem_start_addr.data());
TRAP_ONLY_DEVICE_ASSERT(plan.tmem_start_addr.data()[0] == 0);
cute::TMEM::Allocator2Sm().release_allocation_lock();
__syncwarp();
}
if (warpgroup_idx == 0) {
cutlass::arch::warpgroup_reg_alloc<144>();
// Scale & Exp warps
// The following three numbers are
// - mi: max_logits used to scale Pi (i.e. O := exp2(Pi*scale - mi) @ V)
// - li: sumexp, i.e. li := sum(exp(Pi*scale - mi))
// - real_mi: real max logits, i.e. real_mi := max(Pi*scale)
// where Pi is the i-th row of P, P := QK^T
// mi and real_mi are always consistent within the two threads that
// controls one row (i.e. thread 0+64, 1+65, 2+66, ...) after every update
float mi = MAX_INIT_VAL;
float li = 0.0f;
float real_mi = -CUDART_INF_F;
const float2 scale = float2 {params.sm_scale_div_log2, params.sm_scale_div_log2};
uint128_t* sS_base = (uint128_t*)plan.s.data() + idx_in_warpgroup%64 + 64*((idx_in_warpgroup/64)*8);
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
// Wait for P
plan.bar_qk_done[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
tcgen05_after_thread_sync();
// Load P
float2 p[(B_TOPK/2)/2];
tmem_ld_32dp32bNx<B_TOPK/2>(tmem_cols::p, p);
cutlass::arch::fence_view_async_tmem_load();
tcgen05_before_thread_sync();
plan.bar_p_free[k%NUM_BUFS].arrive(0u);
// Mask
plan.bar_k_valid_ready[k%NUM_BUFS].wait((k/NUM_BUFS)&1);
// The following code enables NVCC to use R2P instruction
// Although we perform 2x LDS.32 instructions here, don't worry, NVCC will
// convert them to one LDS.64 instruction. However, if we write LDS.64
// here, NVCC won't use R2P.
uint32_t is_k_valid_lo = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0));
uint32_t is_k_valid_hi = *(uint32_t*)(plan.is_k_valid[k%NUM_BUFS] + (idx_in_warpgroup>=64?B_TOPK/8/2:0) + 4);
float* p_float = (float*)p;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
if (!(is_k_valid_lo >> i & 1))
p_float[i] = -CUDART_INF_F;
}
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
if (!(is_k_valid_hi >> i & 1))
p_float[i+(B_TOPK/2)/2] = -CUDART_INF_F;
}
// Get rowwise max of Pi
float cur_pi_max = -CUDART_INF_F;
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2); i += 1) {
cur_pi_max = max(cur_pi_max, p_float[i]);
}
cur_pi_max *= params.sm_scale_div_log2;
plan.bar_k_valid_free[k%NUM_BUFS].arrive();
NamedBarrier::arrive_and_wait(128, 0); // Wait for rowwise_max_buf and sP to be ready
plan.rowwise_max_buf[idx_in_warpgroup] = cur_pi_max;
NamedBarrier::arrive_and_wait(128, 0); // TODO Name these barriers
cur_pi_max = max(cur_pi_max, plan.rowwise_max_buf[idx_in_warpgroup^64]);
real_mi = max(real_mi, cur_pi_max);
bool should_scale_o = __any_sync(0xffffffff, cur_pi_max - mi > 6.0f);
// By this point:
// - cur_pi_max, real_mi, and mi is identical within each row (i.e. thread 0+64, 1+65, ...)
// - should_scale_o is identical among threads 0~31+64~95; and is identical among threads 32~63+96~127
// Calc scale factor, and scale li
float new_max, scale_for_old;
if (!should_scale_o) {
// Don't scale O
scale_for_old = 1.0f;
new_max = mi;
} else {
new_max = max(cur_pi_max, mi);
scale_for_old = exp2f(mi - new_max);
}
mi = new_max; // mi is still identical within each row
li *= scale_for_old;
// Calculate S
__nv_bfloat162 s[(B_TOPK/2)/2];
float2 neg_new_max = float2 {-new_max, -new_max};
CUTE_UNROLL
for (int i = 0; i < (B_TOPK/2)/2; i += 1) {
float2 d = float2_fma(p[i], scale, neg_new_max);
d.x = exp2f(d.x);
d.y = exp2f(d.y);
li += d.x + d.y; // NOTE Theorically we can have use FFMA2 here but actually this is faster...
s[i] = __float22bfloat162_rn(d);
}
// Wait for last SV gemm, write S
if (k > 0) {
plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
CUTE_UNROLL
for (int i = 0; i < B_TOPK/2/8; i += 1) {
sS_base[64*i] = *(uint128_t*)(s + i*4);
}
// Scale O
if (k > 0 && should_scale_o) {
float2 scale_for_old_float2 = float2 {scale_for_old, scale_for_old};
// plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1); // NOTE We have waited for last SV gemm before
tcgen05_after_thread_sync();
static constexpr int CHUNK_SIZE = 32;
float2 o[CHUNK_SIZE/2];
CUTE_UNROLL
for (int chunk_idx = 0; chunk_idx < (D_V/2)/CHUNK_SIZE; ++chunk_idx) {
// Load O
tmem_ld_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_load();
// Mult
for (int i = 0; i < CHUNK_SIZE/2; ++i) {
o[i] = float2_mul(o[i], scale_for_old_float2);
}
// Store O
tmem_st_32dp32bNx<CHUNK_SIZE>(tmem_cols::o + chunk_idx*CHUNK_SIZE, o);
cutlass::arch::fence_view_async_tmem_store();
}
tcgen05_before_thread_sync();
}
fence_view_async_shared();
plan.bar_so_ready[k%NUM_BUFS].arrive(0u);
}
// Epilogue
if (real_mi == -CUDART_INF_F) {
// real_mi == -CUDART_INF_F <=> No valid TopK indices
// We set li to 0 to fit the definition that li := exp(x[i] - mi)
li = 0.0f;
mi = -CUDART_INF_F;
}
// Exchange li
plan.rowwise_li_buf[idx_in_warpgroup] = li;
NamedBarrier::arrive_and_wait(128, 0);
li += plan.rowwise_li_buf[idx_in_warpgroup^64];
// Store mi and li
if (idx_in_warpgroup < 64) {
int global_index = s_q_idx*params.h_q + cta_idx*(B_H/2) + idx_in_warpgroup;
float cur_lse = log2f(li) + mi;
params.max_logits[global_index] = real_mi;
params.lse[global_index] = cur_lse;
}
// Wait for the last GEMM
plan.bar_sv_done[(num_k_blocks-1)%NUM_BUFS].wait(((num_k_blocks-1)/NUM_BUFS)&1);
tcgen05_after_thread_sync();
// Store O
float output_scale = __fdividef(1.0f, li);
Tensor sO = make_tensor(make_smem_ptr(plan.u.o.data()), SmemLayoutO{});
constexpr int B_EPI = 64;
Tensor tma_gO = flat_divide(
tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, s_q_idx),
Shape<Int<B_H/2>, Int<B_EPI>>{}
)(_, _, cta_idx, _);
Tensor sO_divided = flat_divide(
sO,
Shape<Int<B_H/2>, Int<B_EPI>>{}
)(_, _, _0{}, _);
auto thr_tma = tma_params.tma_O.get_slice(_0{});
float2 o[B_EPI/2];
bool have_valid_indices = __any_sync(0xffffffff, li != 0); // Prevent some threads' li == 0 and some threads' li != 0 which lead to deadlock during tmem_ld
if (!have_valid_indices) {
// If there are no valid indices, we set o[i] to 0 and don't load from TMEM
CUTE_UNROLL
for (int i = 0; i < B_EPI/2; ++i)
o[i].x = o[i].y = 0.0f;
output_scale = 1.0f;
}
float2 output_scale_float2 = make_float2(output_scale, output_scale);
CUTE_UNROLL
for (int k = 0; k < (D_V/2)/B_EPI; ++k) {
// Load O from tO
if (have_valid_indices) {
tmem_ld_32dp32bNx<B_EPI>(tmem_cols::o + k*B_EPI, o);
cutlass::arch::fence_view_async_tmem_load();
}
// Convert and store
CUTE_UNROLL
for (int i = 0; i < B_EPI/8; ++i) {
__nv_bfloat162 o_bf16[4];
CUTE_UNROLL
for (int j = 0; j < 4; ++j) {
float2 d = float2_mul(o[i*4+j], output_scale_float2);
o_bf16[j] = __float22bfloat162_rn(d);
}
int smem_row = idx_in_warpgroup % 64;
int smem_col = (idx_in_warpgroup/64)*(D_V/2) + k*B_EPI + i*8;
*(uint128_t*)(&sO(smem_row, smem_col)) = *(uint128_t*)(o_bf16);
}
// Sync
fence_view_async_shared();
NamedBarrier::arrive_and_wait(128, 0);
if (warp_idx == 0 && elect_one_sync()) {
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, k)),
thr_tma.partition_D(tma_gO(_, _, k))
);
}
if (warp_idx == 1 && elect_one_sync()) {
int k2 = k + (D_V/B_EPI/2);
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sO_divided(_, _, k2)),
thr_tma.partition_D(tma_gO(_, _, k2))
);
}
}
if (warp_idx == 0) {
cute::TMEM::Allocator2Sm().free(0, 512);
}
} else if (warpgroup_idx == 1) {
// Producer warp for K
cutlass::arch::warpgroup_reg_dealloc<96>();
int warp_idx = cutlass::canonical_warp_idx_sync() - 4;
constexpr int NUM_WARPS = 4, NUM_LOCAL_ROWS_PER_WARP = (B_TOPK/2)/4/NUM_WARPS;
if (elect_one_sync()) {
bf16* sK_base = plan.u.s.k.data() + warp_idx*4*64;
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int4 indices[NUM_LOCAL_ROWS_PER_WARP];
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row)
indices[local_row] = __ldg((int4*)(gIndices + k*B_TOPK + cta_idx*(B_TOPK/2)) + local_row*NUM_WARPS + warp_idx);
auto load_part_ki = [&](transac_bar_t* bar, int local_col_start, int local_col_end) {
CUTE_UNROLL
for (int local_row = 0; local_row < NUM_LOCAL_ROWS_PER_WARP; ++local_row) {
CUTE_UNROLL
for (int local_col = local_col_start; local_col < local_col_end; ++local_col)
tma_gather4<true>(
&(tma_params.tensor_map_kv),
bar,
sK_base + local_row*(4*NUM_WARPS)*64 + local_col*((B_TOPK/2)*64),
local_col*64,
indices[local_row],
TMA::CacheHintSm90::EVICT_LAST
);
}
};
int cur_buf = k%NUM_BUFS;
if (k > 0) {
plan.bar_qk_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_ki(plan.bar_k_part0_ready+cur_buf, 0, D_sQ/64);
if (k > 0) {
plan.bar_qk_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_ki(plan.bar_k_part1_ready+cur_buf, D_sQ/64, D_K/64);
}
}
} else if (warpgroup_idx == 2) {
// Producer warps for V
cutlass::arch::warpgroup_reg_dealloc<96>();
int warp_idx = cutlass::canonical_warp_idx_sync() - 8;
constexpr int NUM_WARPS = 4;
if (elect_one_sync()) {
// Wait for UTCCP
plan.bar_prologue_utccp.wait(0);
bf16* sV_base = plan.u.s.v.data() + warp_idx*4*64;
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
auto load_part_vi = [&](transac_bar_t* bar, int local_row_start, int local_row_end) {
CUTE_UNROLL
for (int local_row = local_row_start; local_row < local_row_end; ++local_row) {
int4 token_idxs = __ldg((int4*)(gIndices + k*B_TOPK) + local_row*NUM_WARPS + warp_idx);
CUTE_UNROLL
for (int local_col = 0; local_col < (D_V/2)/64; ++local_col)
tma_gather4<true>(
&(tma_params.tensor_map_kv),
bar,
sV_base + local_row*(4*NUM_WARPS)*64 + local_col*(B_TOPK*64),
local_col*64 + (cta_idx?256:0),
token_idxs,
TMA::CacheHintSm90::EVICT_LAST
);
}
};
int cur_buf = k%NUM_BUFS;
if (k > 0) {
plan.bar_sv_part_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_vi(plan.bar_v_part0_ready+cur_buf, 0, (B_TOPK/2)/4/NUM_WARPS);
if (k > 0) {
plan.bar_sv_done[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
load_part_vi(plan.bar_v_part1_ready+cur_buf, (B_TOPK/2)/4/NUM_WARPS, B_TOPK/4/NUM_WARPS);
}
}
} else {
cutlass::arch::warpgroup_reg_alloc<168>();
// MMA warp
if (cta_idx == 0 && warp_idx == 12 && elect_one_sync()) {
// S -> T copy for Q
UMMA::SmemDescriptor sQ_desc = UMMA::make_umma_desc<UMMA::Major::K>(
make_tensor(
make_smem_ptr(plan.u.q_full.data() + (B_H/2)*D_sQ),
tile_to_shape(
UMMA::Layout_K_SW128_Atom<bf16>{},
Shape<Int<B_H/2>, Int<64>>{}
)
)
);
plan.bar_prologue_q.arrive_and_expect_tx(B_H*D_K*sizeof(bf16));
plan.bar_prologue_q.wait(0);
tcgen05_after_thread_sync();
CUTE_UNROLL
for (int tile_idx = 0; tile_idx < NUM_tQ_TILES; ++tile_idx) {
// A tile is 64 rows * 64 cols (128B)
CUTE_UNROLL
for (int subtile_idx = 0; subtile_idx < 8; ++subtile_idx) {
// A subtile is 64 rows * 8 cols (128b)
SM100_UTCCP_2x64dp128bitlw0213_2cta::copy(
sQ_desc + tile_idx*((B_H/2)*128/16) + subtile_idx*(16/16), // Remember that 4 LSBs are not included
tmem_cols::q + tile_idx*32 + subtile_idx*4
);
}
}
umma_arrive_multicast_2x1SM_noelect(plan.bar_prologue_utccp, 1|2);
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks+1; ++k) {
if (k < num_k_blocks) {
// Pi = QKi^T
int cur_buf = k%NUM_BUFS;
Tensor sQl = make_tensor(make_smem_ptr(plan.u.s.sq.data()), SmemLayoutQTiles<NUM_sQ_TILES>{});
Tensor sKl = make_tensor(make_smem_ptr(plan.u.s.k.data()), SmemLayoutKTiles<NUM_sQ_TILES>{});
Tensor sKr = make_tensor(make_smem_ptr(plan.u.s.k.data()+64*D_sQ), SmemLayoutKTiles<NUM_tQ_TILES>{});
// Wait for K (part0)
plan.bar_k_part0_ready[cur_buf].arrive_and_expect_tx(B_TOPK*D_sQ*sizeof(bf16));
plan.bar_k_part0_ready[cur_buf].wait((k/NUM_BUFS)&1);
if (k > 0) {
plan.bar_p_free[(k-1)%NUM_BUFS].wait(((k-1)/NUM_BUFS)&1);
}
tcgen05_after_thread_sync();
utcmma_ss(tiled_mma_P_sQ, sQl, sKl, tP, true);
umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_part_done[cur_buf], 1|2);
// Wait for K (part1)
plan.bar_k_part1_ready[cur_buf].arrive_and_expect_tx(B_TOPK*(D_K-D_sQ)*sizeof(bf16));
plan.bar_k_part1_ready[cur_buf].wait((k/NUM_BUFS)&1);
tcgen05_after_thread_sync();
utcmma_ts(tiled_mma_P_tQ, tQr, sKr, tP, false);
umma_arrive_multicast_2x1SM_noelect(plan.bar_qk_done[cur_buf], 1|2);
}
if (k > 0) {
// O += S(i-1)V(i-1)
int cur_buf = (k-1)%NUM_BUFS;
Tensor sS = make_tensor(make_smem_ptr(plan.s.data()), SmemLayoutSTiles<2>{});
Tensor sV = make_tensor(make_smem_ptr(plan.u.s.v.data()), SmemLayoutV{});
Tensor sS_divided = flat_divide(sS, Tile<Int<B_H/2>, _64>{})(_, _, _0{}, _); // (B_H/2, 64, 2)
Tensor sV_divided = flat_divide(sV, Tile<Int<D_V/2>, _64>{})(_, _, _0{}, _); // (D_V/2, 64, 2)
// Wait for S(i-1) and O to be scaled
plan.bar_so_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
// Wait for V (part0), and issue O += sS @ sV
plan.bar_v_part0_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
plan.bar_v_part0_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
tcgen05_after_thread_sync();
utcmma_ss(tiled_mma_O, sS_divided(_, _, _0{}), sV_divided(_, _, _0{}), tO, k == 1);
umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_part_done[cur_buf], 1|2);
// Wait for V (part1), and issue O += sS @ sV
plan.bar_v_part1_ready[cur_buf].arrive_and_expect_tx((B_TOPK/2)*D_V*sizeof(bf16));
plan.bar_v_part1_ready[cur_buf].wait(((k-1)/NUM_BUFS)&1);
tcgen05_after_thread_sync();
utcmma_ss(tiled_mma_O, sS_divided(_, _, _1{}), sV_divided(_, _, _1{}), tO, false);
umma_arrive_multicast_2x1SM_noelect(plan.bar_sv_done[cur_buf], 1|2);
}
}
} else if (warp_idx == 13) {
// KV valid loading warp
static_assert(B_TOPK == 128);
if (lane_idx < 16) {
CUTE_NO_UNROLL
for (int k = 0; k < num_k_blocks; ++k) {
int cur_buf = k%NUM_BUFS;
int32x8_t indices = ldg_256_indices(gIndices + k*B_TOPK + lane_idx*8);
auto is_valid = [&](int index) -> char {
return index >= 0 && index < params.s_kv;
};
char is_ks_valid_mask = \
is_valid(indices.a7) << 7 |
is_valid(indices.a6) << 6 |
is_valid(indices.a5) << 5 |
is_valid(indices.a4) << 4 |
is_valid(indices.a3) << 3 |
is_valid(indices.a2) << 2 |
is_valid(indices.a1) << 1 |
is_valid(indices.a0) << 0;
plan.bar_k_valid_free[cur_buf].wait((k/NUM_BUFS)&1^1);
plan.is_k_valid[cur_buf][lane_idx] = is_ks_valid_mask;
plan.bar_k_valid_ready[cur_buf].arrive();
}
}
}
}
#else
if (cute::thread0()) {
CUTE_INVALID_CONTROL_PATH("This kernel only supports sm100");
}
#endif
}
// Host-side launcher for sparse_attn_fwd_kernel.
// Checks the supported problem shape, builds the TMA descriptors for Q, O and
// the KV cache, raises the kernel's dynamic shared-memory limit, and launches
// 2 * params.s_q CTAs grouped into clusters of 2 on params.stream.
void run_fwd_kernel(const SparsePrefillParams& params) {
// Shape restrictions required by this launcher.
FLASH_ASSERT(params.h_kv == 1);
FLASH_ASSERT(params.topk % B_TOPK == 0); // To save some boundary checks
FLASH_ASSERT(params.h_q == B_H); // To save some calculation
// TMA load descriptor for Q, viewed as (h_q, d_qk, s_q) with unit stride along d_qk.
auto shape_Q = make_shape(params.h_q, params.d_qk, params.s_q);
auto tma_Q = cute::make_tma_copy(
SM100_TMA_2SM_LOAD_NOSPLIT{},
make_tensor(
make_gmem_ptr((bf16*)params.q),
make_layout(
shape_Q,
make_stride(params.stride_q_h_q, _1{}, params.stride_q_s_q)
)
),
// NOTE(review): 9 tiles presumably corresponds to D_K / 64 -- confirm against SmemLayoutQTiles.
SmemLayoutQTiles<9>{}
);
// TMA store descriptor for O, viewed as (h_q, d_v, s_q).
// The strides are derived from d_v and h_q instead of being read from params,
// so this presumably assumes a densely packed output buffer -- TODO confirm.
auto shape_O = make_shape(params.h_q, params.d_v, params.s_q);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((bf16*)params.out),
make_layout(
shape_O,
make_stride(params.d_v, _1{}, params.h_q*params.d_v)
)
),
SmemLayoutOTiles<1>{}
);
// Raw 2D bf16 tensor map over the KV cache: D_K elements per row, s_kv rows,
// rows separated by stride_kv_s_kv elements. Each box is 64 elements x 1 row
// with 128B swizzle. The kernel passes this map to tma_gather4.
CUtensorMap tensor_map_kv;
{
uint64_t size[2] = {D_K, (unsigned long)params.s_kv};
uint64_t stride[1] = {params.stride_kv_s_kv*sizeof(bf16)};
uint32_t box_size[2] = {64, 1};
uint32_t elem_stride[2] = {1, 1};
CUresult res = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)(
&tensor_map_kv,
CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
2,
params.kv,
size,
stride,
box_size,
elem_stride,
CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,
CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B,
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);
FLASH_ASSERT(res == CUresult::CUDA_SUCCESS);
}
// Bundle all descriptors into the struct that is passed to the kernel by value.
TmaParams<
decltype(shape_Q), decltype(tma_Q),
decltype(shape_O), decltype(tma_O)
> tma_params = {
shape_Q, tma_Q,
shape_O, tma_O,
tensor_map_kv
};
auto kernel = &sparse_attn_fwd_kernel<decltype(tma_params)>;
// The whole SharedMemoryPlan lives in dynamic shared memory, so the limit
// must be raised explicitly before launching.
constexpr size_t smem_size = sizeof(SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Grid: 2 CTAs per query token; block: NUM_THREADS threads; cluster: 2 CTAs.
cutlass::ClusterLaunchParams launch_params = {
dim3(2*params.s_q, 1, 1),
dim3(NUM_THREADS, 1, 1),
dim3(2, 1, 1),
smem_size,
params.stream
};
cutlass::launch_kernel_on_cluster(
launch_params, (void*)kernel, params, tma_params
);
CHECK_CUDA_KERNEL_LAUNCH();
}
}
#pragma once
#include "params.h"
namespace sm100 {
// Runs the forward kernel for the problem described by `params`.
// Declaration only; the definition is provided elsewhere.
void run_fwd_kernel(const SparsePrefillParams& params);
}
#pragma once
#include <cute/tensor.hpp>
#include "sm100/defines.h"
namespace sm100 {
using namespace cute;
using _72 = Int<72>;
using _576 = Int<576>;
// Issues the TMA copy described by `tma_copy` from `src` to `dst`.
//   tma_copy   - TMA copy object; slice 0 of it is used to partition both tensors.
//   src, dst   - source and destination tensors of the copy.
//   bar        - transaction barrier bound to the copy.
//   cache_hint - cache hint forwarded to the copy (defaults to EVICT_NORMAL).
template<
    typename TMA,
    typename Tensor0,
    typename Tensor1
>
CUTE_DEVICE
void launch_tma_copy(
    const TMA &tma_copy,
    Tensor0 src,
    Tensor1 dst,
    transac_bar_t &bar,
    const cute::TMA::CacheHintSm90 &cache_hint = cute::TMA::CacheHintSm90::EVICT_NORMAL
) {
    using BarrierWord = typename transac_bar_t::ValueType;
    auto slice0 = tma_copy.get_slice(_0{});
    auto src_part = slice0.partition_S(src);
    auto dst_part = slice0.partition_D(dst);
    cute::copy(
        tma_copy.with(reinterpret_cast<BarrierWord&>(bar), 0, cache_hint),
        src_part,
        dst_part
    );
}
// Runs `tiled_mma` over every k-block of the operands `sA` and `sB`,
// accumulating into `tC_frag`.
//   tiled_mma   - MMA object; its `accumulate_` field is modified and is left
//                 at ScaleOut::One once at least one k-block has been issued.
//   sA, sB      - operand tensors, partitioned here into A/B fragments.
//   tC_frag     - accumulator fragment.
//   clear_accum - when true, the first k-block is issued with ScaleOut::Zero
//                 (the previous accumulator content is not added).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma(
    TiledMMA &tiled_mma,
    TensorA sA,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
    // A/B/C are already CTA-local tiles, so the slice index does not matter
    auto cta_mma = tiled_mma.get_slice(_0{});
    auto frag_a = cta_mma.partition_fragment_A(sA);
    auto frag_b = cta_mma.partition_fragment_B(sB);
    // The fragment extents must agree before the k-blocks can be chained
    static_assert(size<2>(frag_a) == size<2>(frag_b));
    static_assert(size<1>(frag_a) == size<1>(tC_frag));
    static_assert(size<1>(frag_b) == size<2>(tC_frag));
    CUTE_UNROLL
    for (int k_blk = 0; k_blk < size<2>(frag_a); ++k_blk) {
        cute::gemm(tiled_mma, frag_a(_, _, k_blk), frag_b(_, _, k_blk), tC_frag);
        // Every k-block after the first accumulates on top of the previous ones
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Variant of the MMA helper whose A operand is supplied as a ready-made
// fragment (`tA_frag`); only B (`sB`) is partitioned here.
//   tiled_mma   - MMA object; its `accumulate_` field is modified and is left
//                 at ScaleOut::One once at least one k-block has been issued.
//   tA_frag     - pre-built A fragment.
//   sB          - B operand tensor, partitioned here into a B fragment.
//   tC_frag     - accumulator fragment.
//   clear_accum - when true, the first k-block is issued with ScaleOut::Zero
//                 (the previous accumulator content is not added).
template<
    typename TiledMMA,
    typename TensorA,
    typename TensorB,
    typename TensorFragC
>
CUTE_DEVICE
void utcmma_ts(
    TiledMMA &tiled_mma,
    TensorA tA_frag,
    TensorB sB,
    TensorFragC tC_frag,
    bool clear_accum
) {
    if (clear_accum) {
        tiled_mma.accumulate_ = UMMA::ScaleOut::Zero;
    } else {
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
    // A/B/C are already CTA-local tiles, so the slice index does not matter
    auto cta_mma = tiled_mma.get_slice(_0{});
    auto frag_b = cta_mma.partition_fragment_B(sB);
    static_assert(size<2>(tA_frag) == size<2>(frag_b));
    CUTE_UNROLL
    for (int k_blk = 0; k_blk < size<2>(tA_frag); ++k_blk) {
        cute::gemm(tiled_mma, tA_frag(_, _, k_blk), frag_b(_, _, k_blk), tC_frag);
        // Every k-block after the first accumulates on top of the previous ones
        tiled_mma.accumulate_ = UMMA::ScaleOut::One;
    }
}
// Pack of four __nv_bfloat162 values (i.e. eight bfloat16 elements).
// Member names presumably give the element indices each pair holds
// (a01 -> elements 0 and 1, and so on).
struct bf16x8 {
__nv_bfloat162 a01;
__nv_bfloat162 a23;
__nv_bfloat162 a45;
__nv_bfloat162 a67;
};
}
#pragma once
#include <cute/tensor.hpp>
#include "defines.h"
namespace sm100 {
using namespace cute;
// Pack of eight ints, named a0 .. a7.
struct int32x8_t {
int a0, a1, a2, a3, a4, a5, a6, a7;
};
// Pack of four float2 values (i.e. eight floats).
// Member names presumably give the element indices each pair holds.
struct float8 {
float2 a01, a23, a45, a67;
};
// Issues a `cp.async.cg` copy of 16 bytes from global memory (`src`) to shared
// memory (`dst`) with the `L2::256B` qualifier.
// The "256B" in the name refers to that qualifier; the copy size itself is the
// constant 16. No commit or wait is issued here.
__forceinline__ __device__ void cp_async_cacheglobal_l2_prefetch_256B(const void* src, void* dst) {
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst);
asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], %2;\n"
:: "r"(dst_addr),
"l"(src),
"n"(16));
}
// Issues an asynchronous 128-bit store (`st.async ... .v2.s64`) of `data` to the
// shared::cluster address `dst_ptr`, with `mbarrier::complete_tx::bytes`
// signalled on the mbarrier at `mbar_ptr`.
//   T must be exactly 16 bytes; it is reinterpreted as two 64-bit integers.
template<typename T>
CUTE_DEVICE
static void st_async_128b(void* dst_ptr, const T& data, const transac_bar_t* mbar_ptr) {
static_assert(sizeof(T) == 16, "Data type must be 16 bytes (128 bits) for st_async_128b.");
long2 data_long2 = *reinterpret_cast<const long2*>(&data);
uint32_t dst_addr = cute::cast_smem_ptr_to_uint(dst_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
asm volatile (
"st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.s64 [%0], {%1, %2}, [%3]; \n"
:
: "r"(dst_addr), "l"(data_long2.x), "l"(data_long2.y), "r"(mbar_addr)
);
}
// Issues `tcgen05.commit.cta_group::1` with a multicast `mbarrier::arrive::one`
// on the barrier at `smem_ptr`, for the CTAs selected by `cta_mask`.
// "noelect": no thread election is performed here, so every calling thread
// issues the instruction.
CUTE_DEVICE
void umma_arrive_multicast_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
"}"
:
:"r"(bar_intptr), "h"(cta_mask));
}
// Convenience overload for a transaction-barrier pointer: reinterprets the
// barrier as its 64-bit word and forwards to the uint64_t overload.
CUTE_DEVICE
void umma_arrive_multicast_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) {
    auto* raw_bar = (uint64_t*)smem_ptr;
    umma_arrive_multicast_noelect(raw_bar, cta_mask);
}
// Issues `tcgen05.commit.cta_group::2` with a multicast `mbarrier::arrive::one`
// on the barrier at `smem_ptr`, for the CTAs selected by `cta_mask`.
// "noelect": no thread election is performed here, so every calling thread
// issues the instruction.
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(uint64_t const* smem_ptr, uint16_t cta_mask) {
uint32_t bar_intptr = cute::cast_smem_ptr_to_uint(smem_ptr);
asm volatile(
"{\n\t"
"tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1; \n\t"
"}"
:
:"r"(bar_intptr), "h"(cta_mask));
}
// Convenience overload for a transaction-barrier pointer: reinterprets the
// barrier as its 64-bit word and forwards to the uint64_t overload.
CUTE_DEVICE
void umma_arrive_multicast_2x1SM_noelect(transac_bar_t const* smem_ptr, uint16_t cta_mask) {
    auto* raw_bar = (uint64_t*)smem_ptr;
    umma_arrive_multicast_2x1SM_noelect(raw_bar, cta_mask);
}
// Creates and returns a 64-bit L2 cache policy via
// `createpolicy.fractional.L2::evict_last` with fraction 1.0.
CUTE_DEVICE
int64_t createpolicy_evict_last() {
int64_t res;
asm volatile(
"createpolicy.fractional.L2::evict_last.b64 %0, 1.0; \n\t"
: "=l"(res)
:
);
return res;
}
// Adds the four floats of `data` to the four floats at `global_addr` with a
// single `red.relaxed.gpu.global.add ... .v4.f32` reduction, using
// `cache_policy` as the L2 cache hint. No value is returned (it is a `red`).
CUTE_DEVICE
void atomicadd_f32x4_with_policy(void* global_addr, const float4 &data, int64_t cache_policy) {
asm volatile(
"red.relaxed.gpu.global.add.L2::cache_hint.v4.f32 [%4], {%0, %1, %2, %3}, %5; \n\t"
:
: "f"(data.x), "f"(data.y), "f"(data.z), "f"(data.w),
"l"((int64_t)global_addr), "l"(cache_policy)
);
}
// Issues a bulk asynchronous f32 add-reduction (`cp.reduce.async.bulk`) of
// `store_bytes` bytes from shared memory (`src_ptr`) into global memory
// (`dst_ptr`). Completion uses the `bulk_group` mechanism; no commit or wait
// is issued here.
CUTE_DEVICE
void tma_bulk_reduce_add(void const* src_ptr, void* dst_ptr, int32_t store_bytes) {
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(src_ptr);
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
:
: "l"(dst_ptr), "r"(smem_int_ptr), "r"(store_bytes)
: "memory");
}
// Returns a + b, computed with cute::add on float2 operands.
CUTE_DEVICE
float2 float2_add(const float2 &a, const float2 &b) {
    float2 sum;
    cute::add(sum, a, b);
    return sum;
}
// Returns a * b, computed with cute::mul on float2 operands.
CUTE_DEVICE
float2 float2_mul(const float2 &a, const float2 &b) {
    float2 product;
    cute::mul(product, a, b);
    return product;
}
// Returns a * b + c, computed with cute::fma on float2 operands.
CUTE_DEVICE
float2 float2_fma(const float2 &a, const float2 &b, const float2 &c) {
    float2 fused;
    cute::fma(fused, a, b, c);
    return fused;
}
// Returns -a, implemented as a float2 multiplication by (-1, -1).
CUTE_DEVICE
float2 float2_neg(const float2 &a) {
    const float2 minus_one = {-1.0f, -1.0f};
    return float2_mul(a, minus_one);
}
// Emits the `tcgen05.fence::before_thread_sync` instruction.
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
// Emits the `tcgen05.fence::after_thread_sync` instruction.
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// Issues a 2D TMA `tile::gather4` load (`cta_group::2`, with L2 cache hint).
//   desc_ptr   - tensor map describing the global tensor.
//   mbar_ptr   - mbarrier that receives the `complete_tx::bytes` update.
//   smem_ptr   - shared-memory destination.
//   col_idx    - column coordinate shared by all four rows.
//   row_idxs   - the four row coordinates to gather (x, y, z, w).
//   cache_hint - L2 cache hint, passed as a 64-bit value.
// When USE_CTA0_MBAR is true, the barrier address is masked with
// Sm100MmaPeerBitMask before use -- presumably so that it refers to the
// barrier of CTA 0 in the CTA pair (per the parameter name; confirm).
template<bool USE_CTA0_MBAR = false>
CUTE_DEVICE void tma_gather4(const void* desc_ptr, transac_bar_t* mbar_ptr, void* smem_ptr, int col_idx, int4 row_idxs, TMA::CacheHintSm90 cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbar_addr = cute::cast_smem_ptr_to_uint(mbar_ptr);
if constexpr (USE_CTA0_MBAR) {
mbar_addr &= Sm100MmaPeerBitMask;
}
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbar_addr), "l"(uint64_t(cache_hint))
: "memory"
);
}
// Tensor-memory load: 32 data path lanes, 32-bit pattern, repeated N times.
// Emits `tcgen05.ld.sync.aligned.32x32b.xN.b32` from the tensor-memory address
// `src_addr` and writes the N resulting 32-bit registers to dst_ptr_[0 .. N-1]
// (dst_ptr_ is reinterpreted as uint32_t*).
//   N must be a power of two in [1, 128] (enforced at compile time).
// No wait for the load is issued here.
template <int N, typename T>
CUTE_DEVICE void tmem_ld_32dp32bNx(uint32_t const &src_addr, T* dst_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(dst_ptr_);
if constexpr (N == 1) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32"
"{%0},"
"[%1];\n"
: "=r"(dst_ptr[0])
: "r"(src_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32"
"{%0, %1},"
"[%2];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1])
: "r"(src_addr));
} else if constexpr (N == 4) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x4.b32"
"{%0, %1, %2, %3},"
"[%4];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3])
: "r"(src_addr));
} else if constexpr (N == 8) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"[%8];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7])
: "r"(src_addr));
} else if constexpr (N == 16) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x16.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15},"
"[%16];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15])
: "r"(src_addr));
} else if constexpr (N == 32) {
asm volatile("tcgen05.ld.sync.aligned.32x32b.x32.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31},"
"[%32];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31])
: "r"(src_addr));
} else if constexpr (N == 64) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x64.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63},"
"[%64];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63])
: "r"(src_addr));
} else if constexpr (N == 128) {
asm volatile(
"tcgen05.ld.sync.aligned.32x32b.x128.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127},"
"[%128];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]),
"=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]),
"=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]),
"=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]),
"=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]),
"=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]),
"=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]),
"=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]),
"=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]),
"=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]),
"=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]),
"=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]),
"=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]),
"=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]),
"=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]),
"=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]),
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]),
"=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]),
"=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]),
"=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]),
"=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]),
"=r"(dst_ptr[126]), "=r"(dst_ptr[127])
: "r"(src_addr));
} else {
// Not reachable for any N accepted by the static_assert above.
asm volatile ("trap");
}
}
// Tensor-memory load: 16 data path lanes, 256-bit pattern, repeated N times.
// Emits `tcgen05.ld.sync.aligned.16x256b.xN.b32` from the tensor-memory address
// `src_addr` and writes the 4*N resulting 32-bit registers to
// dst_ptr_[0 .. 4*N-1] (dst_ptr_ is reinterpreted as uint32_t*).
//   N must be a power of two in [1, 32] (enforced at compile time).
// No wait for the load is issued here.
template <int N, typename T>
CUTE_DEVICE void tmem_ld_16dp256bNx(uint32_t const &src_addr, T* dst_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 32,
"N must be a power of 2 and lies between 1 ~ 32");
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(dst_ptr_);
if constexpr (N == 1) {
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32"
"{%0, %1, %2, %3},"
"[%4];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3])
: "r"(src_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.ld.sync.aligned.16x256b.x2.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7},"
"[%8];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7])
: "r"(src_addr));
} else if constexpr (N == 4) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x4.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15},"
"[%16];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15])
: "r"(src_addr));
} else if constexpr (N == 8) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x8.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31},"
"[%32];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31])
: "r"(src_addr));
} else if constexpr (N == 16) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x16.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, "
"%28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, "
"%42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, "
"%56, "
"%57, %58, %59, %60, %61, %62, %63},"
"[%64];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63])
: "r"(src_addr));
} else if constexpr (N == 32) {
asm volatile(
"tcgen05.ld.sync.aligned.16x256b.x32.b32"
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, "
"%28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, "
"%42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, "
"%56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, "
"%70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, "
"%84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, "
"%98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127},"
"[%128];\n"
: "=r"(dst_ptr[0]), "=r"(dst_ptr[1]), "=r"(dst_ptr[2]),
"=r"(dst_ptr[3]), "=r"(dst_ptr[4]), "=r"(dst_ptr[5]),
"=r"(dst_ptr[6]), "=r"(dst_ptr[7]), "=r"(dst_ptr[8]),
"=r"(dst_ptr[9]), "=r"(dst_ptr[10]), "=r"(dst_ptr[11]),
"=r"(dst_ptr[12]), "=r"(dst_ptr[13]), "=r"(dst_ptr[14]),
"=r"(dst_ptr[15]), "=r"(dst_ptr[16]), "=r"(dst_ptr[17]),
"=r"(dst_ptr[18]), "=r"(dst_ptr[19]), "=r"(dst_ptr[20]),
"=r"(dst_ptr[21]), "=r"(dst_ptr[22]), "=r"(dst_ptr[23]),
"=r"(dst_ptr[24]), "=r"(dst_ptr[25]), "=r"(dst_ptr[26]),
"=r"(dst_ptr[27]), "=r"(dst_ptr[28]), "=r"(dst_ptr[29]),
"=r"(dst_ptr[30]), "=r"(dst_ptr[31]), "=r"(dst_ptr[32]),
"=r"(dst_ptr[33]), "=r"(dst_ptr[34]), "=r"(dst_ptr[35]),
"=r"(dst_ptr[36]), "=r"(dst_ptr[37]), "=r"(dst_ptr[38]),
"=r"(dst_ptr[39]), "=r"(dst_ptr[40]), "=r"(dst_ptr[41]),
"=r"(dst_ptr[42]), "=r"(dst_ptr[43]), "=r"(dst_ptr[44]),
"=r"(dst_ptr[45]), "=r"(dst_ptr[46]), "=r"(dst_ptr[47]),
"=r"(dst_ptr[48]), "=r"(dst_ptr[49]), "=r"(dst_ptr[50]),
"=r"(dst_ptr[51]), "=r"(dst_ptr[52]), "=r"(dst_ptr[53]),
"=r"(dst_ptr[54]), "=r"(dst_ptr[55]), "=r"(dst_ptr[56]),
"=r"(dst_ptr[57]), "=r"(dst_ptr[58]), "=r"(dst_ptr[59]),
"=r"(dst_ptr[60]), "=r"(dst_ptr[61]), "=r"(dst_ptr[62]),
"=r"(dst_ptr[63]), "=r"(dst_ptr[64]), "=r"(dst_ptr[65]),
"=r"(dst_ptr[66]), "=r"(dst_ptr[67]), "=r"(dst_ptr[68]),
"=r"(dst_ptr[69]), "=r"(dst_ptr[70]), "=r"(dst_ptr[71]),
"=r"(dst_ptr[72]), "=r"(dst_ptr[73]), "=r"(dst_ptr[74]),
"=r"(dst_ptr[75]), "=r"(dst_ptr[76]), "=r"(dst_ptr[77]),
"=r"(dst_ptr[78]), "=r"(dst_ptr[79]), "=r"(dst_ptr[80]),
"=r"(dst_ptr[81]), "=r"(dst_ptr[82]), "=r"(dst_ptr[83]),
"=r"(dst_ptr[84]), "=r"(dst_ptr[85]), "=r"(dst_ptr[86]),
"=r"(dst_ptr[87]), "=r"(dst_ptr[88]), "=r"(dst_ptr[89]),
"=r"(dst_ptr[90]), "=r"(dst_ptr[91]), "=r"(dst_ptr[92]),
"=r"(dst_ptr[93]), "=r"(dst_ptr[94]), "=r"(dst_ptr[95]),
"=r"(dst_ptr[96]), "=r"(dst_ptr[97]), "=r"(dst_ptr[98]),
"=r"(dst_ptr[99]), "=r"(dst_ptr[100]), "=r"(dst_ptr[101]),
"=r"(dst_ptr[102]), "=r"(dst_ptr[103]), "=r"(dst_ptr[104]),
"=r"(dst_ptr[105]), "=r"(dst_ptr[106]), "=r"(dst_ptr[107]),
"=r"(dst_ptr[108]), "=r"(dst_ptr[109]), "=r"(dst_ptr[110]),
"=r"(dst_ptr[111]), "=r"(dst_ptr[112]), "=r"(dst_ptr[113]),
"=r"(dst_ptr[114]), "=r"(dst_ptr[115]), "=r"(dst_ptr[116]),
"=r"(dst_ptr[117]), "=r"(dst_ptr[118]), "=r"(dst_ptr[119]),
"=r"(dst_ptr[120]), "=r"(dst_ptr[121]), "=r"(dst_ptr[122]),
"=r"(dst_ptr[123]), "=r"(dst_ptr[124]), "=r"(dst_ptr[125]),
"=r"(dst_ptr[126]), "=r"(dst_ptr[127])
: "r"(src_addr));
} else {
// Not reachable for any N accepted by the static_assert above.
asm volatile("trap");
}
}
// Tensor-memory store: 32 data path lanes, 32-bit pattern, repeated N times.
// Emits `tcgen05.st.sync.aligned.32x32b.xN.b32`, storing the N 32-bit values
// src_ptr_[0 .. N-1] (src_ptr_ is reinterpreted as uint32_t*) to the
// tensor-memory address `dst_addr`.
//   N must be a power of two in [1, 128] (enforced at compile time).
// No wait for the store is issued here.
template <int N, typename T>
CUTE_DEVICE void tmem_st_32dp32bNx(uint32_t const &dst_addr, T* src_ptr_) {
static_assert(N > 0 && (N & (N - 1)) == 0 && N <= 128, "N must be a power of 2 and lies between 1 ~ 128");
uint32_t* src_ptr = reinterpret_cast<uint32_t*>(src_ptr_);
if constexpr (N == 1) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x1.b32"
"[%1], {%0};\n"
:
: "r"(src_ptr[0]),
"r"(dst_addr));
} else if constexpr (N == 2) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x2.b32"
"[%2], {%0, %1};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]),
"r"(dst_addr));
} else if constexpr (N == 4) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x4.b32"
"[%4], {%0, %1, %2, %3};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]),
"r"(dst_addr));
} else if constexpr (N == 8) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32"
"[%8], {%0, %1, %2, %3, %4, %5, %6, %7};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]),
"r"(dst_addr));
} else if constexpr (N == 16) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x16.b32"
"[%16], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]),
"r"(dst_addr));
} else if constexpr (N == 32) {
asm volatile("tcgen05.st.sync.aligned.32x32b.x32.b32"
"[%32], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, "
"%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, "
"%26, %27, %28, %29, %30, %31};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]),
"r"(dst_addr));
} else if constexpr (N == 64) {
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x64.b32"
"[%64], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]),
"r"(dst_addr));
} else if constexpr (N == 128) {
asm volatile(
"tcgen05.st.sync.aligned.32x32b.x128.b32"
"[%128], {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, "
"%15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, "
"%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, "
"%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, "
"%57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, "
"%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, "
"%85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, "
"%99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, "
"%110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, "
"%121, %122, %123, %124, %125, %126, %127};\n"
:
: "r"(src_ptr[0]), "r"(src_ptr[1]), "r"(src_ptr[2]),
"r"(src_ptr[3]), "r"(src_ptr[4]), "r"(src_ptr[5]),
"r"(src_ptr[6]), "r"(src_ptr[7]), "r"(src_ptr[8]),
"r"(src_ptr[9]), "r"(src_ptr[10]), "r"(src_ptr[11]),
"r"(src_ptr[12]), "r"(src_ptr[13]), "r"(src_ptr[14]),
"r"(src_ptr[15]), "r"(src_ptr[16]), "r"(src_ptr[17]),
"r"(src_ptr[18]), "r"(src_ptr[19]), "r"(src_ptr[20]),
"r"(src_ptr[21]), "r"(src_ptr[22]), "r"(src_ptr[23]),
"r"(src_ptr[24]), "r"(src_ptr[25]), "r"(src_ptr[26]),
"r"(src_ptr[27]), "r"(src_ptr[28]), "r"(src_ptr[29]),
"r"(src_ptr[30]), "r"(src_ptr[31]), "r"(src_ptr[32]),
"r"(src_ptr[33]), "r"(src_ptr[34]), "r"(src_ptr[35]),
"r"(src_ptr[36]), "r"(src_ptr[37]), "r"(src_ptr[38]),
"r"(src_ptr[39]), "r"(src_ptr[40]), "r"(src_ptr[41]),
"r"(src_ptr[42]), "r"(src_ptr[43]), "r"(src_ptr[44]),
"r"(src_ptr[45]), "r"(src_ptr[46]), "r"(src_ptr[47]),
"r"(src_ptr[48]), "r"(src_ptr[49]), "r"(src_ptr[50]),
"r"(src_ptr[51]), "r"(src_ptr[52]), "r"(src_ptr[53]),
"r"(src_ptr[54]), "r"(src_ptr[55]), "r"(src_ptr[56]),
"r"(src_ptr[57]), "r"(src_ptr[58]), "r"(src_ptr[59]),
"r"(src_ptr[60]), "r"(src_ptr[61]), "r"(src_ptr[62]),
"r"(src_ptr[63]), "r"(src_ptr[64]), "r"(src_ptr[65]),
"r"(src_ptr[66]), "r"(src_ptr[67]), "r"(src_ptr[68]),
"r"(src_ptr[69]), "r"(src_ptr[70]), "r"(src_ptr[71]),
"r"(src_ptr[72]), "r"(src_ptr[73]), "r"(src_ptr[74]),
"r"(src_ptr[75]), "r"(src_ptr[76]), "r"(src_ptr[77]),
"r"(src_ptr[78]), "r"(src_ptr[79]), "r"(src_ptr[80]),
"r"(src_ptr[81]), "r"(src_ptr[82]), "r"(src_ptr[83]),
"r"(src_ptr[84]), "r"(src_ptr[85]), "r"(src_ptr[86]),
"r"(src_ptr[87]), "r"(src_ptr[88]), "r"(src_ptr[89]),
"r"(src_ptr[90]), "r"(src_ptr[91]), "r"(src_ptr[92]),
"r"(src_ptr[93]), "r"(src_ptr[94]), "r"(src_ptr[95]),
"r"(src_ptr[96]), "r"(src_ptr[97]), "r"(src_ptr[98]),
"r"(src_ptr[99]), "r"(src_ptr[100]), "r"(src_ptr[101]),
"r"(src_ptr[102]), "r"(src_ptr[103]), "r"(src_ptr[104]),
"r"(src_ptr[105]), "r"(src_ptr[106]), "r"(src_ptr[107]),
"r"(src_ptr[108]), "r"(src_ptr[109]), "r"(src_ptr[110]),
"r"(src_ptr[111]), "r"(src_ptr[112]), "r"(src_ptr[113]),
"r"(src_ptr[114]), "r"(src_ptr[115]), "r"(src_ptr[116]),
"r"(src_ptr[117]), "r"(src_ptr[118]), "r"(src_ptr[119]),
"r"(src_ptr[120]), "r"(src_ptr[121]), "r"(src_ptr[122]),
"r"(src_ptr[123]), "r"(src_ptr[124]), "r"(src_ptr[125]),
"r"(src_ptr[126]), "r"(src_ptr[127]),
"r"(dst_addr));
} else {
// Not reachable for any N accepted by the static_assert above.
asm volatile ("trap");
}
}
// XOR mask relating an address to its peer's: peer_addr = my_addr ^ PEER_ADDR_MASK.
// NOTE: it is not certain that this value is the same on every GPU.
static constexpr int PEER_ADDR_MASK = 1 << 24;  // == 16777216

// Returns the pointer obtained by XOR-ing the address bits of `p` with PEER_ADDR_MASK.
// The const qualifier of the pointee is dropped in the returned pointer.
template<typename T>
CUTE_DEVICE
T* get_peer_addr(const T* p) {
  const int64_t own_bits = (int64_t)(p);
  const int64_t peer_bits = own_bits ^ PEER_ADDR_MASK;
  return (T*)(peer_bits);
}
}
#pragma once
#include <cute/tensor.hpp>
namespace cute {
// Extensions to CuTe
// CuTe doesn't support UTCMMA with .ws, so we add it here
// tcgen05.mma.ws (cta_group::1, kind::f16) operation whose A and B operands are both passed
// as 64-bit descriptors. fma() emits the instruction for every calling thread; no thread
// election is performed inside this struct.
//
// Template parameters:
//   a_type/b_type/c_type : element types of A, B and the accumulator
//   M, N                 : instruction tile extents (validated by the static_asserts below)
//   a_major/b_major      : operand majorness
//   a_neg/b_neg          : operand scaling flags (only forwarded through the template signature here)
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_SS_NOELECT
{
  static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
  // The diagnostic now lists the values the condition actually accepts (64, 128, 256).
  static_assert(N == 64 || N == 128 || N == 256,
                "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256");

  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issues one MMA.
  //   desc_a, desc_b : 64-bit operand descriptors for A and B
  //   tmem_c         : 32-bit address used as the [%0] destination operand
  //   scaleC         : non-zero sets predicate p, zero clears it (see setp below)
  //   idescE         : 64-bit instruction descriptor; only its upper 32 bits are passed to the PTX
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
    // The trailing literal 0 is the last operand of the .ws form
    // (presumably the zero-column-mask descriptor -- confirm against the PTX ISA).
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
  }
};
// CuTe traits for SM100_MMA_F16BF16_WS_SS_NOELECT: element types, fragment types, the
// (M, N, K) instruction shape, thread/value layouts, and the mma_unpack entry point that
// forwards to the operation's fma().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
                                                  M, N, a_major, b_major,
                                                  a_neg, b_neg>>
{
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types");

  // Element types; D and C both use the accumulator type.
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  using ValTypeD = c_type;

  // Fragment types: A and B are UMMA::smem_desc descriptors, C is a tmem fragment.
  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;

  // The logical K extent is always 256 bits; convert it to a number of elements.
  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID     = Layout<_1>;
  using ALayout   = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
                           Stride<_0,Stride<    _1,Int<M>>>>;
  using BLayout   = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
                           Stride<_0,Stride<    _1,Int<N>>>>;
  using CLayout   = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
                           Stride<_0,Stride<    _1,Int<M>>>>;

  // Instruction descriptor built from the template arguments.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Pulls the raw operands out of the tensors and issues a single MMA.
  // The destination address is taken from D; C is only type-checked and not otherwise read.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using Op = SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
                                               M, N, a_major, b_major,
                                               a_neg, b_neg>;

    uint32_t const d_tmem_addr   = raw_pointer_cast(D.data());
    uint64_t const a_desc        = A[0];
    uint64_t const b_desc        = B[0];
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    Op::fma(a_desc, b_desc, d_tmem_addr, uint32_t(traits.accumulate_), runtime_idesc);
  }
};
// tcgen05.mma (cta_group::2, kind::f16) operation whose A operand is given as a 32-bit tmem
// address and whose B operand is a 64-bit descriptor. fma() emits the instruction for every
// calling thread; no thread election is performed inside this struct.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
          UMMA::Saturate c_sat = UMMA::Saturate::False>
struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT
{
  static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
  static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
  static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed");

  using DRegisters = void;
  using ARegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issues one MMA.
  //   tmem_a : 32-bit address of A (the [%1] operand)
  //   desc_b : 64-bit descriptor of B
  //   tmem_c : 32-bit address used as the [%0] destination operand
  //   scaleC : non-zero sets predicate p, zero clears it
  //   idescE : 64-bit instruction descriptor; only its upper 32 bits are passed to the PTX
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& tmem_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
    // Eight-element mask vector operand, all zeros.
    uint32_t zero_mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
    uint32_t const idesc_hi = uint32_t(idescE>>32);
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(idesc_hi), "r"(scaleC),
        "r"(zero_mask[0]), "r"(zero_mask[1]), "r"(zero_mask[2]), "r"(zero_mask[3]),
        "r"(zero_mask[4]), "r"(zero_mask[5]), "r"(zero_mask[6]), "r"(zero_mask[7]));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
  }
};
// CuTe traits for SM100_MMA_F16BF16_2x1SM_TS_NOELECT: element types, fragment types, the
// (M, N, K) instruction shape, the two-participant thread/value layouts, and the mma_unpack
// entry point that forwards to the operation's fma().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,
          UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                                     M, N,
                                                     a_major, b_major,
                                                     a_neg, b_neg, c_sat>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types");

  // A and C live in tmem fragments; B is a UMMA::smem_desc descriptor.
  using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

  // Size of instructions' K extent is always 256 bits; convert to units of element
  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID   = Layout<_2>;
  using ALayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;
  using BLayout = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,
                         Stride<Int<N/2>,Stride<      _1,Int<N>>>>;
  using CLayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;

  // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Instruction descriptor built from the template arguments.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();

  // Pulls the raw operands out of the tensors and issues a single MMA.
  // The destination address is taken from D; C is only type-checked and not otherwise read.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    // A is checked to be tmem, so the diagnostic says tmem rather than "desc registers".
    static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    // fma() takes the A address as uint32_t, so hold it at that width instead of
    // widening to 64 bits and narrowing again at the call.
    uint32_t tmem_a = raw_pointer_cast(A.data());
    uint64_t desc_b = B[0];
    uint32_t tmem_c = raw_pointer_cast(D.data());
    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                       M, N,
                                       a_major, b_major,
                                       a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
  }
};
// Variant of SM100_MMA_F16BF16_2x1SM_SS that does not call elect_one_sync():
// tcgen05.mma (cta_group::2, kind::f16) with A and B both passed as 64-bit descriptors.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT
{
  static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
  static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");

  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issues one MMA.
  //   desc_a, desc_b : 64-bit operand descriptors for A and B
  //   tmem_c         : 32-bit address used as the [%0] destination operand
  //   scaleC         : non-zero sets predicate p, zero clears it
  //   idescE         : 64-bit instruction descriptor; only its upper 32 bits are passed to the PTX
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
    // Eight-element mask vector operand, all zeros.
    uint32_t zero_mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
    uint32_t const idesc_hi = uint32_t(idescE>>32);
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(idesc_hi), "r"(scaleC),
        "r"(zero_mask[0]), "r"(zero_mask[1]), "r"(zero_mask[2]), "r"(zero_mask[3]),
        "r"(zero_mask[4]), "r"(zero_mask[5]), "r"(zero_mask[6]), "r"(zero_mask[7]));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
  }
};
// template <class a_type, class b_type, class c_type,
// int M, int N, UMMA::Major a_major, UMMA::Major b_major,
// UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
// struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> : MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> {};
// CuTe traits for SM100_MMA_F16BF16_2x1SM_SS_NOELECT: element types, fragment types, the
// (M, N, K) instruction shape, the two-participant thread/value layouts, and the mma_unpack
// entry point that forwards to the operation's fma().
template <class a_type, class b_type, class c_type,
          int M, int N,
          UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
                                                     M, N, a_major, b_major,
                                                     a_neg, b_neg>>
{
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types");

  // Element types; D and C both use the accumulator type.
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  using ValTypeD = c_type;

  // Fragment types: A and B are UMMA::smem_desc descriptors, C is a tmem fragment.
  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

  // The instruction's K extent is always 256 bits; convert it to a number of elements.
  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID     = Layout<_2>;
  using ALayout   = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,
                           Stride<Int<M/2>,Stride<      _1,Int<M>>>>;
  using BLayout   = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,
                           Stride<Int<N/2>,Stride<      _1,Int<N>>>>;
  using CLayout   = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,
                           Stride<Int<M/2>,Stride<      _1,Int<M>>>>;

  // Instruction descriptor built from the template arguments.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Pulls the raw operands out of the tensors and issues a single MMA.
  // The destination address is taken from D; C is only type-checked and not otherwise read.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using Op = SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
                                                  M, N,
                                                  a_major, b_major,
                                                  a_neg, b_neg>;

    uint32_t const d_tmem_addr   = raw_pointer_cast(D.data());
    uint64_t const a_desc        = A[0];
    uint64_t const b_desc        = B[0];
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    Op::fma(a_desc, b_desc, d_tmem_addr, uint32_t(traits.accumulate_), runtime_idesc);
  }
};
}
\ No newline at end of file
#pragma once
#include <cute/tensor.hpp>
namespace cute {
// Extensions to CuTe
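// CuTe's built-in SM100_TMA_2SM_LOAD_1D family requires two participating threads (using ThrID = Layout<_2>;) and also splits the data, which is very awkward to use, so we provide our own modified version. In addition, to stay consistent with the other parts that use SM90 TMA, we make it accept TMA::CacheHintSm90 instead of TMA::CacheHintSm100.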
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory
////////////////////////////////////////////////////////////////////////////////////////////////////
// 1-D TMA load (global -> shared::cluster) using the cta_group::2 form of
// cp.async.bulk.tensor with an mbarrier complete_tx and an L2 cache hint.
struct SM100_TMA_2SM_LOAD_1D_NOSPLIT
{
  // desc_ptr   : pointer whose value is passed as the tensor-map operand
  // mbar_ptr   : shared-memory mbarrier that receives the transaction bytes
  // cache_hint : 64-bit L2 cache-hint operand
  // smem_ptr   : shared-memory destination
  // crd0       : tensor coordinate
  CUTE_HOST_DEVICE static void
  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
       [[maybe_unused]] void      * smem_ptr,
       [[maybe_unused]] int32_t const& crd0)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint32_t const dst_smem_addr = cast_smem_ptr_to_uint(smem_ptr);
    // Both CTAs run this code. AND-ing with Sm100MmaPeerBitMask clears the peer bit so
    // that the transaction bytes are credited to CTA0's barrier.
    uint32_t const mbar_smem_addr = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint64_t const tensor_map_addr = reinterpret_cast<uint64_t>(desc_ptr);
    asm volatile (
      "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3}], [%2], %4;"
      :
      : "r"(dst_smem_addr), "l"(tensor_map_addr), "r"(mbar_smem_addr),
        "r"(crd0), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// 2-D TMA load (global -> shared::cluster) using the cta_group::2 form of
// cp.async.bulk.tensor with an mbarrier complete_tx and an L2 cache hint.
struct SM100_TMA_2SM_LOAD_2D_NOSPLIT
{
  // desc_ptr   : pointer whose value is passed as the tensor-map operand
  // mbar_ptr   : shared-memory mbarrier that receives the transaction bytes
  // cache_hint : 64-bit L2 cache-hint operand
  // smem_ptr   : shared-memory destination
  // crd0, crd1 : tensor coordinates
  CUTE_HOST_DEVICE static void
  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
       [[maybe_unused]] void      * smem_ptr,
       [[maybe_unused]] int32_t const& crd0, int32_t const& crd1)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint32_t const dst_smem_addr = cast_smem_ptr_to_uint(smem_ptr);
    // Both CTAs run this code. AND-ing with Sm100MmaPeerBitMask clears the peer bit so
    // that the transaction bytes are credited to CTA0's barrier.
    uint32_t const mbar_smem_addr = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint64_t const tensor_map_addr = reinterpret_cast<uint64_t>(desc_ptr);
    asm volatile (
      "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4}], [%2], %5;"
      :
      : "r"(dst_smem_addr), "l"(tensor_map_addr), "r"(mbar_smem_addr),
        "r"(crd0), "r"(crd1), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// 3-D TMA load (global -> shared::cluster) using the cta_group::2 form of
// cp.async.bulk.tensor with an mbarrier complete_tx and an L2 cache hint.
struct SM100_TMA_2SM_LOAD_3D_NOSPLIT
{
  // desc_ptr    : pointer whose value is passed as the tensor-map operand
  // mbar_ptr    : shared-memory mbarrier that receives the transaction bytes
  // cache_hint  : 64-bit L2 cache-hint operand
  // smem_ptr    : shared-memory destination
  // crd0..crd2  : tensor coordinates
  CUTE_HOST_DEVICE static void
  copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint,
       [[maybe_unused]] void      * smem_ptr,
       [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint32_t const dst_smem_addr = cast_smem_ptr_to_uint(smem_ptr);
    // Both CTAs run this code. AND-ing with Sm100MmaPeerBitMask clears the peer bit so
    // that the transaction bytes are credited to CTA0's barrier.
    uint32_t const mbar_smem_addr = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint64_t const tensor_map_addr = reinterpret_cast<uint64_t>(desc_ptr);
    asm volatile (
      "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4, %5}], [%2], %6;"
      :
      : "r"(dst_smem_addr), "l"(tensor_map_addr), "r"(mbar_smem_addr),
        "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// 4-D TMA load (global -> shared::cluster) using the cta_group::2 form of
// cp.async.bulk.tensor with an mbarrier complete_tx and an L2 cache hint.
struct SM100_TMA_2SM_LOAD_4D_NOSPLIT
{
  // desc_ptr    : pointer whose value is passed as the tensor-map operand
  // mbar_ptr    : shared-memory mbarrier that receives the transaction bytes
  // cache_hint  : 64-bit L2 cache-hint operand
  // smem_ptr    : shared-memory destination
  // crd0..crd3  : tensor coordinates
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint32_t const dst_smem_addr = cast_smem_ptr_to_uint(smem_ptr);
    // Both CTAs run this code. AND-ing with Sm100MmaPeerBitMask clears the peer bit so
    // that the transaction bytes are credited to CTA0's barrier.
    uint32_t const mbar_smem_addr = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint64_t const tensor_map_addr = reinterpret_cast<uint64_t>(desc_ptr);
    asm volatile (
      "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
      :
      : "r"(dst_smem_addr), "l"(tensor_map_addr), "r"(mbar_smem_addr),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// 5-D TMA load (global -> shared::cluster) using the cta_group::2 form of
// cp.async.bulk.tensor with an mbarrier complete_tx and an L2 cache hint.
struct SM100_TMA_2SM_LOAD_5D_NOSPLIT
{
  // desc_ptr    : pointer whose value is passed as the tensor-map operand
  // mbar_ptr    : shared-memory mbarrier that receives the transaction bytes
  // cache_hint  : 64-bit L2 cache-hint operand
  // smem_ptr    : shared-memory destination
  // crd0..crd4  : tensor coordinates
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
    uint32_t const dst_smem_addr = cast_smem_ptr_to_uint(smem_ptr);
    // Both CTAs run this code. AND-ing with Sm100MmaPeerBitMask clears the peer bit so
    // that the transaction bytes are credited to CTA0's barrier.
    uint32_t const mbar_smem_addr = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
    uint64_t const tensor_map_addr = reinterpret_cast<uint64_t>(desc_ptr);
    asm volatile (
      "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint"
      " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
      :
      : "r"(dst_smem_addr), "l"(tensor_map_addr), "r"(mbar_smem_addr),
        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
      : "memory");
#else
    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
#endif
  }
};
// Rank-generic front end: each copy() overload forwards its arguments unchanged to the
// SM100_TMA_2SM_LOAD_<R>D_NOSPLIT operation matching the number of coordinates (1 to 5).
struct SM100_TMA_2SM_LOAD_NOSPLIT
{
  // 1 coordinate -> 1-D load
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0)
  {
    SM100_TMA_2SM_LOAD_1D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0);
  }

  // 2 coordinates -> 2-D load
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1)
  {
    SM100_TMA_2SM_LOAD_2D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1);
  }

  // 3 coordinates -> 3-D load
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
  {
    SM100_TMA_2SM_LOAD_3D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2);
  }

  // 4 coordinates -> 4-D load
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
  {
    SM100_TMA_2SM_LOAD_4D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3);
  }

  // 5 coordinates -> 5-D load
  CUTE_HOST_DEVICE static void
  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint,
       void      * smem_ptr,
       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
  {
    SM100_TMA_2SM_LOAD_5D_NOSPLIT::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
  }

  // Prefetch operation is reused from the SM90 TMA load.
  using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
};
// Distinct tag type for the executable form of SM100_TMA_2SM_LOAD_NOSPLIT. It adds no members
// and inherits every copy() overload; Copy_Traits is specialized separately on this type.
struct SM100_TMA_2SM_LOAD_NOSPLIT_OP : SM100_TMA_2SM_LOAD_NOSPLIT {};
// Non-executable traits for SM100_TMA_2SM_LOAD_NOSPLIT: holds the TMA descriptor but no
// mbarrier. Call .with(tma_mbar) to obtain the executable traits.
template <class NumBitsPerTMA, class AuxParams_>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT, NumBitsPerTMA, AuxParams_>
{
  using ThrID = Layout<_1>;
  // (src-thr, src-val) -> bit
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // (dst-thr, dst-val) -> bit
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // Reference (thr, val) -> bit map
  using RefLayout = SrcLayout;

  // Arguments of SM100_TMA_2SM_LOAD_NOSPLIT
  TmaDescriptor tma_desc_;
  using AuxParams = AuxParams_;
  AuxParams aux_params_;

  // Address of the stored TmaDescriptor/TensorMap.
  CUTE_HOST_DEVICE constexpr
  TmaDescriptor const*
  get_tma_descriptor() const {
    return &tma_desc_;
  }

  // Builds the executable traits from the stored descriptor, the given mbarrier and the
  // cache hint. multicast_mask is accepted only to keep the API consistent across atoms;
  // it is not used.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
  with(
    uint64_t& tma_mbar,
    [[maybe_unused]] uint16_t const& multicast_mask = 0,
    TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
    uint64_t const hint_bits = static_cast<uint64_t>(cache_hint);
    return {&tma_desc_, &tma_mbar, hint_bits};
  }

  // Same as above, but uses the caller-supplied descriptor instead of the stored one
  // (temporary overload for grouped gemm / ptr-array gemm). multicast_mask is again unused.
  CUTE_HOST_DEVICE constexpr
  Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
  with(
    TmaDescriptor const* new_tma_desc,
    uint64_t& tma_mbar,
    [[maybe_unused]] uint16_t const& multicast_mask = 0,
    TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const {
    uint64_t const hint_bits = static_cast<uint64_t>(cache_hint);
    return {new_tma_desc, &tma_mbar, hint_bits};
  }

  // Produces the TMA coordinate tensor: a counting tensor over g_shape with the strides
  // stored in aux_params_. g_shape must be congruent with those strides.
  template <class GShape>
  CUTE_HOST_DEVICE constexpr
  auto
  get_tma_tensor(GShape const& g_shape) const {
    static_assert(is_congruent<decltype(g_shape), decltype(aux_params_.g_stride_)>::value);
    auto const coord_layout = make_layout(g_shape, aux_params_.g_stride_);
    return make_counting_tensor(coord_layout);
  }

  // Deleted on purpose: a copy cannot be executed with these traits until .with() is called.
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE friend constexpr void
  copy_unpack(Copy_Traits        const& traits,
              Tensor<TS,SLayout> const& src,
              Tensor<TD,DLayout>      & dst) = delete;
};
// Executable traits for SM100_TMA_2SM_LOAD_NOSPLIT: stores the descriptor pointer, the
// mbarrier pointer and the cache hint in opargs_. The copy_unpack logic comes from the
// TMA_LOAD_Unpack base class.
template <class NumBitsPerTMA>
struct Copy_Traits<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
     : TMA_LOAD_Unpack<SM100_TMA_2SM_LOAD_NOSPLIT_OP, NumBitsPerTMA>
{
  using ThrID = Layout<_1>;
  // (src-thr, src-val) -> bit
  using SrcLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // (dst-thr, dst-val) -> bit
  using DstLayout = Layout<Shape<_1,NumBitsPerTMA>>;
  // Reference (thr, val) -> bit map
  using RefLayout = SrcLayout;

  // Arguments of SM100_TMA_2SM_LOAD_NOSPLIT, in order:
  // TMA descriptor, shared-memory mbarrier, cache hint.
  tuple<
    TmaDescriptor const*,
    uint64_t*,
    uint64_t
  > const opargs_;

  // Stores the three operation arguments as given.
  CUTE_HOST_DEVICE
  Copy_Traits(TmaDescriptor const* tma_desc, uint64_t* tma_mbar, uint64_t cache_hint)
    : opargs_(tma_desc, tma_mbar, cache_hint) {}
};
}
#pragma once
#include <cute/tensor.hpp>
namespace cute {
// Extensions to CuTe
// CuTe doesn't support UTCMMA with .ws, so we add it here
// tcgen05.mma.ws (cta_group::1, kind::f16) operation whose A and B operands are both passed
// as 64-bit descriptors. fma() emits the instruction for every calling thread; no thread
// election is performed inside this struct.
//
// Template parameters:
//   a_type/b_type/c_type : element types of A, B and the accumulator
//   M, N                 : instruction tile extents (validated by the static_asserts below)
//   a_major/b_major      : operand majorness
//   a_neg/b_neg          : operand scaling flags (only forwarded through the template signature here)
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_SS_NOELECT
{
  static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_SS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
  // The diagnostic now lists the values the condition actually accepts (64, 128, 256).
  static_assert(N == 64 || N == 128 || N == 256,
                "SM100_MMA_F16BF16_WS_SS_NOELECT N-mode size should be 64, 128 or 256");

  using DRegisters = void;
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issues one MMA.
  //   desc_a, desc_b : 64-bit operand descriptors for A and B
  //   tmem_c         : 32-bit address used as the [%0] destination operand
  //   scaleC         : non-zero sets predicate p, zero clears it (see setp below)
  //   idescE         : 64-bit instruction descriptor; only its upper 32 bits are passed to the PTX
  CUTE_HOST_DEVICE static void
  fma(uint64_t const& desc_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
    // The trailing literal 0 is the last operand of the .ws form
    // (presumably the zero-column-mask descriptor -- confirm against the PTX ISA).
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
  }
};
// CuTe traits for SM100_MMA_F16BF16_WS_SS_NOELECT: element types, fragment types, the
// (M, N, K) instruction shape, thread/value layouts, and the mma_unpack entry point that
// forwards to the operation's fma().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
                                                  M, N, a_major, b_major,
                                                  a_neg, b_neg>>
{
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_SS_NOELECT supports 16bit types");

  // Element types; D and C both use the accumulator type.
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  using ValTypeD = c_type;

  // Fragment types: A and B are UMMA::smem_desc descriptors, C is a tmem fragment.
  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;

  // The logical K extent is always 256 bits; convert it to a number of elements.
  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID     = Layout<_1>;
  using ALayout   = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
                           Stride<_0,Stride<    _1,Int<M>>>>;
  using BLayout   = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
                           Stride<_0,Stride<    _1,Int<N>>>>;
  using CLayout   = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
                           Stride<_0,Stride<    _1,Int<M>>>>;

  // Instruction descriptor built from the template arguments.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Pulls the raw operands out of the tensors and issues a single MMA.
  // The destination address is taken from D; C is only type-checked and not otherwise read.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using Op = SM100_MMA_F16BF16_WS_SS_NOELECT<a_type, b_type, c_type,
                                               M, N, a_major, b_major,
                                               a_neg, b_neg>;

    uint32_t const d_tmem_addr   = raw_pointer_cast(D.data());
    uint64_t const a_desc        = A[0];
    uint64_t const b_desc        = B[0];
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    Op::fma(a_desc, b_desc, d_tmem_addr, uint32_t(traits.accumulate_), runtime_idesc);
  }
};
using namespace cute;
// tcgen05.mma.ws (cta_group::1, kind::f16) operation whose A operand is given as a 32-bit
// tmem address and whose B operand is a 64-bit descriptor. fma() emits the instruction for
// every calling thread; no thread election is performed inside this struct.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_TS_NOELECT
{
  static_assert(M == 32 || M == 64 || M == 128, "SM100_MMA_F16BF16_WS_TS_NOELECT M-mode size should be 32, 64 or 128 for 1 CTA cluster MMA.");
  // The diagnostic now lists the values the condition actually accepts (64, 128, 256).
  static_assert(N == 64 || N == 128 || N == 256,
                "SM100_MMA_F16BF16_WS_TS_NOELECT N-mode size should be 64, 128 or 256");

  using DRegisters = void;
  // NOTE(review): fma() takes A as a uint32_t tmem address, and the 2x1SM TS variant declares
  // ARegisters as uint32_t[1]; this alias is left as-is to keep the interface unchanged --
  // confirm whether uint64_t[1] is intentional.
  using ARegisters = uint64_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issues one MMA.
  //   tmem_a : 32-bit address of A (the [%1] operand)
  //   desc_b : 64-bit descriptor of B
  //   tmem_c : 32-bit address used as the [%0] destination operand
  //   scaleC : non-zero sets predicate p, zero clears it (see setp below)
  //   idescE : 64-bit instruction descriptor; only its upper 32 bits are passed to the PTX
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& tmem_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
    // The trailing literal 0 is the last operand of the .ws form
    // (presumably the zero-column-mask descriptor -- confirm against the PTX ISA).
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], [%1], %2, %3, p, 0; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC));
  }
};
// CuTe traits for SM100_MMA_F16BF16_WS_TS_NOELECT: element types, fragment types, the
// (M, N, K) instruction shape, thread/value layouts, and the mma_unpack entry point that
// forwards to the operation's fma().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,
                                                  M, N,
                                                  a_major, b_major,
                                                  a_neg, b_neg>>
{
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_WS_TS_NOELECT supports 16bit types");

  // Element types; D and C both use the accumulator type.
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  using ValTypeD = c_type;

  // Fragment types: A and C are tmem fragments, B is a UMMA::smem_desc descriptor.
  using FrgTypeA = UMMA::tmem_frg_1sm<a_type, a_type, UMMA::TmemAllocMode::NonInterleaved>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_1sm<c_type, int32_t, UMMA::TmemAllocMode::NonInterleaved>;

  // The logical K extent is always 256 bits; convert it to a number of elements.
  static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID     = Layout<_1>;
  using ALayout   = Layout<Shape <_1,Shape <Int<M>,Int<K>>>,
                           Stride<_0,Stride<    _1,Int<M>>>>;
  using BLayout   = Layout<Shape <_1,Shape <Int<N>,Int<K>>>,
                           Stride<_0,Stride<    _1,Int<N>>>>;
  using CLayout   = Layout<Shape <_1,Shape <Int<M>,Int<N>>>,
                           Stride<_0,Stride<    _1,Int<M>>>>;

  // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Instruction descriptor built from the template arguments.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // Pulls the raw operands out of the tensors and issues a single MMA.
  // The destination address is taken from D; C is only type-checked and not otherwise read.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using Op = SM100_MMA_F16BF16_WS_TS_NOELECT<a_type, b_type, c_type,
                                               M, N,
                                               a_major, b_major,
                                               a_neg, b_neg>;

    uint32_t const d_tmem_addr   = raw_pointer_cast(D.data());
    uint32_t const a_tmem_addr   = raw_pointer_cast(A.data());
    uint64_t const b_desc        = B[0];
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    Op::fma(a_tmem_addr, b_desc, d_tmem_addr, uint32_t(traits.accumulate_), runtime_idesc);
  }
};
// tcgen05.mma (cta_group::2, kind::f16) operation whose A operand is given as a 32-bit tmem
// address and whose B operand is a 64-bit descriptor. fma() emits the instruction for every
// calling thread; no thread election is performed inside this struct.
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One,
          UMMA::Saturate c_sat = UMMA::Saturate::False>
struct SM100_MMA_F16BF16_2x1SM_TS_NOELECT
{
  static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
  static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
  static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT A from TMEM can't be transposed");

  using DRegisters = void;
  using ARegisters = uint32_t[1];
  using BRegisters = uint64_t[1];
  using CRegisters = uint32_t[1];

  // Issues one MMA.
  //   tmem_a : 32-bit address of A (the [%1] operand)
  //   desc_b : 64-bit descriptor of B
  //   tmem_c : 32-bit address used as the [%0] destination operand
  //   scaleC : non-zero sets predicate p, zero clears it
  //   idescE : 64-bit instruction descriptor; only its upper 32 bits are passed to the PTX
  CUTE_HOST_DEVICE static void
  fma(uint32_t const& tmem_a,
      uint64_t const& desc_b,
      uint32_t const& tmem_c,
      uint32_t const& scaleC,
      uint64_t const& idescE)
  {
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
    // Eight-element mask vector operand, all zeros.
    uint32_t zero_mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
    uint32_t const idesc_hi = uint32_t(idescE>>32);
    asm volatile(
      "{\n\t"
      ".reg .pred p;\n\t"
      "setp.ne.b32 p, %4, 0;\n\t"
      "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
      "}\n"
      :
      : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(idesc_hi), "r"(scaleC),
        "r"(zero_mask[0]), "r"(zero_mask[1]), "r"(zero_mask[2]), "r"(zero_mask[3]),
        "r"(zero_mask[4]), "r"(zero_mask[5]), "r"(zero_mask[6]), "r"(zero_mask[7]));
#else
    CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
  }
};
// CuTe traits for SM100_MMA_F16BF16_2x1SM_TS_NOELECT: element types, fragment types, the
// (M, N, K) instruction shape, the two-participant thread/value layouts, and the mma_unpack
// entry point that forwards to the operation's fma().
template <class a_type, class b_type, class c_type,
          int M, int N, UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg,
          UMMA::Saturate c_sat>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                                     M, N,
                                                     a_major, b_major,
                                                     a_neg, b_neg, c_sat>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;
  static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> && cute::sizeof_bits_v<b_type> == 16, "SM100_MMA_F16BF16_2x1SM_TS_NOELECT supports 16bit types");

  // A and C live in tmem fragments; B is a UMMA::smem_desc descriptor.
  using FrgTypeA = UMMA::tmem_frg_2sm<a_type, a_type, UMMA::TmemAllocMode::Duplicated>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

  // Size of instructions' K extent is always 256 bits; convert to units of element
  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>,Int<N>,Int<K>>;
  using ThrID   = Layout<_2>;
  using ALayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<K>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;
  using BLayout = Layout<Shape <      _2,Shape <Int<N/2>,Int<K>>>,
                         Stride<Int<N/2>,Stride<      _1,Int<N>>>>;
  using CLayout = Layout<Shape <      _2,Shape <Int<M/2>,Int<N>>>,
                         Stride<Int<M/2>,Stride<      _1,Int<M>>>>;

  // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Instruction descriptor built from the template arguments.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
    a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>();

  // Pulls the raw operands out of the tensors and issues a single MMA.
  // The destination address is taken from D; C is only type-checked and not otherwise read.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    // A is checked to be tmem, so the diagnostic says tmem rather than "desc registers".
    static_assert(is_tmem<TA>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    // fma() takes the A address as uint32_t, so hold it at that width instead of
    // widening to 64 bits and narrowing again at the call.
    uint32_t tmem_a = raw_pointer_cast(A.data());
    uint64_t desc_b = B[0];
    uint32_t tmem_c = raw_pointer_cast(D.data());
    uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    SM100_MMA_F16BF16_2x1SM_TS_NOELECT<a_type, b_type, c_type,
                                       M, N,
                                       a_major, b_major,
                                       a_neg, b_neg, c_sat>::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc);
  }
};
// SM100_MMA_F16BF16_2x1SM_SS without elect_one_sync()
//
// MMA atom whose `fma` emits a `tcgen05.mma.cta_group::2.kind::f16` instruction where
// both A and B are supplied as 64-bit descriptors and the accumulator is addressed by a
// 32-bit value. M must be 128 or 256; N must be a multiple of 32 in [32, 256].
template <class a_type, class b_type, class c_type,
int M, int N, UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One, UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_2x1SM_SS_NOELECT
{
static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_NOELECT M-mode size should be 128 or 256 for 2 CTA cluster MMA.");
static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_NOELECT N-mode size should be a multiple of 32 between 32 and 256.");
// Register-file footprint of each operand as seen by the MMA_Atom machinery:
// no D registers, one 64-bit word each for A and B, one 32-bit word for C.
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
// Issues the instruction.
//   desc_a -> %1   64-bit descriptor used as the A operand
//   desc_b -> %2   64-bit descriptor used as the B operand
//   tmem_c -> [%0] 32-bit value used in brackets as the first operand
//   scaleC -> %4   compared against 0 to build predicate `p`, the instruction's last operand
//   idescE -> %3   only the upper 32 bits (idescE >> 32) are handed to the instruction
CUTE_HOST_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scaleC,
uint64_t const& idescE)
{
#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED)
// Eight all-zero 32-bit words passed as the braced vector operand {%5..%12}.
// NOTE(review): presumably the disable-output-lane mask (nothing disabled) - confirm against the PTX ISA.
uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0};
// p = (scaleC != 0); the MMA itself is issued unconditionally with p as an operand.
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]),
"r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]));
#else
// Built without tcgen05 F16/F32 MMA support: flag the call as an invalid control path.
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_NOELECT without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED");
#endif
}
};
// template <class a_type, class b_type, class c_type,
// int M, int N, UMMA::Major a_major, UMMA::Major b_major,
// UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
// struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> : MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS<a_type, b_type, c_type,
// M, N, a_major, b_major,
// a_neg, b_neg>> {};
// MMA_Traits specialization for SM100_MMA_F16BF16_2x1SM_SS_NOELECT.
//
// Publishes the value/fragment types, the MNK tile shape and the thread/value layouts
// of the atom, and supplies `mma_unpack`, which pulls the raw operands out of the
// tensors and calls the atom's `fma`. A and B are shared-memory descriptors; C/D are
// tensor-memory fragments.
template <class a_type, class b_type, class c_type,
          int M, int N,
          UMMA::Major a_major, UMMA::Major b_major,
          UMMA::ScaleIn a_neg, UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
                                                     M, N, a_major, b_major,
                                                     a_neg, b_neg>>
{
  using ValTypeD = c_type;
  using ValTypeA = a_type;
  using ValTypeB = b_type;
  using ValTypeC = c_type;

  // Both inputs must be 16-bit element types.
  static_assert(cute::sizeof_bits_v<a_type> == 16 && cute::sizeof_bits_v<b_type> == 16,
                "SM100_MMA_F16BF16_2x1SM_SS_NOELECT supports 16bit types");

  using FrgTypeA = UMMA::smem_desc<a_major>;
  using FrgTypeB = UMMA::smem_desc<b_major>;
  using FrgTypeC = UMMA::tmem_frg_2sm<c_type>;

  // One instruction always covers 256 bits along K; express that as an element count.
  constexpr static int K = 256 / cute::sizeof_bits<ValTypeA>::value;

  using Shape_MNK = Shape<Int<M>, Int<N>, Int<K>>;
  using ThrID     = Layout<_2>;

  // (participant, (row, k)) -> coordinate layouts for A, B and C.
  using ALayout = Layout<Shape <_2,       Shape <Int<M/2>, Int<K>>>,
                         Stride<Int<M/2>, Stride<_1,       Int<M>>>>;
  using BLayout = Layout<Shape <_2,       Shape <Int<N/2>, Int<K>>>,
                         Stride<Int<N/2>, Stride<_1,       Int<N>>>>;
  using CLayout = Layout<Shape <_2,       Shape <Int<M/2>, Int<N>>>,
                         Stride<Int<M/2>, Stride<_1,       Int<M>>>>;

  // Compile-time instruction descriptor derived from the template arguments.
  UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc<
      a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>();

  // 1: accumulate onto C, 0: ignore C (clears the accumulators).
  UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;

  // Extracts descriptors/addresses from the operand tensors and issues the MMA.
  // C is only type-checked; the accumulator address is taken from D.
  template <class TD, class DLayout,
            class TA, class ALayout,
            class TB, class BLayout,
            class TC, class CLayout>
  CUTE_HOST_DEVICE constexpr friend
  void
  mma_unpack(MMA_Traits          const& traits,
             Tensor<TD, DLayout>      & D,
             Tensor<TA, ALayout> const& A,
             Tensor<TB, BLayout> const& B,
             Tensor<TC, CLayout> const& C)
  {
    static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
    static_assert(is_rmem<TA>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_rmem<TB>::value, "Expected desc registers in MMA_Atom::call");
    static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");

    using Atom = SM100_MMA_F16BF16_2x1SM_SS_NOELECT<a_type, b_type, c_type,
                                                    M, N,
                                                    a_major, b_major,
                                                    a_neg, b_neg>;

    uint64_t const a_descriptor  = A[0];
    uint64_t const b_descriptor  = B[0];
    uint32_t const acc_address   = raw_pointer_cast(D.data());
    uint64_t const runtime_idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);

    Atom::fma(a_descriptor, b_descriptor, acc_address,
              uint32_t(traits.accumulate_), runtime_idesc);
  }
};
} // namespace cute
......@@ -9,6 +9,7 @@ from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
IS_WINDOWS,
CUDA_HOME
)
......@@ -22,8 +23,21 @@ def get_features_args():
return features_args
def get_arch_flags():
# Check NVCC Version
# NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment variable. For more details, see `torch/utils/cpp_extension.py`
assert CUDA_HOME is not None, "PyTorch must be compiled with CUDA support"
nvcc_version = subprocess.check_output(
[os.path.join(CUDA_HOME, "bin", "nvcc"), '--version'], stderr=subprocess.STDOUT
).decode('utf-8')
nvcc_version_number = nvcc_version.split('release ')[1].split(',')[0].strip()
major, minor = map(int, nvcc_version_number.split('.'))
print(f'Compiling using NVCC {major}.{minor}')
DISABLE_SM100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
DISABLE_SM90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
if major < 12 or (major == 12 and minor <= 8):
assert DISABLE_SM100, "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."
arch_flags = []
if not DISABLE_SM100:
arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"])
......@@ -55,8 +69,10 @@ ext_modules.append(
"csrc/sm90/decode/dense/splitkv_mla.cu",
"csrc/sm90/decode/sparse_fp8/splitkv_mla.cu",
"csrc/sm90/prefill/sparse/fwd.cu",
"csrc/sm100/decode/sparse_fp8/splitkv_mla.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
"csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
"csrc/sm100/prefill/sparse/fwd.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
......
......@@ -320,6 +320,11 @@ def main(torch_dtype):
testcases = correctness_cases + corner_cases + performance_cases
# Prune out unsupported cases
cc_major, cc_minor = torch.cuda.get_device_capability()
if cc_major == 10:
testcases = [t for t in testcases if (t.is_fp8 and t.topk is not None)]
for testcase in testcases:
test_flash_mla(testcase)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment