Unverified Commit e2db2b42 authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Tool Parser][1/3] Pass tools to ToolParser constructor (#38029)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent 87f05d68
......@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
from vllm.tool_parsers.utils import extract_intermediate_diff
from vllm.utils.mistral import is_mistral_tokenizer
......@@ -30,8 +30,8 @@ logger = init_logger(__name__)
class JambaToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
if is_mistral_tokenizer(self.model_tokenizer):
raise ValueError(
......
......@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -27,8 +28,8 @@ logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
......
......@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import (
......@@ -47,8 +48,12 @@ class Llama4PythonicToolParser(ToolParser):
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# Rename for readability. This is NOT a tool id.
@property
......
......@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import (
......@@ -44,8 +45,12 @@ class Llama3JsonToolParser(ToolParser):
llama4_json are set.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# initialize properties used for state when parsing tool calls in
# streaming mode
......
......@@ -4,12 +4,13 @@
import regex as re
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.tool_call_start_token: str = "<longcat_tool_call>"
self.tool_call_end_token: str = "</longcat_tool_call>"
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -29,8 +30,8 @@ logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = []
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff
......@@ -30,8 +31,8 @@ logger = init_logger(__name__)
class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# Initialize streaming state for tracking tool call progress
self.streaming_state: dict[str, Any] = {
......
......@@ -26,6 +26,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.utils.mistral import is_mistral_tokenizer
......@@ -78,8 +79,8 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
if not is_mistral_tokenizer(self.model_tokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
......
......@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import (
......@@ -51,8 +52,12 @@ class Olmo3PythonicToolParser(ToolParser):
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# Rename for readability. This is NOT a tool id.
@property
......
......@@ -16,6 +16,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -28,8 +29,8 @@ logger = init_logger(__name__)
class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "TokenizerLike"):
super().__init__(tokenizer)
def __init__(self, tokenizer: "TokenizerLike", tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
def extract_tool_calls(
self,
......
......@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -35,8 +36,12 @@ class Phi4MiniJsonToolParser(ToolParser):
are all set
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None:
super().__init__(tokenizer)
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
) -> None:
super().__init__(tokenizer, tools)
# initialize properties used for state when parsing tool calls in
# streaming mode
......
......@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import (
......@@ -49,8 +50,12 @@ class PythonicToolParser(ToolParser):
re.DOTALL,
)
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
def __init__(
self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# Rename for readability. This is NOT a tool id.
@property
......
......@@ -10,7 +10,6 @@ import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
......@@ -23,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -30,8 +30,8 @@ logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
......@@ -109,9 +109,7 @@ class Qwen3CoderToolParser(ToolParser):
self.accumulated_params = {}
self.streaming_request = None
def _get_arguments_config(
self, func_name: str, tools: list[ChatCompletionToolsParam] | None
) -> dict:
def _get_arguments_config(self, func_name: str, tools: list[Tool] | None) -> dict:
"""Extract argument configuration for a function."""
if tools is None:
return {}
......@@ -246,7 +244,7 @@ class Qwen3CoderToolParser(ToolParser):
return param_value
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
self, function_call_str: str, tools: list[Tool] | None
) -> ToolCall | None:
# Extract function name
end_index = function_call_str.find(">")
......
......@@ -11,7 +11,6 @@ import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
......@@ -24,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -40,7 +40,7 @@ class StreamingXMLToolCallParser:
self.reset_streaming_state()
# Tool configuration information
self.tools: list[ChatCompletionToolsParam] | None = None
self.tools: list[Tool] | None = None
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.function_start_token: str = "<function="
......@@ -961,7 +961,7 @@ class StreamingXMLToolCallParser:
self.parser.EndElementHandler = self._end_element
self.parser.CharacterDataHandler = self._char_data
def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
def set_tools(self, tools: list[Tool] | None):
"""Set tool configuration information"""
self.tools = tools
......@@ -1167,8 +1167,8 @@ class StreamingXMLToolCallParser:
class Qwen3XMLToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.parser = StreamingXMLToolCallParser()
# Add missing attributes for compatibility with serving_chat.py
......
......@@ -13,7 +13,6 @@ import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
......@@ -26,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
......@@ -36,8 +36,8 @@ class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# --- streaming state ---
self._reset_streaming_state()
......@@ -109,7 +109,7 @@ class SeedOssToolParser(ToolParser):
self.json_closed = False
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
self, function_call_str: str, tools: list[Tool] | None
) -> ToolCall | None:
def get_arguments_config(func_name: str) -> dict:
if tools is None:
......
......@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.utils import random_uuid
......@@ -43,8 +44,8 @@ class Step3ToolParser(ToolParser):
TOOL_SEP = "<|tool_sep|>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.position = 0
# Explicit state flags for robust streaming
self.tool_block_started = False
......
......@@ -11,7 +11,6 @@ import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
......@@ -23,7 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__)
......@@ -38,7 +37,7 @@ class StreamingXMLToolCallParser:
self.reset_streaming_state()
# Tool configuration information
self.tools: list[ChatCompletionToolsParam] | None = None
self.tools: list[Tool] | None = None
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.function_start_token: str = "<function="
......@@ -1161,7 +1160,7 @@ class StreamingXMLToolCallParser:
self.parser.EndElementHandler = self._end_element
self.parser.CharacterDataHandler = self._char_data
def set_tools(self, tools: list[ChatCompletionToolsParam] | None):
def set_tools(self, tools: list[Tool] | None):
"""Set tool configuration information"""
self.tools = tools
......@@ -1365,8 +1364,8 @@ class StreamingXMLToolCallParser:
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.parser = StreamingXMLToolCallParser()
# Add missing attributes for compatibility with serving_chat.py
......
......@@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
......@@ -19,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
ToolCall,
)
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.logger import init_logger
......@@ -29,8 +31,8 @@ logger = init_logger(__name__)
class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
# Initialize state for streaming mode
self.prev_tool_calls: list[dict] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment