Commit c28eca99 authored by Shengyu Liu's avatar Shengyu Liu
Browse files

Reorganize files and add sparse prefill/decoding kernels on hopper

parent 261330bb
...@@ -8,3 +8,4 @@ dist/ ...@@ -8,3 +8,4 @@ dist/
/.vscode /.vscode
compile_commands.json compile_commands.json
.cache .cache
/dev
# FlashMLA # FlashMLA
## Performance Update (2025.04.22) ## Introduction
We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement on compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Just switch to the new version and enjoy the instant speedup! 🚀🚀🚀 FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](TODO) models. This repository contains the following implementations:
Besides, we'd love to share the technical details behind the new kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md). **Sparse Attention Kernels**
The new kernel primarily targets compute-intensive settings (where the number of q heads $\times$ the number of q tokens per request (if MTP is disabled then it's 1) $\ge 64$). For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. *These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](TODO).*
## Introduction - Token-level sparse attention for the prefill stage
- Token-level sparse attention for the decoding stage, with FP8 KV cache
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. **Dense Attention Kernels**
Currently released: - Dense attention for the prefill stage
- BF16, FP16 - Dense attention for the decoding stage
- Paged kvcache with block size of 64
## Requirements ## News
- Hopper GPUs - **2025.09.26(TODO) Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](TODO), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding.
- CUDA 12.8 and above - **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell!
- PyTorch 2.0 and above - **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md).
- **2025.04.22 Performance Update**: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀
## Quick start ## Performance
### Install #### Test & benchmark MLA decoding (Sparse & Dense):
```bash ```bash
pip install -v . python tests/test_flash_mla_decoding.py
``` ```
### Benchmark The dense MLA decoding kernel can achieve up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. For token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16), it can achieve 410 TFLOPS in compute-bound configuration on H800 SXM5, CUDA 12.8.
#### Testing MLA Decoding #### Test & benchmark MHA prefill (Dense):
```bash ```bash
python tests/test_flash_mla_sm90.py python tests/test_fmha_sm100.py
``` ```
#### Testing MLA Forward/Backward It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation on B200, as reported by NVIDIA.
#### Test & benchmark MLA prefill (Sparse):
```bash ```bash
python tests/test_fmha_sm100.py python tests/test_flash_mla_prefill.py
``` ```
It is able up to 3000 GB/s in memory-bound configuration and 660 TFLOPS in computation-bound configuration on H800 SXM5, using CUDA 12.8. It achieves up to 640 TFlops in forward computation on H800 SXM5, CUDA 12.8.
## Requirements
- Hopper / Blackwell GPUs (See the support matrix below)
- CUDA 12.8 and above (CUDA 12.9+ is required for Blackwell kernels)
- PyTorch 2.0 and above
Note. For memory-bound cases, we recommend using version [b31bfe7](https://github.com/deepseek-ai/FlashMLA/tree/b31bfe72a83ea205467b3271a5845440a03ed7cb) for optimal performance. Support matrix:
### Usage | Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: |
| Dense Decoding | Hopper | MQA | BF16 |
| Sparse Decoding | Hopper | MQA | FP8 [1] |
| Dense Prefill | Blackwell | MHA | |
| Sparse Prefill | Hopper | MQA | |
[1]: For more details on using FP8 KV cache, see documents below.
[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](TODO).
## Installation
```bash
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```
## Usage
### MLA Decoding
To use the MLA decoding kernels, call get_mla_metadata once before the decoding loop to get the tile scheduler metadata. Then, call flash_mla_with_kvcache in each decoding step. For example:
```python ```python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache from flash_mla import get_mla_metadata, flash_mla_with_kvcache
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens,
s_q * h_q // h_kv,
h_kv,
h_q,
is_fp8,
topk,
)
for i in range(num_layers): for i in range(num_layers):
... ...
o_i, lse_i = flash_mla_with_kvcache( o_i, lse_i = flash_mla_with_kvcache(
q_i, kvcache_i, block_table, cache_seqlens, dv, q_i, kvcache_i, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=True, tile_scheduler_metadata, num_splits,
is_causal, is_fp8_kvcache, indices,
) )
... ...
``` ```
Where
- `s_q` is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1.
- `h_kv` is the number of key-value heads.
- `h_q` is the number of query heads.
**FP8 KV Cache:**
If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.
In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as:
- **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values.
- **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy.
See `tests/quant.py` for quantization and dequantization details.
**Sparse Attention (`indices` tensor):**
The `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens.
- **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`.
- **Format:** `indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th token for the j-th query sequence in the i-th batch. Since the index of the page block has already been encoded into `indices_in_kvcache`, the kernel does not require the `block_table` parameter.
- **Invalid entries:** Set invalid indices to `-1`.
**Return Values:**
The kernel returns `(out, lse)`, where:
- `out` is the attention result.
- `lse` is the log-sum-exp value of the attention scores for each query head.
See `tests/test_flash_mla_decoding.py` for a complete example.
### Sparse MLA Prefill
For the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters:
- `q`: Query tensor of shape `[s_q, h_q, d_qk]`
- `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]`
- `indices`: Indices tensor of shape `[s_q, h_kv, topk]`
- `sm_scale`: A scalar value
**Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the `indices` parameter to simulate batch processing.
**Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`.
**Return Values and Equivalent PyTorch Code:**
The kernel returns `(out, max_logits, lse)`. This is equivalent to the following PyTorch operations:
```python
Q: [s_q, h_q, d_qk], bfloat16
kv: [s_kv, h_kv, d_qk], bfloat16
indices: [s_q, h_kv, topk], int32
kv = kv.squeeze(1) # [s_kv, d_qk], h_kv must be 1
indices = indices.squeeze(1) # [s_q, topk]
focused_kv = kv[indices] # For the i-th sequence (s_q), the corresponding KV tokens are selected from the KV cache based on indices[i, :]. This operation results in a tensor of shape [s_q, topk, d_qk].
P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e) # [s_q, h_q, topk]
max_logits = P.max(dim=-1) # [s_q, h_q]
lse = log2sumexp2(P, dim=-1, base=2) # [s_q, h_q],"log2sumexp2" means that the exponentiation and logarithm are base-2
S = exp2(P - lse) # [s_q, h_q, topk]
out = S @ focused_kv # [s_q, h_q, d_qk]
return (out, max_logits, lse)
```
See `tests/test_flash_mla_prefill.py` for a complete example.
### Dense MHA Prefill
This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using:
- `flash_attn_varlen_func`
- `flash_attn_varlen_qkvpacked_func`
- `flash_attn_varlen_kvpacked_func`
The usage is similar to the `flash_attn` package. See `tests/test_fmha_sm100.py` for a complete example.
## Acknowledgement ## Acknowledgement
FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects. FlashMLA is inspired by [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [cutlass](https://github.com/nvidia/cutlass) projects.
...@@ -109,7 +224,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c ...@@ -109,7 +224,7 @@ The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.c
```bibtex ```bibtex
@misc{flashmla2025, @misc{flashmla2025,
title={FlashMLA: Efficient MLA decoding kernels}, title={FlashMLA: Efficient Multi-head Latent Attention Kernels},
author={Jiashi Li, Shengyu Liu}, author={Jiashi Li, Shengyu Liu},
year={2025}, year={2025},
publisher = {GitHub}, publisher = {GitHub},
......
#pragma once #pragma once
//////////////////////////////////////////////////////////////////////////////////////////////////// #include "cutlass/bfloat16.h"
struct Flash_fwd_mla_params { struct DecodingParams {
using index_t = int64_t; using index_t = int64_t;
int b; // batch size int b; // batch size
...@@ -14,11 +14,13 @@ struct Flash_fwd_mla_params { ...@@ -14,11 +14,13 @@ struct Flash_fwd_mla_params {
int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k int q_head_per_hk; // The number of q_head(s) per KV head, = h_q / h_k
bool is_causal; bool is_causal;
float scale_softmax, scale_softmax_log2; float scale_softmax, scale_softmax_log2;
int topk;
void *__restrict__ q_ptr; void *__restrict__ q_ptr;
void *__restrict__ k_ptr; void *__restrict__ k_ptr;
void *__restrict__ o_ptr; void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr; void *__restrict__ softmax_lse_ptr;
int *__restrict__ indices_ptr;
index_t q_batch_stride; index_t q_batch_stride;
index_t k_batch_stride; index_t k_batch_stride;
...@@ -29,6 +31,8 @@ struct Flash_fwd_mla_params { ...@@ -29,6 +31,8 @@ struct Flash_fwd_mla_params {
index_t q_head_stride; index_t q_head_stride;
index_t k_head_stride; index_t k_head_stride;
index_t o_head_stride; index_t o_head_stride;
index_t indices_batch_stride;
index_t indices_row_stride;
int *__restrict__ block_table; int *__restrict__ block_table;
index_t block_table_batch_stride; index_t block_table_batch_stride;
...@@ -45,9 +49,9 @@ struct Flash_fwd_mla_params { ...@@ -45,9 +49,9 @@ struct Flash_fwd_mla_params {
}; };
static constexpr int TileSchedulerMetaDataSize = 8; static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] // [begin_idx (inclusive), begin_block_idx (inclusive), end_idx (inclusive), end_block_idx (exclusive), begin_n_split_idx, _, _, _]
struct Mla_metadata_params { struct GetDecodingMetadataParams {
int *__restrict__ seqlens_k_ptr; int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr; int *__restrict__ tile_scheduler_metadata_ptr;
int *__restrict__ num_splits_ptr; int *__restrict__ num_splits_ptr;
...@@ -55,4 +59,26 @@ struct Mla_metadata_params { ...@@ -55,4 +59,26 @@ struct Mla_metadata_params {
int block_size_n; int block_size_n;
int fixed_overhead_num_blocks; int fixed_overhead_num_blocks;
int num_sm_parts; int num_sm_parts;
int topk;
};
struct SparsePrefillParams {
int s_q, s_kv, h_q, h_kv, d_qk, d_v, topk;
float sm_scale, sm_scale_div_log2;
// Input tensors
cutlass::bfloat16_t* __restrict__ q; // [s_q, h_q, d_qk]
cutlass::bfloat16_t* __restrict__ kv; // [s_kv, h_kv, d_qk]
int* __restrict__ indices; // [s_q, h_kv, topk]
int stride_q_s_q; int stride_q_h_q;
int stride_kv_s_kv; int stride_kv_h_kv;
int stride_indices_s_q; int stride_indices_h_kv;
// Output tensors
cutlass::bfloat16_t* __restrict__ out; // [s_q, h_q, d_v]
float* __restrict__ max_logits; // [s_q, h_q]
float* __restrict__ lse; // [s_q, h_q]
cudaStream_t stream;
}; };
...@@ -10,25 +10,104 @@ ...@@ -10,25 +10,104 @@
#include <cutlass/fast_math.h> #include <cutlass/fast_math.h>
#include "kernels/config.h" #include "params.h"
#include "kernels/get_mla_metadata.h" #include "smxx/get_mla_metadata.h"
#include "kernels/mla_combine.h" #include "smxx/mla_combine.h"
#include "kernels/params.h" #include "sm90/decode/dense/splitkv_mla.h"
#include "kernels/splitkv_mla.h" #include "sm90/decode/sparse_fp8/splitkv_mla.h"
#include "sm90/prefill/sparse/fwd.h"
#include "sm100/prefill/dense/interface.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
struct Arch {
int major;
int minor;
bool is_sm90() const {
return major == 9 && minor == 0;
}
bool is_sm100() const {
return major == 10 && minor == 0;
}
void assert_is_supported() const {
TORCH_CHECK(is_sm90() || is_sm100(), "Only SM90 and SM100 are supported");
}
};
// DecodingAttnImplMeta - A struct to hold metadata for Decoding Attention Implementation (i.e. Hopper Dense BF16, Hopper Sparse FP8, etc.)
struct DecodingAttnImplMeta {
int num_sm_parts;
int fixed_overhead_num_blocks;
int k_block_size;
};
DecodingAttnImplMeta get_attn_impl_meta(
Arch arch,
int sm_count,
int num_q_tokens_per_head_k,
int h_k,
std::optional<int> h_q_,
bool is_fp8_kvcache,
bool is_sparse_attn
) {
if (arch.is_sm90()) {
if (is_sparse_attn) {
if (is_fp8_kvcache) {
TORCH_CHECK(h_q_.has_value());
int h_q = h_q_.value();
TORCH_CHECK(h_q % h_k == 0);
int s_q = num_q_tokens_per_head_k * h_k / h_q;
// FP8 + Sparse MLA
return {
std::max((sm_count/2) / h_k / (cutlass::ceil_div(h_q/h_k, 2*64) * s_q), 1),
5,
64
};
} else {
// Sparse BF16 MLA
TORCH_CHECK(false, "Sparse BF16 MLA is not supported on SM90");
}
} else {
if (is_fp8_kvcache) {
// Dense FP8 MLA
TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90");
} else {
// Dense BF16 MLA
return {
std::max(sm_count / h_k / cutlass::ceil_div(num_q_tokens_per_head_k, 64), 1),
5,
64
};
}
}
} else if (arch.is_sm100()) {
TORCH_CHECK(false, "Unsupported GPU architecture");
} else {
TORCH_CHECK(false, "Unsupported GPU architecture");
}
}
std::vector<at::Tensor> std::vector<at::Tensor>
get_mla_metadata( get_mla_decoding_metadata(
at::Tensor &seqlens_k, at::Tensor &seqlens_k,
const int num_heads_per_head_k, const int num_q_tokens_per_head_k,
const int num_heads_k const int h_k,
const std::optional<int> h_q,
const bool is_fp8_kvcache,
const std::optional<int> topk
) { ) {
bool is_sparse_attn = topk.has_value();
CHECK_DEVICE(seqlens_k); CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous()); TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
if (is_sparse_attn)
TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided");
int batch_size = seqlens_k.size(0); int batch_size = seqlens_k.size(0);
int *seqlens_k_ptr = seqlens_k.data_ptr<int>(); int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
...@@ -36,53 +115,67 @@ get_mla_metadata( ...@@ -36,53 +115,67 @@ get_mla_metadata(
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
int sm_count = dprops->multiProcessorCount; int sm_count = dprops->multiProcessorCount;
int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, Config::BLOCK_SIZE_M); Arch arch = {dprops->major, dprops->minor};
arch.assert_is_supported();
DecodingAttnImplMeta attn_impl_meta = get_attn_impl_meta(arch, sm_count, num_q_tokens_per_head_k, h_k, h_q, is_fp8_kvcache, is_sparse_attn);
auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); auto tile_scheduler_metadata = torch::empty({attn_impl_meta.num_sm_parts, TileSchedulerMetaDataSize}, options);
auto num_splits = torch::empty({batch_size + 1}, options); auto num_splits = torch::empty({batch_size + 1}, options);
int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>(); int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
int *num_splits_ptr = num_splits.data_ptr<int>(); int *num_splits_ptr = num_splits.data_ptr<int>();
at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
Mla_metadata_params params = {}; GetDecodingMetadataParams params = {};
params.seqlens_k_ptr = seqlens_k_ptr; params.seqlens_k_ptr = seqlens_k_ptr;
params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
params.num_splits_ptr = num_splits_ptr; params.num_splits_ptr = num_splits_ptr;
params.batch_size = batch_size; params.batch_size = batch_size;
params.block_size_n = Config::PAGE_BLOCK_SIZE; params.block_size_n = attn_impl_meta.k_block_size;
params.fixed_overhead_num_blocks = Config::FIXED_OVERHEAD_NUM_BLOCKS; params.fixed_overhead_num_blocks = attn_impl_meta.fixed_overhead_num_blocks;
params.num_sm_parts = num_sm_parts; params.num_sm_parts = attn_impl_meta.num_sm_parts;
params.topk = is_sparse_attn ? topk.value() : -1;
run_get_mla_metadata_kernel(params, stream); run_get_mla_metadata_kernel(params, stream);
return {tile_scheduler_metadata, num_splits}; return {tile_scheduler_metadata, num_splits};
} }
std::vector<at::Tensor> std::vector<at::Tensor>
mha_fwd_kvcache_mla( fwd_kvcache_mla(
at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
const int head_size_v, const int head_size_v,
const at::Tensor &seqlens_k, // batch_size const at::Tensor &seqlens_k, // batch_size
const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq
const float softmax_scale, const float softmax_scale,
bool is_causal, bool is_causal,
const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
const at::Tensor &num_splits // batch_size + 1 const at::Tensor &num_splits, // batch_size + 1
const bool &is_fp8,
const std::optional<at::Tensor> &indices // None, or batch_size x seqlen_q x topk
) { ) {
bool is_sparse_attn = indices.has_value();
int topk = is_sparse_attn ? indices->size(-1) : -1;
// Check the architecture // Check the architecture
auto dprops = at::cuda::getCurrentDeviceProperties(); auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9 && dprops->minor == 0; Arch arch = {dprops->major, dprops->minor};
TORCH_CHECK(is_sm90); arch.assert_is_supported();
// Check data types // Check data types
auto q_dtype = q.dtype(); auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf); TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
if (!is_fp8) {
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
} else {
TORCH_CHECK(kcache.dtype() == torch::kFloat8_e4m3fn || kcache.dtype() == torch::kInt8 || kcache.dtype() == torch::kUInt8, "key must have dtype fp8_e4m3fn or int8 or uint8");
}
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32");
TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32");
// Check device // Check device
CHECK_DEVICE(q); CHECK_DEVICE(q);
...@@ -91,14 +184,16 @@ mha_fwd_kvcache_mla( ...@@ -91,14 +184,16 @@ mha_fwd_kvcache_mla(
CHECK_DEVICE(block_table); CHECK_DEVICE(block_table);
CHECK_DEVICE(tile_scheduler_metadata); CHECK_DEVICE(tile_scheduler_metadata);
CHECK_DEVICE(num_splits); CHECK_DEVICE(num_splits);
if (is_sparse_attn) CHECK_DEVICE(indices.value());
// Check layout // Check layout
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension");
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension");
CHECK_CONTIGUOUS(seqlens_k); CHECK_CONTIGUOUS(seqlens_k);
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
CHECK_CONTIGUOUS(tile_scheduler_metadata); CHECK_CONTIGUOUS(tile_scheduler_metadata);
CHECK_CONTIGUOUS(num_splits); CHECK_CONTIGUOUS(num_splits);
TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension");
const auto sizes = q.sizes(); const auto sizes = q.sizes();
const int batch_size = sizes[0]; const int batch_size = sizes[0];
...@@ -112,7 +207,8 @@ mha_fwd_kvcache_mla( ...@@ -112,7 +207,8 @@ mha_fwd_kvcache_mla(
const int num_blocks = kcache.size(0); const int num_blocks = kcache.size(0);
const int page_block_size = kcache.size(1); const int page_block_size = kcache.size(1);
const int num_heads_k = kcache.size(2); const int num_heads_k = kcache.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64");
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q_ori == 1) { is_causal = false; } if (seqlen_q_ori == 1) { is_causal = false; }
...@@ -124,11 +220,19 @@ mha_fwd_kvcache_mla( ...@@ -124,11 +220,19 @@ mha_fwd_kvcache_mla(
.reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k});
CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k);
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); if (!is_fp8) {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k);
} else {
int bytes_per_token = 512 + 64*2 + (512/128)*4;
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, bytes_per_token);
TORCH_CHECK(num_heads_k == 1, "Currently the number of k heads must be 1 when is_fp8_kvcache is True");
TORCH_CHECK(kcache.stride(1) == bytes_per_token, "The whole block must be contiguous when is_fp8_cache is True");
}
CHECK_SHAPE(seqlens_k, batch_size); CHECK_SHAPE(seqlens_k, batch_size);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
CHECK_SHAPE(num_splits, batch_size+1); CHECK_SHAPE(num_splits, batch_size+1);
if (is_sparse_attn) CHECK_SHAPE(indices.value(), batch_size, seqlen_q_ori, topk);
at::cuda::CUDAGuard device_guard{(char)q.get_device()}; at::cuda::CUDAGuard device_guard{(char)q.get_device()};
...@@ -137,7 +241,7 @@ mha_fwd_kvcache_mla( ...@@ -137,7 +241,7 @@ mha_fwd_kvcache_mla(
at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
CHECK_CONTIGUOUS(softmax_lse); CHECK_CONTIGUOUS(softmax_lse);
Flash_fwd_mla_params params = {}; DecodingParams params = {};
// Set the sizes. // Set the sizes.
params.b = batch_size; params.b = batch_size;
params.s_q = seqlen_q_ori; params.s_q = seqlen_q_ori;
...@@ -152,21 +256,25 @@ mha_fwd_kvcache_mla( ...@@ -152,21 +256,25 @@ mha_fwd_kvcache_mla(
params.d_v = head_size_v; params.d_v = head_size_v;
params.scale_softmax = softmax_scale; params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
params.topk = topk;
// Set the pointers and strides. // Set the pointers and strides.
params.q_ptr = q.data_ptr(); params.q_ptr = q.data_ptr();
params.k_ptr = kcache.data_ptr(); params.k_ptr = kcache.data_ptr();
params.o_ptr = out.data_ptr(); params.o_ptr = out.data_ptr();
params.indices_ptr = is_sparse_attn ? indices->data_ptr<int>() : nullptr;
params.softmax_lse_ptr = softmax_lse.data_ptr(); params.softmax_lse_ptr = softmax_lse.data_ptr();
// All stride are in elements, not bytes. // All stride are in elements, not bytes.
params.q_batch_stride = q.stride(0); params.q_batch_stride = q.stride(0);
params.k_batch_stride = kcache.stride(0); params.k_batch_stride = kcache.stride(0);
params.o_batch_stride = out.stride(0); params.o_batch_stride = out.stride(0);
params.q_row_stride = q.stride(-3); params.q_row_stride = q.stride(-3);
params.k_row_stride = kcache.stride(-3); params.k_row_stride = kcache.stride(1);
params.o_row_stride = out.stride(-3); params.o_row_stride = out.stride(-3);
params.q_head_stride = q.stride(-2); params.q_head_stride = q.stride(-2);
params.k_head_stride = kcache.stride(-2); params.k_head_stride = kcache.stride(2);
params.o_head_stride = out.stride(-2); params.o_head_stride = out.stride(-2);
params.indices_batch_stride = is_sparse_attn ? indices->stride(0) : 0;
params.indices_row_stride = is_sparse_attn ? indices->stride(1) : 0;
params.block_table = block_table.data_ptr<int>(); params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0); params.block_table_batch_stride = block_table.stride(0);
...@@ -187,14 +295,46 @@ mha_fwd_kvcache_mla( ...@@ -187,14 +295,46 @@ mha_fwd_kvcache_mla(
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(head_size_k == 576); TORCH_CHECK(head_size_k == 576);
if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#endif
}
if (arch.is_sm90()) {
if (is_sparse_attn) {
if (is_fp8) {
TORCH_CHECK(q_dtype == torch::kBFloat16, "Sparse FP8 MLA only supports BFloat16 on SM90");
sm90::run_flash_splitkv_mla_fp8_sparse_kernel(params, stream);
} else {
TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90");
}
} else {
if (is_fp8) {
TORCH_CHECK(false, "Dense FP8 MLA is not supported on SM90");
} else {
if (q_dtype == torch::kBFloat16) {
sm90::run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
} else if (q_dtype == torch::kHalf) {
#ifndef FLASH_MLA_DISABLE_FP16
sm90::run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
#endif
} else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
}
}
}
} else if (arch.is_sm100()) {
TORCH_CHECK(false, "Unsupported GPU architecture");
} else {
TORCH_CHECK(false, "Unsupported GPU architecture");
}
if (q_dtype == torch::kBFloat16) { if (q_dtype == torch::kBFloat16) {
run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream); run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
} else if (q_dtype == torch::kHalf) { } else if (q_dtype == torch::kHalf) {
#ifdef FLASH_MLA_DISABLE_FP16 #ifndef FLASH_MLA_DISABLE_FP16
TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP16. Please remove this flag from your environment and re-compile FlashMLA.");
#else
run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
run_flash_mla_combine_kernel<cutlass::half_t>(params, stream); run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
#endif #endif
} else { } else {
...@@ -209,8 +349,94 @@ mha_fwd_kvcache_mla( ...@@ -209,8 +349,94 @@ mha_fwd_kvcache_mla(
return {out, softmax_lse}; return {out, softmax_lse};
} }
inline int int64_stride_to_int(int64_t orig_stride) {
if (orig_stride > std::numeric_limits<int>::max()) {
TORCH_CHECK(false, "[Sparse TopK Attention] Stride exceeds int32 limit: ", orig_stride);
}
return static_cast<int>(orig_stride);
}
std::vector<at::Tensor> sparse_prefill_fwd(
const at::Tensor &q,
const at::Tensor &kv,
const at::Tensor &indices,
float sm_scale,
int d_v
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm90 = dprops->major == 9;
TORCH_CHECK(is_sm90, "Sparse Attention Forward Kernel (sparse_prefill_fwd) is only supported on SM90 architectures");
CHECK_DEVICE(q);
CHECK_DEVICE(kv);
CHECK_DEVICE(indices);
TORCH_CHECK(q.dtype() == torch::kBFloat16);
TORCH_CHECK(kv.dtype() == torch::kBFloat16);
TORCH_CHECK(indices.dtype() == torch::kInt32);
int s_q = q.size(0);
int s_kv = kv.size(0);
int h_q = q.size(1);
int h_kv = kv.size(1);
int d_qk = q.size(2);
int topk = indices.size(2);
CHECK_SHAPE(q, s_q, h_q, d_qk);
CHECK_SHAPE(kv, s_kv, h_kv, d_qk);
CHECK_SHAPE(indices, s_q, h_kv, topk);
TORCH_CHECK(q.stride(-1) == 1);
TORCH_CHECK(kv.stride(-1) == 1);
TORCH_CHECK(indices.stride(-1) == 1);
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
at::Tensor out = torch::empty({s_q, h_q, d_v}, opts);
CHECK_CONTIGUOUS(out);
at::Tensor buf_attn_score, max_logits, lse, p_sum;
max_logits = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
lse = torch::empty({s_q, h_q}, opts.dtype(torch::kFloat));
CHECK_CONTIGUOUS(max_logits);
CHECK_CONTIGUOUS(lse);
SparsePrefillParams params = {
s_q, s_kv, h_q, h_kv, d_qk, d_v, topk,
sm_scale, sm_scale * 1.44269504f,
(cutlass::bfloat16_t*)q.data_ptr(),
(cutlass::bfloat16_t*)kv.data_ptr(),
(int*)indices.data_ptr(),
int64_stride_to_int(q.stride(0)), int64_stride_to_int(q.stride(1)),
int64_stride_to_int(kv.stride(0)), int64_stride_to_int(kv.stride(1)),
int64_stride_to_int(indices.stride(0)), int64_stride_to_int(indices.stride(1)),
(cutlass::bfloat16_t*)out.data_ptr(),
(float*)max_logits.data_ptr(),
(float*)lse.data_ptr(),
at::cuda::getCurrentCUDAStream().stream()
};
if (is_sm90) {
sm90::run_fwd_kernel(params);
} else {
TORCH_CHECK(false, "Unknown architecture");
}
return {out, max_logits, lse};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashMLA"; m.doc() = "FlashMLA";
m.def("get_mla_metadata", &get_mla_metadata); m.def("get_mla_decoding_metadata", &get_mla_decoding_metadata);
m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); m.def("fwd_kvcache_mla", &fwd_kvcache_mla);
m.def("dense_prefill_fwd", &FMHACutlassSM100FwdRun);
m.def("dense_prefill_bwd", &FMHACutlassSM100BwdRun);
m.def("sparse_prefill_fwd", &sparse_prefill_fwd);
} }
...@@ -37,9 +37,9 @@ ...@@ -37,9 +37,9 @@
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
#include "cute/layout.hpp" #include "cute/layout.hpp"
#include "collective/fmha_common.hpp" #include "../collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp" #include "../collective/fmha_fusion.hpp"
#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" #include "../collective/sm100_fmha_load_tma_warpspecialized.hpp"
namespace cutlass::fmha::collective { namespace cutlass::fmha::collective {
......
...@@ -36,8 +36,8 @@ ...@@ -36,8 +36,8 @@
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
#include "cute/layout.hpp" #include "cute/layout.hpp"
#include "collective/fmha_common.hpp" #include "../collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp" #include "../collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective { namespace cutlass::fmha::collective {
......
...@@ -37,10 +37,10 @@ ...@@ -37,10 +37,10 @@
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
#include "cute/layout.hpp" #include "cute/layout.hpp"
#include "collective/fmha_common.hpp" #include "../collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp" #include "../collective/fmha_fusion.hpp"
#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" #include "../collective/sm100_fmha_mla_load_tma_warpspecialized.hpp"
#include "common/pipeline_mla.hpp" #include "../common/pipeline_mla.hpp"
namespace cutlass::fmha::collective { namespace cutlass::fmha::collective {
......
...@@ -36,8 +36,8 @@ ...@@ -36,8 +36,8 @@
#include "cute/tensor.hpp" #include "cute/tensor.hpp"
#include "cute/layout.hpp" #include "cute/layout.hpp"
#include "collective/fmha_common.hpp" #include "../collective/fmha_common.hpp"
#include "collective/fmha_fusion.hpp" #include "../collective/fmha_fusion.hpp"
namespace cutlass::fmha::collective { namespace cutlass::fmha::collective {
......
#include <Python.h> #include "interface.h"
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h> #include <c10/cuda/CUDAStream.h>
#include <torch/library.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include "common/mask.cuh" #include "common/mask.cuh"
#include "common/utils.hpp" #include "common/utils.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment