# DeepSeek Usage

SGLang provides many optimizations specifically designed for the DeepSeek models, making it the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended) from Day 0.

This document outlines current optimizations for DeepSeek.
For an overview of the implemented features, see the completed [Roadmap](https://github.com/sgl-project/sglang/issues/2591).

## Launch DeepSeek V3 with SGLang

To run DeepSeek V3/R1 models, the requirements are as follows:

| Weight Type | Configuration |
|------------|-------------------|
| **Full precision FP8**<br>*(recommended)* | 8 x H200 |
| | 8 x MI300X |
| | 2 x 8 x H100/H800/H20 |
| | Xeon 6980P CPU |
| **Full precision BF16** | 2 x 8 x H200 |
| | 2 x 8 x MI300X |
| | 4 x 8 x H100/H800/H20 |
| | 4 x 8 x A100/A800 |
| **Quantized weights (AWQ)** | 8 x H100/H800/H20 |
| | 8 x A100/A800 |
| **Quantized weights (int8)** | 16 x A100/A800 |
| | 32 x L40S |
| | Xeon 6980P CPU |

<style>
.md-typeset__table {
  width: 100%;
}

.md-typeset__table table {
  border-collapse: collapse;
  margin: 1em 0;
  border: 2px solid var(--md-typeset-table-color);
  table-layout: fixed;
}

.md-typeset__table th {
  border: 1px solid var(--md-typeset-table-color);
  border-bottom: 2px solid var(--md-typeset-table-color);
  background-color: var(--md-default-bg-color--lighter);
  padding: 12px;
}

.md-typeset__table td {
  border: 1px solid var(--md-typeset-table-color);
  padding: 12px;
}

.md-typeset__table tr:nth-child(2n) {
  background-color: var(--md-default-bg-color--lightest);
}
</style>

Detailed commands for reference:

- [8 x H200](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended)
- [8 x MI300X](../platforms/amd_gpu.md#running-deepseek-v3)
- [2 x 8 x H200](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h208-nodes)
- [4 x 8 x A100](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes)
- [8 x A100 (AWQ)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-8-a100a800-with-awq-quantization)
- [16 x A100 (int8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-16-a100a800-with-int8-quantization)
- [32 x L40S (int8)](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-32-l40s-with-int8-quantization)
- [Xeon 6980P CPU](../platforms/cpu_server.md#example-running-deepseek-r1)

### Download Weights
If you encounter errors when starting the server, make sure the weights have finished downloading. We recommend downloading them beforehand, or restarting the server until all weights are downloaded. Refer to the official [DeepSeek V3](https://huggingface.co/deepseek-ai/DeepSeek-V3-Base#61-inference-with-deepseek-infer-demo-example-only) guide to download the weights.

### Launch with one node of 8 x H200
Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#installation--launch).
**Note that DeepSeek V3 is already in FP8**, so it should not be run with quantization arguments such as `--quantization fp8 --kv-cache-dtype fp8_e5m2`.

### Running examples on Multi-node

- [Serving with two H20*8 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h208-nodes).

- [Serving with two H200*8 nodes and docker](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-two-h2008-nodes-and-docker).

- [Serving with four A100*8 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-four-a1008-nodes).

## Optimizations

### Multi-head Latent Attention (MLA) Throughput Optimizations

**Description**: [MLA](https://arxiv.org/pdf/2405.04434) is an innovative attention mechanism introduced by the DeepSeek team, aimed at improving inference efficiency. SGLang has implemented specific optimizations for this, including:

- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.

- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), **TRTLLM MLA** (optimized for Blackwell architecture), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 backend provides good performance across a wide range of workloads.

- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.

- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduce latency and speed up decoding for small batch sizes.

- **Chunked Prefix Cache**: The chunked prefix cache optimization can increase throughput by cutting the prefix cache into chunks, processing them with multi-head attention, and merging their states. The improvement can be significant when doing chunked prefill on long sequences. Currently this optimization is only available for the FlashAttention3 backend.
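
The weight absorption idea above rests on the associativity of matrix multiplication. The toy sketch below is not SGLang's kernel; the names `W_uk`, `q`, and `latents` and all shapes are illustrative assumptions. It only checks that folding the key up-projection into the query once gives the same attention scores as materializing every key from the compressed cache.

```python
# Toy check of the associativity behind weight absorption (not SGLang's kernel).
# All names, shapes, and values here are illustrative assumptions.

def matvec(m, v):
    """Multiply matrix m (a list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def transpose(m):
    """Return the transpose of matrix m."""
    return [list(col) for col in zip(*m)]

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

W_uk = [[1, 2], [0, 1], [3, -1]]      # up-projects a latent (dim 2) to a key (dim 3)
q = [2, -1, 1]                        # one query head, dim 3
latents = [[1, 0], [2, 5], [-3, 4]]   # compressed KV cache, one row per token

# Naive order: materialize every full key, then score it against the query.
naive = [dot(q, matvec(W_uk, c)) for c in latents]

# Absorbed order: fold W_uk into the query once, then score the latents directly.
q_absorbed = matvec(transpose(W_uk), q)
absorbed = [dot(q_absorbed, c) for c in latents]

assert naive == absorbed
print(absorbed)  # [5, 20, -7]
```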

Overall, with these optimizations, we have achieved up to **7x** acceleration in output throughput compared to the previous version.

<p align="center">
  <img src="https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg" alt="Multi-head Latent Attention for DeepSeek Series Models">
</p>

**Usage**: MLA optimization is enabled by default. For MLA models on Blackwell architecture (e.g., B200), the default backend is FlashInfer. To use the optimized TRTLLM MLA backend for decode operations, explicitly specify `--attention-backend trtllm_mla`. Note that TRTLLM MLA only optimizes decode operations - prefill operations (including multimodal inputs) will fall back to FlashInfer MLA.

**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.

### Data Parallelism Attention

**Description**: This optimization involves data parallelism (DP) for the MLA attention mechanism of DeepSeek Series Models, which allows for a significant reduction in the KV cache size, enabling larger batch sizes. Each DP worker independently handles different types of batches (prefill, decode, idle), which are then synchronized before and after processing through the Mixture-of-Experts (MoE) layer. If you do not use DP attention, KV cache will be duplicated among all TP ranks.

<p align="center">
  <img src="https://lmsys.org/images/blog/sglang_v0_4/dp_attention.svg" alt="Data Parallelism Attention for DeepSeek Series Models">
</p>

With data parallelism attention enabled, we have achieved up to **1.9x** decoding throughput improvement compared to the previous version.

<p align="center">
  <img src="https://lmsys.org/images/blog/sglang_v0_4/deepseek_coder_v2.svg" alt="Data Parallelism Attention Performance Comparison">
</p>

**Usage**:
- Append `--enable-dp-attention --tp 8 --dp 8` to the server arguments when using 8 H200 GPUs. This optimization improves peak throughput in high batch size scenarios where the server is limited by KV cache capacity. However, it is not recommended for low-latency, small-batch use cases.
- DP and TP attention can be flexibly combined. For example, to deploy DeepSeek-V3/R1 on 2 nodes with 8 H100 GPUs each, you can specify `--enable-dp-attention --tp 16 --dp 2`. This configuration runs attention with 2 DP groups, each containing 8 TP GPUs.
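
As a rough aid for picking `--tp` and `--dp`, the sketch below reproduces the arithmetic of the two examples above. The helper `attention_layout` is hypothetical (it is not part of SGLang), and the divisibility check is an assumption inferred from how the groups are described.

```python
def attention_layout(tp: int, dp: int) -> dict:
    """Split `tp` GPUs into `dp` attention groups (hypothetical helper)."""
    if tp <= 0 or dp <= 0:
        raise ValueError("tp and dp must be positive")
    if tp % dp != 0:
        # Assumption: every DP group gets the same number of GPUs.
        raise ValueError("tp must be divisible by dp")
    return {"dp_groups": dp, "attention_tp_size": tp // dp}

# 2 nodes x 8 H100: 2 DP groups, each running attention across 8 TP GPUs.
print(attention_layout(tp=16, dp=2))  # {'dp_groups': 2, 'attention_tp_size': 8}

# 8 x H200: 8 DP groups, each running attention on a single GPU.
print(attention_layout(tp=8, dp=8))   # {'dp_groups': 8, 'attention_tp_size': 1}
```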

**Reference**: Check [Blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models).

### Multi Node Tensor Parallelism

**Description**: For users with limited memory on a single node, SGLang supports serving DeepSeek Series Models, including DeepSeek V3, across multiple nodes using tensor parallelism. This approach partitions the model parameters across multiple GPUs or nodes to handle models that are too large for one node's memory.

**Usage**: Check [here](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) for usage examples.

### Block-wise FP8

**Description**: SGLang implements block-wise FP8 quantization with the following key optimizations:

- **Activation**: E4M3 format using per-token-per-128-channel sub-vector scales with online casting.

- **Weight**: Per-128x128-block quantization for better numerical stability.

- **DeepGEMM**: The [DeepGEMM](https://github.com/deepseek-ai/DeepGEMM) kernel library optimized for FP8 matrix multiplications.
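
To make the two scale granularities above concrete, here is a minimal pure-Python sketch that only computes the scales and does not cast anything to FP8. The function names are hypothetical, and the convention `scale = max(|x|) / 448` (448 being the largest finite E4M3 value) is a common choice assumed here, not necessarily what SGLang's kernels do.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in E4M3

def weight_block_scales(w, block=128):
    """One scale per (block x block) tile of weight matrix w (a list of rows)."""
    rows, cols = len(w), len(w[0])
    scales = []
    for r0 in range(0, rows, block):
        row_scales = []
        for c0 in range(0, cols, block):
            amax = max(abs(w[r][c])
                       for r in range(r0, min(r0 + block, rows))
                       for c in range(c0, min(c0 + block, cols)))
            row_scales.append(amax / FP8_E4M3_MAX)
        scales.append(row_scales)
    return scales

def activation_group_scales(x, group=128):
    """One scale per token (row) per `group` consecutive channels."""
    return [[max(abs(v) for v in row[c0:c0 + group]) / FP8_E4M3_MAX
             for c0 in range(0, len(row), group)]
            for row in x]

# Tiny demo using a block/group size of 2 instead of 128.
w = [[448.0, 1.0, 2.0, 4.0],
     [-3.0, 2.0, -896.0, 1.0]]
x = [[224.0, -448.0, 0.5, 896.0]]
print(weight_block_scales(w, block=2))      # [[1.0, 2.0]]
print(activation_group_scales(x, group=2))  # [[1.0, 2.0]]
```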

**Usage**: The activation and weight optimization above are turned on by default for DeepSeek V3 models. DeepGEMM is enabled by default on NVIDIA Hopper GPUs and disabled by default on other devices. DeepGEMM can also be manually turned off by setting the environment variable `SGL_ENABLE_JIT_DEEPGEMM=0`.

Before serving the DeepSeek model, precompile the DeepGEMM kernels using:
```bash
python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
```
The precompilation process typically takes around 10 minutes to complete.

### Multi-token Prediction
**Description**: SGLang implements DeepSeek V3 Multi-Token Prediction (MTP) based on [EAGLE speculative decoding](https://docs.sglang.ai/backend/speculative_decoding.html#EAGLE-Decoding). With this optimization, decoding speed can be improved by **1.8x** for batch size 1 and by **1.5x** for batch size 32 on an H200 TP8 setup.

**Usage**:
Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` to enable this feature. For example:
```
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8
```
- The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched for a given batch size with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve a speedup for larger batch sizes.
- The FlashAttention3, FlashMLA, and Triton backends fully support MTP. For the FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA and TRTLLM MLA backends is still under development.
- To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (see [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
  - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, increase this value beyond the default.
  - Set `--cuda-graph-bs`. It is a list of batch sizes for CUDA graph capture. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can add more batch sizes to it.


### Reasoning Content for DeepSeek R1

See [Separate Reasoning](https://docs.sglang.ai/backend/separate_reasoning.html).


### Function calling for DeepSeek Models

Add the arguments `--tool-call-parser deepseekv3` and `--chat-template ./examples/chat_template/tool_chat_template_deepseekv3.jinja` (recommended) to enable this feature. For example (running on one H20 node):

```
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3-0324 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --tool-call-parser deepseekv3 --chat-template ./examples/chat_template/tool_chat_template_deepseekv3.jinja
```

Sample Request:

```
curl "http://127.0.0.1:30000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{"temperature": 0, "max_tokens": 100, "model": "deepseek-ai/DeepSeek-V3-0324", "tools": [{"type": "function", "function": {"name": "query_weather", "description": "Get weather of an city, the user should supply a city first", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city, e.g. Beijing"}}, "required": ["city"]}}}], "messages": [{"role": "user", "content": "Hows the weather like in Qingdao today"}]}'
```

Expected Response

```
{"id":"6501ef8e2d874006bf555bc80cddc7c5","object":"chat.completion","created":1745993638,"model":"deepseek-ai/DeepSeek-V3-0324","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":[{"id":"0","index":null,"type":"function","function":{"name":"query_weather","arguments":"{\"city\": \"Qingdao\"}"}}]},"logprobs":null,"finish_reason":"tool_calls","matched_stop":null}],"usage":{"prompt_tokens":116,"total_tokens":138,"completion_tokens":22,"prompt_tokens_details":null}}
```
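
On the client side, the tool call can be pulled out of a response like the one above with a few lines of standard-library Python. This is only a sketch: `response_text` is a trimmed copy of the expected response, and the variable names are illustrative. The `arguments` field is itself a JSON-encoded string, so it needs a second `json.loads`.

```python
import json

# Trimmed copy of the expected response shown above (raw string keeps the \" escapes).
response_text = r'{"choices":[{"index":0,"message":{"role":"assistant","content":null,"tool_calls":[{"id":"0","type":"function","function":{"name":"query_weather","arguments":"{\"city\": \"Qingdao\"}"}}]},"finish_reason":"tool_calls"}]}'

response = json.loads(response_text)
message = response["choices"][0]["message"]

calls = []
for call in message.get("tool_calls") or []:
    name = call["function"]["name"]
    # `arguments` is a JSON string, not a nested object, so decode it separately.
    args = json.loads(call["function"]["arguments"])
    calls.append((name, args))

print(calls)  # [('query_weather', {'city': 'Qingdao'})]
```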
Sample Streaming Request:
```
curl "http://127.0.0.1:30000/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{"temperature": 0, "max_tokens": 100, "model": "deepseek-ai/DeepSeek-V3-0324","stream":true,"tools": [{"type": "function", "function": {"name": "query_weather", "description": "Get weather of an city, the user should supply a city first", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city, e.g. Beijing"}}, "required": ["city"]}}}], "messages": [{"role": "user", "content": "Hows the weather like in Qingdao today"}]}'
```
Expected Streamed Chunks (simplified for clarity):
```
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"{\""}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"city"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\":\""}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"Q"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"ing"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"dao"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\"}"}}]}}]}
data: {"choices":[{"delta":{"tool_calls":null}}], "finish_reason": "tool_calls"}
data: [DONE]
```
The client needs to concatenate all `arguments` fragments to reconstruct the complete tool call arguments:
```
{"city": "Qingdao"}
```
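
The concatenation step can be sketched as follows, using the simplified chunks above as input. The helper `collect_tool_arguments` is hypothetical and assumes a single tool call per response; a real client would read the lines from the HTTP stream and track each tool call by its index.

```python
import json

# The simplified chunks shown above (raw strings keep the \" escapes intact).
sse_lines = [
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"{\""}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"city"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\":\""}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"Q"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"ing"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"dao"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":"\"}"}}]}}]}',
    r'data: {"choices":[{"delta":{"tool_calls":null}}], "finish_reason": "tool_calls"}',
    'data: [DONE]',
]

def collect_tool_arguments(lines):
    """Join the streamed `arguments` fragments of a single tool call."""
    parts = []
    for line in lines:
        payload = line[len("data: "):] if line.startswith("data: ") else line
        if payload.strip() == "[DONE]":
            break
        chunk = json.loads(payload)
        for choice in chunk.get("choices", []):
            # `tool_calls` may be null in the final chunk, hence the `or []`.
            for call in choice.get("delta", {}).get("tool_calls") or []:
                parts.append(call.get("function", {}).get("arguments") or "")
    return "".join(parts)

arguments = json.loads(collect_tool_arguments(sse_lines))
print(arguments)  # {'city': 'Qingdao'}
```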
Important Notes:
1. Use a lower `"temperature"` value for better results.
2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`. It provides an improved unified prompt.


## FAQ

**Q: Model loading is taking too long, and I'm encountering an NCCL timeout. What should I do?**

A: If you're experiencing extended model loading times and an NCCL timeout, you can try increasing the timeout duration. Add the argument `--dist-timeout 3600` when launching your model. This will set the timeout to one hour, which often resolves the issue.