Unverified commit ecbfbb8d, authored by Chauncey, committed by GitHub
Browse files

[Feature] Add auto-detection for reasoning_config when only reasoning_parser is set (#38214)


Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
parent e0613702
......@@ -249,7 +249,7 @@ Token counting starts from `reasoning_start_str`. Once the reasoning token count
To use this feature:
- `--reasoning-parser` enables reasoning extraction.
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`).
- `--reasoning-config` defines the reasoning boundary tokens (e.g., `reasoning_start_str`, `reasoning_end_str`). If not set, vLLM will attempt to automatically initialize these tokens from the reasoning parser.
- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit.
If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`.
......
......@@ -24,6 +24,24 @@ def server():
"--max-model-len",
"2048",
"--enforce-eager",
"--gpu-memory-utilization",
"0.4",
"--no-async-scheduling",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def server_with_auto_reasoning_config():
args = [
"--reasoning-parser",
"qwen3",
"--max-model-len",
"2048",
"--enforce-eager",
"--gpu-memory-utilization",
"0.4",
"--no-async-scheduling",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
......@@ -31,12 +49,18 @@ def server():
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
async def client(request, server, server_with_auto_reasoning_config):
server_map = {
"default": server,
"auto_config": server_with_auto_reasoning_config,
}
target_server = server_map[request.param]
async with target_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("client", ["default", "auto_config"], indirect=True)
async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI):
"""Test that mixed requests (some with thinking_token_budget, some without)
complete successfully without errors."""
......@@ -61,6 +85,7 @@ async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
@pytest.mark.parametrize("client", ["default", "auto_config"], indirect=True)
async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI):
"""Test that thinking_token_budget limits the number of reasoning tokens.
......@@ -82,6 +107,6 @@ async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI
reasoning_token_count += 1
assert reasoning_token_count == THINK_BUDGET, (
f"reasoning tokens ({reasoning_token_count}) != "
f"reasoning tokens ({reasoning_token_count}) exceeded "
f"thinking_token_budget ({THINK_BUDGET})"
)
......@@ -106,6 +106,7 @@ class MockReasoningConfig:
reasoning_start_token_ids = [THINK_START_TOKEN_ID]
reasoning_end_token_ids = [THINK_END_TOKEN_ID]
enabled = True
def _generate_fake_sampling_metadata(
......
......@@ -5,6 +5,7 @@ from dataclasses import field
from vllm.config.model import ModelConfig
from vllm.config.utils import config
from vllm.reasoning import ReasoningParserManager
from vllm.tokenizers import cached_tokenizer_from_config
......@@ -18,11 +19,11 @@ class ReasoningConfig:
`initialize_token_ids` and are not intended to be set directly.
"""
# NOTE: These parameters are temporary, the intent is to derive them
# automatically from the reasoning parser in a future version.
reasoning_start_str: str = "<think>"
reasoning_parser: str = ""
"""The name of the ReasoningParser to use for this model."""
reasoning_start_str: str = ""
"""String that indicates the start of reasoning."""
reasoning_end_str: str = "</think>"
reasoning_end_str: str = ""
"""String that indicates the end of reasoning content."""
_reasoning_start_token_ids: list[int] | None = field(
......@@ -36,6 +37,16 @@ class ReasoningConfig:
"""Private backing field for `reasoning_end_token_ids`. Set by
`initialize_token_ids`. Not intended to be configured directly."""
_enabled: bool = field(default=False, init=False, repr=False)
"""Private field indicating whether reasoning token IDs have been initialized.
Set to True by `initialize_token_ids` once token IDs are initialized."""
@property
def enabled(self) -> bool:
"""Whether the reasoning token IDs have been initialized.

Exposes the private `_enabled` flag read-only. The flag defaults to
False and is set to True by `initialize_token_ids`.
"""
return self._enabled
@property
def reasoning_start_token_ids(self) -> list[int] | None:
"""Token IDs derived from `reasoning_start_str`. Set automatically by
......@@ -54,15 +65,36 @@ class ReasoningConfig:
self._reasoning_start_token_ids is not None
and self._reasoning_end_token_ids is not None
):
return
self._enabled = True
return # Already initialized
tokenizer = cached_tokenizer_from_config(model_config=model_config)
reasoning_start_str = self.reasoning_start_str
reasoning_end_str = self.reasoning_end_str
if self.reasoning_parser is not None and (
not reasoning_start_str or not reasoning_end_str
):
parser_cls = ReasoningParserManager.get_reasoning_parser(
self.reasoning_parser
)
reasoning_parser = parser_cls(tokenizer)
start_token = reasoning_parser.reasoning_start_str
if start_token and not reasoning_start_str:
reasoning_start_str = start_token
end_token = reasoning_parser.reasoning_end_str
if end_token and not reasoning_end_str:
reasoning_end_str = end_token
if not reasoning_start_str or not reasoning_end_str:
# If we don't have valid strings to tokenize,
# we can't initialize the token IDs.
return
self._reasoning_start_token_ids = tokenizer.encode(
self.reasoning_start_str, add_special_tokens=False
reasoning_start_str, add_special_tokens=False
)
self._reasoning_end_token_ids = tokenizer.encode(
self.reasoning_end_str, add_special_tokens=False
reasoning_end_str, add_special_tokens=False
)
if not self._reasoning_start_token_ids or not self._reasoning_end_token_ids:
......@@ -72,3 +104,4 @@ class ReasoningConfig:
f"reasoning_end_str='{self.reasoning_end_str}'. "
"Ensure the strings are valid tokens in the model's vocabulary."
)
self._enabled = True
......@@ -1210,6 +1210,12 @@ class VllmConfig:
if self.reasoning_config is not None and self.model_config is not None:
self.reasoning_config.initialize_token_ids(self.model_config)
if not self.reasoning_config.enabled:
logger.warning_once(
"Auto-initialization of reasoning token IDs failed. "
"Please check whether your reasoning parser has implemented "
"the `reasoning_start_str` and `reasoning_end_str`."
)
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
......
......@@ -1591,7 +1591,7 @@ class EngineArgs:
self._set_default_max_num_seqs_and_batched_tokens_args(
usage_context, model_config
)
self._set_default_reasoning_config_args()
sliding_window: int | None = None
if not is_interleaved(model_config.hf_text_config):
# Only set CacheConfig.sliding_window if the model is all sliding
......@@ -2233,6 +2233,13 @@ class EngineArgs:
)
self.enable_prefix_caching = False
def _set_default_reasoning_config_args(self):
"""Seed `reasoning_config` from `reasoning_parser` when a parser is set.

Does nothing when `self.reasoning_parser` is falsy. Otherwise a default
`ReasoningConfig` is created if `self.reasoning_config` is None, and the
parser name is stored on the config's `reasoning_parser` attribute.
"""
# Nothing to derive without a reasoning parser name.
if not self.reasoning_parser:
return
# No explicit config was supplied: start from a default ReasoningConfig.
if self.reasoning_config is None:
self.reasoning_config = ReasoningConfig()
# NOTE(review): indentation is not visible in this view; presumably this
# assignment also applies to a user-supplied config, not only to the one
# created just above -- confirm against the real file.
self.reasoning_config.reasoning_parser = self.reasoning_parser
def _set_default_max_num_seqs_and_batched_tokens_args(
self,
usage_context: UsageContext | None,
......
......@@ -39,6 +39,20 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
@property
def reasoning_start_str(self) -> str | None:
"""String that delimits the start of the reasoning block
(e.g. `"<seed:think>"` or `"<think>"`).

This base implementation returns None; a parser that has a start
delimiter should override the property to return it.
"""
return None
@property
def reasoning_end_str(self) -> str | None:
"""String that delimits the end of the reasoning block
(e.g. `"</seed:think>"` or `"</think>"`).

This base implementation returns None; a parser that has an end
delimiter should override the property to return it.
"""
return None
@abstractmethod
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
"""
......
......@@ -39,6 +39,14 @@ class BaseThinkingReasoningParser(ReasoningParser):
"""The token that ends reasoning content."""
raise NotImplementedError
@property
def reasoning_start_str(self) -> str:
"""Start-of-reasoning delimiter; returns this parser's `start_token`."""
return self.start_token
@property
def reasoning_end_str(self) -> str:
"""End-of-reasoning delimiter; returns this parser's `end_token`."""
return self.end_token
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
......
......@@ -98,9 +98,9 @@ class InputProcessor:
self.tokenizer,
)
if (
params.thinking_token_budget is not None
and self.vllm_config.reasoning_config is None
if params.thinking_token_budget is not None and (
self.vllm_config.reasoning_config is None
or not self.vllm_config.reasoning_config.enabled
):
raise ValueError(
"thinking_token_budget is set but reasoning_config is "
......
......@@ -301,7 +301,7 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
# Check if thinking is enabled
self.is_enabled = reasoning_config is not None
self.is_enabled = reasoning_config is not None and reasoning_config.enabled
self.reasoning_start_token_ids = getattr(
reasoning_config, "reasoning_start_token_ids", []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment