# Sampling Parameters

This doc describes the sampling parameters of the SGLang Runtime.
It covers the low-level endpoint of the runtime.
If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API](https://docs.sglang.ai/backend/openai_api_completions.html).

## `/generate` Endpoint

The `/generate` endpoint accepts the following parameters in JSON format. For detailed usage, see the [native API doc](https://docs.sglang.ai/backend/native_api.html).

* `text`: The input prompt. Can be a single prompt or a batch of prompts.
* `input_ids`: Alternative to `text`. Specify the input as token IDs instead of text.
* `sampling_params`: The sampling parameters as described in the sections below.
* `return_logprob`: Whether to return log probabilities for tokens.
* `logprob_start_len`: If returning log probabilities, specifies the start position in the prompt. Default is `-1`, which returns logprobs only for the output tokens.
* `top_logprobs_num`: If returning log probabilities, specifies the number of top logprobs to return at each position.
* `stream`: Whether to stream the output.
* `lora_path`: Path to LoRA weights.
* `custom_logit_processor`: Custom logit processor for advanced sampling control. For usage see below.

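For example, a minimal request to `/generate` might look like the sketch below. It assumes a server is already running locally on port 30000 (as in the launch command shown later in this doc); adjust the URL, prompt, and values to your setup.

```python
import requests

# Minimal /generate request using a few of the optional fields above.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0},
        "return_logprob": True,
        "top_logprobs_num": 2,
        "stream": False,
    },
)
print(response.json())
```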
## Sampling params

### Core Parameters

* `max_new_tokens`: The maximum output length measured in tokens.
* `stop`: One or multiple [stop words](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#let_the_model_know_when_to_stop). Generation will stop if one of these words is sampled.
* `stop_token_ids`: Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled.
* `temperature`: [Temperature](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) when sampling the next token. `temperature = 0` corresponds to greedy sampling, higher temperature leads to more diversity.
* `top_p`: [Top-p](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens.
* `top_k`: [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens.
* `min_p`: [Min-p](https://github.com/huggingface/transformers/issues/27670) samples from tokens with probability larger than `min_p * highest_token_probability`.

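As a quick illustration, the core parameters are passed together inside `sampling_params`. The sketch below assumes the same local server as above; the values are arbitrary and only meant to show the expected JSON shape.

```python
import requests

# Illustrative combination of the core sampling parameters.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List three colors:",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 50,
            "min_p": 0.05,
            "stop": ["\n\n"],
        },
    },
)
print(response.json())
```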
### Penalizers

To use penalizers, you need to launch the server with `--disable-overlap`. Please note that this might degrade performance.

* `frequency_penalty`: Penalizes tokens based on their frequency in the generation so far. Must be between `-2` and `2`, where negative numbers encourage repetition of tokens and positive numbers encourage sampling of new tokens. The penalization grows linearly with each appearance of a token.
* `presence_penalty`: Penalizes tokens if they appeared in the generation so far. Must be between `-2` and `2`, where negative numbers encourage repetition of tokens and positive numbers encourage sampling of new tokens. The penalization is constant once a token has occurred, regardless of how often.
* `repetition_penalty`: Penalizes tokens if they appeared in the prompt or the generation so far. Must be between `0` and `2`, where numbers smaller than `1` encourage repetition of tokens and numbers larger than `1` encourage sampling of new tokens. The penalization scales multiplicatively.
* `min_new_tokens`: Forces the model to generate at least `min_new_tokens` tokens before a stop word or the EOS token can be sampled. Note that this might lead to unintended behavior, for example, if the distribution is highly skewed towards these tokens.

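A sketch of a request using these penalizers is shown below. It assumes the server was launched with `--disable-overlap` as noted above; the values are only illustrative.

```python
import requests

# Example request combining the penalizers; tune the values for your use case.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Write a short story about a robot.",
        "sampling_params": {
            "max_new_tokens": 128,
            "frequency_penalty": 0.5,
            "presence_penalty": 0.3,
            "repetition_penalty": 1.05,
            "min_new_tokens": 16,
        },
    },
)
print(response.json())
```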
### Constrained decoding

Please refer to our dedicated guide on [constrained decoding](https://docs.sglang.ai/backend/structured_outputs.html#Native-API-and-SGLang-Runtime-(SRT)) for the following parameters.

* `json_schema`
* `regex`
* `ebnf`

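As a rough sketch, a JSON schema can be passed as a string inside `sampling_params`; see the guide linked above for the authoritative usage, including the `regex` and `ebnf` variants.

```python
import json
import requests

# Constrain the output to match a simple JSON schema (illustrative only).
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "population": {"type": "integer"},
    },
    "required": ["name", "population"],
}

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Give information about the capital of France in JSON format.",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0,
            "json_schema": json.dumps(schema),
        },
    },
)
print(response.json())
```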
### Other options

* `n`: Specifies the number of output sequences to generate per request. Generating multiple outputs in one request (`n > 1`) is discouraged; sending separate requests offers better control and efficiency.
* `spaces_between_special_tokens`: Whether or not to add spaces between special tokens during detokenization.
* `no_stop_trim`: Don't trim stop words or the EOS token from the generated text.
* `ignore_eos`: Don't stop generation when EOS token is sampled.
* `skip_special_tokens`: Remove special tokens during decoding.
* `custom_params`: Used when employing `CustomLogitProcessor`. For usage, see below.
* `return_hidden_states`: Whether to return the hidden states of the model. Note that each time it changes, the CUDA graph will be recaptured, which might lead to a performance hit. See the [examples](https://github.com/sgl-project/sglang/blob/main/examples/runtime/engine/hidden_states.py) for more information.

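The sketch below shows a few of these options in a request; the values are illustrative, and these options can be combined freely with the core parameters.

```python
import requests

# Illustrative use of the miscellaneous options above.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Continue this sentence: The weather today is",
        "sampling_params": {
            "max_new_tokens": 32,
            "ignore_eos": True,           # keep generating even if EOS is sampled
            "skip_special_tokens": True,  # strip special tokens during decoding
        },
    },
)
print(response.json())
```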
### Custom Logit Processor
Launch the server with the `--enable-custom-logit-processor` flag.
```bash
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor
```

Define a custom logit processor that will always sample a specific token id.

```python
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class DeterministicLogitProcessor(CustomLogitProcessor):
    """A dummy logit processor that changes the logits to always
    sample the given token id.
    """

    def __call__(self, logits, custom_param_list):
        # Check that the number of logits matches the number of custom parameters
        assert logits.shape[0] == len(custom_param_list)
        key = "token_id"

        for i, param_dict in enumerate(custom_param_list):
            # Mask all other tokens
            logits[i, :] = -float("inf")
            # Assign highest probability to the specified token
            logits[i, param_dict[key]] = 0.0
        return logits
```

Send a request

```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 32,
            "custom_params": {"token_id": 5},
        },
    },
)
print(response.json())
```