# Sampling Parameters

This doc describes the sampling parameters of the SGLang Runtime, which are accepted by the runtime's low-level `/generate` endpoint.
If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API](openai_api_completions.ipynb).

## `/generate` Endpoint

The `/generate` endpoint accepts the following parameters in JSON format. For detailed usage, see the [native API doc](native_api.ipynb). The object is defined at `io_struct.py::GenerateReqInput`. You can also read the source code to find more arguments and docs.

| Argument                   | Type/Default                                                                 | Description                                                                                                                                                     |
|----------------------------|------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------|
| text                       | `Optional[Union[List[str], str]] = None`                                     | The input prompt. Can be a single prompt or a batch of prompts.                                                                                                 |
| input_ids                  | `Optional[Union[List[List[int]], List[int]]] = None`                         | The token IDs for text; one can specify either text or input_ids.                                                                                               |
| input_embeds               | `Optional[Union[List[List[List[float]]], List[List[float]]]] = None`         | The embeddings for input_ids; one can specify either text, input_ids, or input_embeds.                                                                          |
| image_data                 | `Optional[Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]] = None` | The image input. Can be an image instance, file name, URL, or base64 encoded string. Can be a single image, list of images, or list of lists of images. |
| audio_data                 | `Optional[Union[List[AudioDataItem], AudioDataItem]] = None`                 | The audio input. Can be a file name, URL, or base64 encoded string.                                                                                             |
| sampling_params            | `Optional[Union[List[Dict], Dict]] = None`                                   | The sampling parameters as described in the sections below.                                                                                                     |
| rid                        | `Optional[Union[List[str], str]] = None`                                     | The request ID.                                                                                                                                                 |
| return_logprob             | `Optional[Union[List[bool], bool]] = None`                                   | Whether to return log probabilities for tokens.                                                                                                                 |
| logprob_start_len          | `Optional[Union[List[int], int]] = None`                                     | If return_logprob, the start location in the prompt for returning logprobs. Default is `-1`, which returns logprobs for output tokens only.                     |
| top_logprobs_num           | `Optional[Union[List[int], int]] = None`                                     | If return_logprob, the number of top logprobs to return at each position.                                                                                       |
| token_ids_logprob          | `Optional[Union[List[List[int]], List[int]]] = None`                         | If return_logprob, the token IDs to return logprob for.                                                                                                         |
| return_text_in_logprobs    | `bool = False`                                                               | Whether to detokenize tokens in text in the returned logprobs.                                                                                                  |
| stream                     | `bool = False`                                                               | Whether to stream output.                                                                                                                                       |
| lora_path                  | `Optional[Union[List[Optional[str]], Optional[str]]] = None`                 | The path to the LoRA.                                                                                                                                           |
| custom_logit_processor     | `Optional[Union[List[Optional[str]], str]] = None`                           | Custom logit processor for advanced sampling control. Must be a serialized instance of `CustomLogitProcessor` using its `to_str()` method. For usage see below. |
| return_hidden_states       | `Union[List[bool], bool] = False`                                            | Whether to return hidden states.                                                                                                                                |
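
For instance, the logprob-related arguments can be combined in a single request. Below is a minimal sketch (assuming a server running on port 30000; the exact layout of the returned `meta_info` field may vary across versions):

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,           # attach logprobs to the response
        "top_logprobs_num": 2,            # top-2 alternatives at each position
        "return_text_in_logprobs": True,  # detokenize tokens in the logprob output
    },
)
# Logprobs are returned under the response's meta_info field.
print(response.json()["meta_info"]["output_token_logprobs"])
```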

## Sampling parameters

The object is defined at `sampling_params.py::SamplingParams`. You can also read the source code to find more arguments and docs.

### Note on defaults

By default, SGLang initializes several sampling parameters from the model's `generation_config.json` (when the server is launched with `--sampling-defaults model`, which is the default). To use SGLang/OpenAI constant defaults instead, start the server with `--sampling-defaults openai`. You can always override any parameter per request via `sampling_params`.

```bash
# Use model-provided defaults from generation_config.json (default behavior)
python -m sglang.launch_server --model-path <MODEL> --sampling-defaults model

# Use SGLang/OpenAI constant defaults instead
python -m sglang.launch_server --model-path <MODEL> --sampling-defaults openai
```
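
Per-request values always take precedence over the server defaults. A minimal sketch (assuming a server on port 30000):

```python
import requests

# Whatever the server-side default temperature is, this request uses 0.3.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Write a haiku about rivers.",
        "sampling_params": {"temperature": 0.3, "max_new_tokens": 48},
    },
)
print(response.json()["text"])
```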

### Core parameters

| Argument        | Type/Default                                 | Description                                                                                                                                    |
|-----------------|----------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
| max_new_tokens  | `int = 128`                                  | The maximum output length measured in tokens.                                                                                                  |
| stop            | `Optional[Union[str, List[str]]] = None`     | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. |
| stop_token_ids  | `Optional[List[int]] = None`                 | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled.                                        |
| temperature     | `float (model default; fallback 1.0)`        | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling; a higher temperature leads to more diversity. |
| top_p           | `float (model default; fallback 1.0)`        | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. |
| top_k           | `int (model default; fallback -1)`           | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. |
| min_p           | `float (model default; fallback 0.0)`        | [Min-p](https://github.com/huggingface/transformers/issues/27670) samples from tokens with probability larger than `min_p * highest_token_probability`. |
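
The core parameters map directly onto `sampling_params` in a request body. A minimal sketch (assuming a server on port 30000):

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List three colors:",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0.7,   # some randomness
            "top_p": 0.9,         # nucleus sampling
            "top_k": 50,          # restrict to the 50 most likely tokens
            "min_p": 0.05,        # drop tokens far below the top probability
            "stop": ["\n\n"],     # stop at the first blank line
        },
    },
)
print(response.json()["text"])
```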

### Penalizers

| Argument           | Type/Default           | Description                                                                                                                                    |
|--------------------|------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
| frequency_penalty  | `float = 0.0`          | Penalizes tokens based on how often they have appeared in the generation so far. Must be between `-2` and `2`, where negative numbers encourage repetition of tokens and positive numbers encourage sampling of new tokens. The penalty grows linearly with each additional appearance of a token. |
| presence_penalty   | `float = 0.0`          | Penalizes tokens if they have appeared in the generation so far. Must be between `-2` and `2`, where negative numbers encourage repetition of tokens and positive numbers encourage sampling of new tokens. The penalty is constant once a token has appeared at least once. |
| min_new_tokens     | `int = 0`              | Forces the model to generate at least `min_new_tokens` tokens before a stop word or EOS token can end generation. Note that this might lead to unintended behavior, for example, if the distribution is highly skewed towards these tokens. |
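
A short sketch combining the penalizers (assuming a server on port 30000); positive penalties discourage repetition:

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Brainstorm names for a coffee shop:",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0.8,
            "frequency_penalty": 0.5,  # penalty grows with each repeat
            "presence_penalty": 0.3,   # flat penalty once a token appears
            "min_new_tokens": 16,      # generate at least 16 tokens
        },
    },
)
print(response.json()["text"])
```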

### Constrained decoding

Please refer to our dedicated guide on [constrained decoding](../advanced_features/structured_outputs.ipynb) for the following parameters.

| Argument        | Type/Default                    | Description                                                                                                                                    |
|-----------------|---------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
| json_schema     | `Optional[str] = None`          | JSON schema for structured outputs.                                                                                                            |
| regex           | `Optional[str] = None`          | Regex for structured outputs.                                                                                                                  |
| ebnf            | `Optional[str] = None`          | EBNF for structured outputs.                                                                                                                   |
| structural_tag  | `Optional[str] = None`          | The structural tag for structured outputs.                                                                                                     |

### Other options

| Argument                      | Type/Default                    | Description                                                                                                                                    |
|-------------------------------|---------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
| n                             | `int = 1`                       | Specifies the number of output sequences to generate per request. Generating multiple outputs in one request (`n > 1`) is discouraged; repeating the same prompt several times offers better control and efficiency. |
| ignore_eos                    | `bool = False`                  | Don't stop generation when EOS token is sampled.                                                                                               |
| skip_special_tokens           | `bool = True`                   | Remove special tokens during decoding.                                                                                                         |
| spaces_between_special_tokens | `bool = True`                   | Whether or not to add spaces between special tokens during detokenization.                                                                     |
| no_stop_trim                  | `bool = False`                  | Don't trim stop words or EOS token from the generated text.                                                                                    |
| custom_params                 | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below.                                                                              |
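
For example, `ignore_eos` together with `skip_special_tokens=False` is handy for inspecting raw model output. A minimal sketch (assuming a server on port 30000):

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "max_new_tokens": 32,
            "ignore_eos": True,            # keep generating past the EOS token
            "skip_special_tokens": False,  # keep special tokens in the output text
        },
    },
)
print(response.json()["text"])
```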

## Examples

### Normal

Launch a server:

```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000
```

Send a request:

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    },
)
print(response.json())
```

Detailed example in [send request](./send_request.ipynb).

### Streaming

Send a request and stream the output:

```python
import requests, json

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
        "stream": True,
    },
    stream=True,
)

prev = 0
for chunk in response.iter_lines(decode_unicode=False):
    chunk = chunk.decode("utf-8")
    if chunk and chunk.startswith("data:"):
        if chunk == "data: [DONE]":
            break
        data = json.loads(chunk[5:].strip("\n"))
        output = data["text"].strip()
        print(output[prev:], end="", flush=True)
        prev = len(output)
print("")
```

Detailed example in [OpenAI Compatible API](openai_api_completions.ipynb).

### Multimodal

Launch a server:

```bash
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov
```

Download an image:

```bash
curl -o example_image.png -L https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true
```

Send a request:

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
                "<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n"
                "<|im_start|>assistant\n",
        "image_data": "example_image.png",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    },
)
print(response.json())
```

The `image_data` can be a file name, a URL, or a base64 encoded string. See also `python/sglang/srt/utils.py:load_image`.
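
For example, to send the image as a base64-encoded string instead of a file name (a sketch reusing the prompt from the request above):

```python
import base64
import requests

# Encode the downloaded image as base64 and send it inline.
with open("example_image.png", "rb") as f:
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
                "<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n"
                "<|im_start|>assistant\n",
        "image_data": encoded_image,
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
    },
)
print(response.json())
```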

Streaming is supported in a similar manner as [above](#streaming).

Detailed example in [OpenAI API Vision](openai_api_vision.ipynb).

### Structured Outputs (JSON, Regex, EBNF)

You can specify a JSON schema, regular expression, or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request.

SGLang supports two grammar backends:

- [XGrammar](https://github.com/mlc-ai/xgrammar) (default): Supports JSON schema, regular expression, and EBNF constraints.
  - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).
- [Outlines](https://github.com/dottxt-ai/outlines): Supports JSON schema and regular expression constraints.

If instead you want to initialize the Outlines backend, you can use the `--grammar-backend outlines` flag:

```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
--port 30000 --host 0.0.0.0 --grammar-backend outlines  # or xgrammar (default)
```

```python
import json
import requests

json_schema = json.dumps({
    "type": "object",
    "properties": {
        "name": {"type": "string", "pattern": "^[\\w]+$"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
})

# JSON (works with both Outlines and XGrammar)
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)
print(response.json())

# Regular expression (works with both Outlines and XGrammar)
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Paris is the capital of",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "regex": "(France|England)",
        },
    },
)
print(response.json())

# EBNF (XGrammar backend only)
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Write a greeting.",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "ebnf": 'root ::= "Hello" | "Hi" | "Hey"',
        },
    },
)
print(response.json())
```

Detailed example in [structured outputs](../advanced_features/structured_outputs.ipynb).

### Custom logit processor

Launch a server with the `--enable-custom-logit-processor` flag:

```bash
python -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3-8B-Instruct \
  --port 30000 \
  --enable-custom-logit-processor
```

Define a custom logit processor that will always sample a specific token id.

```python
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class DeterministicLogitProcessor(CustomLogitProcessor):
    """A dummy logit processor that changes the logits to always
    sample the given token id.
    """

    def __call__(self, logits, custom_param_list):
        # Check that the number of logits matches the number of custom parameters
        assert logits.shape[0] == len(custom_param_list)
        key = "token_id"

        for i, param_dict in enumerate(custom_param_list):
            # Mask all other tokens
            logits[i, :] = -float("inf")
            # Assign highest probability to the specified token
            logits[i, param_dict[key]] = 0.0
        return logits
```

Send a request:

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 32,
            "custom_params": {"token_id": 5},
        },
    },
)
print(response.json())
```