Unverified Commit 92762edc authored by Doug Campos's avatar Doug Campos Committed by GitHub
Browse files

[Bugfix] Treat <tool_call> as implicit reasoning end in Qwen3 parser (#35687)


Signed-off-by: default avatarDoug Campos <qmx@qmx.me>
parent 626daa20
......@@ -78,6 +78,25 @@ WITHOUT_THINK_STREAM = {
"content": None,
}
# --- <tool_call> without </think> (implicit reasoning end) ---
TOOL_CALL_BODY = (
"<tool_call>\n<function=bash>\n<parameter=command>"
"\ncat /etc/hosts\n</parameter>\n</function>\n</tool_call>"
)
TOOL_CALL_NO_THINK_END = {
"output": "I need to read the file.\n\n" + TOOL_CALL_BODY,
"reasoning": "I need to read the file.\n\n",
"content": TOOL_CALL_BODY,
}
TOOL_CALL_WITH_THINK_NO_END = {
"output": "<think>I need to read the file.\n\n" + TOOL_CALL_BODY,
"reasoning": "I need to read the file.\n\n",
"content": TOOL_CALL_BODY,
}
# --- Edge cases ---
COMPLETE_REASONING = {
......@@ -199,6 +218,26 @@ TEST_CASES = [
TRUNCATED_NO_START_TOKEN_STREAM,
id="truncated_no_start_token_stream",
),
pytest.param(
False,
TOOL_CALL_NO_THINK_END,
id="tool_call_no_think_end",
),
pytest.param(
True,
TOOL_CALL_NO_THINK_END,
id="tool_call_no_think_end_stream",
),
pytest.param(
False,
TOOL_CALL_WITH_THINK_NO_END,
id="tool_call_with_think_no_end",
),
pytest.param(
True,
TOOL_CALL_WITH_THINK_NO_END,
id="tool_call_with_think_no_end_stream",
),
]
......@@ -255,6 +294,13 @@ MULTI_TOKEN_DELTA_CASES = [
"content",
id="no_start_end_grouped_with_content",
),
pytest.param(
# <tool_call> arrives in a separate delta after reasoning text
["I need to read the file.\n\n", "<tool_call>\n<function=bash>"],
"I need to read the file.\n\n",
"<tool_call>\n<function=bash>",
id="tool_call_implicit_reasoning_end",
),
]
......@@ -296,6 +342,12 @@ THINKING_DISABLED_CASES = [
"Some output without think tokens",
id="thinking_disabled_no_think_tokens",
),
pytest.param(
"I need to read the file.\n\n" + TOOL_CALL_BODY,
None,
"I need to read the file.\n\n" + TOOL_CALL_BODY,
id="thinking_disabled_with_tool_call",
),
]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
......@@ -31,6 +31,10 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
use an older chat template where the model generates <think> itself.
This parser handles both styles: if <think> appears in the generated output
it is stripped before extraction (non-streaming) or skipped (streaming).
NOTE: Qwen3.5 models may emit <tool_call> inside the thinking block
without closing </think> first. <tool_call> is treated as an implicit
end of reasoning, matching the approach in KimiK2ReasoningParser.
"""
def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
......@@ -41,6 +45,11 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
# pure content when the user explicitly disables it.
self.thinking_enabled = chat_kwargs.get("enable_thinking", True)
self._tool_call_tag = "<tool_call>"
self._tool_call_token_id = self.vocab.get(self._tool_call_tag)
self._tool_call_end_tag = "</tool_call>"
self._tool_call_end_token_id = self.vocab.get(self._tool_call_end_tag)
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
......@@ -51,6 +60,58 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
"""The token that ends reasoning content."""
return "</think>"
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
start_token_id = self.start_token_id
end_token_id = self.end_token_id
tool_call_token_id = self._tool_call_token_id
tool_call_end_token_id = self._tool_call_end_token_id
for i in range(len(input_ids) - 1, -1, -1):
token_id = input_ids[i]
if token_id == start_token_id:
# Found <think> before </think> or <tool_call>
return False
if token_id == end_token_id:
return True
if tool_call_token_id is not None and token_id == tool_call_token_id:
# Only treat as implicit reasoning end if this <tool_call>
# is NOT followed by </tool_call>. Paired occurrences are
# template examples in the prompt, not model output.
if tool_call_end_token_id is not None and any(
input_ids[j] == tool_call_end_token_id
for j in range(i + 1, len(input_ids))
):
continue
return True
return False
def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool:
if super().is_reasoning_end_streaming(input_ids, delta_ids):
return True
if self._tool_call_token_id is not None:
return self._tool_call_token_id in delta_ids
return False
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
"""
result = super().extract_content_ids(input_ids)
if result:
return result
# Fall back: content starts at <tool_call> (implicit reasoning end).
if (
self._tool_call_token_id is not None
and self._tool_call_token_id in input_ids
):
tool_call_index = (
len(input_ids) - 1 - input_ids[::-1].index(self._tool_call_token_id)
)
return input_ids[tool_call_index:]
return []
def extract_reasoning(
self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]:
......@@ -78,20 +139,24 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
)
if self.end_token not in model_output:
if self.end_token in model_output:
reasoning, _, content = model_output.partition(self.end_token)
return reasoning, content or None
if not self.thinking_enabled:
# Thinking explicitly disabled — treat everything as content.
return None, model_output
# No </think> — check for implicit reasoning end via <tool_call>.
tool_call_index = model_output.find(self._tool_call_tag)
if tool_call_index != -1:
reasoning = model_output[:tool_call_index]
content = model_output[tool_call_index:]
return reasoning or None, content or None
# Thinking enabled but no </think>: output was truncated.
# Everything generated so far is reasoning.
return model_output, None
# Extract reasoning content from the model output.
reasoning, _, content = model_output.partition(self.end_token)
final_content = content or None
return reasoning, final_content
def extract_reasoning_streaming(
self,
previous_text: str,
......@@ -135,6 +200,20 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
# end_token_id in IDs but not in text (already stripped)
return None
# Implicit reasoning end via <tool_call>.
if (
self._tool_call_token_id is not None
and self._tool_call_token_id in delta_token_ids
):
tool_index = delta_text.find(self._tool_call_tag)
if tool_index >= 0:
reasoning = delta_text[:tool_index]
content = delta_text[tool_index:]
return DeltaMessage(
reasoning=reasoning if reasoning else None,
content=content if content else None,
)
# No end token in this delta.
if not delta_text:
# Nothing left after stripping start token.
......@@ -142,6 +221,11 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
elif self.end_token_id in previous_token_ids:
# End token already passed: everything is content now.
return DeltaMessage(content=delta_text)
elif (
self._tool_call_token_id is not None
and self._tool_call_token_id in previous_token_ids
):
return DeltaMessage(content=delta_text)
else:
# No end token yet: still in reasoning phase.
return DeltaMessage(reasoning=delta_text)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment