Commit 6796ce8b (unverified), authored by Chauncey, committed by GitHub
Browse files

[Bugfix] Fix the issue with interleaved thinking when using streaming (#30033)


Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent e96a6a6d
...@@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods: ...@@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods:
"""Test the is_reasoning_end method.""" """Test the is_reasoning_end method."""
parser = TestThinkingReasoningParser(test_tokenizer) parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id end_token_id = parser.end_token_id
start_token_id = parser.start_token_id
# Test with end token present # Test with end token present
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
...@@ -122,6 +122,16 @@ class TestBaseThinkingReasoningParserMethods: ...@@ -122,6 +122,16 @@ class TestBaseThinkingReasoningParserMethods:
# Test with empty list # Test with empty list
assert parser.is_reasoning_end([]) is False assert parser.is_reasoning_end([]) is False
# Test with interleaved thinking
assert parser.is_reasoning_end([1, start_token_id, 2, end_token_id]) is True
assert parser.is_reasoning_end([1, start_token_id, 2, 3]) is False
assert (
parser.is_reasoning_end(
[1, start_token_id, 2, end_token_id, 2, 2, start_token_id]
)
is False
)
def test_extract_content_ids(self, test_tokenizer): def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method.""" """Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer) parser = TestThinkingReasoningParser(test_tokenizer)
......
...@@ -64,8 +64,15 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -64,8 +64,15 @@ class BaseThinkingReasoningParser(ReasoningParser):
) )
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: list[int]) -> bool:
start_token_id = self.start_token_id
end_token_id = self.end_token_id end_token_id = self.end_token_id
return any(input_id == end_token_id for input_id in reversed(input_ids))
for i in range(len(input_ids) - 1, -1, -1):
if input_ids[i] == start_token_id:
return False
if input_ids[i] == end_token_id:
return True
return False
def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_content_ids(self, input_ids: list[int]) -> list[int]:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment