Unverified Commit c2067be3 authored by Shengyu Liu's avatar Shengyu Liu Committed by GitHub
Browse files

Performance Update (2025.04.22) (#71)

* Fix benchmark script

* Performance optimization for compute-bound cases

* Add new testcase (s_k = 16384)

* Update README.md

* Update comment

* Update README.md

* Add the deep-dive blog

* Add background color for MLA Kernel Sched.drawio.svg

* Use relative path for the schedule image

* Move flash_mla.h to kernels/params.h
parent b31bfe72
......@@ -5,3 +5,4 @@ __pycache__/
dist/
*perf.csv
*.png
/.vscode
# FlashMLA
## Performance Update (2025.04.22)
We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀
Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up here: <LINK>
The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance.
## Introduction
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
Currently released:
- BF16, FP16
- Paged kvcache with block size of 64
## Requirements
- Hopper GPUs
- CUDA 12.3 and above
- **But we highly recommend 12.8 or above for the best performance**
- PyTorch 2.0 and above
## Quick start
### Install
......@@ -20,7 +37,9 @@ python setup.py install
python tests/test_flash_mla.py
```
Achieving up to 3000 GB/s in memory-bound configuration and 580 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.
It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8.
Note. For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance.
### Usage
......@@ -38,13 +57,6 @@ for i in range(num_layers):
...
```
## Requirements
- Hopper GPUs
- CUDA 12.3 and above
- **But we highly recommend 12.8 or above for the best performance**
- PyTorch 2.0 and above
## Acknowledgement
FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.
......@@ -91,7 +103,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c
```bibtex
@misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels},
author={Jiashi Li},
author={Jiashi Li, Shengyu Liu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
......
......@@ -435,7 +435,7 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
if target not in ["flash_infer", "flash_mla_triton"]:
if target not in ["flash_infer", "flash_mla_triton"] and baseline not in ["flash_infer", "flash_mla_triton"]:
# flash_infer has a different lse return value
# flash_mla_triton doesn't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
......
......@@ -10,8 +10,11 @@
#include <cutlass/fast_math.h>
#include "flash_mla.h"
#include "static_switch.h"
#include "kernels/config.h"
#include "kernels/get_mla_metadata.h"
#include "kernels/mla_combine.h"
#include "kernels/params.h"
#include "kernels/splitkv_mla.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
......@@ -23,11 +26,6 @@ get_mla_metadata(
const int num_heads_per_head_k,
const int num_heads_k
) {
// This should match the logic in the MLA kernel.
static constexpr int block_size_m = 64;
static constexpr int block_size_n = 64;
static constexpr int fixed_overhead_num_blocks = 5;
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
......@@ -38,7 +36,7 @@ get_mla_metadata(
auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount;
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M);
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options);
......@@ -52,10 +50,10 @@ get_mla_metadata(
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size;
params.block_size_n = block_size_n;
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
params.block_size_n = Config::PAGE_BLOCK_SIZE;
params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS;
params.num_sm_parts = num_sm_parts;
get_mla_metadata_func(params, stream);
run_get_mla_metadata_kernel(params, stream);
return {tile_scheduler_metadata, num_splits};
}
......@@ -64,7 +62,6 @@ std::vector<at::Tensor>
mha_fwd_kvcache_mla(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size
std::optional<const at::Tensor> &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v
const int head_size_v,
const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
......@@ -73,138 +70,141 @@ mha_fwd_kvcache_mla(
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1
) {
// Check the architecture
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
TORCH_CHECK(is_sm90);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
// Check data types
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
// Check device
CHECK_DEVICE(q);
CHECK_DEVICE(kcache);
CHECK_DEVICE(seqlens_k);
CHECK_DEVICE(block_table);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_DEVICE(num_splits);
// Check layout
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_DEVICE(block_table);
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
CHECK_CONTIGUOUS(tile_scheduler_metadata);
CHECK_CONTIGUOUS(num_splits);
const auto sizes = q.sizes();
const int batch_size = sizes[0];
const int seqlen_q_ori = sizes[1];
const int num_heads_ori = sizes[2];
const int head_size = sizes[3];
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32");
const int num_heads_q = sizes[2];
const int head_size_k = sizes[3];
TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
const int max_num_blocks_per_seq = block_table.size(1);
const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be postive");
TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; }
const int ngroups = num_heads_ori / num_heads_k;
const int seqlen_q = seqlen_q_ori * ngroups;
const int num_q_heads_per_hk = num_heads_q / num_heads_k;
const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk;
const int num_heads = num_heads_k;
q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3)
.reshape({batch_size, seqlen_q, num_heads, head_size});
q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3)
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
int head_size_k = head_size;
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); }
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_SHAPE(num_splits, batch_size+1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse);
Flash_fwd_mla_params params = {};
// Set the sizes.
params.b = batch_size;
params.seqlen_q = seqlen_q;
params.cu_seqlens_k = seqlens_k.data_ptr<int>();
params.h = num_heads;
params.h_h_k_ratio = num_heads / num_heads_k;
params.ngroups = ngroups;
params.s_q = seqlen_q_ori;
params.q_seq_per_hk = q_seq_per_hk;
params.seqlens_k_ptr = seqlens_k.data_ptr<int>();
params.h_q = num_heads_q;
params.h_k = num_heads_k;
params.num_blocks = num_blocks;
params.q_head_per_hk = num_q_heads_per_hk;
params.is_causal = is_causal;
params.d = head_size;
params.d = head_size_k;
params.d_v = head_size_v;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr();
params.v_ptr = vcache.data_ptr();
params.o_ptr = out.data_ptr();
params.softmax_lse_ptr = softmax_lse.data_ptr();
// All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0);
params.v_batch_stride = vcache.stride(0);
params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3);
params.v_row_stride = vcache.stride(-3);
params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2);
params.v_head_stride = vcache.stride(-2);
params.o_head_stride = out.stride(-2);
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
params.page_block_size = page_block_size;
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_DEVICE(tile_scheduler_metadata);
CHECK_CONTIGUOUS(tile_scheduler_metadata);
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
params.num_sm_parts = tile_scheduler_metadata.size(0);
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
CHECK_DEVICE(num_splits);
CHECK_CONTIGUOUS(num_splits);
params.num_splits_ptr = num_splits.data_ptr<int>();
at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat));
const int total_num_splits = batch_size + params.num_sm_parts;
at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse_accum);
CHECK_CONTIGUOUS(out_accum);
params.total_num_splits = total_num_splits;
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size == 576);
TORCH_CHECK(head_size_k == 576);
if (q_dtype == torch::kBFloat16) {
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
}
#ifndef FLASH_MLA_DISABLE_FP16
else if (q_dtype == torch::kHalf) {
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
}
#endif
else {
run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
} else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
.reshape({batch_size, num_heads_ori, seqlen_q_ori});
out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
.reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
.reshape({batch_size, num_heads_q, seqlen_q_ori});
return {out, softmax_lse};
}
......
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
#include "flash_fwd_mla_kernel.h"
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
using namespace cute;
#include "named_barrier.h"
#include "utils.h"
#include "softmax.h"
#include "static_switch.h"
#include "flash_mla.h"
template<typename PrecType, int DIM, int DIM2 = DIM>
constexpr auto getSmemLayoutK() {
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
return GMMA::Layout_K_SW128_Atom<PrecType>{};
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
return GMMA::Layout_K_SW64_Atom<PrecType>{};
} else {
return GMMA::Layout_K_SW32_Atom<PrecType>{};
}
}
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla {
using Element = elem_type;
using ElementAccum = float;
using index_t = int64_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kNWarpsS = 4;
static constexpr int kNThreadsS = kNWarpsS * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
static_assert(kHeadDimV % 32 == 0);
static_assert(kHeadDimV <= kHeadDim);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = decltype(make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
using TiledMmaO = decltype(make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim>(),
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutK = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutV = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(make_tiled_copy(
Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemLayoutAtomO = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyO = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomO{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
using GmemLayoutAtomOaccum = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
};
namespace flash {
using namespace cute;
template<typename Kernel_traits>
struct SharedStorageMLA {
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
};
struct {
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
};
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
__forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
typename Kernel_traits::TiledMmaO tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
// Epilogue
const int split_offset = __ldg(params.num_splits_ptr + bidb);
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum
>;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor rO = flash::convert_type<ElementO>(tOrO);
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
__syncthreads();
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
__syncthreads();
if (tidx >= kNThreadsS) { return; }
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
SharedStorage &shared_storage) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNThreads = Kernel_traits::kNThreads;
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
static_assert(kNThreads == 256 and kNThreadsS == 128);
using Element = typename Kernel_traits::Element;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
int n_block = n_block_max - 1;
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
typename Kernel_traits::TiledMmaO tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
clear(tOrO);
flash::Softmax<2 * size<1>(tOrO)> softmax;
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 0) {
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
if (n_block % 2 == 1) {
// Double buffer for sK
constexpr int sK_offset = size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
#pragma unroll 1
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
__syncthreads();
Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
const bool is_masking_step = masking_step > 0;
const bool is_first_masking_step = masking_step == n_masking_steps;
if (is_masking_step) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if constexpr (!Is_causal) { // Just masking based on col
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
}
}
}
// We have key_padding_mask so we'll need to Check_inf
Tensor scale_o = is_first_masking_step
? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
: is_masking_step ?
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(tSrS);
cute::copy(rP, tPsP);
cute::copy(scale_o, tScale_osScale_o);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cute::copy(softmax.row_max, tRow_maxsRow_max);
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
} else {
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int cur_block_table = __ldg(&block_table[n_block]);
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
if (n_block % 2 == 1) {
// Double buffer for sK
constexpr int sK_offset = size(sK);
tKsK.data() = tKsK.data() + sK_offset;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need to clear the sK smem tiles because K is V.
const index_t offset_k = cur_block_table * params.k_batch_stride;
tKgK.data() = tKgK.data() + offset_k;
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
seqlen_k - n_block * kBlockN);
tKgK.data() = tKgK.data() + -offset_k;
cute::cp_async_fence();
if (n_block - 1 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 1]);
}
#pragma unroll 1
for (; n_block >= n_block_min; --n_block) {
flash::cp_async_wait<0>();
__syncthreads();
if (n_block - 1 >= n_block_min) {
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tKsK.data() = tKsK.data() + sK_offset;
const index_t offset_k = cur_block_table * params.k_batch_stride;
tKgK.data() = tKgK.data() + offset_k;
flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
tKgK.data() = tKgK.data() + -offset_k;
cute::cp_async_fence();
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
if (n_block - 2 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 2]);
}
typename Kernel_traits::TiledMma tiled_mma;
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
Tensor rP = make_tensor<Element>(tSrS_layout);
Tensor scale_o = make_tensor<float>(Shape<_2>{});
cute::copy(tScale_osScale_o, scale_o);
cute::copy(tPsP, rP);
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
cute::copy(tRow_maxsRow_max, softmax.row_max);
cute::copy(tRow_sumsRow_sum, softmax.row_sum);
}
if (NoSplit)
store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
else
store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
constexpr int kBlockN = Kernel_traits::kBlockN;
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int partition_idx = blockIdx.z;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int begin_seqlen = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int end_seqlen = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
#pragma unroll 1
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
if (batch_id > begin_idx) {
__syncthreads(); // Barrier between two tiles.
}
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
__global__ void __launch_bounds__(256, 1, 1)
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
constexpr int kNThreads = 128;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int hs = params.h * params.seqlen_q;
const int batch_idx = bidx / hs;
const int hs_idx = bidx % hs;
const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
if (actual_num_splits == 1) return;
__shared__ ElementAccum sLseScale[kMaxSplits];
const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
const index_t row_offset_lse = bidx;
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
Shape<Int<kMaxSplits>>{}, make_stride(hs));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<_1>{}, Stride<_1>{});
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
float local_lse[kNLsePerThread];
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + tidx;
local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
}
float max_lse = -INFINITY;
for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
float sum_lse = 0;
for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
if (tidx == 0) gLSE(0) = global_lse;
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + tidx;
if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
__syncthreads();
static_assert(kHeadDimV % kNThreads == 0);
constexpr int Elements = kHeadDimV / kNThreads;
const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<kHeadDimV>>{}, Stride<_1>{});
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape<Int<kNThreads>>>{},
Layout<Shape<Int<Elements>>>{}));
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
clear(tOrO);
for (int split = 0; split < actual_num_splits; ++split) {
cute::copy(tOgOaccum, tOrOaccum);
ElementAccum lse_scale = sLseScale[split];
for (int i = 0; i < size(tOrO); ++i) {
tOrO(i) += lse_scale * tOrOaccum(i);
}
tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
}
Tensor rO = flash::convert_type<Element>(tOrO);
const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
cute::copy(rO, gO);
}
} // namespace flash
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
constexpr size_t smem_size = sizeof(SharedStorage);
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
dim3 grid_combine(params.b * params.h * params.seqlen_q);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
}
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
static_assert(Headdim == 576);
FLASH_ASSERT(params.d_v == 512);
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}
#pragma once
namespace Config {
static constexpr int BLOCK_SIZE_M = 64;
static constexpr int PAGE_BLOCK_SIZE = 64;
static constexpr int HEAD_DIM_K = 576;
static constexpr int HEAD_DIM_V = 512;
static constexpr int FIXED_OVERHEAD_NUM_BLOCKS = 5;
}
#include "flash_fwd_mla_kernel.h"
#include "get_mla_metadata.h"
static constexpr int MaxBatchSize = 4096;
#include <cuda_runtime_api.h>
#include <cutlass/fast_math.h>
__global__ void __launch_bounds__(256, 1, 1)
#include "utils.h"
__global__ void __launch_bounds__(32, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
......@@ -12,8 +15,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
__shared__ int num_blocks_shared[MaxBatchSize];
__shared__ int num_splits_shared[MaxBatchSize];
extern __shared__ int shared_mem[];
int* num_blocks_shared = shared_mem; // [batch_size]
int* num_splits_shared = shared_mem + batch_size; // [batch_size+1]
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
......@@ -27,7 +31,7 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
__syncwarp();
if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int payload = max(cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks, 2*fixed_overhead_num_blocks);
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
......@@ -70,8 +74,9 @@ get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
}
}
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.batch_size < MaxBatchSize);
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream) {
int smem_size = sizeof(int) * (params.batch_size*2+1);
CHECK_CUDA(cudaFuncSetAttribute(get_mla_metadata_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
get_mla_metadata_kernel<<<1, 32, smem_size, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}
\ No newline at end of file
}
#pragma once
#include "params.h"
void run_get_mla_metadata_kernel(Mla_metadata_params &params, cudaStream_t stream);
#include "mla_combine.h"
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include "params.h"
#include "utils.h"
#include "config.h" // for BLOCK_SIZE_M and HEAD_DIM_V
using namespace cute;
template<typename ElementT, int HEAD_DIM_V, int BLOCK_SIZE_M, int MAX_SPLITS, int NUM_THREADS>
__global__ void __launch_bounds__(NUM_THREADS)
flash_fwd_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
// grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
// Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
static_assert(NUM_THREADS/32 == BLOCK_SIZE_M); // The number of warps == block_size_m
const int batch_idx = blockIdx.x;
const int m_block_idx = blockIdx.y;
const int warp_idx = threadIdx.x / 32;
const int lane_idx = threadIdx.x % 32;
const int start_split_idx = __ldg(params.num_splits_ptr + batch_idx);
const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
const int my_num_splits = end_split_idx - start_split_idx;
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
if (my_num_splits == 1) {
return;
}
const int num_q_seqs = params.q_seq_per_hk * params.h_k;
const int num_cur_valid_q_seqs = min(BLOCK_SIZE_M, num_q_seqs - m_block_idx*BLOCK_SIZE_M);
Tensor gLseAccum = make_tensor(
make_gmem_ptr((float*)params.softmax_lseaccum_ptr + start_split_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
Shape<Int<MAX_SPLITS>, Int<BLOCK_SIZE_M>>{},
make_stride(num_q_seqs, _1{})
);
Tensor gLse = make_tensor(
make_gmem_ptr((float*)params.softmax_lse_ptr + batch_idx*num_q_seqs + m_block_idx*BLOCK_SIZE_M),
Shape<Int<BLOCK_SIZE_M>>{},
Stride<_1>{}
);
extern __shared__ float smem_buf[];
Tensor sLseScale = make_tensor(
make_smem_ptr(smem_buf),
Shape<Int<BLOCK_SIZE_M>, Int<MAX_SPLITS>>{},
Stride<Int<MAX_SPLITS+1>, _1>{} // +1 to avoid bank conflict
);
// Wait for the previous kernel (the MLA kernel) to finish
cudaGridDependencySynchronize();
// Read gLseAccum into sLseScale
{
#pragma unroll 4
for (int elem_idx = threadIdx.x; elem_idx < my_num_splits*BLOCK_SIZE_M; elem_idx += NUM_THREADS) {
int split_idx = elem_idx / BLOCK_SIZE_M;
int seq_idx = elem_idx % BLOCK_SIZE_M;
sLseScale(seq_idx, split_idx) = seq_idx < num_cur_valid_q_seqs ? gLseAccum(split_idx, seq_idx) : -INFINITY;
}
__syncthreads();
}
if (warp_idx >= num_cur_valid_q_seqs)
return;
// Warp #i gathers LseAccum for seq #i
{
constexpr int NUM_LSE_PER_THREAD = cute::ceil_div(MAX_SPLITS, 32);
float local_lse[NUM_LSE_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
local_lse[i] = split_idx < my_num_splits ? sLseScale(warp_idx, split_idx) : -INFINITY;
}
float max_lse = -INFINITY;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
max_lse = max(max_lse, local_lse[i]);
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2)
max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
float sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i)
sum_lse = sum_lse + exp2f(local_lse[i] - max_lse);
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2)
sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : log2f(sum_lse) + max_lse;
if (lane_idx == 0)
gLse(warp_idx) = global_lse / (float)M_LOG2E;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < NUM_LSE_PER_THREAD; ++i) {
const int split_idx = i*32 + lane_idx;
if (split_idx < my_num_splits) sLseScale(warp_idx, split_idx) = exp2f(local_lse[i] - global_lse);
}
}
__syncwarp();
// Warp #i accumulates activation for seq #i
{
const int64_t row_offset_oaccum = (int64_t)(start_split_idx*num_q_seqs+m_block_idx*BLOCK_SIZE_M+warp_idx) * HEAD_DIM_V;
Tensor gOaccum = make_tensor(
make_gmem_ptr(reinterpret_cast<float *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<MAX_SPLITS>, Int<HEAD_DIM_V>>{},
make_stride(num_q_seqs*HEAD_DIM_V, _1{})
);
static_assert(HEAD_DIM_V % 32 == 0);
constexpr int ELEMS_PER_THREAD = HEAD_DIM_V / 32;
float result[ELEMS_PER_THREAD];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
result[i] = 0.0f;
#pragma unroll 2
for (int split = 0; split < my_num_splits; ++split) {
float lse_scale = sLseScale(warp_idx, split);
if (lse_scale != 0.f) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i) {
result[i] += lse_scale * gOaccum(split, lane_idx + i*32);
}
}
}
cudaTriggerProgrammaticLaunchCompletion();
const int q_seq_idx = m_block_idx*BLOCK_SIZE_M + warp_idx;
const int k_head_idx = q_seq_idx / params.q_seq_per_hk;
auto o_ptr = reinterpret_cast<ElementT *>(params.o_ptr) + batch_idx*params.o_batch_stride + k_head_idx*params.o_head_stride + (q_seq_idx%params.q_seq_per_hk)*params.o_row_stride;
Tensor gO = make_tensor(
make_gmem_ptr(o_ptr),
Shape<Int<HEAD_DIM_V>>{},
Stride<_1>{}
);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ELEMS_PER_THREAD; ++i)
gO(lane_idx+i*32) = (ElementT)result[i];
}
}
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
[&] { \
if (NUM_SPLITS <= 32) { \
constexpr static int NAME = 32; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 64) { \
constexpr static int NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 96) { \
constexpr static int NAME = 96; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 128) { \
constexpr static int NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
}()
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, NUM_SPLITS, [&] {
constexpr int BLOCK_SIZE_M = 8;
constexpr int NUM_THREADS = BLOCK_SIZE_M*32;
constexpr size_t smem_size = BLOCK_SIZE_M*(NUM_SPLITS+1)*sizeof(float);
auto combine_kernel = &flash_fwd_mla_combine_kernel<ElementT, Config::HEAD_DIM_V, BLOCK_SIZE_M, NUM_SPLITS, NUM_THREADS>;
CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t combine_kernel_config = {
dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1),
dim3(NUM_THREADS, 1, 1),
smem_size,
stream,
attribute,
1
};
cudaLaunchKernelEx(&combine_kernel_config, combine_kernel, params);
});
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_mla_combine_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_mla_combine_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#endif
\ No newline at end of file
#pragma once
#include "params.h"
template<typename ElementT>
void run_flash_mla_combine_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);
......@@ -5,39 +5,41 @@
struct Flash_fwd_mla_params {
using index_t = int64_t;
int b, seqlen_q, d, d_v;
int h, h_h_k_ratio, ngroups;
int b; // batch size
int s_q;
int q_seq_per_hk; // The number of q(s) per KV head, = h_q / h_k * s_q
int d, d_v; // K/V dimension
int h_q, h_k; // The number of Q/K heads
int num_blocks; // Number of blocks in total
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal;
float scale_softmax, scale_softmax_log2;
int *__restrict__ cu_seqlens_k;
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr;
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t o_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t o_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
index_t o_head_stride;
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
int total_num_splits;
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
};
......@@ -45,11 +47,6 @@ struct Flash_fwd_mla_params {
static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
struct Mla_metadata_params {
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
......@@ -59,5 +56,3 @@ struct Mla_metadata_params {
int fixed_overhead_num_blocks;
int num_sm_parts;
};
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);
#include <cutlass/cutlass.h>
#include "params.h"
#include "utils.h"
#include "config.h"
#include "traits.h"
using namespace cute;
using cutlass::arch::NamedBarrier;
// Here we use MAX_INIT_VAL_SM to initialize sM, and MAX_INIT_VAL for masking
// The reason is that, we need to calculate new_max = max(sM(row_idx), cur_max*scale_softmax_log2)
// so we must guarantee that MAX_INIT_VAL*scale_softmax_log2 < MAX_INIT_VAL_SM
static constexpr float MAX_INIT_VAL_SM = -1e30f;
static constexpr float MAX_INIT_VAL = -1e33f;
__forceinline__ __device__ int get_AorC_row_idx(int local_row_idx, int idx_in_warpgroup) {
// In the layout of fragment A and fragment C during WGMMA, data each thread holds resides in two particular rows. This function converts the local_row_idx (0~2) to the actual row_idx
// You may refer to this link for the detailed layout: https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n16-a
int row_idx = (idx_in_warpgroup/32)*16 + local_row_idx*8 + (idx_in_warpgroup%32/4);
return row_idx;
}
// Launch TMA copy for a range of KV tile
// A tile has a shape of PAGE_BLOCK_SIZE (64) x 64
template<
int START_HEAD_DIM_TILE_IDX,
int END_HEAD_DIM_TILE_IDX,
typename TMA_K_OneTile,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void launch_kv_tiles_copy_tma(
Tensor<Engine0, Layout0> const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
Tensor<Engine1, Layout1> &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K), swizzled
TMA_K_OneTile &tma_K,
TMABarrier* barriers_K,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
auto thr_tma = tma_K.get_slice(_0{});
Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});
Tensor cur_sKV = thr_tma.partition_D(sKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});
cute::copy(tma_K.with(reinterpret_cast<typename TMABarrier::ValueType &>(barriers_K[START_HEAD_DIM_TILE_IDX]), 0, cute::TMA::CacheHintSm90::EVICT_FIRST), cur_gKV, cur_sKV);
if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
launch_kv_tiles_copy_tma<START_HEAD_DIM_TILE_IDX+1, END_HEAD_DIM_TILE_IDX>(gKV, sKV, tma_K, barriers_K, idx_in_warpgroup);
}
}
}
// Prefetch some KV tiles
// Currently this is not used because it leads to performance degradation
template<
int START_HEAD_DIM_TILE_IDX,
int END_HEAD_DIM_TILE_IDX,
typename TMA_K_OneTile,
typename Engine0, typename Layout0
>
__forceinline__ __device__ void prefetch_kv_tiles(
Tensor<Engine0, Layout0> const &gKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
TMA_K_OneTile &tma_K,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
auto thr_tma = tma_K.get_slice(_0{});
Tensor cur_gKV = thr_tma.partition_S(gKV)(_, _0{}, Int<START_HEAD_DIM_TILE_IDX>{});
cute::prefetch(tma_K, cur_gKV);
if constexpr (START_HEAD_DIM_TILE_IDX+1 < END_HEAD_DIM_TILE_IDX) {
prefetch_kv_tiles<START_HEAD_DIM_TILE_IDX+1, END_HEAD_DIM_TILE_IDX>(gKV, tma_K, idx_in_warpgroup);
}
}
}
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/utils.h
// * Copyright (c) 2024, Tri Dao.
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
// Wait for one KV-tile to be ready, and then calculate P += Q K^T for one Q-tile (BLOCK_SIZE_Mx64) and one KV-tile (PAGE_BLOCK_SIZEx64)
// The Q-tile should be in shared memory
template<
typename TiledMMA,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void qkt_gemm_one_tile_sQ(
TiledMMA &tiled_mma,
Tensor<Engine0, Layout0> const &thr_mma_sQ_tile, // (MMA, 1, 4)
Tensor<Engine1, Layout1> const &thr_mma_sKV_tile, // (MMA, 1, 4)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
TMABarrier* barrier,
bool &cur_phase,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
barrier->arrive_and_expect_tx(64*64*2);
}
barrier->wait(cur_phase ? 1 : 0);
warpgroup_fence_operand(rP);
warpgroup_arrive();
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
cute::gemm(tiled_mma, thr_mma_sQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
warpgroup_commit_batch();
warpgroup_fence_operand(rP);
}
template<
typename TiledMMA,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void qkt_gemm_one_tile_rQ(
TiledMMA &tiled_mma,
Tensor<Engine0, Layout0> const &thr_mma_rQ_tile, // (MMA, 1, 4)
Tensor<Engine1, Layout1> const &thr_mma_sKV_tile, // (MMA, 1, 4)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
TMABarrier* barrier,
bool &cur_phase,
int idx_in_warpgroup
) {
if (idx_in_warpgroup == 0) {
barrier->arrive_and_expect_tx(64*64*2);
}
barrier->wait(cur_phase ? 1 : 0);
warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));
warpgroup_fence_operand(rP);
warpgroup_arrive();
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _0{}), thr_mma_sKV_tile(_, _, _0{}), rP);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _1{}), thr_mma_sKV_tile(_, _, _1{}), rP);
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _2{}), thr_mma_sKV_tile(_, _, _2{}), rP);
cute::gemm(tiled_mma, thr_mma_rQ_tile(_, _, _3{}), thr_mma_sKV_tile(_, _, _3{}), rP);
warpgroup_commit_batch();
warpgroup_fence_operand(rP);
warpgroup_fence_operand(const_cast<Tensor<Engine0, Layout0> &>(thr_mma_rQ_tile));
}
// Pipelined TMA wait and Q K^T gemm
// In order to overlap memory copy (G->S copy for K) and computation, we divide both Q and K into tiles of shape (BLOCK_SIZE_M, 64), and (PAGE_BLOCK_SIZE, 64) respectively, and then do the computation as follows:
// - Wait for the 0-th tile to be ready using `barrier.wait()`
// - Compute Q K^T for the 0-th tile
// - Wait for the 1-st tile to be ready
// - Compute Q K^T for the 1-st tile
// ...
// This gives latter tiles more time to be ready, and thus can overlap the memory copy and computation
template<
typename T, // Traits
int PHASE_IDX, // See comments in the code
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3
>
__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm(
Tensor<Engine0, Layout0> &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K)
Tensor<Engine1, Layout1> &sKV, // (PAGE_BLOCK_SIZE, HEAD_DIM_K)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
Tensor<Engine3, Layout3> &rQ8, // The 8-th tile of Q. We store it separately to leave some room for storing sP1
TMABarrier* barriers,
bool &cur_phase,
int idx_in_warpgroup
) {
Tensor sQ_tiled = flat_divide(sQ, Shape<Int<T::BLOCK_SIZE_M>, _64>{})(_, _, _0{}, _); // (BLOCK_SIZE_M, 64, 9)
Tensor sKV_tiled = flat_divide(sKV, Shape<Int<T::PAGE_BLOCK_SIZE>, _64>{})(_, _, _0{}, _); // (PAGE_BLOCK_SIZE, 64, 9)
TiledMMA tiled_mma_sQ = (typename T::TiledMMA_QK_sQ){};
ThrMMA thr_mma_sQ = tiled_mma_sQ.get_slice(idx_in_warpgroup);
Tensor thr_mma_sQ_tiled = thr_mma_sQ.partition_fragment_A(sQ_tiled); // (MMA, 1, 4, 9)
Tensor thr_mma_sKV_tiled = thr_mma_sQ.partition_fragment_B(sKV_tiled); // (MMA, 1, 4, 9)
TiledMMA tiled_mma_rQ = (typename T::TiledMMA_QK_rQ){};
#define QKT_GEMM_ONE_TILE(TILE_IDX) \
if constexpr(TILE_IDX != 8) { \
qkt_gemm_one_tile_sQ(tiled_mma_sQ, thr_mma_sQ_tiled(_, _, _, Int<TILE_IDX>{}), thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
} else { \
qkt_gemm_one_tile_rQ(tiled_mma_rQ, rQ8, thr_mma_sKV_tiled(_, _, _, Int<TILE_IDX>{}), rP, barriers + TILE_IDX, cur_phase, idx_in_warpgroup); \
}
if constexpr (PHASE_IDX == 0) {
// In PHASE-0, warpgroup 0 calculates Q K^T for the first 4 tiles
tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
QKT_GEMM_ONE_TILE(0);
QKT_GEMM_ONE_TILE(1);
QKT_GEMM_ONE_TILE(2);
QKT_GEMM_ONE_TILE(3);
} else if constexpr (PHASE_IDX == 1) {
// In PHASE-1, warpgroup 1 calculates Q K^T for all the 9 tiles
tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::Zero;
tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
QKT_GEMM_ONE_TILE(4);
QKT_GEMM_ONE_TILE(5);
QKT_GEMM_ONE_TILE(6);
QKT_GEMM_ONE_TILE(7);
QKT_GEMM_ONE_TILE(8);
QKT_GEMM_ONE_TILE(0);
QKT_GEMM_ONE_TILE(1);
QKT_GEMM_ONE_TILE(2);
QKT_GEMM_ONE_TILE(3);
cur_phase ^= 1;
} else {
// In PHASE-2, warpgroup 0 calculates Q K^T for the last 5 tiles
static_assert(PHASE_IDX == 2);
tiled_mma_sQ.accumulate_ = GMMA::ScaleOut::One;
tiled_mma_rQ.accumulate_ = GMMA::ScaleOut::One;
QKT_GEMM_ONE_TILE(4);
QKT_GEMM_ONE_TILE(5);
QKT_GEMM_ONE_TILE(6);
QKT_GEMM_ONE_TILE(7);
QKT_GEMM_ONE_TILE(8);
cur_phase ^= 1;
}
}
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_qkt_gemm_no_pipeline(
Tensor<Engine0, Layout0> &sQ, // (BLOCK_SIZE_M, HEAD_DIM_K)
Tensor<Engine1, Layout1> &sKV, // (BLOCK_SIZE_M, HEAD_DIM_K)
Tensor<Engine2, Layout2> &rP, // ((2, 2, 8), 1, 1)
int idx_in_warpgroup
) {
TiledMMA tiled_mma = (typename T::TiledMMA_QK_sQ){};
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor thr_mma_sQ = thr_mma.partition_fragment_A(sQ); // (MMA, 1, 576/16=36)
Tensor thr_mma_sKV = thr_mma.partition_fragment_B(sKV); // (MMA, 1, 576/16=36)
gemm<true, -1>(tiled_mma, thr_mma_sQ, thr_mma_sKV, rP);
}
// Compute O += PV, where P resides in register
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_localP(
Tensor<Engine0, Layout0> &rP, // ((2, 2, 8), 1, 1), fragment A layout
Tensor<Engine1, Layout1> &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
Tensor<Engine2, Layout2> &rO, // ((2, 2, 32), 1, 1)
int idx_in_warpgroup
) {
TiledMMA tiled_mma = (typename T::TiledMMA_PV_LocalP){};
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor rP_retiled = make_tensor(rP.data(), Layout<
Shape<Shape<_2, _2, _2>, _1, _4>,
Stride<Stride<_1, _2, _4>, _0, _8>
>{});
Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4)
gemm<false, -1>(tiled_mma, rP_retiled, thr_mma_sKV_half, rO);
}
// Compute O += PV, where P resides in shared memory
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void warpgroup_cooperative_pv_gemm_remoteP(
Tensor<Engine0, Layout0> &sP,
Tensor<Engine1, Layout1> &sKV_half, // (HEAD_DIM_V/2, PAGE_BLOCK_SIZE)
Tensor<Engine2, Layout2> &rO, // ((2, 2, 32), 1, 1)
int idx_in_warpgroup
) {
TiledMMA tiled_mma = (typename T::TiledMMA_PV_RemoteP){};
ThrMMA thr_mma = tiled_mma.get_slice(idx_in_warpgroup);
Tensor thr_mma_sP = thr_mma.partition_fragment_A(sP);
Tensor thr_mma_sKV_half = thr_mma.partition_fragment_B(sKV_half); // (MMA, 1, 64/16=4)
gemm<false, -1>(tiled_mma, thr_mma_sP, thr_mma_sKV_half, rO);
}
template<
typename T,
bool DO_OOB_FILLING,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4
>
__forceinline__ __device__ void wg0_bunch_0(
Tensor<Engine0, Layout0> &rPb, // ((2, 2, 8), 1, 1)
Tensor<Engine1, Layout1> &rP0, // ((2, 2, 8), 1, 1)
Tensor<Engine2, Layout2> &rO0, // ((2, 2, 32), 1, 1)
Tensor<Engine3, Layout3> &sScale0, // (BLOCK_SIZE_M)
Tensor<Engine4, Layout4> &sM, // (BLOCK_SIZE_M)
float rL[2],
int rRightBorderForQSeq[2],
float scale_softmax_log2,
int start_token_idx,
int idx_in_warpgroup
) {
// This piece of code is tightly coupled [Accumulate's layout](https://docs.nvidia.com/cuda/parallel-thread-execution/_images/wgmma-64N16-D.png)
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
// Mask, and get row-wise max
float cur_max = MAX_INIT_VAL;
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
if constexpr (DO_OOB_FILLING) {
int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;
rP0(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP0(i) : MAX_INIT_VAL;
rP0(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP0(i+1) : MAX_INIT_VAL;
}
cur_max = max(cur_max, max(rP0(i), rP0(i+1)));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
// Update sM and sL
cur_max *= scale_softmax_log2;
float new_max = max(sM(row_idx), cur_max);
float scale_for_old = exp2f(sM(row_idx) - new_max);
__syncwarp(); // Make sure all reads have finished before updating sM
if (idx_in_warpgroup%4 == 0) {
sScale0(row_idx) = scale_for_old;
sM(row_idx) = new_max;
}
// Scale-O
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) {
rO0(i) *= scale_for_old;
rO0(i+1) *= scale_for_old;
}
// Scale, exp, and get row-wise expsum
float cur_sum = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
rP0(i) = exp2f(rP0(i)*scale_softmax_log2 - new_max);
rP0(i+1) = exp2f(rP0(i+1)*scale_softmax_log2 - new_max);
rPb(i) = (typename T::InputT)rP0(i);
rPb(i+1) = (typename T::InputT)rP0(i+1);
cur_sum += rP0(i) + rP0(i+1);
}
rL[local_row_idx] = rL[local_row_idx]*scale_for_old + cur_sum;
}
}
template<
typename T,
bool IS_BLK0_LAST,
bool IS_BLK1_LAST,
bool IS_BLK2_LAST,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4,
typename Engine5, typename Layout5
>
__forceinline__ __device__ void wg1_bunch_0(
Tensor<Engine0, Layout0> &rP1b, // ((2, 2, 8), 1, 1)
Tensor<Engine1, Layout1> &sScale1, // (BLOCK_SIZE_M)
Tensor<Engine2, Layout2> &rO1, // ((2, 2, 32), 1, 1)
Tensor<Engine3, Layout3> &sM, // (BLOCK_SIZE_M)
float rL[2],
int rRightBorderForQSeq[2],
Tensor<Engine4, Layout4> const &sScale0, // (BLOCK_SIZE_M)
Tensor<Engine5, Layout5> &rP1, // ((2, 2, 8), 1, 1)
float scale_softmax_log2,
int start_token_idx,
int idx_in_warpgroup
) {
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
// Mask, and get row-wise max
float cur_max = MAX_INIT_VAL;
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) {
if constexpr (IS_BLK1_LAST || IS_BLK2_LAST) {
// Need to apply the mask when either this block is the last one, or
// the next block is the last one (because of the causal mask)
int token_idx = start_token_idx + (i/4)*8 + idx_in_warpgroup%4*2;
rP1(i) = token_idx < rRightBorderForQSeq[local_row_idx] ? rP1(i) : MAX_INIT_VAL;
rP1(i+1) = token_idx+1 < rRightBorderForQSeq[local_row_idx] ? rP1(i+1) : MAX_INIT_VAL;
} else if constexpr (IS_BLK0_LAST) {
rP1(i) = rP1(i+1) = MAX_INIT_VAL;
}
cur_max = max(cur_max, max(rP1(i), rP1(i+1)));
}
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 1));
cur_max = max(cur_max, __shfl_xor_sync(0xffffffff, cur_max, 2));
cur_max *= scale_softmax_log2;
float old_max = sM(row_idx);
float new_max = max(old_max, cur_max);
float scale_for_old = exp2f(old_max - new_max);
__syncwarp();
if (idx_in_warpgroup%4 == 0) {
sM(row_idx) = new_max;
sScale1(row_idx) = scale_for_old;
}
// Scale, exp, and get row-wise expsum
float cur_sum = 0;
if constexpr (!IS_BLK0_LAST) {
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP1); i += 4) {
rP1(i) = exp2f(rP1(i)*scale_softmax_log2 - new_max);
rP1(i+1) = exp2f(rP1(i+1)*scale_softmax_log2 - new_max);
rP1b(i) = (typename T::InputT)rP1(i);
rP1b(i+1) = (typename T::InputT)rP1(i+1);
cur_sum += rP1(i) + rP1(i+1);
}
}
// Scale O
float cur_scale_for_o1 = scale_for_old * sScale0(row_idx);
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rO1); i += 4) {
rO1(i) *= cur_scale_for_o1;
rO1(i+1) *= cur_scale_for_o1;
}
// Update rL
rL[local_row_idx] = rL[local_row_idx]*cur_scale_for_o1 + cur_sum;
}
}
// Save rPb (64x64, bfloat16/half) to sP using the stmatrix instruction
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void save_rPb_to_sP(
Tensor<Engine0, Layout0> &rPb,
Tensor<Engine1, Layout1> &sP,
int idx_in_warpgroup
) {
auto r2s_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, typename T::InputT>{},
(typename T::TiledMMA_QK_sQ){}
);
ThrCopy thr_copy = r2s_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_rPb = thr_copy.retile_S(rPb);
Tensor thr_copy_sP = thr_copy.partition_D(sP);
cute::copy(r2s_copy, thr_copy_rPb, thr_copy_sP);
}
// Retrieve rPb (64x64, bfloat16/half) from sP using the ldmatrix instruction
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void retrieve_rP_from_sP(
Tensor<Engine0, Layout0> &rPb,
Tensor<Engine1, Layout1> const &sP,
int idx_in_warpgroup
) {
TiledCopy s2r_copy = make_tiled_copy_A(
Copy_Atom<SM75_U32x4_LDSM_N, typename T::InputT>{},
(typename T::TiledMMA_PV_LocalP){}
);
ThrCopy thr_copy = s2r_copy.get_slice(idx_in_warpgroup);
Tensor thr_copy_sP = thr_copy.partition_S(sP);
Tensor thr_copy_rPb = thr_copy.retile_D(rPb);
cute::copy(s2r_copy, thr_copy_sP, thr_copy_rPb);
}
// Rescale rP0 and save the result to rPb
template<
typename T,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2
>
__forceinline__ __device__ void wg0_scale_rP0(
Tensor<Engine0, Layout0> const &sScale1, // (BLOCK_M)
Tensor<Engine1, Layout1> const &rP0, // ((2, 2, 8), 1, 1)
Tensor<Engine2, Layout2> &rPb, // ((2, 2, 8), 1, 1)
int idx_in_warpgroup
) {
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
float scale_factor = sScale1(row_idx);
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rP0); i += 4) {
rPb(i) = (typename T::InputT)(rP0(i)*scale_factor);
rPb(i+1) = (typename T::InputT)(rP0(i+1)*scale_factor);
}
}
}
// Rescale rO0 according to sScale1
template<
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void wg0_rescale_rO0(
Tensor<Engine0, Layout0> &rO0,
Tensor<Engine1, Layout1> &sScale1,
float rL[2],
int idx_in_warpgroup
) {
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
float scale_factor = sScale1(row_idx);
CUTLASS_PRAGMA_UNROLL
for (int i = local_row_idx ? 2 : 0; i < size(rO0); i += 4) {
rO0(i) *= scale_factor;
rO0(i+1) *= scale_factor;
}
rL[local_row_idx] *= scale_factor;
}
}
// Fill out-of-bound V with 0.0
// We must fill it since it may contain NaN, which may propagate to the final result
template<
typename T,
typename Engine0, typename Layout0
>
__forceinline__ __device__ void fill_oob_V(
Tensor<Engine0, Layout0> &sV, // tile_to_shape(GMMA::Layout_MN_SW128_Atom<InputT>{}, Shape<Int<HALF_HEAD_DIM>, Int<T::PAGE_BLOCK_SIZE>>{}, LayoutRight{} );
int valid_window_size,
int idx_in_warpgroup
) {
Tensor sV_int64 = make_tensor(
make_smem_ptr((int64_t*)(sV.data().get().get())),
tile_to_shape(
GMMA::Layout_MN_SW128_Atom<cute::int64_t>{},
Shape<Int<256/(64/16)>, Int<T::PAGE_BLOCK_SIZE>>{},
LayoutRight{}
)
);
valid_window_size = max(valid_window_size, 0);
int head_dim_size = size<0>(sV_int64); // 128%head_dim_size == 0 should holds
for (int token_idx = valid_window_size + (idx_in_warpgroup/head_dim_size); token_idx < size<1>(sV); token_idx += (128/head_dim_size)) {
sV_int64(idx_in_warpgroup%head_dim_size, token_idx) = 0;
}
}
// Store O / OAccum
template<
typename T,
bool IS_NO_SPLIT,
typename TMAParams,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1
>
__forceinline__ __device__ void store_o(
Tensor<Engine0, Layout0> &rO, // ((2, 2, 32), 1, 1)
Tensor<Engine1, Layout1> &gOorAccum, // (BLOCK_SIZE_M, HEAD_DIM_V)
float rL[2],
char* sO_addr,
TMAParams &tma_params,
int batch_idx,
int k_head_idx,
int m_block_idx,
int num_valid_seq_q,
int warpgroup_idx,
int idx_in_warpgroup
) {
using InputT = typename T::InputT;
if constexpr (IS_NO_SPLIT) {
// Should convert the output to bfloat16 / float16, and save it to O
Tensor sOutputBuf = make_tensor(make_smem_ptr((InputT*)sO_addr), tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{}
));
Tensor rOb = make_tensor_like<InputT>(rO);
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); ++idx) {
rOb(idx) = (InputT)(rO(idx) / rL[idx%4 >= 2]);
}
Tensor sMyOutputBuf = local_tile(sOutputBuf, Shape<_64, _256>{}, make_coord(_0{}, warpgroup_idx));
TiledCopy r2s_tiled_copy = make_tiled_copy_C(
Copy_Atom<SM90_U32x4_STSM_N, InputT>{},
(typename T::TiledMMA_PV_LocalP){}
);
ThrCopy r2s_thr_copy = r2s_tiled_copy.get_slice(idx_in_warpgroup);
Tensor r2s_thr_copy_rOb = r2s_thr_copy.retile_S(rOb);
Tensor r2s_thr_copy_sMyOutputBuf = r2s_thr_copy.partition_D(sMyOutputBuf);
cute::copy(r2s_tiled_copy, r2s_thr_copy_rOb, r2s_thr_copy_sMyOutputBuf);
cutlass::arch::fence_view_async_shared();
__syncthreads();
if (threadIdx.x == 0) {
Tensor tma_gO = tma_params.tma_O.get_tma_tensor(tma_params.shape_O)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM)
auto thr_tma = tma_params.tma_O.get_slice(_0{});
Tensor my_tma_gO = flat_divide(tma_gO, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{})(_, _, m_block_idx, _0{});
cute::copy(
tma_params.tma_O,
thr_tma.partition_S(sOutputBuf),
thr_tma.partition_D(my_tma_gO)
);
cute::tma_store_arrive();
}
} else {
// Should save the result to OAccum
Tensor sOutputBuf = make_tensor(make_smem_ptr((float*)sO_addr), Layout<
Shape<_64, _512>,
Stride<Int<520>, _1> // We use stride = 520 here to avoid bank conflict
>{});
CUTLASS_PRAGMA_UNROLL
for (int idx = 0; idx < size(rO); idx += 2) {
int row = (idx_in_warpgroup/32)*16 + (idx_in_warpgroup%32/4) + (idx%4 >= 2 ? 8 : 0);
int col = warpgroup_idx*256 + (idx_in_warpgroup%4)*2 + idx/4*8;
*(float2*)((float*)sO_addr + sOutputBuf.layout()(row, col)) = float2 {
rO(idx) / rL[idx%4 >= 2],
rO(idx+1) / rL[idx%4 >= 2],
};
}
cutlass::arch::fence_view_async_shared();
__syncthreads();
int row = threadIdx.x;
if (row < num_valid_seq_q) {
SM90_BULK_COPY_S2G::copy(&sOutputBuf(row, _0{}), &gOorAccum(row, _0{}), T::HEAD_DIM_V*sizeof(float));
cute::tma_store_arrive();
}
}
}
template<
typename T,
typename TmaParams, typename Tensor0
>
__forceinline__ __device__ void launch_q_copy(
TmaParams const &tma_params,
int batch_idx,
int m_block_idx,
int k_head_idx,
Tensor0 &sQ,
TMABarrier* barrier_Q
) {
if (threadIdx.x == 0) {
Tensor tma_gQ = tma_params.tma_Q.get_tma_tensor(tma_params.shape_Q)(_, _, k_head_idx, batch_idx); // (seqlen_q, HEAD_DIM)
auto thr_tma = tma_params.tma_Q.get_slice(_0{});
Tensor my_tma_gQ = flat_divide(tma_gQ, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_K>>{})(_, _, m_block_idx, _0{});
cute::copy(
tma_params.tma_Q.with(reinterpret_cast<typename TMABarrier::ValueType &>(*barrier_Q), 0, cute::TMA::CacheHintSm90::EVICT_FIRST),
thr_tma.partition_S(my_tma_gQ),
thr_tma.partition_D(sQ)
);
barrier_Q->arrive_and_expect_tx(64*576*2);
}
}
template<
typename T,
bool IS_R,
typename Engine0, typename Layout0
>
__forceinline__ __device__ auto get_half_V(
Tensor<Engine0, Layout0> &sK
) {
Tensor sV = make_tensor(sK.data(), (typename T::SmemLayoutV){});
return flat_divide(sV, Shape<Int<T::HEAD_DIM_V/2>, Int<T::PAGE_BLOCK_SIZE>>{})(_, _, Int<(int)IS_R>{}, _0{});
}
template<
typename T,
bool IS_BLK0_LAST, // "BLK0" means block_idx+0, "BLK1" means block_idx+1, ...
bool IS_BLK1_LAST,
typename TMAParams,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4,
typename Engine5, typename Layout5,
typename Engine6, typename Layout6,
typename Engine7, typename Layout7,
typename Engine8, typename Layout8,
typename Engine9, typename Layout9,
typename Engine10, typename Layout10,
typename Engine11, typename Layout11
>
__forceinline__ __device__ void wg0_subroutine(
Tensor<Engine0, Layout0> &tma_gK,
Tensor<Engine1, Layout1> &sQ,
Tensor<Engine2, Layout2> &sK0,
Tensor<Engine3, Layout3> &sK1,
Tensor<Engine4, Layout4> &sP0,
Tensor<Engine5, Layout5> &sP1,
Tensor<Engine6, Layout6> &sM,
Tensor<Engine7, Layout7> &sScale0,
Tensor<Engine8, Layout8> &sScale1,
Tensor<Engine9, Layout9> &rQ8,
Tensor<Engine10, Layout10> &rP0,
Tensor<Engine11, Layout11> &rO0,
float rL[2],
int rRightBorderForQSeq[2],
TMABarrier barriers_K0[9],
TMABarrier barriers_K1[9],
bool &cur_phase_K0,
const TMAParams &tma_params,
const Flash_fwd_mla_params &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
int end_block_idx,
int idx_in_warpgroup
) {
int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE;
#define GET_BLOCK_INDEX(block_idx) ((block_idx) >= end_block_idx ? 0 : __ldg(block_table_ptr + (block_idx)))
int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2);
int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3);
Tensor sV0L = get_half_V<T, 0>(sK0);
Tensor sV1L = get_half_V<T, 0>(sK1);
Tensor rPb = make_tensor<T::InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});
// Calc P0 = softmax(P0)
wg0_bunch_0<T, IS_BLK0_LAST||IS_BLK1_LAST>(rPb, rP0, rO0, sScale0, sM, rL, rRightBorderForQSeq, params.scale_softmax_log2, start_token_idx, idx_in_warpgroup);
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale0Ready);
// Issue rO0 += rPb @ sV0L
if constexpr (IS_BLK0_LAST) {
fill_oob_V<T>(sV0L, seqlen_k-start_token_idx, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
warpgroup_cooperative_pv_gemm_localP<T>(rPb, sV0L, rO0, idx_in_warpgroup);
// Wait for rO0, launch TMA for the next V0L
cute::warpgroup_wait<0>();
// Wait for warpgroup 1, rescale P0, notify warpgroup 1
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale1Ready);
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
// Put it here seems to be faster, don't know why
launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup);
}
wg0_scale_rP0<T>(sScale1, rP0, rPb, idx_in_warpgroup);
save_rPb_to_sP<T>(rPb, sP0, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sP0Ready);
// Wait for warpgroup 1, rescale O0, issue rO0 += rPb @ sV1L
if constexpr (!IS_BLK0_LAST) {
if constexpr (IS_BLK1_LAST) {
fill_oob_V<T>(sV1L, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued);
wg0_rescale_rO0(rO0, sScale1, rL, idx_in_warpgroup);
warpgroup_cooperative_pv_gemm_remoteP<T>(sP1, sV1L, rO0, idx_in_warpgroup);
}
// Issue P0 = Q @ K0^T
// Since TMAs for these 4 tiles are launched right after rO0 += rPb @ sV0L finishes, they should have already finished. Therefore, we issue the first 4 tiles to fill the pipeline.
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
warpgroup_cooperative_qkt_gemm<T, 0>(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup);
}
// Wait for rO0 += rPb @ sV1L, launch TMA
if (!IS_BLK0_LAST && !IS_BLK1_LAST && __builtin_expect(block_idx+3 < end_block_idx, true)) {
cute::warpgroup_wait<4>();
launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup);
}
// Issue P0 = Q @ K0^T
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
warpgroup_cooperative_qkt_gemm<T, 2>(sQ, sK0, rP0, rQ8, barriers_K0, cur_phase_K0, idx_in_warpgroup);
}
// Wait for P0 = Q @ K0^T
cute::warpgroup_wait<0>();
}
template<
typename T,
bool IS_BLK0_LAST,
bool IS_BLK1_LAST,
bool IS_BLK2_LAST,
typename TMAParams,
typename Engine0, typename Layout0,
typename Engine1, typename Layout1,
typename Engine2, typename Layout2,
typename Engine3, typename Layout3,
typename Engine4, typename Layout4,
typename Engine5, typename Layout5,
typename Engine6, typename Layout6,
typename Engine7, typename Layout7,
typename Engine8, typename Layout8,
typename Engine9, typename Layout9,
typename Engine10, typename Layout10,
typename Engine11, typename Layout11
>
__forceinline__ __device__ void wg1_subroutine(
Tensor<Engine0, Layout0> &tma_gK,
Tensor<Engine1, Layout1> &sQ,
Tensor<Engine2, Layout2> &sK0,
Tensor<Engine3, Layout3> &sK1,
Tensor<Engine4, Layout4> &sP0,
Tensor<Engine5, Layout5> &sP1,
Tensor<Engine6, Layout6> &sM,
Tensor<Engine7, Layout7> &sScale0,
Tensor<Engine8, Layout8> &sScale1,
Tensor<Engine9, Layout9> &rQ8,
Tensor<Engine10, Layout10> &rP1,
Tensor<Engine11, Layout11> &rO1,
float rL[2],
int rRightBorderForQSeq[2],
TMABarrier barriers_K0[9],
TMABarrier barriers_K1[9],
bool &cur_phase_K1,
const TMAParams &tma_params,
const Flash_fwd_mla_params &params,
int* block_table_ptr,
int seqlen_k,
int block_idx,
int end_block_idx,
int idx_in_warpgroup
) {
int start_token_idx = block_idx * T::PAGE_BLOCK_SIZE;
int nxt_block0_index = GET_BLOCK_INDEX(block_idx+2);
int nxt_block1_index = GET_BLOCK_INDEX(block_idx+3);
Tensor rP1b = make_tensor<T::InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});
Tensor sV0R = get_half_V<T, 1>(sK0);
Tensor sV1R = get_half_V<T, 1>(sK1);
// Wait for rP1 and warpgroup 0, run bunch 1, notify warpgroup 0
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sScale0Ready);
wg1_bunch_0<T, IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST>(rP1b, sScale1, rO1, sM, rL, rRightBorderForQSeq, sScale0, rP1, params.scale_softmax_log2, start_token_idx+T::PAGE_BLOCK_SIZE, idx_in_warpgroup);
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::sScale1Ready);
// Save rPb to sP, and issue rO1 += rP1b @ sV1R
// We do this after notifying warpgroup 1, since both "saving rPb to sP" and "issuing" WGMMA are high-latency operations
if constexpr (!IS_BLK0_LAST) {
save_rPb_to_sP<T>(rP1b, sP1, idx_in_warpgroup);
}
if constexpr (!IS_BLK0_LAST) {
if constexpr (IS_BLK1_LAST) {
fill_oob_V<T>(sV1R, seqlen_k-start_token_idx-T::PAGE_BLOCK_SIZE, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
warpgroup_cooperative_pv_gemm_localP<T>(rP1b, sV1R, rO1, idx_in_warpgroup);
if constexpr (!IS_BLK1_LAST) {
// We use this proxy for making sP1 visible to the async proxy
// We skip it if IS_BLK1_LAST, since in that case we have already put a fence
cutlass::arch::fence_view_async_shared();
}
}
// Wait for sP0, issue rO1 += sP0 @ sV0R, notify warpgroup 0
NamedBarrier::arrive_and_wait(T::NUM_THREADS, NamedBarriers::sP0Ready);
if constexpr (IS_BLK0_LAST) {
fill_oob_V<T>(sV0R, seqlen_k-start_token_idx, idx_in_warpgroup);
cutlass::arch::fence_view_async_shared();
}
warpgroup_cooperative_pv_gemm_remoteP<T>(sP0, sV0R, rO1, idx_in_warpgroup);
if constexpr (!IS_BLK0_LAST) {
NamedBarrier::arrive(T::NUM_THREADS, NamedBarriers::rO1sP0sV0RIssued);
}
// Wait for rO1 += rP1b @ sV1R, launch TMA for the next V1R
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) {
cute::warpgroup_wait<1>();
launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block1_index), sK1, tma_params.tma_K, barriers_K1, idx_in_warpgroup);
}
// Wait for rO1 += sP0 @ sV0R, launch TMA for the next V0R
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST) {
cute::warpgroup_wait<0>();
launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, nxt_block0_index), sK0, tma_params.tma_K, barriers_K0, idx_in_warpgroup);
}
if constexpr (!IS_BLK0_LAST && !IS_BLK1_LAST && !IS_BLK2_LAST) {
// Issue rP1 = sQ @ sK1, wait
warpgroup_cooperative_qkt_gemm<T, 1>(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup);
}
// We put the `cute::warpgroup_wait<0>()` out of the `if` statement above, otherwise
// nvcc cannot correctly analyse the loop, and will think that we are using accumulator
// registers during the WGMMA pipeline, which results in `WARPGROUP.ARRIVE` and `WARPGROUP.DEPBAR.LE` being inserted in SASS and WGMMA instructions being serialized.
// This is also the reason why we put QK^T here, instead of the first operation in the loop
cute::warpgroup_wait<0>();
}
// A helper function for determining the length of the causal mask for one q token
__forceinline__ __device__ int get_mask_len(const Flash_fwd_mla_params &params, int m_block_idx, int local_seq_q_idx) {
int global_seq_q_idx = m_block_idx*Config::BLOCK_SIZE_M + local_seq_q_idx;
if (global_seq_q_idx < params.q_seq_per_hk) {
int s_q_idx = global_seq_q_idx / params.q_head_per_hk;
return params.s_q - s_q_idx - 1;
} else {
// Out-of-bound request, regard as no masks
return 0;
}
}
template<typename T, typename TmaParams>
__global__ void __launch_bounds__(T::NUM_THREADS, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params, __grid_constant__ const TmaParams tma_params) {
// grid shape: [
// num_m_blocks (=ceil_div(seqlen_q_ori*(num_q_heads//num_kv_heads))),
// num_kv_heads,
// num_sm_parts
// ]
// An "sm part" is responsible for all the BLOCK_SIZE_M q_heads in the m_block (as specified by m_block_idx), under one kv head (as specified by k_head_idx), of a segment (as specified by [start_block_idx, end_block_idx]) of one request (as specified by batch_idx).
// If is_no_split is True, then this request is exclusively assigned to this sm_part, so we shall write the result directly into params.o_ptr and params.softmax_lse_ptr. Otherwise, write to oaccum_ptr and softmax_lseaccum_ptr, with the corresponding split idx being (n_split_idx + num_splits_ptr[batch_idx])
// For the complete schedule of the kernel, please read our deep-dive write-up (link can be found in the README.md file).
const int m_block_idx = blockIdx.x;
const int k_head_idx = blockIdx.y;
const int partition_idx = blockIdx.z;
const int warpgroup_idx = threadIdx.x / 128;
const int idx_in_warpgroup = threadIdx.x % 128;
// Define shared tensors
extern __shared__ char wksp_buf[];
using SharedMemoryPlan = typename T::SharedMemoryPlan;
SharedMemoryPlan &plan = *reinterpret_cast<SharedMemoryPlan*>(wksp_buf);
Tensor sQ = make_tensor(make_smem_ptr(plan.smem_sQ.data()), (typename T::SmemLayoutQ){});
Tensor sK0 = make_tensor(make_smem_ptr(plan.smem_sK0.data()), (typename T::SmemLayoutK){});
Tensor sK1 = make_tensor(make_smem_ptr(plan.smem_sK1.data()), (typename T::SmemLayoutK){});
Tensor sP0 = make_tensor(make_smem_ptr(plan.smem_sP0.data()), (typename T::SmemLayoutP0){});
Tensor sP1 = flat_divide(sQ, Shape<Int<T::BLOCK_SIZE_M>, Int<T::PAGE_BLOCK_SIZE>>{})(_, _, _0{}, _8{}); // Overlap with sQ's 8-th tile
Tensor sM = make_tensor(make_smem_ptr(plan.smem_sM.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));
Tensor sL_reduction_wksp = make_tensor(make_smem_ptr(plan.sL_reduction_wksp.data()), make_shape(Int<2*T::BLOCK_SIZE_M>{}));
Tensor sScale0 = make_tensor(make_smem_ptr(plan.smem_sScale0.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));
Tensor sScale1 = make_tensor(make_smem_ptr(plan.smem_sScale1.data()), make_shape(Int<T::BLOCK_SIZE_M>{}));
char* sO_addr = (char*)plan.smem_sK0.data(); // Overlap with sK0 and sK1
// Prefetch TMA descriptors
if (threadIdx.x == 0) {
cute::prefetch_tma_descriptor(tma_params.tma_Q.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_K.get_tma_descriptor());
cute::prefetch_tma_descriptor(tma_params.tma_O.get_tma_descriptor());
}
// Define TMA stuffs
Tensor tma_gK = tma_params.tma_K.get_tma_tensor(tma_params.shape_K)(_, _, k_head_idx, _);
TMABarrier* barriers_K0 = plan.barriers_K0;
TMABarrier* barriers_K1 = plan.barriers_K1;
TMABarrier* barrier_Q = &(plan.barrier_Q);
// Initialize TMA barriers
if (threadIdx.x == 0) {
barrier_Q->init(1);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 9; ++i) {
barriers_K0[i].init(1);
barriers_K1[i].init(1);
}
cutlass::arch::fence_view_async_shared();
}
__syncthreads();
bool cur_phase_Q = 0, cur_phase_K0 = 0, cur_phase_K1 = 0;
// Programmatic Dependent Launch: Wait for the previous kernel to finish
cudaGridDependencySynchronize();
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int begin_seqlen = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int end_seqlen = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
// Copy the first Q
launch_q_copy<T>(tma_params, begin_idx, m_block_idx, k_head_idx, sQ, barrier_Q);
#pragma unroll 1
for (int batch_idx = begin_idx; batch_idx <= end_idx; ++batch_idx) {
constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
const int n_split_idx = batch_idx == begin_idx ? begin_n_split_idx : 0;
int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);
const int start_block_idx = batch_idx == begin_idx ? begin_seqlen / kBlockN : 0;
int end_block_idx = batch_idx == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
const bool is_no_split = start_block_idx == 0 && end_block_idx == cute::ceil_div(seqlen_k, kBlockN);
int rRightBorderForQSeq[2];
if (params.is_causal) {
// The causal mask looks like:
// XXXX
// XXXX
// ...
// XXXX
// XXX
// XXX
// ...
// XXX
// XX
// XX
// ...
// XX
// Firstly, there is a common_mask_len, which is the minimum length of causal masks among all tokens. Since the length of the causal mask decreases monotonically, the common_mask_len is the length of the causal mask for the last token. We consider the common_mask_len as a "reduction in the length of the k-sequence.", and adjust end_block_idx based on it, to save some calculation.
// Besides, a token may have some extra masks other than the common mask. We use rRightBorderForQSeq to denote it, which means the right border of the k-sequence for the particular q token. In this way, (seqlen_k-common_mask_len) - rRightBorderForQSeq < 64 holds, which means that we only need to apply the causal mask to the last two KV blocks
// NOTE This may lead to start_block_idx >= end_block_idx which needs some special handling
int common_mask_len = get_mask_len(params, m_block_idx, T::BLOCK_SIZE_M-1);
end_block_idx = batch_idx == end_idx ? cute::ceil_div(min(end_seqlen, seqlen_k-common_mask_len), kBlockN) : cute::ceil_div(seqlen_k-common_mask_len, kBlockN);
CUTLASS_PRAGMA_UNROLL
for (int local_row_idx = 0; local_row_idx < 2; ++local_row_idx) {
int row_idx = get_AorC_row_idx(local_row_idx, idx_in_warpgroup);
rRightBorderForQSeq[local_row_idx] = min(seqlen_k-get_mask_len(params, m_block_idx, row_idx), end_block_idx*T::PAGE_BLOCK_SIZE);
}
} else {
rRightBorderForQSeq[0] = rRightBorderForQSeq[1] = seqlen_k;
}
// Define global tensors
using InputT = typename T::InputT;
InputT* o_ptr = (InputT*)params.o_ptr + batch_idx*params.o_batch_stride + m_block_idx*T::BLOCK_SIZE_M*params.o_row_stride + k_head_idx*params.o_head_stride; // (BLOCK_SIZE_M, HEAD_DIM_V) : (params.o_row_stride, 1)
float* softmax_lse_ptr = (float*)params.softmax_lse_ptr + (batch_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
int* block_table_ptr = params.block_table + batch_idx*params.block_table_batch_stride; // (/) : (1)
Tensor gO = make_tensor(make_gmem_ptr(o_ptr), make_layout(
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{},
make_stride(params.o_row_stride, _1{})
));
Tensor gSoftmaxLse = make_tensor(make_gmem_ptr(softmax_lse_ptr), Layout<
Shape<Int<T::BLOCK_SIZE_M>>,
Stride<_1>
>{});
// Copy K0 and K1
launch_kv_tiles_copy_tma<0, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx)), sK0, tma_params.tma_K, barriers_K0, threadIdx.x);
if (start_block_idx+1 < end_block_idx) {
launch_kv_tiles_copy_tma<4, 9>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
launch_kv_tiles_copy_tma<0, 4>(tma_gK(_, _, __ldg(block_table_ptr + start_block_idx+1)), sK1, tma_params.tma_K, barriers_K1, threadIdx.x);
}
Tensor rO = partition_fragment_C((typename T::TiledMMA_PV_LocalP){}, Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V / 2>>{}); // ((2, 2, 32), 1, 1)
float rL[2];
rL[0] = rL[1] = 0.0f;
// Clear buffers
cute::fill(rO, 0.);
if (threadIdx.x < size(sM)) {
sM[threadIdx.x] = MAX_INIT_VAL_SM;
}
// Wait for Q
barrier_Q->wait(cur_phase_Q);
cur_phase_Q ^= 1;
Tensor rQ8 = make_tensor<InputT>(Shape<Shape<_2, _2, _2>, _1, _4>{});
retrieve_rP_from_sP<T>(rQ8, local_tile(sQ, Shape<_64, _64>{}, Coord<_0, _8>{}), idx_in_warpgroup);
if (warpgroup_idx == 0) {
// Warpgroup 0
Tensor rP0 = make_tensor<float>((typename T::rP0Layout){});
// NOTE We don't use the pipelined version of Q K^T here since it leads
// to a slow-down (or even register spilling, thanks to the great NVCC)
// Wait for K0
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 9; ++i) {
if (idx_in_warpgroup == 0)
barriers_K0[i].arrive_and_expect_tx(64*64*2);
barriers_K0[i].wait(cur_phase_K0);
}
cur_phase_K0 ^= 1;
// Issue P0 = Q @ K0^T, wait
warpgroup_cooperative_qkt_gemm_no_pipeline<T>(sQ, sK0, rP0, idx_in_warpgroup);
cute::warpgroup_wait<0>();
#define LAUNCH_WG0_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST) \
wg0_subroutine<T, IS_BLK0_LAST, IS_BLK1_LAST>( \
tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \
rQ8, rP0, rO, rL, rRightBorderForQSeq, \
barriers_K0, barriers_K1, cur_phase_K0, \
tma_params, params, \
block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \
);
int block_idx = start_block_idx;
#pragma unroll 1
for (; block_idx < end_block_idx-2; block_idx += 2) {
LAUNCH_WG0_SUBROUTINE(false, false);
}
if (block_idx+1 < end_block_idx) {
LAUNCH_WG0_SUBROUTINE(false, true);
} else if (block_idx < end_block_idx) {
LAUNCH_WG0_SUBROUTINE(true, false);
}
} else {
// Warpgroup 1
Tensor rP1 = make_tensor<float>((typename T::rP0Layout){});
if (start_block_idx+1 < end_block_idx) {
// Issue rP1 = sQ @ sK1, wait
warpgroup_cooperative_qkt_gemm<T, 1>(sQ, sK1, rP1, rQ8, barriers_K1, cur_phase_K1, idx_in_warpgroup);
cute::warpgroup_wait<0>();
}
#define LAUNCH_WG1_SUBROUTINE(IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST) \
wg1_subroutine<T, IS_BLK0_LAST, IS_BLK1_LAST, IS_BLK2_LAST>( \
tma_gK, sQ, sK0, sK1, sP0, sP1, sM, sScale0, sScale1, \
rQ8, rP1, rO, rL, rRightBorderForQSeq, \
barriers_K0, barriers_K1, cur_phase_K1, \
tma_params, params, \
block_table_ptr, seqlen_k, block_idx, end_block_idx, idx_in_warpgroup \
);
int block_idx = start_block_idx;
#pragma unroll 1
for (; block_idx < end_block_idx-3; block_idx += 2) {
LAUNCH_WG1_SUBROUTINE(false, false, false);
}
if (block_idx+2 < end_block_idx) {
LAUNCH_WG1_SUBROUTINE(false, false, true);
block_idx += 2;
LAUNCH_WG1_SUBROUTINE(true, false, false);
} else if (block_idx+1 < end_block_idx) {
LAUNCH_WG1_SUBROUTINE(false, true, false);
} else if (block_idx < end_block_idx) {
LAUNCH_WG1_SUBROUTINE(true, false, false);
}
}
// Reduce rL across threads within the same warp
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 1);
rL[0] += __shfl_xor_sync(0xffffffff, rL[0], 2);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 1);
rL[1] += __shfl_xor_sync(0xffffffff, rL[1], 2);
// Reduce rL across warpgroups
int my_row = get_AorC_row_idx(0, idx_in_warpgroup);
if (idx_in_warpgroup%4 == 0) {
sL_reduction_wksp[my_row + warpgroup_idx*64] = rL[0];
sL_reduction_wksp[my_row + 8 + warpgroup_idx*64] = rL[1];
}
__syncthreads();
if (warpgroup_idx == 0) {
rL[0] += sL_reduction_wksp[my_row + 64];
rL[1] += sL_reduction_wksp[my_row + 8 + 64];
} else {
if (idx_in_warpgroup%4 == 0) {
sL_reduction_wksp[my_row] += rL[0];
sL_reduction_wksp[my_row + 8] += rL[1];
}
__syncwarp();
rL[0] = sL_reduction_wksp[my_row];
rL[1] = sL_reduction_wksp[my_row+8];
}
// Prune out when rL is 0.0f or NaN
// rL may be 0.0f if there are large values (~10^12) in QK^T, which leads
// to exp2f(P(i)*scale-max) = 0.0f or +inf due to FMA error.
// When this happens, we set rL to 1.0f. This aligns with the old version
// of the MLA kernel.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; ++i)
rL[i] = (rL[i] == 0.0f || rL[i] != rL[i]) ? 1.0f : rL[i];
// Copy Q for the next batch
if (batch_idx+1 <= end_idx) {
launch_q_copy<T>(tma_params, batch_idx+1, m_block_idx, k_head_idx, sQ, barrier_Q);
} else {
// Allow the next kernel (the combine kernel) to launch
// The next kernel MUST be the combine kernel
cudaTriggerProgrammaticLaunchCompletion();
}
int num_valid_seq_q = min(params.q_seq_per_hk - m_block_idx*T::BLOCK_SIZE_M, T::BLOCK_SIZE_M);
if (is_no_split) {
store_o<T, true>(rO, gO, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL_reduction_wksp[i];
gSoftmaxLse(i) = (cur_L == 0.0f || cur_L != cur_L) ? INFINITY : logf(cur_L) + sM(i) / (float)M_LOG2E;
}
cute::tma_store_wait<0>();
} else {
int split_idx = __ldg(params.num_splits_ptr+batch_idx) + n_split_idx;
float* oaccum_ptr = (float*)params.oaccum_ptr + ((split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M)*T::HEAD_DIM_V; // (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
float* softmax_lseaccum_ptr = (float*)params.softmax_lseaccum_ptr + (split_idx*params.h_k + k_head_idx)*params.q_seq_per_hk + m_block_idx*T::BLOCK_SIZE_M; // (BLOCK_SIZE_M) : (1)
Tensor gOAccum = make_tensor(make_gmem_ptr(oaccum_ptr), Layout<
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>,
Stride<Int<T::HEAD_DIM_V>, _1>
>{});
Tensor gSoftmaxLseAccum = make_tensor(make_gmem_ptr(softmax_lseaccum_ptr), Layout<
Shape<Int<T::BLOCK_SIZE_M>>,
Stride<_1>
>{});
store_o<T, false>(rO, gOAccum, rL, sO_addr, tma_params, batch_idx, k_head_idx, m_block_idx, num_valid_seq_q, warpgroup_idx, idx_in_warpgroup);
int i = threadIdx.x;
if (i < num_valid_seq_q) {
float cur_L = sL_reduction_wksp[i];
gSoftmaxLseAccum(i) = (cur_L == 0.0f || cur_L != cur_L) ? -INFINITY : log2f(cur_L) + sM(i);
}
cute::tma_store_wait<0>();
}
if (batch_idx != end_idx)
__syncthreads();
}
}
template<typename InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
using T = Traits<InputT>;
auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);
auto tma_Q = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((InputT*)params.q_ptr),
make_layout(
shape_Q,
make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride)
)
),
tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_K>>{}
)
);
auto shape_K = make_shape(Int<T::PAGE_BLOCK_SIZE>{}, Int<T::HEAD_DIM_K>{}, params.h_k, params.num_blocks);
auto tma_K = cute::make_tma_copy(
SM90_TMA_LOAD{},
make_tensor(
make_gmem_ptr((InputT*)params.k_ptr),
make_layout(
shape_K,
make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride)
)
),
tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Layout<
Shape<Int<T::PAGE_BLOCK_SIZE>, Int<64>>,
Stride<Int<T::HEAD_DIM_K>, _1>
>{}
)
);
auto shape_O = make_shape(params.q_seq_per_hk, params.d_v, params.h_k, params.b);
auto tma_O = cute::make_tma_copy(
SM90_TMA_STORE{},
make_tensor(
make_gmem_ptr((InputT*)params.o_ptr),
make_layout(
shape_O,
make_stride(params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride)
)
),
tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<T::BLOCK_SIZE_M>, Int<T::HEAD_DIM_V>>{}
)
);
TmaParams<decltype(shape_Q), decltype(tma_Q), decltype(shape_K), decltype(tma_K), decltype(shape_O), decltype(tma_O)> tma_params = {
shape_Q, tma_Q,
shape_K, tma_K,
shape_O, tma_O
};
auto mla_kernel = &flash_fwd_splitkv_mla_kernel<T, decltype(tma_params)>;
constexpr size_t smem_size = sizeof(typename T::SharedMemoryPlan);
CHECK_CUDA(cudaFuncSetAttribute(mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
const int num_m_block = cute::ceil_div(params.q_seq_per_hk, T::BLOCK_SIZE_M);
cudaLaunchAttribute mla_kernel_attributes[1];
mla_kernel_attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
mla_kernel_attributes[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t mla_kernel_config = {
dim3(num_m_block, params.h_k, params.num_sm_parts),
dim3(T::NUM_THREADS, 1, 1),
smem_size,
stream,
mla_kernel_attributes,
1
};
cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
CHECK_CUDA_KERNEL_LAUNCH();
}
template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#ifndef FLASH_MLA_DISABLE_FP16
template void run_flash_splitkv_mla_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
#endif
#pragma once
#include "params.h"
template<typename InputT>
void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_types.h>
#include <cutlass/barrier.h>
#include "config.h"
using TMABarrier = cutlass::arch::ClusterTransactionBarrier;
using namespace cute;
template<typename InputT_>
struct Traits {
using InputT = InputT_;
static constexpr int BLOCK_SIZE_M = Config::BLOCK_SIZE_M;
static constexpr int PAGE_BLOCK_SIZE = Config::PAGE_BLOCK_SIZE;
static constexpr int HEAD_DIM_K = Config::HEAD_DIM_K;
static constexpr int HEAD_DIM_V = Config::HEAD_DIM_V;
static constexpr int NUM_THREADS = 256;
static_assert(std::is_same_v<InputT, cutlass::bfloat16_t> || std::is_same_v<InputT, cutlass::half_t>);
using TiledMMA_QK_sQ = decltype(make_tiled_mma(
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_QK_rQ = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>, GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_LocalP = decltype(make_tiled_mma(
GMMA::rs_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using TiledMMA_PV_RemoteP = decltype(make_tiled_mma(
GMMA::ss_op_selector<InputT, InputT, float, Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_V/2>, Int<PAGE_BLOCK_SIZE>>, GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<_1, _1, _1>>{}
));
using SmemLayoutQ = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<HEAD_DIM_K>>{}
));
using SmemLayoutK = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<PAGE_BLOCK_SIZE>, Int<HEAD_DIM_K>>{}
));
using SmemLayoutV = decltype(composition(
SmemLayoutK{},
make_layout(Shape<Int<HEAD_DIM_V>, Int<PAGE_BLOCK_SIZE>>{}, GenRowMajor{})
)); // A transposed version of SmemLayoutK
using SmemLayoutP0 = decltype(tile_to_shape(
GMMA::Layout_K_SW128_Atom<InputT>{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
));
using rP0Layout = decltype(layout(partition_fragment_C(
TiledMMA_QK_sQ{},
Shape<Int<BLOCK_SIZE_M>, Int<PAGE_BLOCK_SIZE>>{}
)));
struct SharedMemoryPlan {
cute::array_aligned<InputT, cosize_v<SmemLayoutQ>> smem_sQ;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK0;
cute::array_aligned<InputT, cosize_v<SmemLayoutK>> smem_sK1;
cute::array_aligned<InputT, cosize_v<SmemLayoutP0>> smem_sP0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sM;
cute::array_aligned<float, 2*BLOCK_SIZE_M> sL_reduction_wksp;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale0;
cute::array_aligned<float, BLOCK_SIZE_M> smem_sScale1;
TMABarrier barriers_K0[HEAD_DIM_K/64];
TMABarrier barriers_K1[HEAD_DIM_K/64];
TMABarrier barrier_Q;
};
};
template<
typename ShapeQ, typename TMA_Q,
typename ShapeK, typename TMA_K,
typename ShapeO, typename TMA_O
>
struct TmaParams {
ShapeQ shape_Q;
TMA_Q tma_Q;
ShapeK shape_K;
TMA_K tma_K;
ShapeO shape_O;
TMA_O tma_O;
};
enum NamedBarriers : int {
sScale0Ready = 0,
sScale1Ready = 1,
sP0Ready = 2,
rO1sP0sV0RIssued = 3
};
......@@ -5,7 +5,7 @@
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
exit(1); \
} \
} while(0)
......@@ -29,37 +29,4 @@
} \
} while(0)
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
[&] { \
if (NUM_SPLITS <= 32) { \
constexpr static int NAME = 32; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 64) { \
constexpr static int NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 96) { \
constexpr static int NAME = 96; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 128) { \
constexpr static int NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
}()
#define println(fmt, ...) { print(fmt, ##__VA_ARGS__); print("\n"); }
#pragma once
#include "cutlass/barrier.h"
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
enum class NamedBarriers {
SReady = 1,
SoftmaxReady = 2,
};
} // flash
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// The following macro will disable the use of fma.
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
// This macro is set in PyTorch and not FlashAttention
#ifdef UNFUSE_FMA
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
}
}
return tensor;
}
// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
SumOp<float> sum_op;
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
}
}
template<typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scale_o); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
__forceinline__ __device__ Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scale_o;
clear(scale_o);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scale_o(mi) = scores_scale;
row_sum(mi) *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
return scale_o;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
};
} // namespace flash
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_bf16.h>
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16))
return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
} else {
static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{}); // (((2, 2), N / 32))
// This combines the first two modes (<0, 0> and <0, 1>) into one mode.
// Will require register shuffling later to be correct.
return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
get<1>(acc_layout),
coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
// This combination is right but doesn't work with register shuffling.
// return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
// get<1>(acc_layout),
// coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment