"vllm/vscode:/vscode.git/clone" did not exist on "258a2c58d08fc7a242556120877a89404861fbce"
Unverified Commit 47318847 authored by Sungjae Lee's avatar Sungjae Lee Committed by GitHub
Browse files

[Feature] limit thinking tokens (hard limit) (#20859)


Signed-off-by: default avatarSungjae Lee <33976427+llsj14@users.noreply.github.com>
Signed-off-by: default avatarSungjae Lee <sung-jae.lee@navercorp.com>
Signed-off-by: default avatarChauncey <chaunceyjiang@gmail.com>
Co-authored-by: default avatarChauncey <chaunceyjiang@gmail.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 8de5261e
...@@ -240,6 +240,81 @@ response = client.chat.completions.create( ...@@ -240,6 +240,81 @@ response = client.chat.completions.create(
) )
``` ```
## Thinking Budget Control
Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron3](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning.
Token counting starts from `think_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `think_end_str`, effectively terminating the reasoning block.
To use this feature:
- `--reasoning-parser` enables reasoning extraction.
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `think_start_str`, `think_end_str`).
- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
`--reasoning-config` accepts a JSON object corresponding to
[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields:
| Field | Type | Description |
|-------------------|----------------|--------------------------------------------------|
| `think_start_str` | `str` | String that marks the start of reasoning content (default: `"<think>"`) |
| `think_end_str` | `str` | String that marks the end of reasoning content (default: `"</think>"`) |
!!! note
`think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now.</think>"` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural.
### Online Serving
```bash
vllm serve Qwen/Qwen3-0.6B \
--reasoning-parser qwen3 \
--reasoning-config '{"think_start_str": "<think>", "think_end_str": "I have to give the solution based on the thinking directly now.</think>"}'
```
Then make a request with `thinking_token_budget` to limit the reasoning tokens:
```bash
curl http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen3-0.6B",
"messages": [
{ "role": "user", "content": "9.11 and 9.8, which is greater?" }
],
"thinking_token_budget": 10
}'
```
### Offline Inference
```python
from vllm import LLM, SamplingParams
from vllm.config import ReasoningConfig
llm = LLM(
model="Qwen/Qwen3-0.6B",
reasoning_config=ReasoningConfig(
think_start_str="<think>",
think_end_str="I have to give the solution based on the thinking directly now.</think>",
),
)
sampling_params = SamplingParams(thinking_token_budget=10)
messages = [
{"role": "user", "content": "9.11 and 9.8, which is greater?"},
]
outputs = llm.chat(messages, sampling_params=sampling_params)
for output in outputs:
print("text:", output.outputs[0].text)
```
## Limitations ## Limitations
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). - The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""E2E tests for thinking_token_budget with reasoning models."""
import openai
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
# Model passed to the test server and named in every chat completion request.
MODEL_NAME = "Qwen/Qwen3-0.6B"
# Single-turn prompt shared by all requests in this module.
MESSAGES = [{"role": "user", "content": "What is 1+1? Be concise."}]
# Value sent as `thinking_token_budget`; the streaming test expects exactly
# this many reasoning deltas.
THINK_BUDGET = 5
@pytest.fixture(scope="module")
def server():
    """Start one `RemoteOpenAIServer` for `MODEL_NAME`, shared by the module.

    The server is launched with the `qwen3` reasoning parser and a reasoning
    config whose boundaries are `<think>` / `</think>`. It is torn down when
    the `with` block exits after the last test of the module.
    """
    reasoning_flags = [
        "--reasoning-parser",
        "qwen3",
        "--reasoning-config",
        '{"think_start_str": "<think>", "think_end_str": "</think>"}',
    ]
    runtime_flags = [
        "--max-model-len",
        "2048",
        "--enforce-eager",
        "--no-async-scheduling",
    ]
    with RemoteOpenAIServer(MODEL_NAME, reasoning_flags + runtime_flags) as running:
        yield running
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the module-scoped `server`.

    The client is used as an async context manager, so it is released as
    soon as the requesting test finishes.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client
@pytest.mark.asyncio
async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI):
    """Budgeted and unbudgeted requests must both succeed on the same server.

    One request carries `thinking_token_budget` and the next one does not;
    each response has to contain either answer content or reasoning text.
    """
    shared_kwargs = dict(model=MODEL_NAME, messages=MESSAGES, max_tokens=100)

    budgeted = await client.chat.completions.create(
        **shared_kwargs,
        extra_body={"thinking_token_budget": THINK_BUDGET},
    )
    unbudgeted = await client.chat.completions.create(**shared_kwargs)

    for reply in (budgeted, unbudgeted):
        message = reply.choices[0].message
        assert message.content or getattr(message, "reasoning", None)
@pytest.mark.asyncio
async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI):
    """Check that the streamed reasoning length equals the requested budget.

    The request is streamed and every chunk whose delta has a non-empty
    `reasoning` attribute is counted. The test relies on the server emitting
    one reasoning delta per generated token, so the number of such chunks is
    compared directly against `THINK_BUDGET`.
    """
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES,
        max_tokens=100,
        stream=True,
        extra_body={"thinking_token_budget": THINK_BUDGET},
    )

    reasoning_chunks = [
        chunk
        async for chunk in stream
        if getattr(chunk.choices[0].delta, "reasoning", None)
    ]
    reasoning_token_count = len(reasoning_chunks)

    assert reasoning_token_count == THINK_BUDGET, (
        f"reasoning tokens ({reasoning_token_count}) != "
        f"thinking_token_budget ({THINK_BUDGET})"
    )
...@@ -30,6 +30,7 @@ from vllm.v1.sample.logits_processor import ( ...@@ -30,6 +30,7 @@ from vllm.v1.sample.logits_processor import (
MinPLogitsProcessor, MinPLogitsProcessor,
MinTokensLogitsProcessor, MinTokensLogitsProcessor,
MoveDirectionality, MoveDirectionality,
ThinkingTokenBudgetLogitsProcessor,
build_logitsprocs, build_logitsprocs,
) )
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
...@@ -47,6 +48,11 @@ MIN_TOKENS_LEN_THRESHOLD = 5 ...@@ -47,6 +48,11 @@ MIN_TOKENS_LEN_THRESHOLD = 5
REQS_PER_LOGITPROC = 50 REQS_PER_LOGITPROC = 50
STR_NO_LOGITPROC = "none" STR_NO_LOGITPROC = "none"
# ThinkingTokenBudgetLogitsProcessor testing constants
THINKING_TOKEN_BUDGET = 5
THINK_START_TOKEN_ID = 999
THINK_END_TOKEN_ID = 998
# LogitsProcessor subclass or "none" # LogitsProcessor subclass or "none"
LogitprocType: TypeAlias = type[LogitsProcessor] | str LogitprocType: TypeAlias = type[LogitsProcessor] | str
...@@ -67,9 +73,24 @@ class LogitsProcsRequestParams: ...@@ -67,9 +73,24 @@ class LogitsProcsRequestParams:
self.workload_index = workload_index self.workload_index = workload_index
self.logitproc_type = logitproc_type self.logitproc_type = logitproc_type
# Number of output tokens is randomly 0 or twice the min-tokens # Number of output tokens is randomly 0 or twice the min-tokens
# threshold which will be used in testing. Output token values # threshold which will be used in testing.
# don't matter *for these tests* so use 0 as a dummy value # Generate diverse random tokens for all processors (more realistic)
self.out_tokens = [0] * (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)) num_tokens = MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2)
if num_tokens > 0:
# Use diverse random tokens
self.out_tokens = [random.randint(1, 950) for _ in range(num_tokens)]
# Set first token for ThinkingTokenBudget testing
is_thinking_processor = (
logitproc_type is ThinkingTokenBudgetLogitsProcessor
or (
hasattr(logitproc_type, "__name__")
and logitproc_type.__name__ == "ThinkingTokenBudgetLogitsProcessor"
)
)
if is_thinking_processor:
self.out_tokens[0] = THINK_START_TOKEN_ID
else:
self.out_tokens = []
self.prompt_tokens = [] self.prompt_tokens = []
self.params = _sampling_params_from_logitproc(logitproc_type) self.params = _sampling_params_from_logitproc(logitproc_type)
...@@ -79,6 +100,13 @@ class LogitsProcsRequestParams: ...@@ -79,6 +100,13 @@ class LogitsProcsRequestParams:
return f"MyClass({summ})" return f"MyClass({summ})"
class MockReasoningConfig:
    """Minimal reasoning-config stand-in for the thinking-budget processor tests.

    Only the two token-id lists are provided, each holding the single test
    token id defined by the module constants.
    """

    think_start_token_ids: list[int] = [THINK_START_TOKEN_ID]
    think_end_token_ids: list[int] = [THINK_END_TOKEN_ID]
def _generate_fake_sampling_metadata( def _generate_fake_sampling_metadata(
num_output_tokens: int, num_output_tokens: int,
batch_size: int, batch_size: int,
...@@ -97,8 +125,12 @@ def _generate_fake_sampling_metadata( ...@@ -97,8 +125,12 @@ def _generate_fake_sampling_metadata(
0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS) 0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS)
).tolist() ).tolist()
) )
vllm_config = VllmConfig()
vllm_config.reasoning_config = MockReasoningConfig()
logitsprocs = build_logitsprocs( logitsprocs = build_logitsprocs(
vllm_config=VllmConfig(), vllm_config=vllm_config,
device=device, device=device,
is_pin_memory=PIN_MEMORY_AVAILABLE, is_pin_memory=PIN_MEMORY_AVAILABLE,
is_pooling_model=False, is_pooling_model=False,
...@@ -403,6 +435,127 @@ def _min_tokens_validate( ...@@ -403,6 +435,127 @@ def _min_tokens_validate(
) )
def _thinking_budget_params(kwargs: dict) -> None:
    """Add the test thinking-token budget to the SamplingParams kwargs in place.

    Args:
        kwargs: Keyword arguments being assembled for `SamplingParams`;
            the `thinking_token_budget` entry is set to
            `THINKING_TOKEN_BUDGET`.
    """
    kwargs.update(thinking_token_budget=THINKING_TOKEN_BUDGET)
def _thinking_budget_validate(
    test_fakes: LogitsprocsTestFakes,
    persistent_batch: list[LogitsProcsRequestParams],
    logits_new: torch.Tensor,
    batch_index: int,
    request_params: LogitsProcsRequestParams,
    step_idx: int,
) -> None:
    """Validate thinking token budget processor behavior.

    For a request that carries a truthy `thinking_token_budget` this checks:

    1. the processor holds a state entry for `batch_index`;
    2. the budget stored in that state equals the requested budget;
    3. if the think-start sequence occurs in the request's output tokens and
       the state's `think_count` has reached the budget, the state is
       flagged `in_end` and the logits row leaves only the expected end
       token unmasked (every other entry must be `-inf`).

    Any violation is reported through `_raise_error_invalid`.

    Args:
        test_fakes: Test fixture bundle used to look up the processor.
        persistent_batch: Current persistent batch (unused here; kept for
            the common validator signature).
        logits_new: Logits after the processors ran; indexed by batch row.
        batch_index: Row of the request under validation.
        request_params: Parameters and fake output tokens of the request.
        step_idx: Current test step, forwarded into error reports.
    """
    tb_processor: ThinkingTokenBudgetLogitsProcessor = next(
        test_fakes.get_logitsprocs_by_cls(ThinkingTokenBudgetLogitsProcessor)
    )

    def _fail(msg_suffix: str) -> None:
        # Single place that attaches the request context to an error report.
        _raise_error_invalid(
            msg_suffix=msg_suffix,
            batch_index=batch_index,
            request_params=request_params,
            step_idx=step_idx,
        )

    state = tb_processor._state.get(batch_index)
    params = request_params.params

    # Requests without a (non-zero) budget have nothing to validate.
    if not getattr(params, "thinking_token_budget", None):
        return

    # State should exist for requests with thinking_token_budget
    if state is None:
        _fail(
            f"Expected state for batch {batch_index} "
            f"with thinking_token_budget={params.thinking_token_budget}"
        )

    # Validate budget matches what was set
    expected_budget = params.thinking_token_budget
    actual_budget = state["thinking_token_budget"]
    if actual_budget != expected_budget:
        _fail(f"Budget mismatch: expected {expected_budget}, got {actual_budget}")

    # Find out whether the think-start sequence appears in the output tokens
    output_tokens = request_params.out_tokens
    start_tokens = tb_processor.think_start_token_ids
    start_len = len(start_tokens)
    thinking_started = start_len > 0 and any(
        output_tokens[i : i + start_len] == start_tokens
        for i in range(len(output_tokens) - start_len + 1)
    )
    if not thinking_started:
        return

    # Only a request that has used up its budget must be forcing end tokens
    think_count = state["think_count"]
    budget = state["thinking_token_budget"]
    if think_count < budget:
        return

    if not state["in_end"]:
        _fail(
            f"Budget exceeded ({think_count} >= "
            f"{budget}) but not "
            "forcing end tokens"
        )

    end_tokens = tb_processor.think_end_token_ids
    if len(end_tokens) == 0:
        return

    expected_end_token_id = end_tokens[min(state["end_count"], len(end_tokens) - 1)]

    # Compare the whole row at once instead of looping over the vocabulary
    # in Python: every entry except the expected end token must be -inf.
    batch_logits = logits_new[batch_index]
    is_masked = batch_logits == -float("inf")
    should_be_masked = torch.ones_like(is_masked)
    if expected_end_token_id < should_be_masked.shape[0]:
        should_be_masked[expected_end_token_id] = False

    offenders = torch.nonzero(is_masked != should_be_masked)
    if offenders.numel() == 0:
        return

    # nonzero() yields indices in ascending order, so this is the same token
    # the element-by-element scan would have reported first.
    token_id = int(offenders[0].item())
    if token_id == expected_end_token_id:
        # End token should not be masked
        _fail(f"End token {token_id} should not be masked but is")
    else:
        # All other tokens should be masked when forcing end
        logit_value = batch_logits[token_id]
        _fail(
            f"Token {token_id} should be masked "
            f"when forcing end tokens, but "
            f"logit={logit_value}"
        )
def _none_validate( def _none_validate(
test_fakes: LogitsprocsTestFakes, test_fakes: LogitsprocsTestFakes,
persistent_batch: list[LogitsProcsRequestParams], persistent_batch: list[LogitsProcsRequestParams],
...@@ -449,20 +602,30 @@ logitsprocs_test_mapping = { ...@@ -449,20 +602,30 @@ logitsprocs_test_mapping = {
MinTokensLogitsProcessor: LogitsprocTestHelpers( MinTokensLogitsProcessor: LogitsprocTestHelpers(
gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate
), ),
ThinkingTokenBudgetLogitsProcessor: LogitsprocTestHelpers(
gen_request_fxn=_thinking_budget_params, eval_fxn=_thinking_budget_validate
),
} }
def _get_test_cases() -> list[list[str]]: def _get_test_cases() -> list[list[str]]:
"""Each test case is a set of logitsprocs""" """Each test case is a set of logitsprocs"""
logitsprocs_types = list(logitsprocs_test_mapping.keys()) logitsprocs_types = list(logitsprocs_test_mapping.keys())
# Isolate ThinkingTokenBudgetLogitsProcessor from all other processors
# to avoid unexpected modification of logits interference
thinking_processor = ThinkingTokenBudgetLogitsProcessor
other_processors = [
p
for p in logitsprocs_types
if p != STR_NO_LOGITPROC and p != thinking_processor
]
return ( return (
[[STR_NO_LOGITPROC]] [[STR_NO_LOGITPROC]]
+ [ + [[logitproc_type, STR_NO_LOGITPROC] for logitproc_type in other_processors]
[logitproc_type, STR_NO_LOGITPROC] + [other_processors]
for logitproc_type in logitsprocs_types + [[thinking_processor]]
if logitproc_type != STR_NO_LOGITPROC
]
+ [logitsprocs_types]
) )
......
...@@ -33,6 +33,7 @@ from vllm.config.offload import ( ...@@ -33,6 +33,7 @@ from vllm.config.offload import (
from vllm.config.parallel import EPLBConfig, ParallelConfig from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig from vllm.config.profiler import ProfilerConfig
from vllm.config.reasoning import ReasoningConfig
from vllm.config.scheduler import SchedulerConfig from vllm.config.scheduler import SchedulerConfig
from vllm.config.speculative import SpeculativeConfig from vllm.config.speculative import SpeculativeConfig
from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.speech_to_text import SpeechToTextConfig
...@@ -101,6 +102,8 @@ __all__ = [ ...@@ -101,6 +102,8 @@ __all__ = [
"ParallelConfig", "ParallelConfig",
# From vllm.config.pooler # From vllm.config.pooler
"PoolerConfig", "PoolerConfig",
# From vllm.config.reasoning
"ReasoningConfig",
# From vllm.config.scheduler # From vllm.config.scheduler
"SchedulerConfig", "SchedulerConfig",
# From vllm.config.speculative # From vllm.config.speculative
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import field
from vllm.config.model import ModelConfig
from vllm.config.utils import config
from vllm.tokenizers import cached_tokenizer_from_config
@config
class ReasoningConfig:
    """Configuration for reasoning models.

    Set `think_start_str` and `think_end_str` to the strings that delimit
    the reasoning block (e.g. `"<think>"` and `"</think>"`). The
    corresponding token IDs are derived automatically via
    `initialize_token_ids` and are not intended to be set directly.
    """

    # NOTE: These parameters are temporary, the intent is to derive them
    # automatically from the reasoning parser in a future version.
    think_start_str: str = "<think>"
    """String that indicates the start of reasoning."""
    think_end_str: str = "</think>"
    """String that indicates the end of reasoning content."""

    _think_start_token_ids: list[int] | None = field(
        default=None, init=False, repr=False
    )
    """Private backing field for `think_start_token_ids`. Set by
    `initialize_token_ids`. Not intended to be configured directly."""
    _think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False)
    """Private backing field for `think_end_token_ids`. Set by
    `initialize_token_ids`. Not intended to be configured directly."""

    @property
    def think_start_token_ids(self) -> list[int] | None:
        """Token IDs derived from `think_start_str`. Set automatically by
        `initialize_token_ids`. Not intended to be configured directly."""
        return self._think_start_token_ids

    @property
    def think_end_token_ids(self) -> list[int] | None:
        """Token IDs derived from `think_end_str`. Set automatically by
        `initialize_token_ids`. Not intended to be configured directly."""
        return self._think_end_token_ids

    def initialize_token_ids(self, model_config: ModelConfig) -> None:
        """Initialize reasoning token IDs from strings using the tokenizer.

        Does nothing when both ID lists have already been set. Otherwise both
        strings are encoded without special tokens and stored.

        Args:
            model_config: Model configuration used to obtain the tokenizer.

        Raises:
            ValueError: If either string encodes to an empty token list. In
                that case the stored IDs are left untouched.
        """
        if (
            self._think_start_token_ids is not None
            and self._think_end_token_ids is not None
        ):
            return

        tokenizer = cached_tokenizer_from_config(model_config=model_config)
        start_ids = tokenizer.encode(self.think_start_str, add_special_tokens=False)
        end_ids = tokenizer.encode(self.think_end_str, add_special_tokens=False)

        # Validate before storing: committing empty lists would make the
        # properties return [] and turn any later call into a silent no-op
        # via the "already initialised" check above.
        if not start_ids or not end_ids:
            raise ValueError(
                f"ReasoningConfig: failed to tokenize reasoning strings: "
                f"think_start_str='{self.think_start_str}', "
                f"think_end_str='{self.think_end_str}'. "
                "Ensure the strings are valid tokens in the model's vocabulary."
            )

        self._think_start_token_ids = start_ids
        self._think_end_token_ids = end_ids
...@@ -40,6 +40,7 @@ from .observability import ObservabilityConfig ...@@ -40,6 +40,7 @@ from .observability import ObservabilityConfig
from .offload import OffloadConfig from .offload import OffloadConfig
from .parallel import ParallelConfig from .parallel import ParallelConfig
from .profiler import ProfilerConfig from .profiler import ProfilerConfig
from .reasoning import ReasoningConfig
from .scheduler import SchedulerConfig from .scheduler import SchedulerConfig
from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig from .structured_outputs import StructuredOutputsConfig
...@@ -302,6 +303,8 @@ class VllmConfig: # type: ignore[misc] ...@@ -302,6 +303,8 @@ class VllmConfig: # type: ignore[misc]
"""The configurations for event publishing.""" """The configurations for event publishing."""
ec_transfer_config: ECTransferConfig | None = None ec_transfer_config: ECTransferConfig | None = None
"""The configurations for distributed EC cache transfer.""" """The configurations for distributed EC cache transfer."""
reasoning_config: ReasoningConfig | None = None
"""The configurations for reasoning model."""
# some opaque config, only used to provide additional information # some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of # for the hash computation, mainly used for testing, debugging or out of
# tree config registration. # tree config registration.
...@@ -1143,6 +1146,9 @@ class VllmConfig: # type: ignore[misc] ...@@ -1143,6 +1146,9 @@ class VllmConfig: # type: ignore[misc]
if not self.instance_id: if not self.instance_id:
self.instance_id = random_uuid()[:5] self.instance_id = random_uuid()[:5]
if self.reasoning_config is not None and self.model_config is not None:
self.reasoning_config.initialize_token_ids(self.model_config)
# Hybrid KV cache manager (HMA) runtime rules: # Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime # - Explicit enable (--no-disable-kv-cache-manager): error if runtime
# disables it # disables it
......
...@@ -53,6 +53,7 @@ from vllm.config import ( ...@@ -53,6 +53,7 @@ from vllm.config import (
PoolerConfig, PoolerConfig,
PrefetchOffloadConfig, PrefetchOffloadConfig,
ProfilerConfig, ProfilerConfig,
ReasoningConfig,
SchedulerConfig, SchedulerConfig,
SpeculativeConfig, SpeculativeConfig,
StructuredOutputsConfig, StructuredOutputsConfig,
...@@ -581,6 +582,7 @@ class EngineArgs: ...@@ -581,6 +582,7 @@ class EngineArgs:
kv_events_config: KVEventsConfig | None = None kv_events_config: KVEventsConfig | None = None
ec_transfer_config: ECTransferConfig | None = None ec_transfer_config: ECTransferConfig | None = None
reasoning_config: ReasoningConfig = get_field(VllmConfig, "reasoning_config")
generation_config: str = ModelConfig.generation_config generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
...@@ -1297,6 +1299,7 @@ class EngineArgs: ...@@ -1297,6 +1299,7 @@ class EngineArgs:
vllm_group.add_argument( vllm_group.add_argument(
"--attention-config", "-ac", **vllm_kwargs["attention_config"] "--attention-config", "-ac", **vllm_kwargs["attention_config"]
) )
vllm_group.add_argument("--reasoning-config", **vllm_kwargs["reasoning_config"])
vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"]) vllm_group.add_argument("--kernel-config", **vllm_kwargs["kernel_config"])
vllm_group.add_argument( vllm_group.add_argument(
"--additional-config", **vllm_kwargs["additional_config"] "--additional-config", **vllm_kwargs["additional_config"]
...@@ -1958,6 +1961,7 @@ class EngineArgs: ...@@ -1958,6 +1961,7 @@ class EngineArgs:
kv_transfer_config=self.kv_transfer_config, kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config, kv_events_config=self.kv_events_config,
ec_transfer_config=self.ec_transfer_config, ec_transfer_config=self.ec_transfer_config,
reasoning_config=self.reasoning_config,
profiler_config=self.profiler_config, profiler_config=self.profiler_config,
additional_config=self.additional_config, additional_config=self.additional_config,
optimization_level=self.optimization_level, optimization_level=self.optimization_level,
......
...@@ -180,6 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -180,6 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
| None | None
) = "none" ) = "none"
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
thinking_token_budget: int | None = None
include_reasoning: bool = True include_reasoning: bool = True
parallel_tool_calls: bool | None = True parallel_tool_calls: bool | None = True
...@@ -515,6 +516,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -515,6 +516,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
structured_outputs=self.structured_outputs, structured_outputs=self.structured_outputs,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
bad_words=self.bad_words, bad_words=self.bad_words,
thinking_token_budget=self.thinking_token_budget,
allowed_token_ids=self.allowed_token_ids, allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None, extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone skip_clone=True, # Created fresh per request, safe to skip clone
......
...@@ -281,6 +281,8 @@ class SamplingParams( ...@@ -281,6 +281,8 @@ class SamplingParams(
_bad_words_token_ids: list[list[int]] | None = None _bad_words_token_ids: list[list[int]] | None = None
skip_reading_prefix_cache: bool | None = None skip_reading_prefix_cache: bool | None = None
thinking_token_budget: int | None = None
"""Maximum number of tokens allowed for thinking operations."""
repetition_detection: RepetitionDetectionParams | None = None repetition_detection: RepetitionDetectionParams | None = None
"""Parameters for detecting repetitive N-gram patterns in output tokens. """Parameters for detecting repetitive N-gram patterns in output tokens.
...@@ -304,6 +306,7 @@ class SamplingParams( ...@@ -304,6 +306,7 @@ class SamplingParams(
stop: str | list[str] | None = None, stop: str | list[str] | None = None,
stop_token_ids: list[int] | None = None, stop_token_ids: list[int] | None = None,
bad_words: list[str] | None = None, bad_words: list[str] | None = None,
thinking_token_budget: int | None = None,
include_stop_str_in_output: bool = False, include_stop_str_in_output: bool = False,
ignore_eos: bool = False, ignore_eos: bool = False,
max_tokens: int | None = 16, max_tokens: int | None = 16,
...@@ -344,6 +347,7 @@ class SamplingParams( ...@@ -344,6 +347,7 @@ class SamplingParams(
stop=stop, stop=stop,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
bad_words=bad_words, bad_words=bad_words,
thinking_token_budget=thinking_token_budget,
include_stop_str_in_output=include_stop_str_in_output, include_stop_str_in_output=include_stop_str_in_output,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
max_tokens=max_tokens, max_tokens=max_tokens,
...@@ -858,6 +862,7 @@ class SamplingParams( ...@@ -858,6 +862,7 @@ class SamplingParams(
f"stop={self.stop}, " f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, " f"stop_token_ids={self.stop_token_ids}, "
f"bad_words={self.bad_words}, " f"bad_words={self.bad_words}, "
f"thinking_token_budget={self.thinking_token_budget}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, " f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, " f"max_tokens={self.max_tokens}, "
......
...@@ -99,6 +99,16 @@ class InputProcessor: ...@@ -99,6 +99,16 @@ class InputProcessor:
self.structured_outputs_config, self.structured_outputs_config,
self.tokenizer, self.tokenizer,
) )
if (
params.thinking_token_budget is not None
and self.vllm_config.reasoning_config is None
):
raise ValueError(
"thinking_token_budget is set but reasoning_config is "
"not configured. Please set --reasoning-config to use "
"thinking_token_budget."
)
elif isinstance(params, PoolingParams): elif isinstance(params, PoolingParams):
supported_pooling_tasks = [ supported_pooling_tasks = [
task for task in supported_tasks if task in POOLING_TASKS task for task in supported_tasks if task in POOLING_TASKS
......
...@@ -18,6 +18,7 @@ from vllm.v1.sample.logits_processor.builtin import ( ...@@ -18,6 +18,7 @@ from vllm.v1.sample.logits_processor.builtin import (
LogitBiasLogitsProcessor, LogitBiasLogitsProcessor,
MinPLogitsProcessor, MinPLogitsProcessor,
MinTokensLogitsProcessor, MinTokensLogitsProcessor,
ThinkingTokenBudgetLogitsProcessor,
process_dict_updates, process_dict_updates,
) )
from vllm.v1.sample.logits_processor.interface import ( from vllm.v1.sample.logits_processor.interface import (
...@@ -50,6 +51,7 @@ BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ ...@@ -50,6 +51,7 @@ BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor, MinTokensLogitsProcessor,
LogitBiasLogitsProcessor, LogitBiasLogitsProcessor,
MinPLogitsProcessor, MinPLogitsProcessor,
ThinkingTokenBudgetLogitsProcessor,
] ]
...@@ -354,4 +356,5 @@ __all__ = [ ...@@ -354,4 +356,5 @@ __all__ = [
"STR_POOLING_REJECTS_LOGITSPROCS", "STR_POOLING_REJECTS_LOGITSPROCS",
"LOGITSPROCS_GROUP", "LOGITSPROCS_GROUP",
"AdapterLogitsProcessor", "AdapterLogitsProcessor",
"ThinkingTokenBudgetLogitsProcessor",
] ]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeVar from typing import TYPE_CHECKING, Any, TypeVar
import numpy as np import numpy as np
import torch import torch
...@@ -291,6 +291,263 @@ class MinTokensLogitsProcessor(LogitsProcessor): ...@@ -291,6 +291,263 @@ class MinTokensLogitsProcessor(LogitsProcessor):
return logits return logits
class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
"""Limits the number of tokens allowed inside a 'thinking' section."""
def __init__(
    self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
    """Set up boundary token IDs, per-request state and reusable tensors.

    Args:
        vllm_config: Source of `reasoning_config` (may be None) and of
            `scheduler_config.max_num_seqs`, which sizes the tensors.
        device: Device on which the preallocated tensors are created.
        is_pin_memory: Stored as `self.pin_memory`.
    """
    reasoning_cfg = vllm_config.reasoning_config
    batch_capacity = vllm_config.scheduler_config.max_num_seqs

    # Thinking is considered enabled whenever a reasoning config is present
    self.is_enabled = reasoning_cfg is not None

    # getattr() falls back to an empty list when the config is None
    self.think_start_token_ids = getattr(reasoning_cfg, "think_start_token_ids", [])
    self.think_end_token_ids = getattr(reasoning_cfg, "think_end_token_ids", [])

    self.pin_memory = is_pin_memory
    self.device = device

    # Per-request state tracking for thinking token management
    # Key: request_index, Value: state dict containing:
    #   "in_think": bool - currently in thinking mode
    #   "in_end": bool - currently forcing end tokens output
    #   "check_count_down": int - steps remaining until next think
    #                             start/end token parsing
    #   "think_count": int - number of thinking tokens generated
    #   "end_count": int - number of end tokens forced so far
    #   "thinking_token_budget": int - max allowed thinking tokens
    #   "output_tok_ids": list[int] - generated output tokens
    #   "prev_output_length": int - previous output length for
    #                               incremental processing
    self._state: dict[int, dict[str, Any]] = {}

    # Preallocate reusable tensors, one slot per possible request
    self.mask = torch.zeros(batch_capacity, dtype=torch.bool, device=device)
    self.force_token_ids = torch.full(
        (batch_capacity,), -1, dtype=torch.long, device=device
    )
@staticmethod
def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int:
"""
Returns the index of the last occurrence of token_ids in target_list.
Args:
target_list (list[int]): The list of token IDs.
token_ids (list[int]): The sequence of token IDs to find.
"""
if not token_ids:
return -1
for i in range(len(target_list) - len(token_ids), -1, -1):
if target_list[i : i + len(token_ids)] == token_ids:
return i
return -1
def _init_state_entry(
self, prompt_tok_ids: list[int] | None, thinking_token_budget: int
) -> dict[str, Any]:
"""Initializes the tracking state for a given sequence index."""
if prompt_tok_ids is None:
last_start = -1
last_end = -1
in_think = False
think_count = 0
else:
last_start = self._find_last_sequence_index(
prompt_tok_ids, self.think_start_token_ids
)
last_end = self._find_last_sequence_index(
prompt_tok_ids, self.think_end_token_ids
)
in_think = last_start > last_end
if in_think:
think_count = len(prompt_tok_ids) - (
last_start + len(self.think_start_token_ids)
)
else:
think_count = 0
return {
"in_think": in_think, # Currently in thinking mode
"in_end": in_think and thinking_token_budget == 0,
"check_count_down": thinking_token_budget,
"think_count": think_count, # Number of tokens in thinking section
"end_count": 0, # Number of end tokens forced so far
"prompt_tok_ids": prompt_tok_ids,
"output_tok_ids": [],
"thinking_token_budget": thinking_token_budget,
"prev_output_length": 0,
# Track previous output length for incremental updates
}
def _update_think_state(self, state: dict[str, Any]):
    """Advance one request's think/end tracking from its output tokens.

    While not forcing the end sequence, calls are skipped (only decrementing
    ``check_count_down``) until the countdown reaches zero. A full check then
    looks at the tokens added since the previous check, decides whether the
    request is inside a thinking section, and switches to forced-end mode
    once ``think_count`` reaches ``thinking_token_budget``. In forced-end
    mode each call that sees new output counts one more end token until the
    whole think-end sequence has been accounted for.

    Args:
        state: Per-request state dict; modified in place.
    """
    # Cheap path: nothing to inspect until the countdown has run out.
    if not state.get("in_end", False) and state.get("check_count_down", 0) > 0:
        state["check_count_down"] -= 1
        return

    generated = state.get("output_tok_ids", [])
    if not generated:
        return

    seen = state.get("prev_output_length", 0)
    total = len(generated)
    if total <= seen:
        # No new output since the previous full check.
        return
    state["prev_output_length"] = total

    if state["in_end"]:
        # Forced-end mode: one more end token has been produced.
        state["end_count"] += 1
        if state["end_count"] >= len(self.think_end_token_ids):
            state["in_end"] = False
            state["end_count"] = 0
            state["check_count_down"] = state["thinking_token_budget"]
        return

    start_len = len(self.think_start_token_ids)
    end_len = len(self.think_end_token_ids)
    # Re-scan a few already-seen tokens so that a start/end sequence that
    # straddles the old/new boundary is still detected.
    window_begin = max(0, seen - max(start_len, end_len) + 1)
    window = generated[window_begin:]
    start_at = self._find_last_sequence_index(window, self.think_start_token_ids)
    end_at = self._find_last_sequence_index(window, self.think_end_token_ids)

    if start_at >= 0 and start_at > end_at:
        # Latest marker is a think-start (end_at is -1 when no end was found):
        # thinking tokens are those after the start sequence.
        state["in_think"] = True
        state["think_count"] = total - (window_begin + start_at + start_len)
    elif end_at >= 0:
        # Latest marker is a think-end: the thinking section is closed.
        state["in_think"] = False
        state["think_count"] = 0
    elif state["in_think"]:
        # No marker seen: every new token belongs to the thinking section.
        state["think_count"] += total - seen

    budget = state["thinking_token_budget"]
    if not state["in_think"]:
        state["check_count_down"] = budget
    elif state["think_count"] >= budget:
        # Budget reached: start forcing the think-end sequence.
        state["in_think"] = False
        state["in_end"] = True
        state["end_count"] = 0
        state["check_count_down"] = budget
    else:
        # Skip checks until one token before the budget would be reached.
        state["check_count_down"] = max(0, budget - state["think_count"] - 1)
def is_argmax_invariant(self) -> bool:
    """Report whether greedy sampling is unaffected by this processor.

    Returns:
        Always ``False``: forcing the thinking section to end after a
        certain number of tokens can change which token has the highest
        logit.
    """
    return False
def update_state(self, batch_update: BatchUpdate | None):
    """Sync per-request state with the batch, then advance every request.

    Added requests get a fresh state entry when they carry a
    ``thinking_token_budget`` (any existing entry at that index is dropped
    otherwise), removed indices are forgotten, and moved/swapped indices
    carry their state along. Afterwards every tracked request is passed to
    ``_update_think_state``, even when ``batch_update`` is empty or ``None``.

    Args:
        batch_update: Batch changes for this step, or ``None``.
    """
    if not self.is_enabled:
        return

    tracked = self._state
    if batch_update:
        for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
            budget = params.thinking_token_budget
            if budget is None:
                # Request does not use a thinking budget: drop any old state.
                tracked.pop(index, None)
                continue
            entry = self._init_state_entry(prompt_tok_ids, budget)
            tracked[index] = entry
            # Keep the caller's list object so later appends are visible.
            entry["output_tok_ids"] = output_tok_ids

        for index in batch_update.removed:
            tracked.pop(index, None)

        for src, dst, direction in batch_update.moved:
            if direction == MoveDirectionality.SWAP:
                src_entry = tracked.pop(src, None)
                dst_entry = tracked.pop(dst, None)
                for slot, entry in ((dst, src_entry), (src, dst_entry)):
                    if entry is not None:
                        tracked[slot] = entry
            else:
                src_entry = tracked.pop(src, None)
                if src_entry is not None:
                    tracked[dst] = src_entry

    for entry in tracked.values():
        self._update_think_state(entry)
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    """Force the next think-end token for every request in forced-end mode.

    For each row whose state has ``in_end`` set, the logit of the think-end
    token at position ``end_count`` is overwritten with a very large value.
    ``logits`` is modified in place and returned.

    Args:
        logits: Logits tensor indexed as ``[request_row, token_id]``.

    Returns:
        The same ``logits`` tensor.
    """
    if not (self.is_enabled and self._state):
        return logits

    num_rows = logits.size(0)
    row_mask = self.mask[:num_rows]  # view into the preallocated buffer
    row_mask.fill_(False)

    for row in range(num_rows):
        entry = self._state.get(row)
        if not entry or not entry["in_end"]:
            continue
        self.mask[row] = True
        self.force_token_ids[row] = self.think_end_token_ids[entry["end_count"]]

    # Decide from the Python-side state first so the tensor lookups below
    # only run when some request is actually forcing the end sequence.
    if not any(entry.get("in_end", False) for entry in self._state.values()):
        return logits

    forced_rows = row_mask.nonzero(as_tuple=False).view(-1)
    if len(forced_rows) > 0:
        # A large logit makes the forced token win the sampling step.
        logits[forced_rows, self.force_token_ids[forced_rows]] = 1e9
    return logits
def process_dict_updates( def process_dict_updates(
req_entries: dict[int, T], req_entries: dict[int, T],
batch_update: BatchUpdate | None, batch_update: BatchUpdate | None,
......
...@@ -629,7 +629,10 @@ class GPUModelRunner( ...@@ -629,7 +629,10 @@ class GPUModelRunner(
), ),
# We currently don't know whether a particular custom logits processor # We currently don't know whether a particular custom logits processor
# uses output token ids so we set this conservatively. # uses output token ids so we set this conservatively.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs), # ThinkingTokenBudgetLogitsProcessor also needs output token ids to
# correctly track think start/end token sequences in async scheduling.
logitsprocs_need_output_token_ids=bool(custom_logitsprocs)
or self.vllm_config.reasoning_config is not None,
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment