# Attention Backend

SGLang supports multiple attention backends, each with different pros and cons.
You can test them and choose the one that best fits your needs.

## Support matrix for different attention backends

| **Backend**              | **Page Size > 1** | **Spec Decoding** | **MLA** | **Sliding Window** | **MultiModal** |
|--------------------------|-------------------|-------------------|---------|--------------------|----------------|
| **FlashInfer**           | ❌                | ✅                 | ✅      | ✅                 | ✅              |
| **FA3**                  | ✅                | ✅                 | ✅      | ✅                 | ✅              |
| **Triton**               | ❌                | ✅                 | ✅      | ✅                 | ❌              |
| **Torch Native**         | ❌                | ❌                 | ✅      | ❌                 | ❌              |
| **FlashMLA**             | ✅                | ✅                 | ✅      | ❌                 | ❌              |
| **TRTLLM MLA**           | ✅                | ❌                 | ✅      | ✅                 | ❌              |
| **Ascend**               | ✅                | ❌                 | ✅      | ❌                 | ❌              |
| **Wave**                 | ✅                | ❌                 | ❌      | ❌                 | ❌              |

**Notes:**
- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to the FlashInfer MLA backend.
- Every kernel backend is compatible with a page size > 1 when launched with an argument such as `--page-size 16`, because a page size of 16 can be converted to a page size of 1 inside the kernel backend. The "❌" and "✅" symbols under "Page Size > 1" in the table above indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1.
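
For example, an illustrative launch command that requests a page size of 16 explicitly (the model and backend here are simply the ones used in the user guide below; any combination from that guide works):

```bash
# Illustrative only: --page-size can be combined with any backend/model below.
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend fa3 \
  --page-size 16
```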

## User guide

### Launch commands for different attention backends

- FlashInfer (Default for Non-Hopper Machines, e.g., A100, A40)
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend flashinfer
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --attention-backend flashinfer \
  --trust-remote-code
```

- FlashAttention 3 (Default for Hopper Machines, e.g., H100, H200, H20)
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend fa3
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --trust-remote-code \
  --attention-backend fa3
```

- Triton
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend triton
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --attention-backend triton \
  --trust-remote-code
```

- Torch Native
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend torch_native
```

- FlashMLA
```bash
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend flashmla \
  --trust-remote-code
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend flashmla \
  --kv-cache-dtype fp8_e4m3 \
  --trust-remote-code
```

- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200)
```bash
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend trtllm_mla \
  --trust-remote-code
```

- TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint)
```bash
python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend trtllm_mla \
  --kv-cache-dtype fp8_e4m3 \
  --trust-remote-code
```

- Ascend
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend ascend
```

- Wave
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend wave
```

## Steps to add a new attention backend
To add a new attention backend, you can learn from the existing backends
(`python/sglang/srt/layers/attention/triton_backend.py`, `python/sglang/srt/layers/attention/flashattention_backend.py`)
and follow the steps below.

1. Run without CUDA graph. Implement the two forward functions and the metadata initialization below (see the sketch after this list).
    - forward_extend
        - Used for prefill, prefill with KV cache, and target verification
        - Called once per layer
    - forward_decode
        - Used for normal decode and draft decode
        - Called once per layer
    - init_forward_metadata
        - Initializes the class and the common metadata shared by all layers
        - Calls the plan function for optimizations like split_kv
        - Called once per forward pass
2. Run with CUDA graph. It has two phases (capture and replay), and you need to implement three functions.
    - init_cuda_graph_state
        - Called once during the lifetime of the server
        - Creates all common shared buffers
    - init_forward_metadata_capture_cuda_graph
        - Called before capturing a CUDA graph
        - Similar to init_forward_metadata, but writes the metadata into pre-defined buffers
    - init_forward_metadata_replay_cuda_graph
        - Called before replaying a CUDA graph
        - On the critical path, so it needs to be fast
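
The sketch below shows how these pieces fit together in one class. It is only an illustrative outline, not SGLang's actual base class: the class name `MyAttentionBackend`, the constructor argument `model_runner`, and the parameter lists (`forward_batch`, `max_bs`, `bs`, `req_pool_indices`, `seq_lens`) are assumptions for readability. Copy the exact signatures from the existing backends referenced above.

```python
# Illustrative skeleton only; see triton_backend.py / flashattention_backend.py
# for the real method signatures, which may differ between SGLang versions.
import torch


class MyAttentionBackend:
    def __init__(self, model_runner):
        # Cache anything shared by all layers (device handle, KV pool, head dims, ...).
        self.device = model_runner.device
        self.forward_metadata = None

    # ---- Step 1: run without CUDA graph ----
    def init_forward_metadata(self, forward_batch):
        # Called once per forward pass: build the metadata shared by all layers,
        # e.g. run the kernel's plan step for optimizations like split_kv.
        self.forward_metadata = None  # replace with your metadata / plan() result

    def forward_extend(self, q, k, v, layer, forward_batch):
        # Called once per layer for prefill, prefill with KV cache,
        # and target verification.
        raise NotImplementedError

    def forward_decode(self, q, k, v, layer, forward_batch):
        # Called once per layer for normal decode and draft decode.
        raise NotImplementedError

    # ---- Step 2: run with CUDA graph (capture + replay) ----
    def init_cuda_graph_state(self, max_bs):
        # Called once during the lifetime of the server:
        # pre-allocate all common shared buffers at the maximum batch size.
        self.graph_seq_lens = torch.zeros(max_bs, dtype=torch.int32, device=self.device)

    def init_forward_metadata_capture_cuda_graph(self, bs, req_pool_indices, seq_lens):
        # Called before capturing a CUDA graph: like init_forward_metadata,
        # but write the metadata into the pre-allocated buffers above.
        self.graph_seq_lens[:bs].copy_(seq_lens)

    def init_forward_metadata_replay_cuda_graph(self, bs, req_pool_indices, seq_lens):
        # Called before replaying a CUDA graph: this is on the critical path,
        # so do only cheap, in-place buffer updates here.
        self.graph_seq_lens[:bs].copy_(seq_lens)
```

Because a captured CUDA graph replays kernels on fixed buffer addresses, the replay hook must copy new metadata into the buffers allocated in `init_cuda_graph_state` rather than allocating fresh tensors.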