# FlashMLA

## Introduction

FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2](TODO) models. This repository contains the following implementations:

**Sparse Attention Kernels**

*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](TODO).*

- Token-level sparse attention for the prefill stage
- Token-level sparse attention for the decoding stage, with FP8 KV cache

**Dense Attention Kernels**

- Dense attention for the prefill stage
- Dense attention for the decoding stage

## News

- **2025.09.26(TODO) Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](TODO), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFLOPS during prefilling and 410 TFLOPS during decoding.
- **2025.08.01 Kernels for MHA on Blackwell**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on Blackwell!
- **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md).
- **2025.04.22 Performance Update**: We're excited to announce a new release of FlashMLA, which delivers a 5% to 15% performance improvement for compute-bound workloads and achieves up to 660 TFLOPS on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀

## Performance

#### Test & benchmark MLA decoding (Sparse & Dense):

```bash
python tests/test_flash_mla_decoding.py
```

The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configurations and 660 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel, which reads an FP8 KV cache and performs the matrix multiplication in bfloat16, achieves up to 410 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8.

#### Test & benchmark MHA prefill (Dense):

```bash
python tests/test_fmha_sm100.py
```

It achieves up to 1460 TFLOPS in forward and 1000 TFLOPS in backward computation on B200, as reported by NVIDIA.

#### Test & benchmark MLA prefill (Sparse):

```bash
python tests/test_flash_mla_prefill.py
```

It achieves up to 640 TFLOPS in forward computation on H800 SXM5 with CUDA 12.8.

## Requirements

- Hopper / Blackwell GPUs (See the support matrix below)
- CUDA 12.8 and above (CUDA 12.9+ is required for Blackwell kernels)
- PyTorch 2.0 and above

Support matrix:

| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: |
| Dense Decoding | Hopper | MQA | BF16 |
| Sparse Decoding | Hopper | MQA | FP8 [1] |
| Dense Prefill | Blackwell | MHA |  |
| Sparse Prefill | Hopper | MQA |  |

[1]: For details on the FP8 KV cache format, see the "FP8 KV Cache" notes in the Usage section below.

[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of [DeepSeek V3.2's Paper](TODO).

## Installation

```bash
git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .
```

## Usage

### MLA Decoding

To use the MLA decoding kernels, call `get_mla_metadata` once before the decoding loop to get the tile scheduler metadata. Then call `flash_mla_with_kvcache` in each decoding step. For example:

```python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    s_q * h_q // h_kv,
    h_kv,
    h_q,
    is_fp8,
    topk,
)

for i in range(num_layers):
    ...
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits,
        is_causal, is_fp8_kvcache, indices,
    )
    ...
```

Where

- `s_q` is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1.
- `h_kv` is the number of key-value heads.
- `h_q` is the number of query heads.
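
The second argument passed to `get_mla_metadata` in the example, `s_q * h_q // h_kv`, is plain integer arithmetic on the quantities defined above. A minimal sketch of that computation follows; the helper name `q_tokens_per_kv_head` and the sample head counts are illustrative assumptions, not part of the FlashMLA API.

```python
def q_tokens_per_kv_head(s_q: int, h_q: int, h_kv: int) -> int:
    """Compute s_q * h_q // h_kv, the second argument in the example above.

    Hypothetical helper for illustration only.
    """
    if h_kv <= 0 or h_q % h_kv != 0:
        raise ValueError("h_q must be a positive multiple of h_kv")
    return s_q * h_q // h_kv


# Illustrative values (assumed, not taken from the README):
# 128 query heads sharing 1 KV head.
no_mtp = q_tokens_per_kv_head(s_q=1, h_q=128, h_kv=1)    # MTP disabled
with_mtp = q_tokens_per_kv_head(s_q=2, h_q=128, h_kv=1)  # 2 q tokens per sequence
```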

**FP8 KV Cache:**
If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.

In the "FP8 with scale" format, each token's KV cache entry occupies 656 bytes, structured as:
-   **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values.
-   **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
-   **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy.

See `tests/quant.py` for quantization and dequantization details.
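
The byte layout described above can be checked with a few lines of offset arithmetic. The sketch below uses only the Python standard library; `split_token_bytes`, `scale_index_for`, and the constant names are hypothetical, and little-endian `float32` scales are an assumption. `tests/quant.py` remains the authoritative reference for the packing.

```python
import struct

NOPE_DIM = 512    # quantized NoPE values, 1 byte each (float8_e4m3)
TILE_SIZE = 128   # number of NoPE values that share one float32 scale
ROPE_DIM = 64     # unquantized RoPE values, 2 bytes each (bfloat16)

NOPE_BYTES = NOPE_DIM * 1
SCALE_BYTES = (NOPE_DIM // TILE_SIZE) * 4
ROPE_BYTES = ROPE_DIM * 2
TOKEN_BYTES = NOPE_BYTES + SCALE_BYTES + ROPE_BYTES  # 512 + 16 + 128


def split_token_bytes(token: bytes):
    """Split one token's raw bytes into (nope_fp8, scales, rope_bf16).

    The FP8 and bfloat16 parts are returned as raw bytes; only the four
    float32 scales are decoded (little-endian is assumed here).
    """
    if len(token) != TOKEN_BYTES:
        raise ValueError(f"expected {TOKEN_BYTES} bytes, got {len(token)}")
    nope = token[:NOPE_BYTES]
    scales = struct.unpack("<4f", token[NOPE_BYTES:NOPE_BYTES + SCALE_BYTES])
    rope = token[NOPE_BYTES + SCALE_BYTES:]
    return nope, scales, rope


def scale_index_for(value_index: int) -> int:
    """Return which of the 4 scales applies to the given NoPE value."""
    if not 0 <= value_index < NOPE_DIM:
        raise ValueError("value_index out of range")
    return value_index // TILE_SIZE
```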

**Sparse Attention (`indices` tensor):**
The `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens.

-   **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`.
-   **Format:** `indices[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th selected token for the j-th query token of the i-th sequence in the batch. Because the page block index is already encoded in each entry, the kernel does not need the `block_table` parameter in this mode.
-   **Invalid entries:** Set invalid indices to `-1`.
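
The index encoding above can be sketched in plain Python. Both helper names are hypothetical, and the page block size of 64 in the usage lines is only an example value.

```python
def flat_kv_index(block_index: int, offset: int, page_block_size: int) -> int:
    """Encode (page block index, offset within the block) as one flat index."""
    if not 0 <= offset < page_block_size:
        raise ValueError("offset must lie inside the page block")
    return block_index * page_block_size + offset


def build_indices_row(token_locations, page_block_size: int, topk: int):
    """Build one row of `indices` for a single query token.

    `token_locations` is a list of (block_index, offset) pairs. Unused
    slots are padded with -1, the invalid-entry marker described above.
    """
    if len(token_locations) > topk:
        raise ValueError("more selected tokens than topk slots")
    row = [flat_kv_index(b, o, page_block_size) for b, o in token_locations]
    return row + [-1] * (topk - len(row))


# Example with an assumed page block size of 64 and topk of 4.
row = build_indices_row([(0, 0), (3, 5)], page_block_size=64, topk=4)
```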

**Return Values:**
The kernel returns `(out, lse)`, where:
-   `out` is the attention result.
-   `lse` is the log-sum-exp value of the attention scores for each query head.

See `tests/test_flash_mla_decoding.py` for a complete example.

### Sparse MLA Prefill

For the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters:
-   `q`: Query tensor of shape `[s_q, h_q, d_qk]`
-   `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]`
-   `indices`: Indices tensor of shape `[s_q, h_kv, topk]`
-   `sm_scale`: A scalar value

**Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the `indices` parameter to simulate batch processing.

**Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`.
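
One possible reading of the batching note is to concatenate all sequences along the token dimension and shift each sequence's indices by the number of KV tokens packed before it. The sketch below shows only that index bookkeeping with plain Python lists; `pack_batch_indices` is a hypothetical helper, and this packing scheme is an assumption rather than something the kernel prescribes.

```python
def pack_batch_indices(per_seq_indices, kv_lens):
    """Shift per-sequence KV indices so they address one concatenated KV tensor.

    per_seq_indices[b] holds one row per query token of sequence b; each row
    lists indices local to sequence b's KV tokens, with -1 for invalid slots.
    kv_lens[b] is the number of KV tokens of sequence b.
    """
    packed = []
    base = 0  # number of KV tokens packed before the current sequence
    for rows, kv_len in zip(per_seq_indices, kv_lens):
        for row in rows:
            # Valid entries move by `base`; invalid entries stay -1.
            packed.append([idx + base if idx >= 0 else -1 for idx in row])
        base += kv_len
    return packed


# Two sequences with 3 and 2 KV tokens, topk = 3.
packed = pack_batch_indices(
    [[[0, 2, -1]], [[1, -1, -1], [0, 1, -1]]],
    kv_lens=[3, 2],
)
```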

**Return Values and Equivalent PyTorch Code:**
The kernel returns `(out, max_logits, lse)`. This is equivalent to the following PyTorch operations:

```python
import math
import torch

def sparse_mla_prefill_reference(q, kv, indices, sm_scale):
    """Reference semantics only; assumes every entry of `indices` is valid.

    q:       [s_q, h_q, d_qk], bfloat16
    kv:      [s_kv, h_kv, d_qk], bfloat16 (h_kv must be 1)
    indices: [s_q, h_kv, topk], int32
    """
    kv = kv.squeeze(1)            # [s_kv, d_qk]
    indices = indices.squeeze(1)  # [s_q, topk]
    # For the i-th query token, select the KV tokens listed in indices[i, :].
    focused_kv = kv[indices.long()]  # [s_q, topk, d_qk]

    # Logits are scaled by log2(e) so that the softmax is evaluated in base 2.
    p = (q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e)  # [s_q, h_q, topk]
    max_logits = p.max(dim=-1).values  # [s_q, h_q]
    # Base-2 log-sum-exp, i.e. log2(sum(2 ** p)).
    lse = torch.logsumexp(p * math.log(2), dim=-1) / math.log(2)  # [s_q, h_q]
    s = torch.exp2(p - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    out = s @ focused_kv                   # [s_q, h_q, d_qk]
    return out, max_logits, lse
```

See `tests/test_flash_mla_prefill.py` for a complete example.
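
Because the logits are multiplied by `log2(e)` before the base-2 log-sum-exp is taken, the returned `lse` differs from the usual natural-log value only by a factor of `ln 2`. The following standard-library sketch checks that identity numerically; both function names are hypothetical.

```python
import math


def log2sumexp2(values):
    """Numerically stable log2(sum(2 ** v for v in values))."""
    m = max(values)
    return m + math.log2(sum(2.0 ** (v - m) for v in values))


def natural_lse_from_base2(lse2: float) -> float:
    """Convert a base-2 log-sum-exp into the natural-log log-sum-exp."""
    return lse2 * math.log(2.0)


# Scaled logits z (already multiplied by sm_scale), converted to base 2.
z = [0.1, -1.3, 2.0]
p = [v * math.log2(math.e) for v in z]
lse_natural = natural_lse_from_base2(log2sumexp2(p))
```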

### Dense MHA Prefill

This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using:
-   `flash_attn_varlen_func`
-   `flash_attn_varlen_qkvpacked_func`
-   `flash_attn_varlen_kvpacked_func`

The usage is similar to the `flash_attn` package. See `tests/test_fmha_sm100.py` for a complete example.
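
The varlen entry points in the `flash_attn` package take sequences packed along one token dimension, together with cumulative sequence lengths and a maximum sequence length. Assuming FlashMLA's functions follow the same convention (check `tests/test_fmha_sm100.py` for the exact arguments), that bookkeeping could be sketched as follows; `build_cu_seqlens` is a hypothetical helper, not part of either package.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Return (cumulative sequence lengths, maximum sequence length).

    For lengths [3, 5, 2] the cumulative list is [0, 3, 8, 10], so
    sequence b occupies packed rows cu[b]:cu[b + 1].
    """
    if any(n < 0 for n in seq_lens):
        raise ValueError("sequence lengths must be non-negative")
    cu = [0, *accumulate(seq_lens)]
    return cu, max(seq_lens, default=0)


cu_seqlens, max_seqlen = build_cu_seqlens([3, 5, 2])
```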

## Acknowledgement

FlashMLA is inspired by the [FlashAttention 2&3](https://github.com/dao-AILab/flash-attention/) and [CUTLASS](https://github.com/nvidia/cutlass) projects.

## Community Support

### MetaX
For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com).

The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA)


### Moore Threads
For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/).

The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA).


### Hygon DCU
For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/).

The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention).


### Intellifusion
For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com).

The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py).


### Iluvatar Corex
For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com).

The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla)


### AMD Instinct
For AMD Instinct GPUs, visit the official website: [AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html).

The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py)

## Citation

```bibtex
@misc{flashmla2025,
      title={FlashMLA: Efficient Multi-head Latent Attention Kernels},
      author={Jiashi Li and Shengyu Liu},
      year={2025},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}
```