Unverified Commit 7f1f36bf authored by Martin Hickey's avatar Martin Hickey Committed by GitHub
Browse files

[CI] Fix mypy for vllm/reasoning (#35742)


Signed-off-by: default avatarMartin Hickey <martin.hickey@ie.ibm.com>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5282c7d4
...@@ -23,7 +23,7 @@ class TestGptOssStructuralTagsIntegration: ...@@ -23,7 +23,7 @@ class TestGptOssStructuralTagsIntegration:
"""Create a mock tokenizer.""" """Create a mock tokenizer."""
tokenizer = Mock() tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
tokenizer.vocab = {"<|end|>": 6} tokenizer.get_vocab = Mock(return_value={"<|end|>": 6})
return tokenizer return tokenizer
@pytest.fixture @pytest.fixture
......
...@@ -25,7 +25,7 @@ class TestGptOssReasoningParser: ...@@ -25,7 +25,7 @@ class TestGptOssReasoningParser:
"""Create a mock tokenizer for testing.""" """Create a mock tokenizer for testing."""
tokenizer = Mock() tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5]) tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
tokenizer.vocab = {"<|end|>": 6} tokenizer.get_vocab = Mock(return_value={"<|end|>": 6})
return tokenizer return tokenizer
@pytest.fixture @pytest.fixture
......
...@@ -41,7 +41,6 @@ EXCLUDE = [ ...@@ -41,7 +41,6 @@ EXCLUDE = [
# TODO: Remove these entries after fixing mypy errors. # TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks", "vllm/benchmarks",
"vllm/config", "vllm/config",
"vllm/reasoning",
] ]
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -14,21 +14,10 @@ from vllm.utils.collection_utils import is_list_of ...@@ -14,21 +14,10 @@ from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path from vllm.utils.import_utils import import_from_path
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
ChatCompletionRequest, from vllm.entrypoints.openai.engine.protocol import DeltaMessage
) from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
else:
ChatCompletionRequest = Any
DeltaMessage = Any
ResponsesRequest = Any
TokenizerLike = Any
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -41,7 +30,7 @@ class ReasoningParser: ...@@ -41,7 +30,7 @@ class ReasoningParser:
It is used to extract reasoning content from the model output. It is used to extract reasoning content from the model output.
""" """
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
@cached_property @cached_property
...@@ -127,7 +116,7 @@ class ReasoningParser: ...@@ -127,7 +116,7 @@ class ReasoningParser:
def extract_reasoning( def extract_reasoning(
self, self,
model_output: str, model_output: str,
request: ChatCompletionRequest | ResponsesRequest, request: "ChatCompletionRequest | ResponsesRequest",
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
""" """
Extract reasoning content from a complete model-generated string. Extract reasoning content from a complete model-generated string.
...@@ -136,14 +125,10 @@ class ReasoningParser: ...@@ -136,14 +125,10 @@ class ReasoningParser:
available before sending to the client. available before sending to the client.
Parameters: Parameters:
model_output: str model_output: The model-generated string to extract reasoning content from.
The model-generated string to extract reasoning content from. request: The request object that was used to generate the model_output.
request: ChatCompletionRequest
The request object that was used to generate the model_output.
Returns: Returns:
tuple[Optional[str], Optional[str]]
A tuple containing the reasoning content and the content. A tuple containing the reasoning content and the content.
""" """
...@@ -156,7 +141,7 @@ class ReasoningParser: ...@@ -156,7 +141,7 @@ class ReasoningParser:
previous_token_ids: Sequence[int], previous_token_ids: Sequence[int],
current_token_ids: Sequence[int], current_token_ids: Sequence[int],
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
) -> DeltaMessage | None: ) -> "DeltaMessage | None":
""" """
Instance method that should be implemented for extracting reasoning Instance method that should be implemented for extracting reasoning
from an incomplete response; for use when handling reasoning calls and from an incomplete response; for use when handling reasoning calls and
......
...@@ -4,22 +4,15 @@ ...@@ -4,22 +4,15 @@
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
ChatCompletionRequest, from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else:
ChatCompletionRequest = Any
ResponsesRequest = Any
class BaseThinkingReasoningParser(ReasoningParser): class BaseThinkingReasoningParser(ReasoningParser):
...@@ -58,13 +51,15 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -58,13 +51,15 @@ class BaseThinkingReasoningParser(ReasoningParser):
if not self.start_token or not self.end_token: if not self.start_token or not self.end_token:
raise ValueError("start_token and end_token must be defined in subclasses") raise ValueError("start_token and end_token must be defined in subclasses")
self.start_token_id = self.vocab.get(self.start_token) start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token) end_token_id = self.vocab.get(self.end_token)
if self.start_token_id is None or self.end_token_id is None: if start_token_id is None or end_token_id is None:
raise RuntimeError( raise RuntimeError(
f"{self.__class__.__name__} reasoning parser could not locate " f"{self.__class__.__name__} reasoning parser could not locate "
"think start/end tokens in the tokenizer!" "think start/end tokens in the tokenizer!"
) )
self.start_token_id: int = start_token_id
self.end_token_id: int = end_token_id
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
start_token_id = self.start_token_id start_token_id = self.start_token_id
...@@ -152,7 +147,7 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -152,7 +147,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
""" """
Extract reasoning content from the model output. Extract reasoning content from the model output.
......
...@@ -2,19 +2,21 @@ ...@@ -2,19 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .identity_reasoning_parser import IdentityReasoningParser from .identity_reasoning_parser import IdentityReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -32,6 +34,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser): ...@@ -32,6 +34,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) enable_thinking = bool(chat_kwargs.get("enable_thinking", False))
thinking = thinking or enable_thinking thinking = thinking or enable_thinking
self._parser: ReasoningParser
if thinking: if thinking:
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
else: else:
...@@ -49,7 +52,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser): ...@@ -49,7 +52,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
return self._parser.extract_content_ids(input_ids) return self._parser.extract_content_ids(input_ids)
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
return self._parser.extract_reasoning(model_output, request) return self._parser.extract_reasoning(model_output, request)
...@@ -61,7 +64,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser): ...@@ -61,7 +64,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
previous_token_ids: Sequence[int], previous_token_ids: Sequence[int],
current_token_ids: Sequence[int], current_token_ids: Sequence[int],
delta_token_ids: Sequence[int], delta_token_ids: Sequence[int],
) -> DeltaMessage | None: ) -> "DeltaMessage | None":
return self._parser.extract_reasoning_streaming( return self._parser.extract_reasoning_streaming(
previous_text, previous_text,
current_text, current_text,
......
...@@ -2,16 +2,18 @@ ...@@ -2,16 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,20 +48,12 @@ class Ernie45ReasoningParser(BaseThinkingReasoningParser): ...@@ -46,20 +48,12 @@ class Ernie45ReasoningParser(BaseThinkingReasoningParser):
"constructor during construction." "constructor during construction."
) )
self.start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token)
self.response_start_token_id = self.vocab.get(self.response_start_token) self.response_start_token_id = self.vocab.get(self.response_start_token)
self.response_end_token_id = self.vocab.get(self.response_end_token) self.response_end_token_id = self.vocab.get(self.response_end_token)
self.newline_token_id = self.vocab.get(self.newline_token) self.newline_token_id = self.vocab.get(self.newline_token)
self.parser_token_ids = [self.end_token_id, self.response_end_token_id] self.parser_token_ids = [self.end_token_id, self.response_end_token_id]
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"Ernie45 reasoning parser could not locate think start/end "
"tokens in the tokenizer!"
)
def extract_reasoning_streaming( def extract_reasoning_streaming(
self, self,
previous_text: str, previous_text: str,
...@@ -144,7 +138,7 @@ class Ernie45ReasoningParser(BaseThinkingReasoningParser): ...@@ -144,7 +138,7 @@ class Ernie45ReasoningParser(BaseThinkingReasoningParser):
return DeltaMessage(reasoning=delta_text) return DeltaMessage(reasoning=delta_text)
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
""" """
Extract reasoning content from the model output. Extract reasoning content from the model output.
......
...@@ -2,18 +2,20 @@ ...@@ -2,18 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json import json
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.parser.harmony_utils import parse_chat_output from vllm.entrypoints.openai.parser.harmony_utils import parse_chat_output
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
no_func_reaonsing_tag = { no_func_reaonsing_tag = {
...@@ -78,7 +80,7 @@ class GptOssReasoningParser(ReasoningParser): ...@@ -78,7 +80,7 @@ class GptOssReasoningParser(ReasoningParser):
self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>") self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
# We also need to check for the <|end|> token to avoid false positives from # We also need to check for the <|end|> token to avoid false positives from
# previous messages in multi-turn conversations. # previous messages in multi-turn conversations.
self.eom_token_id = self.model_tokenizer.vocab["<|end|>"] self.eom_token_id = self.vocab["<|end|>"]
self.reasoning_max_num_between_tokens = 20 self.reasoning_max_num_between_tokens = 20
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
...@@ -148,7 +150,7 @@ class GptOssReasoningParser(ReasoningParser): ...@@ -148,7 +150,7 @@ class GptOssReasoningParser(ReasoningParser):
def extract_reasoning( def extract_reasoning(
self, self,
model_output: str, model_output: str,
request: ChatCompletionRequest, request: "ChatCompletionRequest | ResponsesRequest",
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
raise NotImplementedError( raise NotImplementedError(
"gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501 "gpt-oss has a special branch for parsing reasoning in non-streaming mode. This method shouldn't be used." # noqa: E501
......
...@@ -2,17 +2,19 @@ ...@@ -2,17 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -53,7 +55,7 @@ class GraniteReasoningParser(ReasoningParser): ...@@ -53,7 +55,7 @@ class GraniteReasoningParser(ReasoningParser):
) )
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
"""Extract the reasoning content & content sections, respectively. """Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates If the sequence doesn't match what we expect, i.e., the model generates
......
...@@ -2,17 +2,19 @@ ...@@ -2,17 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -65,8 +67,8 @@ class HunyuanA13BReasoningParser(ReasoningParser): ...@@ -65,8 +67,8 @@ class HunyuanA13BReasoningParser(ReasoningParser):
self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397] self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397]
# when state change, send out all the buffered text in last state # when state change, send out all the buffered text in last state
self.buffered_text = [] self.buffered_text: list[str] = []
self.buffered_ids = [] self.buffered_ids: list[int] = []
self.current_state = "reasoning" self.current_state = "reasoning"
self.all_states = ["reasoning", "response"] self.all_states = ["reasoning", "response"]
...@@ -76,7 +78,7 @@ class HunyuanA13BReasoningParser(ReasoningParser): ...@@ -76,7 +78,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
# this sequence only for the think start, it has two way to start. # this sequence only for the think start, it has two way to start.
self.expected_sequence_side = self.think_start_ids_fast self.expected_sequence_side = self.think_start_ids_fast
self.sequence_index = 0 self.sequence_index = 0
self.token_buffer = [] self.token_buffer: list[int] = []
self.text_buffer = "" self.text_buffer = ""
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
...@@ -90,7 +92,7 @@ class HunyuanA13BReasoningParser(ReasoningParser): ...@@ -90,7 +92,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
return [] return []
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
"""Extract the reasoning content & content sections, respectively. """Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates If the sequence doesn't match what we expect, i.e., the model generates
......
...@@ -2,16 +2,18 @@ ...@@ -2,16 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -59,7 +61,7 @@ class IdentityReasoningParser(ReasoningParser): ...@@ -59,7 +61,7 @@ class IdentityReasoningParser(ReasoningParser):
return None return None
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
# No reasoning separation: return None for reasoning, # No reasoning separation: return None for reasoning,
# and full model_output as content # and full model_output as content
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
class KimiK2ReasoningParser(ReasoningParser): class KimiK2ReasoningParser(ReasoningParser):
""" """
...@@ -39,6 +41,7 @@ class KimiK2ReasoningParser(ReasoningParser): ...@@ -39,6 +41,7 @@ class KimiK2ReasoningParser(ReasoningParser):
thinking = bool(chat_kwargs.get("thinking", True)) thinking = bool(chat_kwargs.get("thinking", True))
# If thinking is not enabled, use identity parser to fall through # If thinking is not enabled, use identity parser to fall through
self._identity_parser: IdentityReasoningParser | None
if not thinking: if not thinking:
self._identity_parser = IdentityReasoningParser(tokenizer, *args, **kwargs) self._identity_parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
else: else:
...@@ -62,10 +65,6 @@ class KimiK2ReasoningParser(ReasoningParser): ...@@ -62,10 +65,6 @@ class KimiK2ReasoningParser(ReasoningParser):
"tokens in the tokenizer!" "tokens in the tokenizer!"
) )
def _is_identity_mode(self) -> bool:
"""Check if parser is in identity mode (no reasoning extraction)."""
return self._identity_parser is not None
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
""" """
Check if the reasoning content ends in the input_ids. Check if the reasoning content ends in the input_ids.
...@@ -74,7 +73,7 @@ class KimiK2ReasoningParser(ReasoningParser): ...@@ -74,7 +73,7 @@ class KimiK2ReasoningParser(ReasoningParser):
1. The end token (</think>) 1. The end token (</think>)
2. The tool section start token (<|tool_calls_section_begin|>) 2. The tool section start token (<|tool_calls_section_begin|>)
""" """
if self._is_identity_mode(): if self._identity_parser is not None:
return self._identity_parser.is_reasoning_end(input_ids) return self._identity_parser.is_reasoning_end(input_ids)
start_token_id = self._start_token_id start_token_id = self._start_token_id
...@@ -95,29 +94,32 @@ class KimiK2ReasoningParser(ReasoningParser): ...@@ -95,29 +94,32 @@ class KimiK2ReasoningParser(ReasoningParser):
return False return False
def is_reasoning_end_streaming( def is_reasoning_end_streaming(
self, input_ids: Sequence[int], delta_ids: Sequence[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
""" """
Check if the reasoning content ends in the input_ids on a decode step. Check if the reasoning content ends in the input_ids on a decode step.
""" """
if self._is_identity_mode(): if self._identity_parser is not None:
return self._identity_parser.is_reasoning_end_streaming( return self._identity_parser.is_reasoning_end_streaming(
input_ids, delta_ids input_ids, delta_ids
) )
# Materialize iterable for membership checks
delta_ids_set = set(delta_ids)
# Check for explicit end token or implicit tool section start in delta # Check for explicit end token or implicit tool section start in delta
if self._end_token_id in delta_ids: if self._end_token_id in delta_ids_set:
return True return True
return ( return (
self._tool_section_start_token_id is not None self._tool_section_start_token_id is not None
and self._tool_section_start_token_id in delta_ids and self._tool_section_start_token_id in delta_ids_set
) )
def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_content_ids(self, input_ids: list[int]) -> list[int]:
""" """
Extract content token ids from the input_ids. Extract content token ids from the input_ids.
""" """
if self._is_identity_mode(): if self._identity_parser is not None:
return self._identity_parser.extract_content_ids(input_ids) return self._identity_parser.extract_content_ids(input_ids)
if self._end_token_id in input_ids: if self._end_token_id in input_ids:
...@@ -145,12 +147,12 @@ class KimiK2ReasoningParser(ReasoningParser): ...@@ -145,12 +147,12 @@ class KimiK2ReasoningParser(ReasoningParser):
return [] return []
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
""" """
Extract reasoning content from the model output. Extract reasoning content from the model output.
""" """
if self._is_identity_mode(): if self._identity_parser is not None:
return self._identity_parser.extract_reasoning(model_output, request) return self._identity_parser.extract_reasoning(model_output, request)
# thinking does not require a think start token but consume it if present # thinking does not require a think start token but consume it if present
...@@ -189,7 +191,7 @@ class KimiK2ReasoningParser(ReasoningParser): ...@@ -189,7 +191,7 @@ class KimiK2ReasoningParser(ReasoningParser):
""" """
Extract reasoning content from a delta message during streaming. Extract reasoning content from a delta message during streaming.
""" """
if self._is_identity_mode(): if self._identity_parser is not None:
return self._identity_parser.extract_reasoning_streaming( return self._identity_parser.extract_reasoning_streaming(
previous_text, previous_text,
current_text, current_text,
......
...@@ -2,21 +2,20 @@ ...@@ -2,21 +2,20 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
) )
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -114,6 +113,6 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser): ...@@ -114,6 +113,6 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
return None, "<think>" + model_output return None, "<think>" + model_output
...@@ -3,18 +3,17 @@ ...@@ -3,18 +3,17 @@
from collections.abc import Sequence from collections.abc import Sequence
from functools import cached_property from functools import cached_property
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -113,7 +112,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser): ...@@ -113,7 +112,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :] return input_ids[:eot_token_index] + input_ids[eot_token_index + 1 :]
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
""" """
Extract reasoning content from the model output. Extract reasoning content from the model output.
......
...@@ -8,20 +8,15 @@ from typing import TYPE_CHECKING ...@@ -8,20 +8,15 @@ from typing import TYPE_CHECKING
import regex as re import regex as re
if TYPE_CHECKING: from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.tokenizers import TokenizerLike
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -256,15 +251,15 @@ class Olmo3ReasoningParser(ReasoningParser): ...@@ -256,15 +251,15 @@ class Olmo3ReasoningParser(ReasoningParser):
def extract_reasoning( def extract_reasoning(
self, self,
model_output: str, model_output: str,
request: ChatCompletionRequest | ResponsesRequest, request: "ChatCompletionRequest | ResponsesRequest",
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
"""Extract the reasoning content & content sections, respectively. """Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content. something else, all content is considered non-reasoning content.
Args: Args:
model_output (str): Output of the model to be parsed. model_output: Output of the model to be parsed.
request (ChatCompletionRequest | ResponsesRequest): Request being request: Request being
processed. processed.
Returns: Returns:
......
...@@ -2,16 +2,15 @@ ...@@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.tokenizers import TokenizerLike
class Qwen3ReasoningParser(BaseThinkingReasoningParser): class Qwen3ReasoningParser(BaseThinkingReasoningParser):
...@@ -34,7 +33,7 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser): ...@@ -34,7 +33,7 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
it is stripped before extraction (non-streaming) or skipped (streaming). it is stripped before extraction (non-streaming) or skipped (streaming).
""" """
def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
...@@ -53,7 +52,7 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser): ...@@ -53,7 +52,7 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser):
return "</think>" return "</think>"
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest | ResponsesRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
""" """
Extract reasoning content from the model output. Extract reasoning content from the model output.
......
...@@ -3,17 +3,19 @@ ...@@ -3,17 +3,19 @@
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from itertools import islice from itertools import islice
from typing import TYPE_CHECKING
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,12 +39,13 @@ class Step3ReasoningParser(ReasoningParser): ...@@ -37,12 +39,13 @@ class Step3ReasoningParser(ReasoningParser):
"constructor during construction." "constructor during construction."
) )
self.think_end_token_id = self.vocab.get(self.think_end_token) think_end_token_id = self.vocab.get(self.think_end_token)
if self.think_end_token_id is None: if think_end_token_id is None:
raise RuntimeError( raise RuntimeError(
"Step3 reasoning parser could not locate think end " "Step3 reasoning parser could not locate think end "
"token in the tokenizer!" "token in the tokenizer!"
) )
self.think_end_token_id: int = think_end_token_id
def extract_reasoning_streaming( def extract_reasoning_streaming(
self, self,
...@@ -82,7 +85,7 @@ class Step3ReasoningParser(ReasoningParser): ...@@ -82,7 +85,7 @@ class Step3ReasoningParser(ReasoningParser):
return DeltaMessage(reasoning=delta_text) return DeltaMessage(reasoning=delta_text)
def extract_reasoning( def extract_reasoning(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
# Check if the model output contains the </think> token # Check if the model output contains the </think> token
if self.think_end_token not in model_output: if self.think_end_token not in model_output:
...@@ -94,10 +97,7 @@ class Step3ReasoningParser(ReasoningParser): ...@@ -94,10 +97,7 @@ class Step3ReasoningParser(ReasoningParser):
reasoning = model_output[:end_index] reasoning = model_output[:end_index]
# Content after </think> token # Content after </think> token
content = model_output[end_index + len(self.think_end_token) :] content = model_output[end_index + len(self.think_end_token) :] or None
if len(content) == 0:
content = None
return reasoning, content return reasoning, content
......
...@@ -2,17 +2,16 @@ ...@@ -2,17 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
class Step3p5ReasoningParser(BaseThinkingReasoningParser): class Step3p5ReasoningParser(BaseThinkingReasoningParser):
""" """
...@@ -50,7 +49,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser): ...@@ -50,7 +49,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser):
self, input_ids: Sequence[int], delta_ids: Iterable[int] self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool: ) -> bool:
# Only examine newly generated tokens; they may contain multiple ids. # Only examine newly generated tokens; they may contain multiple ids.
return self._is_reasoning_end_from_ids(delta_ids) return self._is_reasoning_end_from_ids(tuple(delta_ids))
def _is_reasoning_end_from_ids(self, input_ids: Sequence[int]) -> bool: def _is_reasoning_end_from_ids(self, input_ids: Sequence[int]) -> bool:
# Scan backwards to find the last special token, <think> or </think>. # Scan backwards to find the last special token, <think> or </think>.
...@@ -96,7 +95,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser): ...@@ -96,7 +95,7 @@ class Step3p5ReasoningParser(BaseThinkingReasoningParser):
def extract_reasoning( def extract_reasoning(
self, self,
model_output: str, model_output: str,
request: ChatCompletionRequest | ResponsesRequest, request: "ChatCompletionRequest | ResponsesRequest",
) -> tuple[str | None, str | None]: ) -> tuple[str | None, str | None]:
reasoning, content = super().extract_reasoning(model_output, request) reasoning, content = super().extract_reasoning(model_output, request)
if reasoning is not None: if reasoning is not None:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import functools import functools
import json import json
from collections.abc import Collection, Set from collections.abc import Collection, Sequence, Set
from pathlib import Path from pathlib import Path
from typing import Any, Literal, overload from typing import Any, Literal, overload
...@@ -348,7 +348,9 @@ class Grok2Tokenizer(TokenizerLike): ...@@ -348,7 +348,9 @@ class Grok2Tokenizer(TokenizerLike):
tokens = self._maybe_truncate(tokens, max_length) tokens = self._maybe_truncate(tokens, max_length)
return tokens return tokens
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: def decode(
self, ids: Sequence[int] | int, skip_special_tokens: bool = False
) -> str:
if isinstance(ids, int): if isinstance(ids, int):
ids = [ids] ids = [ids]
if skip_special_tokens: if skip_special_tokens:
...@@ -371,7 +373,7 @@ class Grok2Tokenizer(TokenizerLike): ...@@ -371,7 +373,7 @@ class Grok2Tokenizer(TokenizerLike):
return [self._token_to_id.get(token, self._unk_token_id) for token in tokens] return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]
def convert_ids_to_tokens( def convert_ids_to_tokens(
self, ids: list[int], skip_special_tokens: bool = False self, ids: Sequence[int], skip_special_tokens: bool = False
) -> list[str]: ) -> list[str]:
tokens = [] tokens = []
for token_id in ids: for token_id in ids:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, cast, overload from typing import TYPE_CHECKING, Any, cast, overload
...@@ -434,7 +435,9 @@ class MistralTokenizer(TokenizerLike): ...@@ -434,7 +435,9 @@ class MistralTokenizer(TokenizerLike):
return_dict=False, return_dict=False,
) )
def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: def decode(
self, ids: Sequence[int] | int, skip_special_tokens: bool = False
) -> str:
# TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962 # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962
# is in, directly call self.transformers_tokenizer.decode(...). # is in, directly call self.transformers_tokenizer.decode(...).
if isinstance(ids, int): if isinstance(ids, int):
...@@ -512,7 +515,7 @@ class MistralTokenizer(TokenizerLike): ...@@ -512,7 +515,7 @@ class MistralTokenizer(TokenizerLike):
def convert_ids_to_tokens( def convert_ids_to_tokens(
self, self,
ids: list[int], ids: Sequence[int],
skip_special_tokens: bool = False, skip_special_tokens: bool = False,
) -> list[str]: ) -> list[str]:
if not skip_special_tokens: if not skip_special_tokens:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment