Unverified commit e2db2b42, authored by Flora Feng, committed by GitHub
Browse files

[Tool Parser][1/3] Pass tools to ToolParser constructor (#38029)


Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent 87f05d68
...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,7 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
from vllm.tool_parsers.utils import extract_intermediate_diff from vllm.tool_parsers.utils import extract_intermediate_diff
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -30,8 +30,8 @@ logger = init_logger(__name__) ...@@ -30,8 +30,8 @@ logger = init_logger(__name__)
class JambaToolParser(ToolParser): class JambaToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
if is_mistral_tokenizer(self.model_tokenizer): if is_mistral_tokenizer(self.model_tokenizer):
raise ValueError( raise ValueError(
......
...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -27,8 +28,8 @@ logger = init_logger(__name__) ...@@ -27,8 +28,8 @@ logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser): class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1 self.current_tool_id: int = -1
......
...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import ( from vllm.tool_parsers.utils import (
...@@ -47,8 +48,12 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -47,8 +48,12 @@ class Llama4PythonicToolParser(ToolParser):
re.DOTALL, re.DOTALL,
) )
def __init__(self, tokenizer: PreTrainedTokenizerBase): def __init__(
super().__init__(tokenizer) self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# Rename for readability. This is NOT a tool id. # Rename for readability. This is NOT a tool id.
@property @property
......
...@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import ( from vllm.tool_parsers.utils import (
...@@ -44,8 +45,12 @@ class Llama3JsonToolParser(ToolParser): ...@@ -44,8 +45,12 @@ class Llama3JsonToolParser(ToolParser):
llama4_json are set. llama4_json are set.
""" """
def __init__(self, tokenizer: PreTrainedTokenizerBase): def __init__(
super().__init__(tokenizer) self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# initialize properties used for state when parsing tool calls in # initialize properties used for state when parsing tool calls in
# streaming mode # streaming mode
......
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
import regex as re import regex as re
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import Tool
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
class LongcatFlashToolParser(Hermes2ProToolParser): class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.tool_call_start_token: str = "<longcat_tool_call>" self.tool_call_start_token: str = "<longcat_tool_call>"
self.tool_call_end_token: str = "</longcat_tool_call>" self.tool_call_end_token: str = "</longcat_tool_call>"
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -29,8 +30,8 @@ logger = init_logger(__name__) ...@@ -29,8 +30,8 @@ logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser): class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import extract_intermediate_diff from vllm.tool_parsers.utils import extract_intermediate_diff
...@@ -30,8 +31,8 @@ logger = init_logger(__name__) ...@@ -30,8 +31,8 @@ logger = init_logger(__name__)
class MinimaxToolParser(ToolParser): class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# Initialize streaming state for tracking tool call progress # Initialize streaming state for tracking tool call progress
self.streaming_state: dict[str, Any] = { self.streaming_state: dict[str, Any] = {
......
...@@ -26,6 +26,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -26,6 +26,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
...@@ -78,8 +79,8 @@ class MistralToolParser(ToolParser): ...@@ -78,8 +79,8 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
""" """
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
if not is_mistral_tokenizer(self.model_tokenizer): if not is_mistral_tokenizer(self.model_tokenizer):
logger.info("Non-Mistral tokenizer detected when using a Mistral model...") logger.info("Non-Mistral tokenizer detected when using a Mistral model...")
......
...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import ( from vllm.tool_parsers.utils import (
...@@ -51,8 +52,12 @@ class Olmo3PythonicToolParser(ToolParser): ...@@ -51,8 +52,12 @@ class Olmo3PythonicToolParser(ToolParser):
re.DOTALL, re.DOTALL,
) )
def __init__(self, tokenizer: PreTrainedTokenizerBase): def __init__(
super().__init__(tokenizer) self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# Rename for readability. This is NOT a tool id. # Rename for readability. This is NOT a tool id.
@property @property
......
...@@ -16,6 +16,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -16,6 +16,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -28,8 +29,8 @@ logger = init_logger(__name__) ...@@ -28,8 +29,8 @@ logger = init_logger(__name__)
class OpenAIToolParser(ToolParser): class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "TokenizerLike"): def __init__(self, tokenizer: "TokenizerLike", tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
def extract_tool_calls( def extract_tool_calls(
self, self,
......
...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -35,8 +36,12 @@ class Phi4MiniJsonToolParser(ToolParser): ...@@ -35,8 +36,12 @@ class Phi4MiniJsonToolParser(ToolParser):
are all set are all set
""" """
def __init__(self, tokenizer: PreTrainedTokenizerBase) -> None: def __init__(
super().__init__(tokenizer) self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
) -> None:
super().__init__(tokenizer, tools)
# initialize properties used for state when parsing tool calls in # initialize properties used for state when parsing tool calls in
# streaming mode # streaming mode
......
...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import ( from vllm.tool_parsers.utils import (
...@@ -49,8 +50,12 @@ class PythonicToolParser(ToolParser): ...@@ -49,8 +50,12 @@ class PythonicToolParser(ToolParser):
re.DOTALL, re.DOTALL,
) )
def __init__(self, tokenizer: PreTrainedTokenizerBase): def __init__(
super().__init__(tokenizer) self,
tokenizer: PreTrainedTokenizerBase,
tools: list[Tool] | None = None,
):
super().__init__(tokenizer, tools)
# Rename for readability. This is NOT a tool id. # Rename for readability. This is NOT a tool id.
@property @property
......
...@@ -10,7 +10,6 @@ import regex as re ...@@ -10,7 +10,6 @@ import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall, DeltaFunctionCall,
...@@ -23,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -23,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -30,8 +30,8 @@ logger = init_logger(__name__) ...@@ -30,8 +30,8 @@ logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser): class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
...@@ -109,9 +109,7 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -109,9 +109,7 @@ class Qwen3CoderToolParser(ToolParser):
self.accumulated_params = {} self.accumulated_params = {}
self.streaming_request = None self.streaming_request = None
def _get_arguments_config( def _get_arguments_config(self, func_name: str, tools: list[Tool] | None) -> dict:
self, func_name: str, tools: list[ChatCompletionToolsParam] | None
) -> dict:
"""Extract argument configuration for a function.""" """Extract argument configuration for a function."""
if tools is None: if tools is None:
return {} return {}
...@@ -246,7 +244,7 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -246,7 +244,7 @@ class Qwen3CoderToolParser(ToolParser):
return param_value return param_value
def _parse_xml_function_call( def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None self, function_call_str: str, tools: list[Tool] | None
) -> ToolCall | None: ) -> ToolCall | None:
# Extract function name # Extract function name
end_index = function_call_str.find(">") end_index = function_call_str.find(">")
......
...@@ -11,7 +11,6 @@ import regex as re ...@@ -11,7 +11,6 @@ import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall, DeltaFunctionCall,
...@@ -24,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -24,6 +23,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -40,7 +40,7 @@ class StreamingXMLToolCallParser: ...@@ -40,7 +40,7 @@ class StreamingXMLToolCallParser:
self.reset_streaming_state() self.reset_streaming_state()
# Tool configuration information # Tool configuration information
self.tools: list[ChatCompletionToolsParam] | None = None self.tools: list[Tool] | None = None
self.tool_call_start_token: str = "<tool_call>" self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>" self.tool_call_end_token: str = "</tool_call>"
self.function_start_token: str = "<function=" self.function_start_token: str = "<function="
...@@ -961,7 +961,7 @@ class StreamingXMLToolCallParser: ...@@ -961,7 +961,7 @@ class StreamingXMLToolCallParser:
self.parser.EndElementHandler = self._end_element self.parser.EndElementHandler = self._end_element
self.parser.CharacterDataHandler = self._char_data self.parser.CharacterDataHandler = self._char_data
def set_tools(self, tools: list[ChatCompletionToolsParam] | None): def set_tools(self, tools: list[Tool] | None):
"""Set tool configuration information""" """Set tool configuration information"""
self.tools = tools self.tools = tools
...@@ -1167,8 +1167,8 @@ class StreamingXMLToolCallParser: ...@@ -1167,8 +1167,8 @@ class StreamingXMLToolCallParser:
class Qwen3XMLToolParser(ToolParser): class Qwen3XMLToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.parser = StreamingXMLToolCallParser() self.parser = StreamingXMLToolCallParser()
# Add missing attributes for compatibility with serving_chat.py # Add missing attributes for compatibility with serving_chat.py
......
...@@ -13,7 +13,6 @@ import regex as re ...@@ -13,7 +13,6 @@ import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall, DeltaFunctionCall,
...@@ -26,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -26,6 +25,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
...@@ -36,8 +36,8 @@ class SeedOssToolParser(ToolParser): ...@@ -36,8 +36,8 @@ class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>" TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>" TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# --- streaming state --- # --- streaming state ---
self._reset_streaming_state() self._reset_streaming_state()
...@@ -109,7 +109,7 @@ class SeedOssToolParser(ToolParser): ...@@ -109,7 +109,7 @@ class SeedOssToolParser(ToolParser):
self.json_closed = False self.json_closed = False
def _parse_xml_function_call( def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None self, function_call_str: str, tools: list[Tool] | None
) -> ToolCall | None: ) -> ToolCall | None:
def get_arguments_config(func_name: str) -> dict: def get_arguments_config(func_name: str) -> dict:
if tools is None: if tools is None:
......
...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -43,8 +44,8 @@ class Step3ToolParser(ToolParser): ...@@ -43,8 +44,8 @@ class Step3ToolParser(ToolParser):
TOOL_SEP = "<|tool_sep|>" TOOL_SEP = "<|tool_sep|>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END] SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.position = 0 self.position = 0
# Explicit state flags for robust streaming # Explicit state flags for robust streaming
self.tool_block_started = False self.tool_block_started = False
......
...@@ -11,7 +11,6 @@ import regex as re ...@@ -11,7 +11,6 @@ import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall, DeltaFunctionCall,
...@@ -23,7 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -23,7 +22,7 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser from vllm.tool_parsers.abstract_tool_parser import Tool, ToolParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -38,7 +37,7 @@ class StreamingXMLToolCallParser: ...@@ -38,7 +37,7 @@ class StreamingXMLToolCallParser:
self.reset_streaming_state() self.reset_streaming_state()
# Tool configuration information # Tool configuration information
self.tools: list[ChatCompletionToolsParam] | None = None self.tools: list[Tool] | None = None
self.tool_call_start_token: str = "<tool_call>" self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>" self.tool_call_end_token: str = "</tool_call>"
self.function_start_token: str = "<function=" self.function_start_token: str = "<function="
...@@ -1161,7 +1160,7 @@ class StreamingXMLToolCallParser: ...@@ -1161,7 +1160,7 @@ class StreamingXMLToolCallParser:
self.parser.EndElementHandler = self._end_element self.parser.EndElementHandler = self._end_element
self.parser.CharacterDataHandler = self._char_data self.parser.CharacterDataHandler = self._char_data
def set_tools(self, tools: list[ChatCompletionToolsParam] | None): def set_tools(self, tools: list[Tool] | None):
"""Set tool configuration information""" """Set tool configuration information"""
self.tools = tools self.tools = tools
...@@ -1365,8 +1364,8 @@ class StreamingXMLToolCallParser: ...@@ -1365,8 +1364,8 @@ class StreamingXMLToolCallParser:
class Step3p5ToolParser(ToolParser): class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
self.parser = StreamingXMLToolCallParser() self.parser = StreamingXMLToolCallParser()
# Add missing attributes for compatibility with serving_chat.py # Add missing attributes for compatibility with serving_chat.py
......
...@@ -6,6 +6,7 @@ from collections.abc import Sequence ...@@ -6,6 +6,7 @@ from collections.abc import Sequence
from typing import Any, Optional, Union from typing import Any, Optional, Union
import regex as re import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
) )
...@@ -19,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -19,6 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import (
ToolCall, ToolCall,
) )
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -29,8 +31,8 @@ logger = init_logger(__name__) ...@@ -29,8 +31,8 @@ logger = init_logger(__name__)
class xLAMToolParser(ToolParser): class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer) super().__init__(tokenizer, tools)
# Initialize state for streaming mode # Initialize state for streaming mode
self.prev_tool_calls: list[dict] = [] self.prev_tool_calls: list[dict] = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment