Unverified Commit dbb16bed authored by ybyang, committed by GitHub

Support Thinking Budget (via custom_logit_processor for OpenAI API) [Fix #6572] (#11416)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: YorkSu <york_su@qq.com>
parent c1e16003
@@ -235,6 +235,44 @@ Important Notes:
2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`. It provides an improved unified prompt.
### Thinking Budget for DeepSeek R1
In SGLang, a thinking budget can be implemented with a `CustomLogitProcessor`.

Launch a server with the `--enable-custom-logit-processor` flag:
```bash
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --disable-cuda-graph --reasoning-parser deepseek-r1 --enable-custom-logit-processor
```
Sample Request:
```python
import openai
from rich.pretty import pprint

from sglang.srt.sampling.custom_logit_processor import (
    DeepSeekR1ThinkingBudgetLogitProcessor,
)

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[
        {
            "role": "user",
            "content": "Question: Is Paris the Capital of France?",
        }
    ],
    max_tokens=1024,
    extra_body={
        "custom_logit_processor": DeepSeekR1ThinkingBudgetLogitProcessor().to_str(),
        "custom_params": {
            "thinking_budget": 512,
        },
    },
)
pprint(response)
```
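Once the number of generated thinking tokens reaches `thinking_budget`, the processor forces a newline followed by the thinking end token, closing the reasoning section so the model moves on to its final answer.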
## FAQ

**Q: Model loading is taking too long, and I'm encountering an NCCL timeout. What should I do?**

...
@@ -319,3 +319,27 @@ response = requests.post(
)
print(response.json())
```
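The chat completion example below references `DeterministicLogitProcessor` without defining it; it is assumed to be defined earlier in your script. A minimal sketch, modeled on the processor used in SGLang's own test suite:

```python
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DeterministicLogitProcessor(CustomLogitProcessor):
    """A dummy logit processor that always samples the token id from custom_params."""

    def __call__(self, logits, custom_param_list):
        for i, param_dict in enumerate(custom_param_list):
            # Mask every token, then give the requested token probability 1
            logits[i, :] = -float("inf")
            logits[i, param_dict["token_id"]] = 0.0
        return logits
```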
Send an OpenAI chat completion request:
```python
import openai

from sglang.utils import print_highlight

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.0,
    max_tokens=32,
    extra_body={
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
        "custom_params": {"token_id": 5},
    },
)
print_highlight(f"Response: {response}")
```
@@ -243,6 +243,8 @@ class CompletionRequest(BaseModel):
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    session_params: Optional[Dict] = None
    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
    custom_params: Optional[Dict] = None
    custom_logit_processor: Optional[str] = None

    # For PD disaggregation
    bootstrap_host: Optional[Union[List[str], str]] = None
@@ -504,6 +506,10 @@ class ChatCompletionRequest(BaseModel):
    stream_reasoning: bool = True
    chat_template_kwargs: Optional[Dict] = None

    # Custom logit processor for advanced sampling control
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
    custom_params: Optional[Dict] = None

    # For request id
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
@@ -636,6 +642,7 @@ class ChatCompletionRequest(BaseModel):
            "ignore_eos": self.ignore_eos,
            "skip_special_tokens": self.skip_special_tokens,
            "logit_bias": self.logit_bias,
            "custom_params": self.custom_params,
        }

        if self.response_format and self.response_format.type == "json_schema":
...
@@ -196,6 +196,7 @@ class OpenAIServingChat(OpenAIServingBase):
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            custom_logit_processor=request.custom_logit_processor,
        )
        return adapted_request, request
...
@@ -121,6 +121,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            custom_logit_processor=request.custom_logit_processor,
        )
        return adapted_request, request

@@ -149,6 +150,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
            "logit_bias": request.logit_bias,
            "custom_params": request.custom_params,
        }

        # Handle response_format constraints
...
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import dill
import orjson
import torch

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req
@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):

@@ -52,3 +55,74 @@ class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
        ), f"{custom_param_list=}"
        logits[..., disallowed_token_ids] = -float("inf")
        return logits
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    """A logit processor that controls the length of thinking."""

    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(self, logits, custom_param_list: list[dict[str, Any]]):
        if custom_param_list is None or not custom_param_list:
            return logits

        for i, param_dict in enumerate(custom_param_list):
            if param_dict is None:
                continue
            thinking_budget: int | None = param_dict.get("thinking_budget")
            # Skip if thinking_budget is unset, not an integer, or negative
            if (
                thinking_budget is None
                or not isinstance(thinking_budget, int)
                or thinking_budget < 0
            ):
                continue

            req: Req = param_dict.get("__req__")
            cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids]

            # Check if out of the thinking stage
            if (
                self.THINKING_START_TOKEN_ID not in cur_ids
                or self.THINKING_END_TOKEN_ID in cur_ids
            ):
                continue

            # Find the index of the thinking start token
            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
            # Count the number of tokens after the thinking start token
            num_tokens_after_start = len(cur_ids) - start_index - 1
            if num_tokens_after_start < thinking_budget:
                continue

            # Ensure a new line token precedes the thinking end token
            if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                logits[i, :] = -float("inf")
                logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
                continue

            # Assign the highest probability to the thinking end token
            logits[i, :] = -float("inf")
            logits[i, self.THINKING_END_TOKEN_ID] = 0.0
        return logits


class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for Qwen3 models."""

    THINKING_START_TOKEN_ID: int = 151667
    THINKING_END_TOKEN_ID: int = 151668
    NEW_LINE_TOKEN_ID: int = 198


class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for DeepSeek-R1 models."""

    THINKING_START_TOKEN_ID: int = 128798
    THINKING_END_TOKEN_ID: int = 128799
    NEW_LINE_TOKEN_ID: int = 201
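The token IDs are model-specific, so each subclass pins the thinking start/end and newline IDs from its tokenizer. A request against a Qwen3 server mirrors the DeepSeek-R1 docs example above; the following is a minimal sketch, assuming `Qwen/Qwen3-8B` is served locally with `--enable-custom-logit-processor` (model name and port are illustrative):

```python
import openai

from sglang.srt.sampling.custom_logit_processor import (
    Qwen3ThinkingBudgetLogitProcessor,
)

# Assumes a server launched along the lines of:
#   python3 -m sglang.launch_server --model Qwen/Qwen3-8B \
#       --port 30000 --enable-custom-logit-processor
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*")
response = client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "Is Paris the capital of France?"}],
    max_tokens=1024,
    extra_body={
        # Serialize the processor and cap the thinking section at 256 tokens
        "custom_logit_processor": Qwen3ThinkingBudgetLogitProcessor().to_str(),
        "custom_params": {"thinking_budget": 256},
    },
)
print(response.choices[0].message.content)
```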
@@ -6,13 +6,17 @@ python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test
"""

import json
import random
import re
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

import numpy as np
import openai
import requests

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS

@@ -848,6 +852,94 @@ class TestOpenAIV1Rerank(CustomTestCase):
        self.assertTrue(isinstance(response[1]["index"], int))
class TestOpenAIServerCustomLogitProcessor(CustomTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=["--enable-custom-logit-processor"],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls) -> None:
        kill_process_tree(cls.process.pid)

    def run_custom_logit_processor(self, target_token_id: Optional[int] = None) -> None:
        """
        Test custom logit processor with custom params.

        If target_token_id is None, the custom logit processor won't be passed in.
        """

        class DeterministicLogitProcessor(CustomLogitProcessor):
            """A dummy logit processor that changes the logits to always sample the given token id."""

            CUSTOM_PARAM_KEY = "token_id"

            def __call__(self, logits, custom_param_list):
                assert logits.shape[0] == len(custom_param_list)
                for i, param_dict in enumerate(custom_param_list):
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign highest probability to the specified token
                    logits[i, param_dict[self.CUSTOM_PARAM_KEY]] = 0.0
                return logits

        extra_body = {}
        if target_token_id is not None:
            extra_body["custom_logit_processor"] = (
                DeterministicLogitProcessor().to_str()
            )
            extra_body["custom_params"] = {
                "token_id": target_token_id,
            }

        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        max_tokens = 200
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": "Question: Is Paris the Capital of France?",
                },
            ],
            temperature=0.0,
            max_tokens=max_tokens,
            extra_body=extra_body,
        )

        if target_token_id is not None:
            target_text = self.tokenizer.decode([target_token_id] * max_tokens)
            self.assertTrue(
                target_text == response.choices[0].message.content,
                f"{target_token_id=}\n{target_text=}\n{response.model_dump(mode='json')}",
            )
    def test_custom_logit_processor(self) -> None:
        """Test custom logit processor with a single request."""
        self.run_custom_logit_processor(target_token_id=5)

    def test_custom_logit_processor_batch_mixed(self) -> None:
        """Test a batch that mixes requests with and without a custom logit processor."""
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            list(executor.map(self.run_custom_logit_processor, target_token_ids))
class TestOpenAIV1Score(CustomTestCase):
    @classmethod
    def setUpClass(cls):
...