# SGLang
| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

SGLang is a structured generation language designed for large language models (LLMs).
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.

The core features of SGLang include:
- **A Flexible Front-End Language**: This allows for easy programming of LLM applications with multiple chained generation calls, advanced prompting techniques, control flow, multiple modalities, parallelism, and external interaction.
- **A High-Performance Runtime with RadixAttention**: This feature significantly accelerates the execution of complex LLM programs by automatic KV cache reuse across multiple calls. It also supports other common techniques like continuous batching and tensor parallelism.

## Contents
- [Install](#install)
- [Quick Start](#quick-start)
- [Frontend: Structured Generation Language (SGLang)](#frontend-structured-generation-language-sglang)
- [Backend: SGLang Runtime (SRT)](#backend-sglang-runtime-srt)
- [Benchmark And Performance](#benchmark-and-performance)
- [Roadmap](#roadmap)
- [Citation And Acknowledgment](#citation-and-acknowledgment)

## Install

### Method 1: With pip
```
pip install "sglang[all]"
```

### Method 2: From source
```
git clone git@github.com:sgl-project/sglang.git
cd sglang

pip install --upgrade pip
pip install -e "python[all]"
```

### Notes
- If you are using older GPUs (NVIDIA V100, T4), please pick the correct triton compiler version to avoid some known bugs.
  - For NVIDIA T4, please use `pip install "triton>=2.2.0"`.
  - For NVIDIA V100, please install the [nightly](https://triton-lang.org/main/getting-started/installation.html) version.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.

## Quick Start
The example below shows how to use sglang to answer a multi-turn question.

### Using Local Models
First, launch a server with
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Then, connect to the server and answer a multi-turn question.

```python
from sglang import function, system, user, assistant, gen, set_default_backend, RuntimeEndpoint

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(RuntimeEndpoint("http://localhost:30000"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])

print(state["answer_1"])
```

### Using OpenAI Models
Set the OpenAI API key:
```
export OPENAI_API_KEY=sk-******
```

Then, answer a multi-turn question.
```python
from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI

@function
def multi_turn_question(s, question_1, question_2):
    s += system("You are a helpful assistant.")
    s += user(question_1)
    s += assistant(gen("answer_1", max_tokens=256))
    s += user(question_2)
    s += assistant(gen("answer_2", max_tokens=256))

set_default_backend(OpenAI("gpt-3.5-turbo"))

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
)

for m in state.messages():
    print(m["role"], ":", m["content"])

print(state["answer_1"])
```

### More Examples

Anthropic and VertexAI (Gemini) models are also supported.
You can find more examples at [examples/quick_start](examples/quick_start).

## Frontend: Structured Generation Language (SGLang)

To begin with, import sglang.
```python
import sglang as sgl
```

`sglang` provides some simple primitives such as `gen`, `select`, `fork`, `image`.
You can implement your prompt flow in a function decorated by `sgl.function`.
You can then invoke the function with `run` or `run_batch`.
The system will manage the state, chat template, parallelism and batching for you.

The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py).

### Control Flow
You can use any Python code within the function body, including control flow, nested function calls, and external libraries.

```python
@sgl.function
def tool_use(s, question):
    s += "To answer this question: " + question + ". "
    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"]) + ". "

    if s["tool"] == "calculator":
        s += "The math expression is" + sgl.gen("expression")
    elif s["tool"] == "search engine":
        s += "The key word to search is" + sgl.gen("word")
```

### Parallelism
Use `fork` to launch parallel prompts.
Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.

```python
@sgl.function
def tip_suggestion(s):
    s += (
        "Here are two tips for staying healthy: "
        "1. Balanced Diet. 2. Regular Exercise.\n\n"
    )

    forks = s.fork(2)
    for i, f in enumerate(forks):
        f += f"Now, expand tip {i+1} into a paragraph:\n"
        f += sgl.gen("detailed_tip", max_tokens=256, stop="\n\n")

    s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
    s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
    s += "In summary" + sgl.gen("summary")
```
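
As a rough analogy for the non-blocking behavior described above, the sketch below uses Python's standard `concurrent.futures` instead of SGLang itself. `fake_generate` and `expand_tips` are hypothetical names, and this is not how SGLang implements `fork`; it only illustrates issuing several calls first and reading their results afterwards.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_generate(prompt: str) -> str:
    """Hypothetical stand-in for a model call; sleeps to simulate latency."""
    time.sleep(0.2)
    return f"<expansion of: {prompt}>"


def expand_tips(tips: list[str]) -> list[str]:
    with ThreadPoolExecutor(max_workers=len(tips)) as pool:
        # submit() returns immediately, so all calls are in flight at once.
        futures = [pool.submit(fake_generate, tip) for tip in tips]
        # result() blocks only at the point where the value is needed.
        return [future.result() for future in futures]


results = expand_tips(["Balanced Diet", "Regular Exercise"])
for text in results:
    print(text)
```

In the SGLang example, reading `forks[0]["detailed_tip"]` plays the role of `result()`: it is the point where the program waits for the generation to finish.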

### Multi Modality
Use `sgl.image` to pass an image as input.

```python
@sgl.function
def image_qa(s, image_file, question):
    s += sgl.user(sgl.image(image_file) + question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=256))
```

See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py).

### Constrained Decoding
Use `regex` to specify a regular expression as a decoding constraint.
This is only supported for local models.

```python
@sgl.function
def regular_expression_gen(s):
    s += "Q: What is the IP address of the Google DNS servers?\n"
    s += "A: " + sgl.gen(
        "answer",
        temperature=0,
        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
    )
```
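
One way to sanity-check a constraint pattern before passing it to `sgl.gen` is to try it against a few strings with Python's standard `re` module. The sketch below does that for an IPv4 pattern in which the separator is written as `\.` so that it matches only a literal dot. `matches_constraint` is a hypothetical helper, and the regex dialect accepted by the runtime may differ from Python's `re`, so treat this as a quick check rather than a guarantee.

```python
import re

# IPv4-style pattern: three "octet + dot" groups followed by a final octet.
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"


def matches_constraint(text: str) -> bool:
    """Return True if the whole string satisfies the pattern."""
    return re.fullmatch(IP_REGEX, text) is not None


for candidate in ["8.8.8.8", "256.1.1.1", "8.8.8"]:
    print(candidate, matches_constraint(candidate))
```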

### JSON Decoding

```python
character_regex = (
    r"""\{\n"""
    + r"""    "name": "[\w\d\s]{1,16}",\n"""
    + r"""    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r"""    "blood status": "(Pure-blood|Half-blood|Muggle-born)",\n"""
    + r"""    "occupation": "(student|teacher|auror|ministry of magic|death eater|order of the phoenix)",\n"""
    + r"""    "wand": \{\n"""
    + r"""        "wood": "[\w\d\s]{1,16}",\n"""
    + r"""        "core": "[\w\d\s]{1,16}",\n"""
    + r"""        "length": [0-9]{1,2}\.[0-9]{0,2}\n"""
    + r"""    \},\n"""
    + r"""    "alive": "(Alive|Deceased)",\n"""
    + r"""    "patronus": "[\w\d\s]{1,16}",\n"""
    + r"""    "bogart": "[\w\d\s]{1,16}"\n"""
    + r"""\}"""
)

@sgl.function
def character_gen(s, name):
    s += name + " is a character in Harry Potter. Please fill in the following information about him/her.\n"
    s += sgl.gen("json_output", max_tokens=256, regex=character_regex)
```

See also [json_decode.py](examples/usage/json_decode.py).
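
Because the output is constrained to a JSON-shaped pattern, it can usually be handed to a JSON parser afterwards. The sketch below checks this idea with the standard library only. `mini_regex` is a shortened, hypothetical variant of the pattern above, and `sample_output` is hand-written text standing in for a generation result, not real model output.

```python
import json
import re

# Shortened, hypothetical variant of the character regex.
mini_regex = (
    r"""\{\n"""
    + r"""    "name": "[\w\d\s]{1,16}",\n"""
    + r"""    "house": "(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)",\n"""
    + r"""    "alive": "(Alive|Deceased)"\n"""
    + r"""\}"""
)

# Hand-written stand-in for what a constrained generation could return.
sample_output = (
    "{\n"
    '    "name": "Harry Potter",\n'
    '    "house": "Gryffindor",\n'
    '    "alive": "Alive"\n'
    "}"
)

# The text satisfies the pattern and is also valid JSON.
is_match = re.fullmatch(mini_regex, sample_output) is not None
character = json.loads(sample_output)
print(is_match, character["house"])
```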


### Batching
Use `run_batch` to run a batch of requests with continuous batching.

```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

states = text_qa.run_batch(
    [
        {"question": "What is the capital of the United Kingdom?"},
        {"question": "What is the capital of France?"},
        {"question": "What is the capital of Japan?"},
    ],
    progress_bar=True
)
```

### Streaming
Add `stream=True` to enable streaming.

```python
@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

state = text_qa.run(
    question="What is the capital of France?",
    temperature=0.1,
    stream=True
)

for out in state.text_iter():
    print(out, end="", flush=True)
```

### Tips and Implementation Details
- The `choices` argument in `sgl.gen` is implemented by computing the normalized log probabilities of all choices and selecting the one with the highest probability.
- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex.
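
The selection rule for `choices` can be sketched in a few lines of plain Python. This is a minimal illustration, not SGLang's code: it assumes "normalized" means averaging the per-token log probabilities over each choice's length, and both `select_choice` and the numbers in `scores` are made up for the example.

```python
def select_choice(token_logprobs: dict[str, list[float]]) -> str:
    """Return the choice with the highest length-normalized log probability."""

    def normalized(logprobs: list[float]) -> float:
        # Averaging keeps longer choices from being penalized for their length.
        return sum(logprobs) / len(logprobs)

    return max(token_logprobs, key=lambda choice: normalized(token_logprobs[choice]))


# Made-up per-token log probabilities for each candidate string.
scores = {
    "calculator": [-0.2, -0.1, -0.3],
    "search engine": [-0.9, -0.4, -0.5, -0.6],
}
selected = select_choice(scores)
print(selected)
```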

## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
However, it can also be used as a standalone API server.
In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.
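
To make the idea of KV cache reuse concrete, here is a toy prefix cache over token IDs. It is only an illustration under simplifying assumptions: `PrefixCache` is a hypothetical class, the token IDs are made up, and it omits everything else a real runtime needs, such as storing the actual KV tensors and evicting old entries.

```python
class PrefixCache:
    """Toy trie over token IDs; each node stands for one cached token."""

    def __init__(self):
        self.root = {}

    def match_prefix(self, tokens):
        """Return how many leading tokens are already cached."""
        node, matched = self.root, 0
        for token in tokens:
            if token not in node:
                break
            node = node[token]
            matched += 1
        return matched

    def insert(self, tokens):
        """Record a token sequence so later requests can reuse its prefix."""
        node = self.root
        for token in tokens:
            node = node.setdefault(token, {})


cache = PrefixCache()
system_prompt = [1, 15, 27, 42]  # made-up token IDs shared by both requests
first_request = system_prompt + [7, 8]
second_request = system_prompt + [9, 10, 11]

reused_first = cache.match_prefix(first_request)    # nothing is cached yet
cache.insert(first_request)
reused_second = cache.match_prefix(second_request)  # the shared prefix is found
cache.insert(second_request)
print(reused_first, reused_second)
```

In this toy model, the second request would only need to compute the tokens after the shared prefix.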

### Usage
Launch a server
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Send a request
```
curl http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Once upon a time,",
    "sampling_params": {
      "max_new_tokens": 16,
      "temperature": 0
    }
  }'
```
Learn more about the argument format [here](docs/sampling_params.md).
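
The same request can be built from Python with the standard library. The sketch below mirrors the `curl` call above; `build_generate_request` and `send` are hypothetical helper names, and `send` assumes the response body is JSON. Nothing is sent until `send` is called, so the request can be inspected without a running server.

```python
import json
import urllib.request


def build_generate_request(text, max_new_tokens=16, temperature=0,
                           base_url="http://localhost:30000"):
    """Build a POST request with the same payload as the curl example."""
    payload = {
        "text": text,
        "sampling_params": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
        },
    }
    return urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send(request):
    """Send the request; this needs a server listening at the target URL."""
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


request = build_generate_request("Once upon a time,")
print(request.full_url)
print(request.data.decode("utf-8"))
# With the server running, the request can be sent with: send(request)
```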

### OpenAI Compatible API

In addition, the server supports an experimental OpenAI-compatible API.

```python
import openai
client = openai.Client(
    base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

# Text completion
response = client.completions.create(
    model="default",
    prompt="The capital of France is",
    temperature=0,
    max_tokens=32,
)
print(response)

# Chat completion
response = client.chat.completions.create(
    model="default",
    messages=[
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response)
```

In the above example, the server uses the chat template specified in the model tokenizer.
You can override the chat template if needed when launching the server:

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template llama-2
```

If the chat template you are looking for is missing, you are welcome to contribute it.
Meanwhile, you can also temporarily register your chat template as follows:

```json
{
  "name": "my_model",
  "system": "<|im_start|>system",
  "user": "<|im_start|>user",
  "assistant": "<|im_start|>assistant",
  "sep_style": "CHATML",
  "sep": "<|im_end|>",
  "stop_str": ["<|im_end|>", "<|im_start|>"]
}
```

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --chat-template ./my_model_template.json
```
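
The sketch below writes the template above to a JSON file and shows roughly how its fields combine into a prompt. `render_chatml` is a hypothetical helper that follows the common ChatML layout (role header, newline, content, separator); the exact string the server produces may differ, and the temporary-directory path is only an assumption for the example.

```python
import json
import tempfile
from pathlib import Path

template = {
    "name": "my_model",
    "system": "<|im_start|>system",
    "user": "<|im_start|>user",
    "assistant": "<|im_start|>assistant",
    "sep_style": "CHATML",
    "sep": "<|im_end|>",
    "stop_str": ["<|im_end|>", "<|im_start|>"],
}


def render_chatml(template, messages):
    """Approximate ChatML rendering of a message list."""
    parts = [
        f"{template[m['role']]}\n{m['content']}{template['sep']}\n"
        for m in messages
    ]
    # End with the assistant header so the model continues from there.
    parts.append(template["assistant"] + "\n")
    return "".join(parts)


# Write the template to a file that --chat-template can point to.
path = Path(tempfile.gettempdir()) / "my_model_template.json"
path.write_text(json.dumps(template, indent=2), encoding="utf-8")

prompt = render_chatml(template, [
    {"role": "system", "content": "You are a helpful AI assistant"},
    {"role": "user", "content": "List 3 countries and their capitals."},
])
print(prompt)
```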

### Additional Arguments
- Add `--tp 2` to enable tensor parallelism.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
```
- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
```
- You can turn on [flashinfer](docs/flashinfer.md) to accelerate inference by using highly optimized CUDA kernels.
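
To get a feel for how `--mem-fraction-static` trades off against KV cache capacity, the following back-of-the-envelope sketch may help. It rests on stated assumptions: that the static fraction covers both the model weights and the KV cache pool, that other overheads are ignored, and that the model has a Llama-2-7B-like shape (32 layers, 32 KV heads, head dimension 128, FP16). `kv_cache_token_budget` and the 24 GiB / 12.6 GiB figures are illustrative, not measured values.

```python
GIB = 1024 ** 3


def kv_cache_token_budget(gpu_mem_gib, weights_gib, mem_fraction_static,
                          num_layers=32, num_kv_heads=32, head_dim=128,
                          bytes_per_value=2):
    """Estimate how many tokens fit in the KV cache pool (rough sketch)."""
    # K and V are each stored per layer, per head, per token.
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    # Assumption: the static budget holds the weights plus the KV cache pool.
    static_bytes = gpu_mem_gib * GIB * mem_fraction_static
    kv_pool_bytes = max(static_bytes - weights_gib * GIB, 0)
    return int(kv_pool_bytes // bytes_per_token)


for fraction in (0.9, 0.7):
    tokens = kv_cache_token_budget(24, 12.6, fraction)
    print(f"--mem-fraction-static {fraction}: about {tokens} cached tokens")
```

A smaller fraction leaves more memory free for everything outside the static budget, at the cost of fewer cached tokens.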

### Supported Models
- Llama
- Mistral
- Mixtral
- LLaVA
  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- Qwen / Qwen 2
- AWQ quantization

## Benchmark And Performance

- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
![llama_7b](assets/llama_7b.jpg)

- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
![mixtral_8x7b](assets/mixtral_8x7b.jpg)

Learn more [here](docs/benchmark_results.md).

## Roadmap
- [ ] Function call APIs
- [ ] S-LoRA (expected by Feb. 5)
- [ ] Support more models
- [ ] Support more hardware backends

## Citation And Acknowledgment
```
@misc{zheng2023efficiently,
      title={Efficiently Programming Large Language Models using SGLang},
      author={Lianmin Zheng and Liangsheng Yin and Zhiqiang Xie and Jeff Huang and Chuyue Sun and Cody Hao Yu and Shiyi Cao and Christos Kozyrakis and Ion Stoica and Joseph E. Gonzalez and Clark Barrett and Ying Sheng},
      year={2023},
      eprint={2312.07104},
      archivePrefix={arXiv},
      primaryClass={cs.AI}
}
```

[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104)


We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).