Unverified Commit dbb16bed authored by ybyang, committed by GitHub

Support Thinking Budget (via custom_logit_processor for OpenAI API) [Fix #6572] (#11416)


Signed-off-by: ybyang <ybyang7@iflytek.com>
Co-authored-by: YorkSu <york_su@qq.com>
parent c1e16003
@@ -235,6 +235,44 @@ Important Notes:
2. To receive more consistent tool call results, it is recommended to use `--chat-template examples/chat_template/tool_chat_template_deepseekv3.jinja`. It provides an improved unified prompt.
### Thinking Budget for DeepSeek R1
In SGLang, a thinking budget can be implemented with a `CustomLogitProcessor`.
Launch a server with the `--enable-custom-logit-processor` flag:
```
python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-R1 --tp 8 --port 30000 --host 0.0.0.0 --mem-fraction-static 0.9 --disable-cuda-graph --reasoning-parser deepseek-r1 --enable-custom-logit-processor
```
Sample Request:
```python
import openai
from rich.pretty import pprint

from sglang.srt.sampling.custom_logit_processor import (
    DeepSeekR1ThinkingBudgetLogitProcessor,
)

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="*")
response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",
    messages=[
        {
            "role": "user",
            "content": "Question: Is Paris the Capital of France?",
        }
    ],
    max_tokens=1024,
    extra_body={
        "custom_logit_processor": DeepSeekR1ThinkingBudgetLogitProcessor().to_str(),
        "custom_params": {
            "thinking_budget": 512,
        },
    },
)
pprint(response)
```
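When the budget is exhausted, the processor forces a newline and then the end-of-thinking token, so the reasoning section is capped near `thinking_budget` tokens. As a rough sanity check (a sketch, not part of the original example: it assumes the response carries SGLang's non-standard `reasoning_content` field, which the `--reasoning-parser` flag enables):
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")

# `reasoning_content` is an SGLang extension, not part of the standard OpenAI schema.
reasoning = getattr(response.choices[0].message, "reasoning_content", None) or ""
num_reasoning_tokens = len(tokenizer.encode(reasoning, add_special_tokens=False))
# The count should not exceed the budget by more than the few forced tokens.
print(num_reasoning_tokens)
```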
## FAQ
**Q: Model loading is taking too long, and I'm encountering an NCCL timeout. What should I do?**
......
@@ -319,3 +319,27 @@ response = requests.post(
)
print(response.json())
```
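The chat completion example below reuses a `DeterministicLogitProcessor`. If it is not already defined in your session, a minimal sketch (mirroring the processor used in SGLang's test suite) looks like this:
```python
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DeterministicLogitProcessor(CustomLogitProcessor):
    """Force every sampled token to the id given in custom_params["token_id"]."""

    def __call__(self, logits, custom_param_list):
        for i, param_dict in enumerate(custom_param_list):
            # Mask all tokens, then give the requested token id probability 1.
            logits[i, :] = -float("inf")
            logits[i, param_dict["token_id"]] = 0.0
        return logits
```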
Send an OpenAI chat completion request:
```python
import openai

from sglang.utils import print_highlight

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")

response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0.0,
    max_tokens=32,
    extra_body={
        "custom_logit_processor": DeterministicLogitProcessor().to_str(),
        "custom_params": {"token_id": 5},
    },
)
print_highlight(f"Response: {response}")
```
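Because the processor masks every other token, the returned content is simply token id 5 decoded and repeated until `max_tokens` is reached.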
@@ -243,6 +243,8 @@ class CompletionRequest(BaseModel):
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
    session_params: Optional[Dict] = None
    response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
    custom_params: Optional[Dict] = None
    custom_logit_processor: Optional[str] = None

    # For PD disaggregation
    bootstrap_host: Optional[Union[List[str], str]] = None
@@ -504,6 +506,10 @@ class ChatCompletionRequest(BaseModel):
    stream_reasoning: bool = True
    chat_template_kwargs: Optional[Dict] = None

    # Custom logit processor for advanced sampling control
    custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
    custom_params: Optional[Dict] = None

    # For request id
    rid: Optional[Union[List[str], str]] = None
    # Extra key for classifying the request (e.g. cache_salt)
@@ -636,6 +642,7 @@ class ChatCompletionRequest(BaseModel):
            "ignore_eos": self.ignore_eos,
            "skip_special_tokens": self.skip_special_tokens,
            "logit_bias": self.logit_bias,
            "custom_params": self.custom_params,
        }

        if self.response_format and self.response_format.type == "json_schema":
......
@@ -196,6 +196,7 @@ class OpenAIServingChat(OpenAIServingBase):
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            custom_logit_processor=request.custom_logit_processor,
        )
        return adapted_request, request
......
@@ -121,6 +121,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
            extra_key=self._compute_extra_key(request),
            priority=request.priority,
            custom_labels=custom_labels,
            custom_logit_processor=request.custom_logit_processor,
        )
        return adapted_request, request
@@ -149,6 +150,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
            "ignore_eos": request.ignore_eos,
            "skip_special_tokens": request.skip_special_tokens,
            "logit_bias": request.logit_bias,
            "custom_params": request.custom_params,
        }

        # Handle response_format constraints
......
import json
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import dill
import orjson
import torch

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import Req


@lru_cache(maxsize=None)
def _cache_from_str(json_str: str):
@@ -52,3 +55,74 @@ class DisallowedTokensLogitsProcessor(CustomLogitProcessor):
        ), f"{custom_param_list=}"
        logits[..., disallowed_token_ids] = -float("inf")
        return logits
class ThinkingBudgetLogitProcessor(CustomLogitProcessor):
    """A logit processor that controls the length of thinking."""

    THINKING_START_TOKEN_ID: int
    THINKING_END_TOKEN_ID: int
    NEW_LINE_TOKEN_ID: int

    def __call__(self, logits, custom_param_list: list[dict[str, Any]]):
        if custom_param_list is None or not custom_param_list:
            return logits

        for i, param_dict in enumerate(custom_param_list):
            if param_dict is None:
                continue

            thinking_budget: int | None = param_dict.get("thinking_budget")
            # Skip if thinking_budget is unset, not an integer, or negative
            if (
                thinking_budget is None
                or not isinstance(thinking_budget, int)
                or thinking_budget < 0
            ):
                continue

            req: Req = param_dict.get("__req__")
            cur_ids: list[int] = [*req.origin_input_ids, *req.output_ids]
            # Skip if the request is not (or no longer) in the thinking stage
            if (
                self.THINKING_START_TOKEN_ID not in cur_ids
                or self.THINKING_END_TOKEN_ID in cur_ids
            ):
                continue

            # Find the index of the thinking start token
            start_index = cur_ids.index(self.THINKING_START_TOKEN_ID)
            # Count the number of tokens after the thinking start token
            num_tokens_after_start = len(cur_ids) - start_index - 1
            if num_tokens_after_start < thinking_budget:
                continue

            # Ensure a newline token precedes the thinking end token
            if not req.output_ids or req.output_ids[-1] != self.NEW_LINE_TOKEN_ID:
                logits[i, :] = -float("inf")
                logits[i, self.NEW_LINE_TOKEN_ID] = 0.0
                continue

            # Assign the highest probability to the thinking end token
            logits[i, :] = -float("inf")
            logits[i, self.THINKING_END_TOKEN_ID] = 0.0
        return logits
class Qwen3ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for Qwen3 models."""

    THINKING_START_TOKEN_ID: int = 151667
    THINKING_END_TOKEN_ID: int = 151668
    NEW_LINE_TOKEN_ID: int = 198


class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """A logit processor that controls the length of thinking for DeepSeek-R1 models."""

    THINKING_START_TOKEN_ID: int = 128798
    THINKING_END_TOKEN_ID: int = 128799
    NEW_LINE_TOKEN_ID: int = 201
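Supporting another reasoning model is then just a matter of subclassing with that model's token ids. A minimal sketch follows; the model name and think-token strings are placeholders for illustration, so look up the real ids from the tokenizer of the model you serve:

```python
from transformers import AutoTokenizer

from sglang.srt.sampling.custom_logit_processor import ThinkingBudgetLogitProcessor

# Hypothetical model name, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("my-org/my-reasoning-model")


class MyModelThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    """Thinking budget for a hypothetical reasoning model."""

    THINKING_START_TOKEN_ID: int = tokenizer.convert_tokens_to_ids("<think>")
    THINKING_END_TOKEN_ID: int = tokenizer.convert_tokens_to_ids("</think>")
    # Assumes "\n" encodes to a single token for this tokenizer.
    NEW_LINE_TOKEN_ID: int = tokenizer.encode("\n", add_special_tokens=False)[0]
```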
@@ -6,13 +6,17 @@ python3 -m unittest openai_server.basic.test_openai_server.TestOpenAIServer.test
"""
import json
import random
import re
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import numpy as np
import openai
import requests
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.runners import TEST_RERANK_QUERY_DOCS
@@ -848,6 +852,94 @@ class TestOpenAIV1Rerank(CustomTestCase):
        self.assertTrue(isinstance(response[1]["index"], int))


class TestOpenAIServerCustomLogitProcessor(CustomTestCase):
    @classmethod
    def setUpClass(cls) -> None:
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-123456"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=["--enable-custom-logit-processor"],
        )
        cls.base_url += "/v1"
        cls.tokenizer = get_tokenizer(cls.model)

    @classmethod
    def tearDownClass(cls) -> None:
        kill_process_tree(cls.process.pid)

    def run_custom_logit_processor(self, target_token_id: Optional[int] = None) -> None:
        """Test a custom logit processor with custom params.

        If target_token_id is None, no custom logit processor is passed in.
        """

        class DeterministicLogitProcessor(CustomLogitProcessor):
            """A dummy logit processor that forces sampling of the given token id."""

            CUSTOM_PARAM_KEY = "token_id"

            def __call__(self, logits, custom_param_list):
                assert logits.shape[0] == len(custom_param_list)
                for i, param_dict in enumerate(custom_param_list):
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign the highest probability to the specified token
                    logits[i, param_dict[self.CUSTOM_PARAM_KEY]] = 0.0
                return logits

        extra_body = {}
        if target_token_id is not None:
            extra_body["custom_logit_processor"] = (
                DeterministicLogitProcessor().to_str()
            )
            extra_body["custom_params"] = {
                "token_id": target_token_id,
            }

        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
        max_tokens = 200
        response = client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "user",
                    "content": "Question: Is Paris the Capital of France?",
                },
            ],
            temperature=0.0,
            max_tokens=max_tokens,
            extra_body=extra_body,
        )
        if target_token_id is not None:
            target_text = self.tokenizer.decode([target_token_id] * max_tokens)
            self.assertTrue(
                target_text == response.choices[0].message.content,
                f"{target_token_id=}\n{target_text=}\n{response.model_dump(mode='json')}",
            )

    def test_custom_logit_processor(self) -> None:
        """Test the custom logit processor with a single request."""
        self.run_custom_logit_processor(target_token_id=5)

    def test_custom_logit_processor_batch_mixed(self) -> None:
        """Test a batch of requests, mixing requests with and without a custom logit processor."""
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            list(executor.map(self.run_custom_logit_processor, target_token_ids))


class TestOpenAIV1Score(CustomTestCase):
    @classmethod
    def setUpClass(cls):
......