# SGLang
| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

SGLang is a structured generation language designed for large language models (LLMs).
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.

The core features of SGLang include:
- **A Flexible Front-End Language**: Enables easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction.
- **A High-Performance Runtime with RadixAttention**: Significantly accelerates the execution of complex LLM programs by automatically reusing the KV cache across multiple calls. It also supports other common techniques such as continuous batching and tensor parallelism.

## Contents
- [Install](#install)
- [Quick Start](#quick-start)
- [Frontend: Structured Generation Language (SGLang)](#frontend-structured-generation-language-sglang)
- [Backend: SGLang Runtime (SRT)](#backend-sglang-runtime-srt)
- [Benchmark And Performance](#benchmark-and-performance)
- [Roadmap](#roadmap)
- [Citation And Acknowledgment](#citation-and-acknowledgment)

## Install

### Method 1: With pip
```
pip install "sglang[all]"
```

### Method 2: From source
```
git clone git@github.com:sgl-project/sglang.git
cd sglang

pip install --upgrade pip
pip install -e "python[all]"
```

### Notes
- If you are using older GPUs (NVIDIA T4, V100), run `pip install "triton>=2.2.0"` to avoid some bugs in the Triton compiler.
- If you only need the OpenAI backend, you can skip the other dependencies by running `pip install "sglang[openai]"`.

## Quick Start
The examples below show how to use SGLang to answer a multi-turn question.

### Using OpenAI Models
Set the OpenAI API key:
```
export OPENAI_API_KEY=sk-******
```

Then, answer a multi-turn question.
```python
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])
```

### Using Local Models
First, launch a server with
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Then, connect to the server and answer a multi-turn question.

```python
from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(RuntimeEndpoint("http://localhost:30000"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])
```

### More Examples

Anthropic and VertexAI (Gemini) models are also supported.
You can find more examples at [examples/quick_start](examples/quick_start).

## Frontend: Structured Generation Language (SGLang)
To begin with, import sglang.
```python
import sglang as sgl
```

`sglang` provides simple primitives such as `gen`, `select`, `fork`, and `image`.
You can implement your prompt flow in a function decorated by `sgl.function`.
You can then invoke the function with `run` or `run_batch`.
The system will manage the state, chat template, and parallelism for you.

### Control Flow
You can use any Python code within the function body, including control flow, nested function calls, and external libraries.

```python
@sgl.function
def control_flow(s, question):
    s += "To answer this question: " + question + ", "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "

    if s["tool"] == "calculator":
        s += "The math expression is" + sgl.gen("expression")
    elif s["tool"] == "web browser":
        s += "The website url is" + sgl.gen("url")
```
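
Passing `choices` restricts the generated value to one of the listed strings, so the `if`/`elif` checks above always see one of the two options. One simple way to choose among fixed options is to score each candidate continuation with the model and keep the best one. The sketch below illustrates that idea with made-up per-token log-probabilities and an average-per-token score; both the numbers and the scoring rule are assumptions for illustration, not a description of how SGLang scores choices internally.

```python
def pick_choice(choice_logprobs: dict[str, list[float]]) -> str:
    """Return the option whose tokens have the highest mean log-probability.

    Averaging per token (rather than summing) avoids penalizing options that
    simply tokenize into more pieces. This scoring rule is an assumption.
    """
    def mean(values: list[float]) -> float:
        return sum(values) / len(values)

    return max(choice_logprobs, key=lambda option: mean(choice_logprobs[option]))


# Made-up per-token log-probabilities for each option in `choices`.
scores = {
    "calculator": [-0.2, -0.1],
    "web browser": [-1.5, -0.3, -0.4],
}
tool = pick_choice(scores)
print(tool)
```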

### Parallelism
Use `fork` to launch parallel prompts.
Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.

```python
@sgl.function
def tip_suggestion(s):
    s += (
        "Here are two tips for staying healthy: "
        "1. Balanced Diet. 2. Regular Exercise.\n\n"
    )

    forks = s.fork(2)
    for i, f in enumerate(forks):
        f += f"Now, expand tip {i+1} into a paragraph:\n"
        f += sgl.gen("detailed_tip", max_tokens=256, stop="\n\n")

    s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
    s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
    s += "In summary" + sgl.gen("summary")
```
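
Conceptually, `fork` follows the familiar fork/join pattern: several branches extend a shared prefix concurrently, and their results are spliced back into the parent prompt. The plain-Python analogy below makes that control structure explicit. It does not use SGLang; `fake_generate` is a hypothetical stand-in for a model call, and a thread pool stands in for the runtime.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for an LLM call: echoes the prompt's last line."""
    return "[" + prompt.splitlines()[-1] + "]"


def tip_suggestion_sketch() -> str:
    prefix = (
        "Here are two tips for staying healthy: "
        "1. Balanced Diet. 2. Regular Exercise.\n\n"
    )
    # Fork: every branch starts from the same shared prefix.
    branch_prompts = [
        prefix + f"Now, expand tip {i + 1} into a paragraph:" for i in range(2)
    ]
    # Run the branches concurrently; map() returns results in input order.
    with ThreadPoolExecutor(max_workers=2) as pool:
        details = list(pool.map(fake_generate, branch_prompts))
    # Join: splice each branch's output back into the parent prompt.
    return prefix + "Tip 1:" + details[0] + "\n" + "Tip 2:" + details[1] + "\n"


print(tip_suggestion_sketch())
```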

### Multi Modality
Use `sgl.image` to pass an image as input.

```python
@sgl.function
def image_qa(s, image_file, question):
    s += sgl.user(sgl.image(image_file) + question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=256))
```

### Constrained Decoding
Use `regex` to specify a regular expression as a decoding constraint.
This is only supported for local models.

```python
@sgl.function
def regular_expression_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        temperature=0,
        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    )
```
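
Before using a pattern as a decoding constraint, it can help to sanity-check it in plain Python. The sketch below uses the standard `re` module to show why the dots between octets should be escaped as `\.`: a bare `.` matches any character, so the looser pattern accepts strings that are not IP addresses. This only checks the pattern with Python's regex engine; it does not verify that SGLang's constrained decoder supports every `re` feature.

```python
import re

# Constraint with the octet separator escaped as "\." (a literal dot).
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
# Same pattern with a bare "." that matches any character.
LOOSE_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


def accepts(pattern: str, text: str) -> bool:
    """True if the entire string matches the pattern."""
    return re.fullmatch(pattern, text) is not None


for candidate in ["8.8.8.8", "8.8.4.4", "8x8x8x8"]:
    print(candidate, accepts(IP_REGEX, candidate), accepts(LOOSE_REGEX, candidate))
```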

### Batching
Use `run_batch` to run a batch of requests with continuous batching.

```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

states = text_qa.run_batch(
    [
        {"question": "What is the capital of the United Kingdom?"},
        {"question": "What is the capital of France?"},
        {"question": "What is the capital of Japan?"},
    ],
    progress_bar=True
)
```

### Streaming
Add `stream=True` to enable streaming.

```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

state = text_qa.run(
    question="What is the capital of France?",
    temperature=0.1,
    stream=True
)

for out in state.text_iter():
    print(out, end="", flush=True)
```

## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server.
In this case, [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases through automatic KV cache reuse.
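
The RadixAttention paper describes keeping computed KV cache entries in a radix tree keyed by token sequences, so a new request only has to compute attention state for the part of its prompt that is not already cached. The toy sketch below shows just the prefix-matching idea with a plain token trie (a radix tree without edge compression). It stores no tensors, omits eviction, and the class name `PrefixCache` is hypothetical.

```python
class PrefixCache:
    """Toy token trie that records which prompt prefixes are already cached."""

    def __init__(self) -> None:
        self.root: dict = {}

    def match_prefix(self, tokens: list[str]) -> int:
        """Return how many leading tokens of `tokens` are already cached."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node:
                break
            node = node[tok]
            matched += 1
        return matched

    def insert(self, tokens: list[str]) -> None:
        """Record every prefix of `tokens` as cached."""
        node = self.root
        for tok in tokens:
            node = node.setdefault(tok, {})


# Two requests that share the same system prompt (tokens shown as words).
system = ["You", "are", "a", "helpful", "assistant", "."]
req1 = system + ["What", "is", "the", "capital", "?"]
req2 = system + ["List", "two", "attractions", "."]

cache = PrefixCache()
reused_first = cache.match_prefix(req1)   # 0: nothing cached yet
cache.insert(req1)                        # req1's prefixes are now "cached"
reused_second = cache.match_prefix(req2)  # 6: the shared system prompt
print(reused_first, reused_second, len(req2) - reused_second)
```

A production radix tree additionally merges single-child chains into one edge and, per the paper, evicts least-recently-used entries when memory runs low.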

### Usage
Launch a server
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Send a request
```
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 16,
      "temperature": 0
    }
  }'
```
Learn more about the argument format [here](docs/sampling_params.md).
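
For scripts, the same kind of request can be sent from Python using only the standard library. A minimal sketch follows; the helper names are hypothetical, and the payload shape (`text` plus a `sampling_params` object holding `max_new_tokens` and `temperature`) is our assumption based on the argument format linked above, so check docs/sampling_params.md for the fields your version accepts.

```python
import json
import urllib.request


def build_generate_request(
    text: str,
    max_new_tokens: int = 16,
    temperature: float = 0.0,
    url: str = "http://localhost:30000/generate",
) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request for the /generate endpoint."""
    payload = {
        "text": text,
        # Assumed field name for sampling options; see docs/sampling_params.md.
        "sampling_params": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def generate(text: str, **kwargs) -> dict:
    """Send the request to a running server and decode its JSON reply."""
    with urllib.request.urlopen(build_generate_request(text, **kwargs)) as resp:
        return json.loads(resp.read().decode("utf-8"))


req = build_generate_request("Once upon a time,")
print(req.full_url, req.get_method(), req.data.decode("utf-8"))
```

With a server listening on port 30000, `generate("Once upon a time,")` sends the request and returns the decoded JSON response.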

### OpenAI Compatible API

In addition, the server supports an experimental OpenAI-compatible API.

```python
import openai
client = openai.Client(
    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Text completion
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response)
```

In the above example, the server uses the chat template specified in the model's tokenizer.
You can override the chat template if needed when launching the server:

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
```

If the chat template you are looking for is missing, you are welcome to contribute it.
Meanwhile, you can temporarily register your own chat template by writing it to a JSON file and passing the file path to `--chat-template`:

```json
{
  "name": "my_model",
  "system": "<|im_start|>system",
  "user": "<|im_start|>user",
  "assistant": "<|im_start|>assistant",
  "sep_style": "CHATML",
  "sep": "<|im_end|>",
  "stop_str": ["<|im_end|>", "<|im_start|>"]
}
```
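
If you generate templates from a setup script, the minimal sketch below writes the JSON above to `my_model_template.json` (the path passed to `--chat-template` in the launch command that follows) and reads it back to catch syntax errors before the server starts. The `"my_model"` name is a placeholder.

```python
import json
from pathlib import Path

# Same fields as the JSON template above; "my_model" is a placeholder name.
template = {
    "name": "my_model",
    "system": "<|im_start|>system",
    "user": "<|im_start|>user",
    "assistant": "<|im_start|>assistant",
    "sep_style": "CHATML",
    "sep": "<|im_end|>",
    "stop_str": ["<|im_end|>", "<|im_start|>"],
}

path = Path("my_model_template.json")
path.write_text(json.dumps(template, indent=2), encoding="utf-8")

# Read the file back to confirm it is valid JSON with the expected content.
loaded = json.loads(path.read_text(encoding="utf-8"))
assert loaded["sep_style"] == "CHATML"
```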

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
```

### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
```
- If you see out-of-memory errors during serving, try reducing the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
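
A rough way to see the trade-off: assume the static fraction of GPU memory covers the model weights plus the KV cache pool, so whatever the weights do not use becomes KV cache. Lowering `--mem-fraction-static` then leaves more headroom for other allocations at the cost of fewer cacheable tokens. The back-of-the-envelope sketch below uses the Llama-2-7B shape (32 layers, 32 KV heads, head dimension 128, FP16). The 24 GiB GPU size and 12.5 GiB weight size are illustrative figures, not measurements, and real usage also includes activations and other overheads.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    """Bytes of KV cache per token: one K and one V vector per layer."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def estimate_kv_tokens(gpu_mem_gib: float, mem_fraction_static: float,
                       weight_gib: float, per_token_bytes: int) -> int:
    """Rough count of tokens the KV pool can hold.

    Assumes the static fraction covers model weights plus the KV cache pool,
    and ignores activations, fragmentation, and other overheads.
    """
    pool_gib = gpu_mem_gib * mem_fraction_static - weight_gib
    pool_bytes = pool_gib * 2**30
    return max(0, int(pool_bytes // per_token_bytes))


# Llama-2-7B: 32 layers, 32 KV heads, head dim 128, FP16 (2 bytes).
per_token = kv_bytes_per_token(32, 32, 128)  # 524,288 bytes = 0.5 MiB
# Illustrative numbers: a 24 GiB GPU and about 12.5 GiB of FP16 weights.
for frac in (0.9, 0.7):
    print(frac, estimate_kv_tokens(24, frac, 12.5, per_token))
```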

### Supported Models
- Llama
- Mistral
- Mixtral
- LLaVA
  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
- AWQ quantization

## Benchmark And Performance

- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
![llama_7b](assets/llama_7b.jpg)

- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
![mixtral_8x7b](assets/mixtral_8x7b.jpg)

Learn more [here](docs/benchmark_results.md).

## Roadmap
- [ ] Function call APIs
- [ ] S-LoRA (expected by Feb. 5)
- [ ] Support more models
- [ ] Support more hardware backends

## Citation And Acknowledgment
```
@misc{zheng2023efficiently,
      title={Efficiently Programming Large Language Models using SGLang},
      author={Lianmin Zheng and Liangsheng Yin and Zhiqiang Xie and Jeff Huang and Chuyue Sun and Cody Hao Yu and Shiyi Cao and Christos Kozyrakis and Ion Stoica and Joseph E. Gonzalez and Clark Barrett and Ying Sheng},
      year={2023},
      eprint={2312.07104},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}
```

We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).