Unverified Commit e2db2b42 authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Tool Parser][1/3] Pass tools to ToolParser constructor (#38029)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent 87f05d68
......@@ -565,7 +565,7 @@ class OpenAIServingChat(OpenAIServing):
)
tool_parsers: list[ToolParser | None] = [
self.tool_parser(tokenizer)
self.tool_parser(tokenizer, request.tools)
] * num_choices
else:
tool_parsers = [None] * num_choices
......@@ -1331,7 +1331,7 @@ class OpenAIServingChat(OpenAIServing):
"Tokenizer not available when `skip_tokenizer_init=True`"
)
tool_parser = self.tool_parser(tokenizer)
tool_parser = self.tool_parser(tokenizer, request.tools)
# NOTE: We use token_ids for openai tool parser
tool_call_info = tool_parser.extract_tool_calls(
"",
......
......@@ -925,7 +925,7 @@ class OpenAIServing:
# Automatic Tool Call Parsing
try:
tool_parser = tool_parser_cls(tokenizer)
tool_parser = tool_parser_cls(tokenizer, request.tools)
except RuntimeError as e:
logger.exception("Error in tool parser creation.")
raise e
......
......@@ -52,7 +52,7 @@ class ResponsesParser:
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer)
self.tool_parser_instance = tool_parser_cls(tokenizer, request.tools)
# Store the last finish_reason to determine response status
self.finish_reason: str | None = None
......
......@@ -1344,7 +1344,7 @@ class OpenAIServingResponses(OpenAIServing):
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
tool_parser = None
if self.parser and self.parser.tool_parser_cls:
tool_parser = self.parser.tool_parser_cls(tokenizer)
tool_parser = self.parser.tool_parser_cls(tokenizer, request.tools)
reasoning_ended = False
tool_call_text_started = False
previous_text = ""
......
......@@ -545,6 +545,8 @@ class OpenAIServingRender:
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
request = tool_parser(tokenizer, request.tools).adjust_request(
request=request # type: ignore[arg-type]
)
return conversation, [engine_input]
......@@ -5,13 +5,18 @@ import importlib
import os
from collections.abc import Callable, Sequence
from functools import cached_property
from typing import TypeAlias
from openai.types.responses import (
ResponseFormatTextJSONSchemaConfig,
ResponseTextConfig,
)
from openai.types.responses.tool import Tool as ResponsesTool
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage,
ExtractedToolCallInformation,
......@@ -30,6 +35,8 @@ from vllm.utils.import_utils import import_from_path
logger = init_logger(__name__)
Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
class ToolParser:
"""
......@@ -38,7 +45,11 @@ class ToolParser:
derived classes.
"""
def __init__(self, tokenizer: TokenizerLike):
def __init__(
self,
tokenizer: TokenizerLike,
tools: list[Tool] | None = None,
):
self.prev_tool_call_arr: list[dict] = []
# the index of the tool call that is currently being parsed
self.current_tool_id: int = -1
......@@ -46,6 +57,7 @@ class ToolParser:
self.streamed_args_for_tool: list[str] = []
self.model_tokenizer = tokenizer
self.tools = tools
@cached_property
def vocab(self) -> dict[str, int]:
......
......@@ -19,14 +19,14 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__)
class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -43,8 +44,8 @@ class DeepSeekV32ToolParser(ToolParser):
</|DSML|function_calls>
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = []
......
......@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -27,8 +28,8 @@ logger = init_logger(__name__)
class DeepSeekV3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
......
......@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -27,12 +28,12 @@ logger = init_logger(__name__)
class Ernie45ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
"""
Ernie thinking model format:
abc\n</think>\n\n\n<tool_call>\ndef\n</tool_call>\n
"""
super().__init__(tokenizer)
super().__init__(tokenizer, tools)
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id = -1
......
......@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__)
......@@ -33,8 +33,8 @@ class FunctionGemmaToolParser(ToolParser):
<start_function_call>call:func_name{param:<escape>value<escape>}<end_function_call>
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# Streaming state
self.current_tool_name_sent: bool = False
......
......@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__)
......@@ -46,8 +46,8 @@ ARGS_REGEX = re.compile(
class GigaChat3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.tool_started: bool = False
self.tool_name_sent: bool = False
self.tool_id: str | None = None
......
......@@ -16,14 +16,15 @@ import regex as re
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool
from vllm.tool_parsers.glm4_moe_tool_parser import Glm4MoeModelToolParser
logger = init_logger(__name__)
class Glm47MoeModelToolParser(Glm4MoeModelToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# GLM-4.7 format: <tool_call>func_name[<arg_key>...]*</tool_call>
# The function name can be followed by a newline, whitespace, or
# directly by <arg_key> tags (no separator). The arg section is
......
......@@ -21,7 +21,6 @@ import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
......@@ -34,6 +33,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -48,8 +48,8 @@ class Glm4MoeModelToolParser(ToolParser):
rather than waiting for the complete </arg_value> tag.
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# Stateful streaming fields
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict[str, Any]] = []
......@@ -122,7 +122,7 @@ class Glm4MoeModelToolParser(ToolParser):
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[ChatCompletionToolsParam] | None,
tools: list[Tool] | None,
) -> bool:
if tools is None:
return False
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -43,8 +44,8 @@ FuncT = TypeVar("FuncT", bound=_FunctionCallCtor)
class Granite4ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
......
......@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import (
......@@ -46,8 +47,8 @@ class Granite20bFCToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.bot_token = "<function_call>"
self.tool_start_token = self.bot_token
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import (
......@@ -44,8 +45,8 @@ class GraniteToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>"
# for granite 3.1, the string `<tool_call>`
......
......@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.utils.mistral import is_mistral_tokenizer
......@@ -31,8 +32,8 @@ logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
if is_mistral_tokenizer(tokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model")
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import consume_space
......@@ -31,8 +32,8 @@ logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# Initialize state for streaming mode
self.prev_tool_calls: list[dict] = []
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff
......@@ -30,8 +31,8 @@ logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.position = 0
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment