Unverified commit 3e802e87, authored by Flora Feng, committed by GitHub
Browse files

[Mypy] Fix adjust_request typing (#38264)


Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent 350af48e
...@@ -505,7 +505,7 @@ Here is a summary of a plugin file: ...@@ -505,7 +505,7 @@ Here is a summary of a plugin file:
# adjust request. e.g.: set skip special tokens # adjust request. e.g.: set skip special tokens
# to False for tool call output. # to False for tool call output.
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(self, request: ChatCompletionRequest | ResponsesRequest) -> ChatCompletionRequest | ResponsesRequest:
return request return request
# implement the tool call parse for stream call # implement the tool call parse for stream call
......
...@@ -546,7 +546,7 @@ class OpenAIServingRender: ...@@ -546,7 +546,7 @@ class OpenAIServingRender:
raise NotImplementedError(msg) raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer() tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer, request.tools).adjust_request( request = tool_parser(tokenizer, request.tools).adjust_request(
request=request # type: ignore[arg-type] request=request
) )
return conversation, [engine_input] return conversation, [engine_input]
...@@ -32,9 +32,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -32,9 +32,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
FunctionDefinition, FunctionDefinition,
) )
from vllm.entrypoints.openai.responses.protocol import ( from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
ResponsesRequest,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -229,7 +227,9 @@ class Parser: ...@@ -229,7 +227,9 @@ class Parser:
# ========== Tool Parser Methods ========== # ========== Tool Parser Methods ==========
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
""" """
Adjust the request parameters for tool calling. Adjust the request parameters for tool calling.
......
...@@ -62,7 +62,9 @@ class ToolParser: ...@@ -62,7 +62,9 @@ class ToolParser:
# whereas all tokenizers have .get_vocab() # whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab() return self.model_tokenizer.get_vocab()
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
""" """
Static method that used to adjust the request parameters. Static method that used to adjust the request parameters.
""" """
......
...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
...@@ -78,7 +79,9 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -78,7 +79,9 @@ class DeepSeekV32ToolParser(ToolParser):
"vLLM Successfully import tool parser %s !", self.__class__.__name__ "vLLM Successfully import tool parser %s !", self.__class__.__name__
) )
def adjust_request(self, request): def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
# Ensure tool call tokens # Ensure tool call tokens
......
...@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
...@@ -86,7 +87,9 @@ class FunctionGemmaToolParser(ToolParser): ...@@ -86,7 +87,9 @@ class FunctionGemmaToolParser(ToolParser):
return arguments return arguments
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False request.skip_special_tokens = False
......
...@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
...@@ -55,7 +56,9 @@ class GigaChat3ToolParser(ToolParser): ...@@ -55,7 +56,9 @@ class GigaChat3ToolParser(ToolParser):
self.end_content: bool = False self.end_content: bool = False
self.streamed_args_for_tool: list[str] = [] self.streamed_args_for_tool: list[str] = []
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False request.skip_special_tokens = False
......
...@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
...@@ -151,7 +152,9 @@ class Glm4MoeModelToolParser(ToolParser): ...@@ -151,7 +152,9 @@ class Glm4MoeModelToolParser(ToolParser):
logger.exception("Failed to determine if tools are enabled.") logger.exception("Failed to determine if tools are enabled.")
return False return False
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
"""Adjust request parameters for tool call token handling.""" """Adjust request parameters for tool call token handling."""
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
......
...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
...@@ -59,7 +60,9 @@ class Granite4ToolParser(ToolParser): ...@@ -59,7 +60,9 @@ class Granite4ToolParser(ToolParser):
self.start_regex = re.compile(self.tc_start) self.start_regex = re.compile(self.tc_start)
self.end_regex = re.compile(self.tc_end) self.end_regex = re.compile(self.tc_end)
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
# do not skip special tokens because the tool_call tokens are # do not skip special tokens because the tool_call tokens are
......
...@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -18,6 +18,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
...@@ -77,7 +78,9 @@ class Hermes2ProToolParser(ToolParser): ...@@ -77,7 +78,9 @@ class Hermes2ProToolParser(ToolParser):
# Streaming state: what has been sent to the client. # Streaming state: what has been sent to the client.
self._sent_content_idx: int = 0 self._sent_content_idx: int = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
# do not skip special tokens because the tool_call tokens are # do not skip special tokens because the tool_call tokens are
......
...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
...@@ -35,7 +36,9 @@ class Internlm2ToolParser(ToolParser): ...@@ -35,7 +36,9 @@ class Internlm2ToolParser(ToolParser):
super().__init__(tokenizer, tools) super().__init__(tokenizer, tools)
self.position = 0 self.position = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
# do not skip special tokens because internlm use the special # do not skip special tokens because internlm use the special
......
...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
...@@ -68,7 +69,9 @@ class JambaToolParser(ToolParser): ...@@ -68,7 +69,9 @@ class JambaToolParser(ToolParser):
"tokens in the tokenizer!" "tokens in the tokenizer!"
) )
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
# do not skip special tokens because jamba use the special # do not skip special tokens because jamba use the special
......
...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
...@@ -111,7 +112,9 @@ class MistralToolParser(ToolParser): ...@@ -111,7 +112,9 @@ class MistralToolParser(ToolParser):
"the tokenizer!" "the tokenizer!"
) )
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if ( if (
not is_mistral_tokenizer(self.model_tokenizer) not is_mistral_tokenizer(self.model_tokenizer)
......
...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -19,6 +19,7 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall, FunctionCall,
ToolCall, ToolCall,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
...@@ -51,7 +52,9 @@ class Step3ToolParser(ToolParser): ...@@ -51,7 +52,9 @@ class Step3ToolParser(ToolParser):
self.tool_block_started = False self.tool_block_started = False
self.tool_block_finished = False self.tool_block_finished = False
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
request = super().adjust_request(request) request = super().adjust_request(request)
if request.tools and request.tool_choice != "none": if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False request.skip_special_tokens = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment