Unverified Commit 3e802e87 authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Mypy] Fix adjust_request typing (#38264)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent 350af48e
......@@ -505,7 +505,7 @@ Here is a summary of a plugin file:
# adjust request. e.g.: set skip special tokens
# to False for tool call output.
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(self, request: ChatCompletionRequest | ResponsesRequest) -> ChatCompletionRequest | ResponsesRequest:
return request
# implement the tool call parse for stream call
......
......@@ -546,7 +546,7 @@ class OpenAIServingRender:
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer, request.tools).adjust_request(
request=request # type: ignore[arg-type]
request=request
)
return conversation, [engine_input]
......@@ -32,9 +32,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
FunctionDefinition,
)
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike
......@@ -229,7 +227,9 @@ class Parser:
# ========== Tool Parser Methods ==========
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
"""
Adjust the request parameters for tool calling.
......
......@@ -62,7 +62,9 @@ class ToolParser:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
"""
Static method that used to adjust the request parameters.
"""
......
......@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
......@@ -78,7 +79,9 @@ class DeepSeekV32ToolParser(ToolParser):
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def adjust_request(self, request):
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# Ensure tool call tokens
......
......@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
......@@ -86,7 +87,9 @@ class FunctionGemmaToolParser(ToolParser):
return arguments
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
......
......@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
......@@ -55,7 +56,9 @@ class GigaChat3ToolParser(ToolParser):
self.end_content: bool = False
self.streamed_args_for_tool: list[str] = []
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
......
......@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
......@@ -151,7 +152,9 @@ class Glm4MoeModelToolParser(ToolParser):
logger.exception("Failed to determine if tools are enabled.")
return False
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
"""Adjust request parameters for tool call token handling."""
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
......
......@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
......@@ -59,7 +60,9 @@ class Granite4ToolParser(ToolParser):
self.start_regex = re.compile(self.tc_start)
self.end_regex = re.compile(self.tc_end)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because the tool_call tokens are
......
......@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
......@@ -77,7 +78,9 @@ class Hermes2ProToolParser(ToolParser):
# Streaming state: what has been sent to the client.
self._sent_content_idx: int = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because the tool_call tokens are
......
......@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
......@@ -35,7 +36,9 @@ class Internlm2ToolParser(ToolParser):
super().__init__(tokenizer, tools)
self.position = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because internlm use the special
......
......@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
......@@ -68,7 +69,9 @@ class JambaToolParser(ToolParser):
"tokens in the tokenizer!"
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# do not skip special tokens because jamba use the special
......
......@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
......@@ -111,7 +112,9 @@ class MistralToolParser(ToolParser):
"the tokenizer!"
)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if (
not is_mistral_tokenizer(self.model_tokenizer)
......
......@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
......@@ -51,7 +52,9 @@ class Step3ToolParser(ToolParser):
self.tool_block_started = False
self.tool_block_finished = False
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment