# FlashMLA

## Introduction

FlashMLA is DeepSeek's library of optimized attention kernels, powering the [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3) and [DeepSeek-V3.2-Exp](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp) models. This repository contains the following implementations:

**Sparse Attention Kernels**

*These kernels power DeepSeek Sparse Attention (DSA), as introduced in [this paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).*

- Token-level sparse attention for the prefill stage
- Token-level sparse attention for the decoding stage, with FP8 KV cache
- DSA MLA sparse prefill and DSA BF16 sparse decoding on Hygon DCU gfx93

**Dense Attention Kernels**

- Dense attention for the prefill stage
- Dense attention for the decoding stage

## News

- **2025.09.29 Release of Sparse Attention Kernels**: With the launch of [DeepSeek-V3.2](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp), we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFLOPS during prefilling and 410 TFLOPS during decoding. We have also published a deep-dive blog post on our new FP8 sparse decoding kernel. Check it out [here](docs/20250929-hopper-fp8-sparse-deep-dive.md).
- **2025.08.01 Kernels for MHA on SM100**: Thanks to [NVIDIA's PR](https://github.com/deepseek-ai/FlashMLA/pull/76) for MHA forward / backward kernels on SM100!
- **2025.04.22 Deep-Dive Blog**: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up [here](docs/20250422-new-kernel-deep-dive.md).
- **2025.04.22 Performance Update**: We're excited to announce the new release of FlashMLA, which delivers a 5% to 15% performance improvement for compute-bound workloads, achieving up to 660 TFLOPS on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀

## Performance

#### Test & benchmark MLA decoding (Sparse & Dense):

```bash
python tests/test_flash_mla_dense_decoding.py
python tests/test_flash_mla_sparse_decoding.py
FLASH_MLA_DECODE_BF16=1 python tests/test_flash_mla_sparse_decoding.py
```

The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configurations and 660 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplication in bfloat16) achieves 410 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8, and up to 350 TFLOPS on B200 (where it is not yet fully optimized).

#### Test & benchmark MHA prefill (Dense):

```bash
python tests/test_fmha_sm100.py
```

It achieves up to 1460 TFLOPS in forward and 1000 TFLOPS in backward computation on B200, as reported by NVIDIA.

#### Test & benchmark MLA prefill (Sparse):

```bash
python tests/test_flash_mla_sparse_prefill.py
```

It achieves up to 640 TFLOPS in forward computation on H800 SXM5 with CUDA 12.8, and up to 1450 TFLOPS on B200 with CUDA 12.9.

## Requirements

- SM90 / SM100 (See the support matrix below)
- CUDA 12.8 and above (CUDA 12.9+ is required for SM100 kernels)
- PyTorch 2.0 and above

For the Hygon DCU gfx93 path in this branch, build with the local AICC/ROCm
toolchain and the pinned CUTLASS submodule described below.

Support matrix:

| Kernel | GPU Architecture | MLA Mode [2] | KVCache Format |
| :---: | :---: | :---: | :---: |
| Dense Decoding | SM90 | MQA | BF16 |
| Sparse Decoding | SM90 & SM100 | MQA | FP8 [1] |
| DSA BF16 Sparse Decoding | gfx93 | MQA | BF16 |
| Dense Prefill | SM100 | MHA |  |
| Sparse Prefill | SM90 & SM100 / gfx93 | MQA |  |

[1]: For details on using the FP8 KV cache, see the **FP8 KV Cache** notes in the Usage section below.

[2]: Here "MLA Mode" refers to the mode used for MLA calculation. MQA stands for Multi-Query Attention mode (i.e. `head_dim_k` = 576 with `head_dim_v` = 512), while MHA stands for Multi-Head Attention mode (i.e. `head_dim_k` = 192 / 128 with `head_dim_v` = 128). For a detailed explanation of these modes, please refer to the appendix of the [DeepSeek-V3.2 paper](https://github.com/deepseek-ai/DeepSeek-V3.2-Exp).

## Installation

```bash
git clone -b master-aicc ssh://git@10.16.1.204:10022/dcutoolkit/deeplearing/flashmla.git flashmla
cd flashmla
git submodule update --init --recursive
pip install -v .
```

The CUTLASS dependency is pinned as `csrc/cutlass/cutlass_3.2.1` on branch
`feature/16x64-mmac`. If the submodule is missing, `setup.py` will try to
initialize it before compiling.

When running tests directly from the source tree, prefer an in-place build so
the local `flash_mla` package can find the compiled `flash_mla.cuda` extension:

```bash
python setup.py build_ext --inplace
python tests/test_flash_mla_sparse_prefill.py
FLASH_MLA_DECODE_BF16=1 python tests/test_flash_mla_sparse_decoding.py
```

If you only run `pip install -v .` or `python setup.py install`, avoid launching
tests from a different, unbuilt source checkout. Otherwise, Python may import that
checkout's `flash_mla/` source directory and fail with `No module named flash_mla.cuda`.
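
To check which copy of `flash_mla` Python would pick up before running the tests, a small standard-library sketch like the one below may help. The helper name `locate_module` is illustrative only and is not part of FlashMLA.

```python
import importlib.util


def locate_module(name):
    """Return the file a module would be imported from, or None if it cannot be found."""
    try:
        spec = importlib.util.find_spec(name)
    except ImportError:
        # Raised when a parent package is missing or fails to import.
        return None
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # If the first path points into an unbuilt source checkout and the second
    # is None, Python is importing the wrong `flash_mla/` directory.
    print("flash_mla      ->", locate_module("flash_mla"))
    print("flash_mla.cuda ->", locate_module("flash_mla.cuda"))
```

If the reported `flash_mla` path is not the one you built, change the working directory or rebuild in place as described above.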

## Usage

### MLA Decoding

To use the MLA decoding kernels, call `get_mla_metadata` once before the decoding loop to get the tile scheduler metadata. Then, call `flash_mla_with_kvcache` in each decoding step. For example:

```python
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    s_q * h_q // h_kv,
    h_kv,
    h_q,
    is_fp8,
    topk,
)

for i in range(num_layers):
    ...
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits,
        is_causal, is_fp8_kvcache, indices,
    )
    ...
```

Where

- `s_q` is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1.
- `h_kv` is the number of key-value heads.
- `h_q` is the number of query heads.
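
As a rough illustration of the second argument passed to `get_mla_metadata` in the example above (`s_q * h_q // h_kv`), the sketch below evaluates it for a couple of hypothetical configurations. The helper name `q_tokens_per_kv_head`, the divisibility check, and the head counts are assumptions made for this example.

```python
def q_tokens_per_kv_head(s_q, h_q, h_kv):
    """Compute s_q * h_q // h_kv, the second argument in the example above."""
    # Assumption for this sketch: query heads divide evenly among KV heads.
    if h_q % h_kv != 0:
        raise ValueError("h_q must be a multiple of h_kv")
    return s_q * h_q // h_kv


# Hypothetical configuration: 128 query heads sharing 1 KV head.
print(q_tokens_per_kv_head(s_q=1, h_q=128, h_kv=1))  # 128 (MTP disabled)
print(q_tokens_per_kv_head(s_q=2, h_q=128, h_kv=1))  # 256 (two query tokens per sequence)
```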

**FP8 KV Cache:**
If `is_fp8_kvcache` is set to `True`, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.

**BF16 Sparse Decode on gfx93:**
This branch also supports the DSA BF16 sparse decode path when
`is_fp8_kvcache=False` and `k_cache` / `extra_k_cache` are `torch.bfloat16`.
The test suite enables this path with `FLASH_MLA_DECODE_BF16=1`.

In the "FP8 with scale" format, each token's KV cache entry occupies 656 bytes, structured as:
-   **First 512 bytes:** The "quantized NoPE" part, containing 512 `float8_e4m3` values.
-   **Next 16 bytes:** Scale factors, containing 4 `float32` values. The first `float32` is the scale for the first 128 `float8_e4m3` values, the second for the next 128, and so on.
-   **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This part is not quantized for accuracy.

See `tests/quant.py` for quantization and dequantization details.
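
The byte layout described above can be double-checked with a little arithmetic. The following is a minimal sketch in plain Python; the constant and function names are made up for illustration and do not come from `tests/quant.py`.

```python
FP8_BYTES = 1     # one float8_e4m3 value
FP32_BYTES = 4    # one float32 scale factor
BF16_BYTES = 2    # one bfloat16 value

NOPE_DIM = 512    # quantized NoPE values per token
GROUP_SIZE = 128  # NoPE values that share one scale factor
ROPE_DIM = 64     # unquantized RoPE values per token


def token_layout():
    """Return {part: (byte_offset, byte_size)} for one token's cache entry."""
    nope = NOPE_DIM * FP8_BYTES
    scales = (NOPE_DIM // GROUP_SIZE) * FP32_BYTES
    rope = ROPE_DIM * BF16_BYTES
    return {
        "nope": (0, nope),
        "scales": (nope, scales),
        "rope": (nope + scales, rope),
    }


def scale_index(nope_element):
    """Index of the float32 scale that applies to a given NoPE element."""
    return nope_element // GROUP_SIZE


layout = token_layout()
bytes_per_token = sum(size for _, size in layout.values())
print(layout)           # {'nope': (0, 512), 'scales': (512, 16), 'rope': (528, 128)}
print(bytes_per_token)  # 656
```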

**Sparse Attention (`indices` tensor):**
The `indices` tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens.

-   **Shape:** `indices` should be a 3D tensor of shape `(batch_size, seq_len_q, topk)`.
-   **Format:** `indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block)`, where `t` is the k-th token for the j-th query sequence in the i-th batch. Since the index of the page block has already been encoded into `indices_in_kvcache`, the kernel does not require the `block_table` parameter.
-   **Invalid entries:** Set invalid indices to `-1`.
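
The index formula above can be sketched in plain Python as follows. This assumes the usual paged-KV-cache convention in which a block table row maps a sequence's logical block number to a physical page block index. The helper name `encode_kv_index` and the sample values are hypothetical.

```python
def encode_kv_index(block_table_row, token_pos, page_block_size):
    """Map a token position within a sequence to its flat KV-cache index.

    block_table_row[n] is assumed to hold the physical page block index of the
    sequence's n-th logical block. Negative positions are treated as invalid.
    """
    if token_pos < 0:
        return -1
    page_block = block_table_row[token_pos // page_block_size]
    offset_in_block = token_pos % page_block_size
    return page_block * page_block_size + offset_in_block


block_table_row = [7, 2]  # logical block 0 -> page block 7, logical block 1 -> page block 2
print(encode_kv_index(block_table_row, 3, 64))   # 7 * 64 + 3 = 451
print(encode_kv_index(block_table_row, 70, 64))  # 2 * 64 + 6 = 134
print(encode_kv_index(block_table_row, -1, 64))  # -1 (invalid entry)
```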

**Return Values:**
The kernel returns `(out, lse)`, where:
-   `out` is the attention result.
-   `lse` is the log-sum-exp value of the attention scores for each query head.

See `tests/test_flash_mla_dense_decoding.py` and `tests/test_flash_mla_sparse_decoding.py` for complete examples.

### Sparse MLA Prefill

For the sparse MLA prefill kernel, call `flash_mla_sparse_fwd` directly with the following parameters:
-   `q`: Query tensor of shape `[s_q, h_q, d_qk]`
-   `kv`: Key-Value tensor of shape `[s_kv, h_kv, d_qk]`
-   `indices`: Indices tensor of shape `[s_q, h_kv, topk]`
-   `sm_scale`: A scalar value

**Note on batching:** This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the `indices` parameter to simulate batch processing.

**Invalid indices:** Set invalid entries in `indices` to `-1` or any number `>= s_kv`.
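
One possible way to follow the batching note above is to concatenate all batches along the token dimension and shift each batch's indices by the number of KV tokens that precede it. The sketch below shows only the index adjustment, in plain Python, with the `h_kv` dimension omitted. The function name and the concatenation scheme are assumptions, not the library's prescribed method.

```python
def flatten_batched_indices(batched_indices, kv_lens):
    """Shift per-batch KV indices so they address one concatenated KV tensor.

    batched_indices[b][i] lists the top-k KV positions (local to batch b) for
    the i-th query token of batch b; kv_lens[b] is batch b's KV length.
    Out-of-range entries are mapped to -1 (invalid).
    """
    flat = []
    offset = 0
    for per_batch, kv_len in zip(batched_indices, kv_lens):
        for row in per_batch:
            flat.append([idx + offset if 0 <= idx < kv_len else -1 for idx in row])
        offset += kv_len
    return flat


batched = [
    [[0, 2, -1]],  # batch 0: one query token, KV length 3
    [[1, 0, 5]],   # batch 1: one query token, KV length 4 (5 is out of range)
]
print(flatten_batched_indices(batched, kv_lens=[3, 4]))  # [[0, 2, -1], [4, 3, -1]]
```

Mapping out-of-range entries to `-1` before adding the offset keeps an invalid index in one batch from silently pointing at another batch's tokens.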

**Return Values and Equivalent PyTorch Code:**
The kernel returns `(out, max_logits, lse)`. This is equivalent to the following PyTorch operations:

```python
import math
import torch

def sparse_prefill_reference(q, kv, indices, sm_scale):
    # q:       [s_q, h_q, d_qk], bfloat16
    # kv:      [s_kv, h_kv, d_qk], bfloat16 (h_kv must be 1)
    # indices: [s_q, h_kv, topk], int32 (all entries assumed valid here)
    kv = kv.squeeze(1)                   # [s_kv, d_qk]
    indices = indices.squeeze(1).long()  # [s_q, topk]
    # For the i-th query token, select the KV tokens listed in indices[i, :].
    focused_kv = kv[indices]             # [s_q, topk, d_qk]

    # Scale logits by log2(e) so that exponentiation and logarithm are base-2.
    P = (q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e)  # [s_q, h_q, topk]
    max_logits = P.max(dim=-1).values    # [s_q, h_q]
    # Base-2 log-sum-exp: log2(sum(2 ** P)), computed via the natural-log version.
    lse = torch.logsumexp(P * math.log(2), dim=-1) / math.log(2)  # [s_q, h_q]
    S = torch.exp2(P - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    out = S @ focused_kv                 # [s_q, h_q, d_qk]

    return out, max_logits, lse
```

See `tests/test_flash_mla_sparse_prefill.py` for a complete example.

### Dense MHA Prefill

This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using:
-   `flash_attn_varlen_func`
-   `flash_attn_varlen_qkvpacked_func`
-   `flash_attn_varlen_kvpacked_func`

The usage is similar to the `flash_attn` package. See `tests/test_fmha_sm100.py` for a complete example.
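
In the `flash_attn` package, the varlen functions take sequences packed along one token dimension together with cumulative sequence lengths and a maximum sequence length. Assuming the functions above follow the same convention (check `tests/test_fmha_sm100.py` to confirm), the bookkeeping could be sketched in plain Python as below. The helper name `cumulative_seqlens` is hypothetical.

```python
from itertools import accumulate


def cumulative_seqlens(seqlens):
    """Return the start offset of each sequence in the packed tensor, plus the total length."""
    return [0] + list(accumulate(seqlens))


seqlens = [3, 5, 2]   # hypothetical per-sequence token counts
cu_seqlens = cumulative_seqlens(seqlens)
max_seqlen = max(seqlens)
print(cu_seqlens)     # [0, 3, 8, 10]
print(max_seqlen)     # 5
```

Sequence `i` then occupies rows `cu_seqlens[i]` up to (but not including) `cu_seqlens[i + 1]` of the packed tensor.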

## Acknowledgement

FlashMLA is inspired by the [FlashAttention 2 & 3](https://github.com/dao-AILab/flash-attention/) and [CUTLASS](https://github.com/nvidia/cutlass) projects.

## Community Support

### MetaX
For MetaX GPUs, visit the official website: [MetaX](https://www.metax-tech.com).

The corresponding FlashMLA version can be found at: [MetaX-MACA/FlashMLA](https://github.com/MetaX-MACA/FlashMLA)


### Moore Threads
For the Moore Threads GPU, visit the official website: [Moore Threads](https://www.mthreads.com/).

The corresponding FlashMLA version is available on GitHub: [MooreThreads/MT-flashMLA](https://github.com/MooreThreads/MT-flashMLA).


### Hygon DCU
For the Hygon DCU, visit the official website: [Hygon Developer](https://developer.sourcefind.cn/).

The corresponding FlashMLA version is available here: [OpenDAS/MLAttention](https://developer.sourcefind.cn/codes/OpenDAS/MLAttention).


### Intellifusion
For the Intellifusion NNP, visit the official website: [Intellifusion](https://www.intellif.com).

The corresponding FlashMLA version is available on Gitee: [Intellifusion/tyllm](https://gitee.com/Intellifusion_2025/tyllm/blob/master/python/tylang/flash_mla.py).


### Iluvatar Corex
For Iluvatar Corex GPUs, visit the official website: [Iluvatar Corex](https://www.iluvatar.com).

The corresponding FlashMLA version is available on GitHub: [Deep-Spark/FlashMLA](https://github.com/Deep-Spark/FlashMLA/tree/iluvatar_flashmla)


### AMD Instinct
For AMD Instinct GPUs, visit the official website: [AMD Instinct](https://www.amd.com/en/products/accelerators/instinct.html).

The corresponding FlashMLA version can be found at: [AITER/MLA](https://github.com/ROCm/aiter/blob/main/aiter/mla.py)

## Citation

```bibtex
@misc{flashmla2025,
      title={FlashMLA: Efficient Multi-head Latent Attention Kernels},
      author={Jiashi Li and Shengyu Liu},
      year={2025},
      publisher = {GitHub},
      howpublished = {\url{https://github.com/deepseek-ai/FlashMLA}},
}
```