"vscode:/vscode.git/clone" did not exist on "a7ae55b785108d4bfcbf9c95eea5458faf4c0699"
Unverified Commit 6bc503af authored by b8zhong's avatar b8zhong Committed by GitHub
Browse files

[Doc] Update support matrix for attn and hybrid attn (#11293)

parent b2c85669
# Attention Backend
SGLang supports a large variety of attention backends, each with different pros and cons.
You can test them according to your needs.
```{important}
Selecting the right attention backend is crucial for maximizing performance. Different backends excel in different scenarios, so choose based on your model, hardware, and use case. Not all backends are supported on all platforms and model architectures.
```
## Support Matrix
The support matrix is split into two parts: MHA (standard attention) and MLA (multi-head latent attention). For an explanation of the key differences between MHA and MLA, please see the [SGLang documentation on DeepSeek MLA](https://github.com/sgl-project/sglang/blob/main/docs/basic_usage/deepseek.md#multi-head-latent-attention-mla) and the original [DeepSeek MLA paper](https://arxiv.org/pdf/2405.04434).
### MHA Backends
| **Backend** | **Page Size > 1 (native)** | **FP8 KV Cache** | **Spec topk=1** | **Spec topk>1** | **Sliding Window** | **MultiModal** |
|---------------------------------|-----------------------------|------------------|-----------------|-----------------|--------------------|----------------|
| **FlashInfer** | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| **FA3 (FlashAttention 3)** | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| **FA4 (FlashAttention 4)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Triton** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ |
| **Torch Native (SDPA)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **FlexAttention (PyTorch)** | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **TRTLLM MHA** | 16, 32 or 64 | ❌ | ✅ | ❌ | ❌ | ❌ |
| **Dual Chunk FlashAttention** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **AITER (ROCm)** | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
| **Wave (ROCm)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| **Ascend (NPU)** | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
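As a quick example of reading this matrix, the sketch below launches an MHA model with FA3 and an FP8 KV cache. This is only a sketch: the model is an illustrative choice, and it assumes your installation accepts `--kv-cache-dtype fp8_e4m3` (check `python3 -m sglang.launch_server --help` if unsure).
```bash
# Sketch: FA3 supports FP8 KV cache per the matrix above.
# Model and KV-cache dtype are illustrative assumptions.
python3 -m sglang.launch_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--attention-backend fa3 \
--kv-cache-dtype fp8_e4m3
```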
### MLA Backends
| **Backend** | **Native Page Sizes** | **FP8 KV Cache** | **Chunked Prefix Cache** | **Spec topk=1** | **Spec topk>1** |
|----------------------------|---------------------------|------------------|--------------------------|-----------------|-----------------|
| **FlashInfer MLA** | 1 | ❌ | ✅ | ✅ | ❌ |
| **FlashMLA** | 64 | ❌ | ✅ | ✅ | ❌ |
| **Cutlass MLA** | 128 | ✅ | ✅ | ✅ | ❌ |
| **TRTLLM MLA (Blackwell)** | 32 or 64 | ✅ | ✅ | ✅ | ❌ |
| **FA3 (FlashAttention 3)** | n/a | ❌ | ✅ | ✅ | ⚠️ (page_size=1 only) |
| **Triton** | n/a | ❌ | ❌ | ✅ | ⚠️ (page_size=1 only) |
| **FA4** | n/a | ❌ | ❌ | ❌ | ❌ |
| **Ascend MLA (NPU)** | 128 | ❌ | ❌ | ❌ | ❌ |
```{warning}
FlashMLA FP8 KV cache is currently not working; see [#8856](https://github.com/sgl-project/sglang/pull/8856). Use a non-FP8 KV cache or another backend when an FP8 KV cache is required.
```
```{note}
- FlashAttention 4 is prefill-only for now.
- NSA is specifically designed for [DeepSeek V3.2 DSA](https://lmsys.org/blog/2025-09-29-deepseek-V32/).
```
```{tip}
Speculative decoding topk: `topk` is the number of draft tokens sampled per step from the draft model. `topk = 1` follows classic EAGLE; `topk > 1` explores multiple branches and requires backend support in both draft and verification paths.
```
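For reference, here is a minimal sketch of how the `topk` setting appears on the command line with EAGLE speculative decoding. The draft model path is a placeholder, and the step/draft-token values are illustrative rather than tuned recommendations.
```bash
# Sketch: EAGLE speculative decoding with topk > 1 on a backend that supports it
# (e.g. FA3 per the MHA matrix). <eagle-draft-model> is a placeholder path.
python3 -m sglang.launch_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--attention-backend fa3 \
--speculative-algorithm EAGLE \
--speculative-draft-model-path <eagle-draft-model> \
--speculative-num-steps 3 \
--speculative-eagle-topk 4 \
--speculative-num-draft-tokens 8
```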
Note: Many backends that do not natively operate on pages can emulate `page_size > 1` at the wrapper layer by expanding page tables to per-token indices. The "Page Size > 1 (native)" column indicates true in-kernel paging. Some backends require a fixed native page size and cannot emulate other sizes: TRTLLM MHA (16/32/64), TRTLLM MLA (32/64), FlashMLA (64), Cutlass MLA (128), Ascend (128).
MLA page-size constraints:
- FlashInfer MLA: page_size = 1.
- FlashMLA: page_size = 64.
- Cutlass MLA: page_size = 128.
- TRTLLM MLA: page_size ∈ {32, 64}.
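To make these constraints concrete, the sketch below selects FlashMLA together with its required native page size. The model follows the MLA examples later in this guide; treat the configuration as illustrative, not a recommendation.
```bash
# Sketch: FlashMLA operates on a native page size of 64, passed via --page-size.
python3 -m sglang.launch_server \
--tp 8 \
--model deepseek-ai/DeepSeek-R1 \
--attention-backend flashmla \
--page-size 64 \
--trust-remote-code
```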
### Hybrid attention (different backends for prefill vs decode) (Experimental)
```{warning}
Hybrid attention is an experimental feature.
```
You can mix and match attention backends for prefill and decode. This is useful when one backend excels at prefill and another excels at decode. For implementation details, see `python/sglang/srt/layers/attention/hybrid_attn_backend.py`.
```bash
# Example: Prefill with FA4, Decode with TRTLLM MLA (Blackwell)
python3 -m sglang.launch_server \
--model-path nvidia/DeepSeek-R1-FP4 \
--tp 8 \
--attention-backend trtllm_mla \
--moe-runner-backend flashinfer_trtllm \
--quantization modelopt_fp4 \
--prefill-attention-backend fa4
```
#### Speculative decoding with hybrid attention
Hybrid attention also works with speculative decoding. The backend used for draft decoding and target verification depends on `--speculative-attention-mode`:
- `--speculative-attention-mode decode` (recommended): draft/verify use the decode backend.
- `--speculative-attention-mode prefill` (default): draft/verify use the prefill backend.
**Notes:**
- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend.
Constraints when combining hybrid attention with speculative decoding:
- If any attention backend is `trtllm_mha`, speculative decoding supports only `--speculative-eagle-topk 1`.
- For paged MHA backends with `--page-size > 1` and `--speculative-eagle-topk > 1`, only `flashinfer` is supported.
- `flex_attention` is not supported with speculative decoding.
- For MLA backends, `trtllm_mla` supports `topk > 1`; `flashmla` and `flashinfer_mla` support only `topk = 1`.
- CUDA Graph: the decode backend is always captured; the prefill backend is captured only when `--speculative-attention-mode prefill`.
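Putting these constraints together, the following sketch uses FA3 for prefill, FlashMLA for decode, and EAGLE with `topk = 1`, with draft and verification running on the decode backend. The draft model path is a placeholder and the numeric settings are illustrative only.
```bash
# Sketch: hybrid attention combined with speculative decoding.
# flashmla supports only topk = 1, so --speculative-eagle-topk stays at 1.
python3 -m sglang.launch_server \
--tp 8 \
--model deepseek-ai/DeepSeek-R1 \
--prefill-attention-backend fa3 \
--decode-attention-backend flashmla \
--speculative-algorithm EAGLE \
--speculative-draft-model-path <draft-model> \
--speculative-num-steps 3 \
--speculative-eagle-topk 1 \
--speculative-num-draft-tokens 4 \
--speculative-attention-mode decode \
--trust-remote-code
```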
```{tip}
If you set only one of `--prefill-attention-backend` or `--decode-attention-backend`, the unspecified phase inherits `--attention-backend`.
If both are specified and differ, SGLang automatically enables a hybrid wrapper to dispatch to the chosen backend per phase.
```
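For instance, the sketch below sets only `--decode-attention-backend`; prefill then inherits FA3 from `--attention-backend`, and SGLang dispatches per phase through the hybrid wrapper. The model is an illustrative choice.
```bash
# Sketch: prefill inherits --attention-backend (fa3); decode is overridden to flashinfer.
python3 -m sglang.launch_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--attention-backend fa3 \
--decode-attention-backend flashinfer
```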
## User guide
- Wave (ROCm)
```bash
python3 -m sglang.launch_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--attention-backend wave
```
- FlexAttention
```bash
python3 -m sglang.launch_server \
--model meta-llama/Meta-Llama-3.1-8B-Instruct \
--attention-backend flex_attention
```
- Dual Chunk FlashAttention (MHA-only)
```bash
python3 -m sglang.launch_server \
--model Qwen/Qwen2.5-14B-Instruct-1M \
--attention-backend dual_chunk_flash_attn
```
- Cutlass MLA
```bash
python3 -m sglang.launch_server \
--tp 8 \
--model deepseek-ai/DeepSeek-R1 \
--attention-backend cutlass_mla \
--trust-remote-code
```
- FlashAttention 4 (MHA & MLA)
```bash
python3 -m sglang.launch_server \
--tp 8 \
--model deepseek-ai/DeepSeek-R1 \
--attention-backend fa4 \
--trust-remote-code
```
## Steps to add a new attention backend
To add a new attention backend, you can learn from the existing backends (`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`).