Unverified commit e2db2b42, authored by Flora Feng, committed by GitHub
Browse files

[Tool Parser][1/3] Pass tools to ToolParser constructor (#38029)


Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent 87f05d68
...@@ -565,7 +565,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -565,7 +565,7 @@ class OpenAIServingChat(OpenAIServing):
) )
tool_parsers: list[ToolParser | None] = [ tool_parsers: list[ToolParser | None] = [
self.tool_parser(tokenizer) self.tool_parser(tokenizer, request.tools)
] * num_choices ] * num_choices
else: else:
tool_parsers = [None] * num_choices tool_parsers = [None] * num_choices
...@@ -1331,7 +1331,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1331,7 +1331,7 @@ class OpenAIServingChat(OpenAIServing):
"Tokenizer not available when `skip_tokenizer_init=True`" "Tokenizer not available when `skip_tokenizer_init=True`"
) )
tool_parser = self.tool_parser(tokenizer) tool_parser = self.tool_parser(tokenizer, request.tools)
# NOTE: We use token_ids for openai tool parser # NOTE: We use token_ids for openai tool parser
tool_call_info = tool_parser.extract_tool_calls( tool_call_info = tool_parser.extract_tool_calls(
"", "",
......
...@@ -925,7 +925,7 @@ class OpenAIServing: ...@@ -925,7 +925,7 @@ class OpenAIServing:
# Automatic Tool Call Parsing # Automatic Tool Call Parsing
try: try:
tool_parser = tool_parser_cls(tokenizer) tool_parser = tool_parser_cls(tokenizer, request.tools)
except RuntimeError as e: except RuntimeError as e:
logger.exception("Error in tool parser creation.") logger.exception("Error in tool parser creation.")
raise e raise e
......
...@@ -52,7 +52,7 @@ class ResponsesParser: ...@@ -52,7 +52,7 @@ class ResponsesParser:
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer) self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None self.tool_parser_instance = None
if tool_parser_cls is not None: if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer) self.tool_parser_instance = tool_parser_cls(tokenizer, request.tools)
# Store the last finish_reason to determine response status # Store the last finish_reason to determine response status
self.finish_reason: str | None = None self.finish_reason: str | None = None
......
...@@ -1344,7 +1344,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1344,7 +1344,7 @@ class OpenAIServingResponses(OpenAIServing):
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
tool_parser = None tool_parser = None
if self.parser and self.parser.tool_parser_cls: if self.parser and self.parser.tool_parser_cls:
tool_parser = self.parser.tool_parser_cls(tokenizer) tool_parser = self.parser.tool_parser_cls(tokenizer, request.tools)
reasoning_ended = False reasoning_ended = False
tool_call_text_started = False tool_call_text_started = False
previous_text = "" previous_text = ""
......
...@@ -545,6 +545,8 @@ class OpenAIServingRender: ...@@ -545,6 +545,8 @@ class OpenAIServingRender:
) )
raise NotImplementedError(msg) raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer() tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] request = tool_parser(tokenizer, request.tools).adjust_request(
request=request # type: ignore[arg-type]
)
return conversation, [engine_input] return conversation, [engine_input]
...@@ -5,13 +5,18 @@ import importlib ...@@ -5,13 +5,18 @@ import importlib
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from functools import cached_property from functools import cached_property
from typing import TypeAlias
from openai.types.responses import ( from openai.types.responses import (
ResponseFormatTextJSONSchemaConfig, ResponseFormatTextJSONSchemaConfig,
ResponseTextConfig, ResponseTextConfig,
) )
from openai.types.responses.tool import Tool as ResponsesTool
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
ExtractedToolCallInformation, ExtractedToolCallInformation,
...@@ -30,6 +35,8 @@ from vllm.utils.import_utils import import_from_path ...@@ -30,6 +35,8 @@ from vllm.utils.import_utils import import_from_path
logger = init_logger(__name__) logger = init_logger(__name__)
Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
class ToolParser: class ToolParser:
""" """
...@@ -38,7 +45,11 @@ class ToolParser: ...@@ -38,7 +45,11 @@ class ToolParser:
derived classes. derived classes.
""" """
def __init__(self, tokenizer: TokenizerLike): def __init__(
self,
tokenizer: TokenizerLike,
tools: list[Tool] | None = None,
):
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed # the index of the tool call that is currently being parsed
self.current_tool_id: int = -1 self.current_tool_id: int = -1
...@@ -46,6 +57,7 @@ class ToolParser: ...@@ -46,6 +57,7 @@ class ToolParser:
self.streamed_args_for_tool: list[str] = [] self.streamed_args_for_tool: list[str] = []
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
self.tools = tools
@cached_property @cached_property
def vocab(self) -> dict[str, int]: def vocab(self) -> dict[str, int]:
......
...@@ -19,14 +19,14 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -19,14 +19,14 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__) logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser): class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -43,8 +44,8 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -43,8 +44,8 @@ class DeepSeekV32ToolParser(ToolParser):
</|DSML|function_calls> </|DSML|function_calls>
""" """
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -27,8 +28,8 @@ logger = init_logger(__name__) ...@@ -27,8 +28,8 @@ logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser): class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -27,12 +28,12 @@ logger = init_logger(__name__) ...@@ -27,12 +28,12 @@ logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser): class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
""" """
Ernie thinking model format: Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
""" """
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.current_tool_name_sent = False self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
self.current_tool_id = -1 self.current_tool_id = -1
......
...@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,8 +33,8 @@ class FunctionGemmaToolParser(ToolParser): ...@@ -33,8 +33,8 @@ class FunctionGemmaToolParser(ToolParser):
<start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call> <start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
""" """
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# Streaming state # Streaming state
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
......
...@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,8 +46,8 @@ ARGS_REGEX = re.compile( ...@@ -46,8 +46,8 @@ ARGS_REGEX = re.compile(
class GigaChat3ToolParser(ToolParser): class GigaChat3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.tool_started: bool = False self.tool_started: bool = False
self.tool_name_sent: bool = False self.tool_name_sent: bool = False
self.tool_id: str | None = None self.tool_id: str | None = None
......
...@@ -16,14 +16,15 @@ import regex as re ...@@ -16,14 +16,15 @@ import regex as re
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool
from vllm.tool_parsers.glm4_moe_tool_parser import Glm4MoeModelToolParser from vllm.tool_parsers.glm4_moe_tool_parser import Glm4MoeModelToolParser
logger = init_logger(__name__) logger = init_logger(__name__)
class Glm47MoeModelToolParser(Glm4MoeModelToolParser): class Glm47MoeModelToolParser(Glm4MoeModelToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# GLM-4.7 format: <tool_call>func_name[<arg_key>...]*</tool_call> # GLM-4.7 format: <tool_call>func_name[<arg_key>...]*</tool_call>
# The function name can be followed by a newline, whitespace, or # The function name can be followed by a newline, whitespace, or
# directly by <arg_key> tags (no separator). The arg section is # directly by <arg_key> tags (no separator). The arg section is
......
...@@ -21,7 +21,6 @@ import regex as re ...@@ -21,7 +21,6 @@ import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall, DeltaFunctionCall,
...@@ -34,6 +33,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -34,6 +33,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -48,8 +48,8 @@ class Glm4MoeModelToolParser(ToolParser): ...@@ -48,8 +48,8 @@ class Glm4MoeModelToolParser(ToolParser):
rather than waiting for the complete </arg_value> tag. rather than waiting for the complete </arg_value> tag.
""" """
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# Stateful streaming fields # Stateful streaming fields
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict[str, Any]] = [] self.prev_tool_call_arr: list[dict[str, Any]] = []
...@@ -122,7 +122,7 @@ class Glm4MoeModelToolParser(ToolParser): ...@@ -122,7 +122,7 @@ class Glm4MoeModelToolParser(ToolParser):
def _is_string_type( def _is_string_type(
tool_name: str, tool_name: str,
arg_name: str, arg_name: str,
tools: list[ChatCompletionToolsParam] | None, tools: list[Tool] | None,
) -> bool: ) -> bool:
if tools is None: if tools is None:
return False return False
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -43,8 +44,8 @@ FuncT = TypeVar("FuncT", bound=_FunctionCallCtor) ...@@ -43,8 +44,8 @@ FuncT = TypeVar("FuncT", bound=_FunctionCallCtor)
class Granite4ToolParser(ToolParser): class Granite4ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
......
...@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import ( from vllm.tool_parsers.utils import (
...@@ -46,8 +47,8 @@ class Granite20bFCToolParser(ToolParser): ...@@ -46,8 +47,8 @@ class Granite20bFCToolParser(ToolParser):
are all set are all set
""" """
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.bot_token = "<function_call>" self.bot_token = "<function_call>"
self.tool_start_token = self.bot_token self.tool_start_token = self.bot_token
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import ( from vllm.tool_parsers.utils import (
...@@ -44,8 +45,8 @@ class GraniteToolParser(ToolParser): ...@@ -44,8 +45,8 @@ class GraniteToolParser(ToolParser):
are all set are all set
""" """
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# for granite 3.0, the token `<|tool_call|>` # for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>" self.bot_token = "<|tool_call|>"
# for granite 3.1, the string `<tool_call>` # for granite 3.1, the string `<tool_call>`
......
...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -31,8 +32,8 @@ logger = init_logger(__name__) ...@@ -31,8 +32,8 @@ logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser): class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
if is_mistral_tokenizer(tokenizer): if is_mistral_tokenizer(tokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model") logger.error("Detected Mistral tokenizer when using a Hermes model")
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import consume_space from vllm.tool_parsers.utils import consume_space
...@@ -31,8 +32,8 @@ logger = init_logger(__name__) ...@@ -31,8 +32,8 @@ logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser): class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# Initialize state for streaming mode # Initialize state for streaming mode
self.prev_tool_calls: list[dict] = [] self.prev_tool_calls: list[dict] = []
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import extract_intermediate_diff from vllm.tool_parsers.utils import extract_intermediate_diff
...@@ -30,8 +31,8 @@ logger = init_logger(__name__) ...@@ -30,8 +31,8 @@ logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser): class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.position = 0 self.position = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment