Unverified Commit af5e6afa authored by Mario Hong's avatar Mario Hong Committed by GitHub
Browse files

[Bugfix] Fix step3p5 reasoning with interleaved thinking (#34211)


Signed-off-by: default avatarmariohong <mariohong128@gmail.com>
Co-authored-by: default avatarChauncey <chaunceyjiang@gmail.com>
parent ee59a7c6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "step3p5"
start_token = "<think>"
end_token = "</think>"
REASONING_MODEL_NAME = "stepfun-ai/Step-3.5-Flash"
@pytest.fixture(scope="module")
def step3p5_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
SIMPLE_REASONING = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
# need to get into parser again to remove newline after </think>
COMPLETE_REASONING = {
"output": "This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
NO_CONTENT = {
"output": "This is content",
"reasoning_content": "This is content",
"content": None,
"is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
MULTIPLE_LINES = {
"output": "This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
MULTIPLE_LINES_WITH_THINK = {
"output": "<think>This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
THINK_NO_END = {
"output": "<think>This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
EMPTY = {
"output": "",
"reasoning_content": None,
"content": None,
"is_reasoning_end": False,
}
EMPTY_STREAMING = {
"output": "",
"reasoning_content": None,
"content": None,
"is_reasoning_end": False,
}
NEW_LINE = {
"output": "\n<think>This is a reasoning section</think>\nThis is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
NEW_LINE_STREAMING = {
"output": "\n<think>This is a reasoning section\n</think>\nThis is the rest",
"reasoning_content": "\nThis is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
NEW_LINE_STREAMING_COMPLEX_CONTENT = {
"output": "\n This is a \n reasoning section\n\n\n</think>\n\nThis is the rest",
"reasoning_content": "\n This is a \n reasoning section\n\n",
"content": "\nThis is the rest",
"is_reasoning_end": True,
}
MULTI_TURN_PROMPT_CONTENT = {
"output": "<think> This is last turn's reasoning section </think> hello <think>",
"reasoning_content": "",
"content": "",
"is_reasoning_end": False,
}
TEST_CASES = [
pytest.param(
False,
SIMPLE_REASONING,
id="simple_reasoning",
),
pytest.param(
True,
SIMPLE_REASONING,
id="simple_reasoning_streaming",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_streaming",
),
pytest.param(
False,
NO_CONTENT,
id="no_content_token",
),
pytest.param(
True,
NO_REASONING_STREAMING,
id="no_reasoning_token_streaming",
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
),
pytest.param(
True,
SHORTEST_REASONING,
id="shortest",
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING,
id="shortest_streaming",
),
pytest.param(
False,
REASONING_WITH_THINK,
id="reasoning_with_think",
),
pytest.param(
True,
REASONING_WITH_THINK,
id="reasoning_with_think_streaming",
),
pytest.param(
False,
COMPLETE_REASONING_WITH_THINK,
id="complete_reasoning_with_think",
),
pytest.param(
True,
COMPLETE_REASONING_WITH_THINK,
id="complete_reasoning_with_think_streaming",
),
pytest.param(
False,
MULTIPLE_LINES_WITH_THINK,
id="multiple_lines_with_think",
),
pytest.param(
True,
MULTIPLE_LINES_WITH_THINK,
id="multiple_lines_with_think_streaming",
),
pytest.param(
False,
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
id="shortest_with_think",
),
pytest.param(
True,
SHORTEST_REASONING_WITH_THINK,
id="shortest_with_think_streaming",
),
pytest.param(
False,
THINK_NO_END,
id="think_no_end",
),
pytest.param(
True,
THINK_NO_END,
id="think_no_end_streaming",
),
pytest.param(
False,
EMPTY,
id="empty",
),
pytest.param(
True,
EMPTY_STREAMING,
id="empty_streaming",
),
pytest.param(
False,
NEW_LINE,
id="new_line",
),
pytest.param(
True,
NEW_LINE_STREAMING,
id="new_line_streaming",
),
pytest.param(
True,
NEW_LINE_STREAMING_COMPLEX_CONTENT,
id="new_line_streaming_complex_content",
),
pytest.param(
True,
MULTI_TURN_PROMPT_CONTENT,
id="multi_turn_prompt_content",
),
]
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
step3p5_tokenizer,
request,
):
output = step3p5_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
step3p5_tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
step3p5_tokenizer
)
reasoning, content = run_reasoning_extraction(
parser, output_tokens, streaming=streaming
)
print(f"reasoning: {reasoning}")
print(f"content: {content}")
test_id = request.node.callspec.id if hasattr(request.node, "callspec") else None
if request.node.callspec.id != "multi_turn_prompt_content":
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = step3p5_tokenizer.convert_tokens_to_ids(output)
if streaming:
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_ids)
# Fixed expected token ids for specific test cases
test_id = (
request.node.callspec.id if hasattr(request.node, "callspec") else None
)
# Match most specific first
if test_id not in [
"new_line_streaming_complex_content",
"new_line_streaming",
"new_line",
"multi_turn_prompt_content",
]:
expected_content_ids = step3p5_tokenizer.convert_tokens_to_ids(
step3p5_tokenizer.tokenize(param_dict["content"])
)
assert content == expected_content_ids
else:
content = parser.extract_content_ids(output)
assert content == []
def test_step3p5_streaming_drops_leading_newline(step3p5_tokenizer):
parser_cls = ReasoningParserManager.get_reasoning_parser("step3p5")
parser = parser_cls(step3p5_tokenizer)
output = "<think>calc</think>\nAnswer"
tokens = step3p5_tokenizer.tokenize(output)
output_tokens = [
step3p5_tokenizer.convert_tokens_to_string([token]) for token in tokens
]
_, content = run_reasoning_extraction(parser, output_tokens, streaming=True)
assert content == "Answer"
...@@ -39,24 +39,59 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser): ...@@ -39,24 +39,59 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser):
# whether it is immediately before </think>. # whether it is immediately before </think>.
self._pending_reasoning_newline = False self._pending_reasoning_newline = False
# Used to delay the reasoning end detection. # Tracks whether we've seen </think> but are still waiting for one more
# This is necessary to remove the newline appears immediately after </think>, # token to confirm the end.
# which may cause the end detection to be delayed by one round. self._end_token_pending = False
self.end_offset = 1
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
if self.end_token_id in input_ids and self.end_offset > 0: return self._is_reasoning_end_from_ids(input_ids)
self.end_offset -= 1
return False
return self.end_offset < 1
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Iterable[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
if self.end_token_id in input_ids and self.end_offset > 0: # Only examine newly generated tokens; they may contain multiple ids.
self.end_offset -= 1 return self._is_reasoning_end_from_ids(delta_ids)
def _is_reasoning_end_from_ids(self, input_ids: Sequence[int]) -> bool:
# Scan backwards to find the last special token, <think> or </think>.
last_special = None
last_idx = -1
for i in range(len(input_ids) - 1, -1, -1):
token_id = input_ids[i]
if token_id == self.start_token_id:
last_special = "start"
last_idx = i
break
if token_id == self.end_token_id:
last_special = "end"
last_idx = i
break
if last_special == "start":
# If we're already waiting for one token after </think>, do not
# clear the pending state just because the prompt contains <think>.
# Streaming deltas should not include <think> for this model.
if self._end_token_pending:
return False
# A start token after any end token means reasoning is ongoing.
self._end_token_pending = False
return False
if last_special == "end":
# Require at least one token after </think> before ending.
if last_idx < len(input_ids) - 1:
self._end_token_pending = False
return True
self._end_token_pending = True
return False return False
return self.end_offset < 1
# No special tokens in this input. If we were waiting for one token
# after </think>, any new token completes the end.
if self._end_token_pending and input_ids:
self._end_token_pending = False
return True
return False
def extract_reasoning( def extract_reasoning(
self, self,
...@@ -136,9 +171,6 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser): ...@@ -136,9 +171,6 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser):
# Content: handle the newline immediately after </think>. # Content: handle the newline immediately after </think>.
if content_to_output is not None: if content_to_output is not None:
# No need to get into parser again to remove newline after </think>.
self.end_offset -= 1
# If we have content, reasoning must have ended. # If we have content, reasoning must have ended.
self._pending_reasoning_newline = False self._pending_reasoning_newline = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment