Unverified Commit 9ce8fad2 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Perf] Optimize Python Slice for Structured Output using `islice` instead of [:] (#33593)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
Signed-off-by: default avatarWentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent c38b8d5a
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import importlib import importlib
import os import os
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Callable, Iterable, Sequence
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
...@@ -68,7 +68,7 @@ class ReasoningParser: ...@@ -68,7 +68,7 @@ class ReasoningParser:
""" """
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
""" """
Check if the reasoning content ends in the input_ids on a Check if the reasoning content ends in the input_ids on a
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Sequence from collections.abc import Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
...@@ -77,7 +78,7 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -77,7 +78,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
return False return False
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
end_token_id = self.end_token_id end_token_id = self.end_token_id
return end_token_id in delta_ids return end_token_id in delta_ids
...@@ -86,7 +87,7 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -86,7 +87,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
""" """
Extract the content after the end tokens Extract the content after the end tokens
""" """
if self.end_token_id not in input_ids[:-1]: if self.end_token_id not in islice(input_ids, 0, max(0, len(input_ids) - 1)):
return [] return []
else: else:
return input_ids[input_ids.index(self.end_token_id) + 1 :] return input_ids[input_ids.index(self.end_token_id) + 1 :]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Iterable, Sequence
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -41,7 +41,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser): ...@@ -41,7 +41,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
return self._parser.is_reasoning_end(input_ids) return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Iterable, Sequence
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -36,7 +36,7 @@ class IdentityReasoningParser(ReasoningParser): ...@@ -36,7 +36,7 @@ class IdentityReasoningParser(ReasoningParser):
return True return True
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
return True return True
......
...@@ -69,7 +69,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser): ...@@ -69,7 +69,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
has_eot_token = False has_eot_token = False
for id in input_ids[::-1]: for id in reversed(input_ids):
if id == self.start_token_id: if id == self.start_token_id:
# Reasoning ends only if a BOT token is found before a EOT token. # Reasoning ends only if a BOT token is found before a EOT token.
return has_eot_token return has_eot_token
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Iterable, Sequence
from itertools import islice
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -104,13 +105,15 @@ class Step3ReasoningParser(ReasoningParser): ...@@ -104,13 +105,15 @@ class Step3ReasoningParser(ReasoningParser):
return self.think_end_token_id in input_ids return self.think_end_token_id in input_ids
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
end_token_id = self.think_end_token_id end_token_id = self.think_end_token_id
return end_token_id in delta_ids return end_token_id in delta_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_content_ids(self, input_ids: list[int]) -> list[int]:
if self.think_end_token_id not in input_ids[:-1]: if self.think_end_token_id not in islice(
input_ids, 0, max(0, len(input_ids) - 1)
):
return [] return []
else: else:
return input_ids[input_ids.index(self.think_end_token_id) + 1 :] return input_ids[input_ids.index(self.think_end_token_id) + 1 :]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Iterable, Sequence
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
...@@ -51,7 +51,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser): ...@@ -51,7 +51,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser):
return self.end_offset < 1 return self.end_offset < 1
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
if self.end_token_id in input_ids and self.end_offset > 0: if self.end_token_id in input_ids and self.end_offset > 0:
self.end_offset -= 1 self.end_offset -= 1
......
...@@ -324,8 +324,11 @@ class StructuredOutputManager: ...@@ -324,8 +324,11 @@ class StructuredOutputManager:
# Check if reasoning ends in *this* step # Check if reasoning ends in *this* step
delta_from = request.num_computed_tokens - request.num_output_placeholders delta_from = request.num_computed_tokens - request.num_output_placeholders
all_token_ids = request.all_token_ids all_token_ids = request.all_token_ids
start = (
delta_from if delta_from >= 0 else max(len(all_token_ids) + delta_from, 0)
)
if self.reasoner.is_reasoning_end_streaming( if self.reasoner.is_reasoning_end_streaming(
all_token_ids, all_token_ids[delta_from:] all_token_ids, itertools.islice(all_token_ids, start, None)
): ):
# Reasoning just ended, so we shouldn't advance til # Reasoning just ended, so we shouldn't advance til
# next pass # next pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment