# SGLang
| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

SGLang is a structured generation language designed for large language models (LLMs).
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.

The core features of SGLang include:
- **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction.
- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by automatic KV cache reuse across multiple calls. It also supports other common techniques like continuous batching and tensor parallelism.

## News
- [2024/01] 🔥 SGLang powers the serving of the official LLaVA v1.6 release demo ([blog](https://llava-vl.github.io/blog/2024-01-30-llava-1-6/)).
- [2024/01] SGLang provides up to 5x faster inference with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).

## Contents
- [Install](#install)
- [Quick Start](#quick-start)
- [Frontend: Structured Generation Language (SGLang)](#frontend-structured-generation-language-sglang)
- [Backend: SGLang Runtime (SRT)](#backend-sglang-runtime-srt)
- [Benchmark And Performance](#benchmark-and-performance)
- [Roadmap](#roadmap)
- [Citation And Acknowledgment](#citation-and-acknowledgment)

## Install

### Method 1: With pip
```
pip install "sglang[all]"
```

### Method 2: From source
```
git clone git@github.com:sgl-project/sglang.git
cd sglang

pip install --upgrade pip
pip install -e "python[all]"
```

### Notes
- If you are using older GPUs (NVIDIA V100 or T4), please pick the correct Triton compiler version to avoid known bugs:
  - For NVIDIA T4, please use `pip install "triton>=2.2.0"`.
  - For NVIDIA V100, please install the [nightly](https://triton-lang.org/main/getting-started/installation.html) version.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.

## Quick Start
The example below shows how to use SGLang to answer a multi-turn question.

### Using Local Models
First, launch a server with
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Then, connect to the server and answer a multi-turn question.

```python
from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(RuntimeEndpoint("http://localhost:30000"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])

print(state["answer_1"])
```

### Using OpenAI Models
Set the OpenAI API Key
```
export OPENAI_API_KEY=sk-******
```

Then, answer a multi-turn question.
```python
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])

print(state["answer_1"])
```

### More Examples

Anthropic and VertexAI (Gemini) models are also supported.
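
Switching providers is a one-line backend change. Below is a minimal sketch, assuming an `Anthropic` backend class and the `claude-2` model name (see the quick start examples for working scripts):

```python
import sglang as sgl

# Hypothetical backend/model names; adjust to your provider and account.
sgl.set_default_backend(sgl.Anthropic("claude-2"))
```
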
You can find more examples at [examples/quick_start](examples/quick_start).

## Frontend: Structured Generation Language (SGLang)

To begin with, import sglang.
```python
import sglang as sgl
```

`sglang` provides some simple primitives such as `gen`, `select`, `fork`, `image`.
You can implement your prompt flow in a function decorated by `sgl.function`.
You can then invoke the function with `run` or `run_batch`.
The system will manage the state, chat template, parallelism, and batching for you.
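
As a small taste of these primitives, here is a minimal sketch combining `gen` and `select` (the prompt strings and the `sgl.select(name, choices=...)` signature shown here are illustrative):

```python
import sglang as sgl

@sgl.function
def sentiment(s, review):
    s += "Review: " + review + "\n"
    # `select` constrains the next output to one of the listed choices.
    s += "Sentiment: " + sgl.select("label", choices=["positive", "negative"])

# Assumes a default backend has been set, as in the quick start above.
state = sentiment.run(review="The battery lasts all day. Great phone!")
print(state["label"])
```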

The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py).

### Control Flow
You can use any Python code within the function body, including control flow, nested function calls, and external libraries.

```python
@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ". "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "

    if s["tool"] == "calculator":
        s += "The math expression is" + sgl.gen("expression")
    elif s["tool"] == "search engine":
        s += "The key word to search is" + sgl.gen("word")
```

### Parallelism
Use `fork` to launch parallel prompts.
Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.

```python
@sgl.function
def tip_suggestion(s):
    s += (
        "Here are two tips for staying healthy: "
        "1. Balanced Diet. 2. Regular Exercise.\n\n"
    )

    forks = s.fork(2)
    for i, f in enumerate(forks):
        f += f"Now, expand tip {i+1} into a paragraph:\n"
        f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")

    s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
    s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
    s += "In summary" + sgl.gen("summary")
```

### Multi-Modality
Use `sgl.image` to pass an image as input.

```python
@sgl.function
def image_qa(s, image_file, question):
    s += sgl.user(sgl.image(image_file) + question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=256)
```

See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py).

### Constrained Decoding
Use `regex` to specify a regular expression as a decoding constraint.
This is only supported for local models.

```python
@sgl.function
def regular_expression_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        temperature=0,
        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    )
```

### JSON Decoding

```python
character_regex = (
    r"""\{\n"""
    + r"""    "name": "[\w\d\s]{1,16}",\n"""
    + r"""    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r"""    "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r"""    "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r"""    "wand": \{\n"""
    + r"""        "wood": "[\w\d\s]{1,16}",\n"""
    + r"""        "core": "[\w\d\s]{1,16}",\n"""
    + r"""        "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r"""    \},\n"""
    + r"""    "alive": "(Alive|Deceased)",\n"""
    + r"""    "patronus": "[\w\d\s]{1,16}",\n"""
    + r"""    "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)

@sgl.function
def character_gen(s, name):
    s += name + " is a character in Harry Potter. Please fill in the following information about him/her.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
```

See also [json_decode.py](examples/usage/json_decode.py).


### Batching
Use `run_batch` to run a batch of requests with continuous batching.

```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

states = text_qa.run_batch(
    [
        {"question": "What is the capital of the United Kingdom?"},
        {"question": "What is the capital of France?"},
        {"question": "What is the capital of Japan?"},
    ],
    progress_bar=True
)
```
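
`run_batch` returns one program state per input, so the captured answers can be read back just like a single `run` result:

```python
for state in states:
    print(state["answer"])
```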

### Streaming
Add `stream=True` to enable streaming.

```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

state = text_qa.run(
    question="What is the capital of France?",
    temperature=0.1,
    stream=True
)

for out in state.text_iter():
    print(out, end="", flush=True)
```

### Tips and Implementation Details
- The `choices` argument in `sgl.gen` is implemented by computing the normalized log probabilities of all choices and selecting the one with the highest probability, as illustrated in the sketch below.
- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex.
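
As a rough illustration of the `choices` selection rule (not SGLang's actual code; the token lists and log probabilities below are made up):

```python
# Hypothetical per-token log probabilities for each choice's continuation.
choice_logprobs = {
    "calculator": [-0.9, -0.2],           # 2 tokens
    "search engine": [-1.1, -0.4, -0.3],  # 3 tokens
}

# Length-normalize so longer choices are not penalized, then take the argmax.
def normalized_logprob(token_logprobs):
    return sum(token_logprobs) / len(token_logprobs)

best = max(choice_logprobs, key=lambda c: normalized_logprob(choice_logprobs[c]))
print(best)  # "calculator": -1.1 / 2 = -0.55 beats -1.8 / 3 = -0.6
```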

## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server.
In this case, [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases through automatic KV cache reuse.

### Usage
Launch a server
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Send a request
```
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 16,
      "temperature": 0
    }
  }'
```
Learn more about the argument format [here](docs/sampling_params.md).
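
The same request can also be sent from Python; here is a minimal sketch using the `requests` library, mirroring the curl call above:

```python
import requests

# POST the same JSON payload to the /generate endpoint.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
    },
)
print(response.json())
```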

### OpenAI Compatible API

In addition, the server supports an experimental OpenAI-compatible API.

```python
import openai
client = openai.Client(
    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Text completion
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response)
```

In the example above, the server uses the chat template specified in the model's tokenizer.
You can override the chat template if needed when launching the server:

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
```

If the chat template you are looking for is missing, you are welcome to contribute it.
Meanwhile, you can also temporarily register your chat template as follows:

```json
{
  "name": "my_model",
  "system": "<|im_start|>system",
  "user": "<|im_start|>user",
  "assistant": "<|im_start|>assistant",
  "sep_style": "CHATML",
  "sep": "<|im_end|>",
  "stop_str": ["<|im_end|>", "<|im_start|>"]
}
```

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
```

### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
```
- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
- You can turn on [flashinfer](docs/flashinfer.md) to accelerate inference with highly optimized CUDA kernels.

### Supported Models
- Llama
- Mistral
- Mixtral
- Qwen / Qwen 2
- LLaVA
  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- Yi-VL
  - See [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
- AWQ quantization

## Benchmark And Performance

- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
![llama_7b](assets/llama_7b.jpg)

- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
![mixtral_8x7b](assets/mixtral_8x7b.jpg)

Learn more [here](docs/benchmark_results.md).

## Roadmap
- [ ] Function call APIs
- [ ] S-LoRA (expected by Feb. 5)
- [ ] Support more models
- [ ] Support more hardware backends

## Citation And Acknowledgment
```
@misc{zheng2023efficiently,
      title={Efficiently Programming Large Language Models using SGLang},
      author={Lianmin Zheng and Liangsheng Yin and Zhiqiang Xie and Jeff Huang and Chuyue Sun and Cody Hao Yu and Shiyi Cao and Christos Kozyrakis and Ion Stoica and Joseph E. Gonzalez and Clark Barrett and Ying Sheng},
      year={2023},
      eprint={2312.07104},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}
```

[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104)


We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).