"components/src/vscode:/vscode.git/clone" did not exist on "7d78fdad8dc7249b4940098ec77d0e4fbfeab1c2"
Unverified Commit 5544f8c1 authored by Fergus's avatar Fergus Committed by GitHub
Browse files

[Performance] Add is_reasoning_end_streaming() override to GptOssReasoningParser (#35745)


Signed-off-by: default avatarFergus <fergus.barratt00@gmail.com>
Signed-off-by: default avatarfergus barratt <fergus.barratt00@gmail.com>
Co-authored-by: default avatarChauncey <chaunceyjiang@gmail.com>
parent 9f39b380
...@@ -280,3 +280,72 @@ class TestGptOssStructuralTags: ...@@ -280,3 +280,72 @@ class TestGptOssStructuralTags:
assert tag["content"]["type"] == "any_text" assert tag["content"]["type"] == "any_text"
assert tag["end"] == "<|end|>" assert tag["end"] == "<|end|>"
assert tag["begin"].startswith("<|channel|>") assert tag["begin"].startswith("<|channel|>")
@pytest.mark.parametrize(
"output, is_reasoning_end",
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end_streaming(
output,
is_reasoning_end,
gpt_oss_tokenizer,
):
"""Streaming override must agree with is_reasoning_end for all cases."""
tokens = gpt_oss_tokenizer.tokenize(output)
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
delta_ids = output_ids[-1:] if output_ids else []
actual = parser.is_reasoning_end_streaming(output_ids, delta_ids)
assert is_reasoning_end == actual
@pytest.mark.parametrize(
"output, is_reasoning_end",
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end_streaming_long_prefix(
output,
is_reasoning_end,
gpt_oss_tokenizer,
):
"""Windowing must produce correct results even with a long prefix."""
tokens = gpt_oss_tokenizer.tokenize(output)
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
# Prepend 10k dummy reasoning tokens to simulate a long generation
long_prefix = [1] * 10_000
padded_ids = long_prefix + list(output_ids)
delta_ids = output_ids[-1:] if output_ids else []
actual = parser.is_reasoning_end_streaming(padded_ids, delta_ids)
assert is_reasoning_end == actual
@pytest.mark.parametrize(
"output, is_reasoning_end",
[(t["output"], t["is_reasoning_end"]) for t in TEST_CASES],
)
def test_gptoss_is_reasoning_end_streaming_large_delta(
output,
is_reasoning_end,
gpt_oss_tokenizer,
):
"""Simulate speculative decoding where the entire test sequence arrives
as a single large delta appended after a long prefix. The window must
expand to cover delta_ids so the end pattern is never missed."""
tokens = gpt_oss_tokenizer.tokenize(output)
parser: ReasoningParser = GptOssReasoningParser(gpt_oss_tokenizer)
output_ids = gpt_oss_tokenizer.convert_tokens_to_ids(tokens)
long_prefix = [1] * 10_000
padded_ids = long_prefix + list(output_ids)
# delta_ids = the entire test sequence (as if accepted in one spec step)
delta_ids = list(output_ids)
actual = parser.is_reasoning_end_streaming(padded_ids, delta_ids)
assert is_reasoning_end == actual
def test_gptoss_is_reasoning_end_streaming_signature(gpt_oss_tokenizer):
"""Verify the method is callable with the expected signature."""
parser = GptOssReasoningParser(gpt_oss_tokenizer)
result = parser.is_reasoning_end_streaming([], [])
assert result is False
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from collections.abc import Sequence from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -112,6 +112,25 @@ class GptOssReasoningParser(ReasoningParser): ...@@ -112,6 +112,25 @@ class GptOssReasoningParser(ReasoningParser):
return True return True
return False return False
def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool:
# The pattern window covers the end-of-reasoning marker itself.
# We add len(delta_ids) so that under speculative decoding (where
# a single step can accept many tokens) the entire accepted chunk
# is always inside the scan region.
delta_ids = tuple(delta_ids)
pattern_len = (
len(self.reasoning_end_token_ids_prefix)
+ self.reasoning_max_num_between_tokens
+ len(self.reasoning_end_token_ids_suffix)
)
window = pattern_len + len(delta_ids)
n = len(input_ids)
if n <= window:
return self.is_reasoning_end(input_ids)
return self.is_reasoning_end(input_ids[n - window :])
def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_content_ids(self, input_ids: list[int]) -> list[int]:
_, content, _ = parse_chat_output(input_ids) _, content, _ = parse_chat_output(input_ids)
if content is None: if content is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment