Unverified Commit d53cb9cb authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Tool Parser][2/3] Use self.tools instead of request.tools in tool parsers (#38189)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent 44eef0ca
...@@ -11,6 +11,10 @@ from unittest.mock import MagicMock ...@@ -11,6 +11,10 @@ from unittest.mock import MagicMock
import pytest import pytest
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionToolsParam,
FunctionDefinition,
)
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser
...@@ -24,8 +28,8 @@ MOCK_TOKENIZER = MagicMock() ...@@ -24,8 +28,8 @@ MOCK_TOKENIZER = MagicMock()
MOCK_TOKENIZER.get_vocab.return_value = {} MOCK_TOKENIZER.get_vocab.return_value = {}
def make_parser() -> DeepSeekV32ToolParser: def make_parser(tools=None) -> DeepSeekV32ToolParser:
return DeepSeekV32ToolParser(MOCK_TOKENIZER) return DeepSeekV32ToolParser(MOCK_TOKENIZER, tools=tools)
def make_tool_param(name: str, params: dict) -> MagicMock: def make_tool_param(name: str, params: dict) -> MagicMock:
...@@ -275,20 +279,22 @@ class TestExtractToolCallsStreaming: ...@@ -275,20 +279,22 @@ class TestExtractToolCallsStreaming:
content = "".join(d.content for d in deltas if d.content is not None) content = "".join(d.content for d in deltas if d.content is not None)
assert "Thinking" in content assert "Thinking" in content
def test_type_conversion_in_streaming(self, parser): def test_type_conversion_in_streaming(self):
tool = make_tool_param( tool = ChatCompletionToolsParam(
"add", function=FunctionDefinition(
{ name="add",
parameters={
"type": "object", "type": "object",
"properties": { "properties": {
"x": {"type": "integer"}, "x": {"type": "integer"},
"y": {"type": "integer"}, "y": {"type": "integer"},
}, },
}, },
),
) )
request = make_request(tools=[tool]) parser = make_parser(tools=[tool])
full_text = build_tool_call("add", {"x": "3", "y": "4"}) full_text = build_tool_call("add", {"x": "3", "y": "4"})
deltas = self._stream(parser, full_text, request=request) deltas = self._stream(parser, full_text)
args_str = self._reconstruct_args(deltas) args_str = self._reconstruct_args(deltas)
assert json.loads(args_str) == {"x": 3, "y": 4} assert json.loads(args_str) == {"x": 3, "y": 4}
......
...@@ -25,14 +25,8 @@ def glm47_tokenizer(): ...@@ -25,14 +25,8 @@ def glm47_tokenizer():
@pytest.fixture @pytest.fixture
def glm47_tool_parser(glm47_tokenizer): def sample_tools():
return Glm47MoeModelToolParser(glm47_tokenizer) return [
@pytest.fixture
def mock_request() -> ChatCompletionRequest:
request = Mock(spec=ChatCompletionRequest)
request.tools = [
ChatCompletionToolsParam( ChatCompletionToolsParam(
function=FunctionDefinition(name="get_current_date", parameters={}), function=FunctionDefinition(name="get_current_date", parameters={}),
), ),
...@@ -49,6 +43,17 @@ def mock_request() -> ChatCompletionRequest: ...@@ -49,6 +43,17 @@ def mock_request() -> ChatCompletionRequest:
), ),
), ),
] ]
@pytest.fixture
def glm47_tool_parser(glm47_tokenizer, sample_tools):
return Glm47MoeModelToolParser(glm47_tokenizer, tools=sample_tools)
@pytest.fixture
def mock_request(sample_tools) -> ChatCompletionRequest:
request = Mock(spec=ChatCompletionRequest)
request.tools = sample_tools
request.tool_choice = "auto" request.tool_choice = "auto"
return request return request
......
...@@ -27,14 +27,8 @@ def glm4_moe_tokenizer(): ...@@ -27,14 +27,8 @@ def glm4_moe_tokenizer():
@pytest.fixture @pytest.fixture
def glm4_moe_tool_parser(glm4_moe_tokenizer): def sample_tools():
return Glm4MoeModelToolParser(glm4_moe_tokenizer) return [
@pytest.fixture
def mock_request() -> ChatCompletionRequest:
request = Mock(spec=ChatCompletionRequest)
request.tools = [ # GLM45 parser needs this attribute to enable tool parsing.
ChatCompletionToolsParam( ChatCompletionToolsParam(
function=FunctionDefinition( function=FunctionDefinition(
name="get_weather", name="get_weather",
...@@ -42,6 +36,17 @@ def mock_request() -> ChatCompletionRequest: ...@@ -42,6 +36,17 @@ def mock_request() -> ChatCompletionRequest:
), ),
), ),
] ]
@pytest.fixture
def glm4_moe_tool_parser(glm4_moe_tokenizer, sample_tools):
return Glm4MoeModelToolParser(glm4_moe_tokenizer, tools=sample_tools)
@pytest.fixture
def mock_request(sample_tools) -> ChatCompletionRequest:
request = Mock(spec=ChatCompletionRequest)
request.tools = sample_tools
return request return request
...@@ -671,14 +676,13 @@ def test_streaming_json_escape_in_string(glm4_moe_tool_parser, mock_request): ...@@ -671,14 +676,13 @@ def test_streaming_json_escape_in_string(glm4_moe_tool_parser, mock_request):
assert '"' in parsed["message"] or "world" in parsed["message"] assert '"' in parsed["message"] or "world" in parsed["message"]
def test_streaming_long_content_incremental(glm4_moe_tool_parser): def test_streaming_long_content_incremental(glm4_moe_tokenizer):
"""Test incremental streaming of long content (Issue #32829). """Test incremental streaming of long content (Issue #32829).
This is the core fix: for long string values like code (4000+ chars), This is the core fix: for long string values like code (4000+ chars),
the parser should stream incrementally rather than buffering until the parser should stream incrementally rather than buffering until
complete. This test verifies we get many fragments, not just 1-3. complete. This test verifies we get many fragments, not just 1-3.
""" """
_reset_streaming_state(glm4_moe_tool_parser)
# Bubble sort example from Issue #32829 - realistic long content # Bubble sort example from Issue #32829 - realistic long content
bubble_sort_code = '''#!/usr/bin/env python3 bubble_sort_code = '''#!/usr/bin/env python3
...@@ -705,27 +709,28 @@ if __name__ == "__main__": ...@@ -705,27 +709,28 @@ if __name__ == "__main__":
sorted_arr = bubble_sort(test_arr.copy()) sorted_arr = bubble_sort(test_arr.copy())
print(f"Sorted: {sorted_arr}")''' print(f"Sorted: {sorted_arr}")'''
# Create a request with tool schema to enable string type detection # Create tools with schema to enable string type detection
# This is required for incremental streaming of string values # This is required for incremental streaming of string values
request = ChatCompletionRequest( tools = [
model=MODEL, ChatCompletionToolsParam(
messages=[], function=FunctionDefinition(
tools=[ name="write_to_file",
{ parameters={
"type": "function",
"function": {
"name": "write_to_file",
"parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"file_path": {"type": "string"}, "file_path": {"type": "string"},
"content": {"type": "string"}, "content": {"type": "string"},
}, },
}, },
}, ),
} ),
], ]
) # type: ignore glm4_moe_tool_parser = Glm4MoeModelToolParser(glm4_moe_tokenizer, tools=tools)
request = ChatCompletionRequest(
model=MODEL,
messages=[],
tools=tools,
)
# Simulate token-based streaming (special tags as single tokens) # Simulate token-based streaming (special tags as single tokens)
chunks = [ chunks = [
......
...@@ -31,13 +31,13 @@ def qwen3_tokenizer(): ...@@ -31,13 +31,13 @@ def qwen3_tokenizer():
@pytest.fixture @pytest.fixture
def qwen3_tool_parser(qwen3_tokenizer): def qwen3_tool_parser(qwen3_tokenizer, sample_tools):
return Qwen3CoderToolParser(qwen3_tokenizer) return Qwen3CoderToolParser(qwen3_tokenizer, tools=sample_tools)
@pytest.fixture @pytest.fixture
def qwen3_xml_tool_parser(qwen3_tokenizer): def qwen3_xml_tool_parser(qwen3_tokenizer, sample_tools):
return Qwen3XMLToolParser(qwen3_tokenizer) return Qwen3XMLToolParser(qwen3_tokenizer, tools=sample_tools)
@pytest.fixture(params=["xml"]) @pytest.fixture(params=["xml"])
...@@ -376,7 +376,7 @@ TX ...@@ -376,7 +376,7 @@ TX
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized): def test_extract_tool_calls_type_conversion(qwen3_tokenizer):
"""Test parameter type conversion based on tool schema""" """Test parameter type conversion based on tool schema"""
tools = [ tools = [
ChatCompletionToolsParam( ChatCompletionToolsParam(
...@@ -417,10 +417,9 @@ hello world ...@@ -417,10 +417,9 @@ hello world
</function> </function>
</tool_call>""" </tool_call>"""
parser = Qwen3XMLToolParser(qwen3_tokenizer, tools=tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( extracted_tool_calls = parser.extract_tool_calls(model_output, request=request)
model_output, request=request
)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["int_param"] == 42 assert args["int_param"] == 42
...@@ -859,7 +858,7 @@ TX ...@@ -859,7 +858,7 @@ TX
def test_extract_tool_calls_complex_type_with_single_quote( def test_extract_tool_calls_complex_type_with_single_quote(
qwen3_tool_parser_parametrized, qwen3_tokenizer,
): ):
"""Test parameter type conversion based on tool schema""" """Test parameter type conversion based on tool schema"""
tools = [ tools = [
...@@ -889,10 +888,9 @@ def test_extract_tool_calls_complex_type_with_single_quote( ...@@ -889,10 +888,9 @@ def test_extract_tool_calls_complex_type_with_single_quote(
</function> </function>
</tool_call>""" </tool_call>"""
parser = Qwen3XMLToolParser(qwen3_tokenizer, tools=tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls( extracted_tool_calls = parser.extract_tool_calls(model_output, request=request)
model_output, request=request
)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["obj_param"] == {"key": "value"} assert args["obj_param"] == {"key": "value"}
......
...@@ -30,8 +30,8 @@ def seed_oss_tokenizer(): ...@@ -30,8 +30,8 @@ def seed_oss_tokenizer():
@pytest.fixture @pytest.fixture
def seed_oss_tool_parser(seed_oss_tokenizer): def seed_oss_tool_parser(seed_oss_tokenizer, sample_tools):
return SeedOssToolParser(seed_oss_tokenizer) return SeedOssToolParser(seed_oss_tokenizer, tools=sample_tools)
@pytest.fixture @pytest.fixture
......
...@@ -28,8 +28,8 @@ def step3p5_tokenizer(): ...@@ -28,8 +28,8 @@ def step3p5_tokenizer():
@pytest.fixture @pytest.fixture
def step3p5_tool_parser(step3p5_tokenizer): def step3p5_tool_parser(step3p5_tokenizer, sample_tools):
return Step3p5ToolParser(step3p5_tokenizer) return Step3p5ToolParser(step3p5_tokenizer, tools=sample_tools)
@pytest.fixture @pytest.fixture
...@@ -386,7 +386,7 @@ TX ...@@ -386,7 +386,7 @@ TX
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather" assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
def test_extract_tool_calls_type_conversion(step3p5_tool_parser): def test_extract_tool_calls_type_conversion(step3p5_tokenizer):
"""Test parameter type conversion based on tool schema""" """Test parameter type conversion based on tool schema"""
tools = [ tools = [
ChatCompletionToolsParam( ChatCompletionToolsParam(
...@@ -427,10 +427,9 @@ hello world ...@@ -427,10 +427,9 @@ hello world
</function> </function>
</tool_call>""" </tool_call>"""
parser = Step3p5ToolParser(step3p5_tokenizer, tools=tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls( extracted_tool_calls = parser.extract_tool_calls(model_output, request=request)
model_output, request=request
)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["int_param"] == 42 assert args["int_param"] == 42
...@@ -864,7 +863,7 @@ TX ...@@ -864,7 +863,7 @@ TX
assert parsed_args["state"] == "TX" assert parsed_args["state"] == "TX"
def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tool_parser): def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tokenizer):
"""Test parameter type conversion based on tool schema""" """Test parameter type conversion based on tool schema"""
tools = [ tools = [
ChatCompletionToolsParam( ChatCompletionToolsParam(
...@@ -893,10 +892,9 @@ def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tool_parser): ...@@ -893,10 +892,9 @@ def test_extract_tool_calls_complex_type_with_single_quote(step3p5_tool_parser):
</function> </function>
</tool_call>""" </tool_call>"""
parser = Step3p5ToolParser(step3p5_tokenizer, tools=tools)
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
extracted_tool_calls = step3p5_tool_parser.extract_tool_calls( extracted_tool_calls = parser.extract_tool_calls(model_output, request=request)
model_output, request=request
)
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["obj_param"] == {"key": "value"} assert args["obj_param"] == {"key": "value"}
......
...@@ -10,9 +10,11 @@ from openai.types.responses import ( ...@@ -10,9 +10,11 @@ from openai.types.responses import (
ResponseFormatTextJSONSchemaConfig, ResponseFormatTextJSONSchemaConfig,
ResponseTextConfig, ResponseTextConfig,
) )
from openai.types.responses.function_tool import FunctionTool
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
...@@ -54,7 +56,14 @@ class ToolParser: ...@@ -54,7 +56,14 @@ class ToolParser:
self.streamed_args_for_tool: list[str] = [] self.streamed_args_for_tool: list[str] = []
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
self.tools = tools if tools:
self.tools: list[ChatCompletionToolsParam | FunctionTool] = [
tool
for tool in tools
if isinstance(tool, (ChatCompletionToolsParam, FunctionTool))
]
else:
self.tools = []
@cached_property @cached_property
def vocab(self) -> dict[str, int]: def vocab(self) -> dict[str, int]:
......
...@@ -142,12 +142,11 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -142,12 +142,11 @@ class DeepSeekV32ToolParser(ToolParser):
self, self,
function_name: str, function_name: str,
param_dict: dict[str, str], param_dict: dict[str, str],
request: ChatCompletionRequest | None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Convert raw string param values using the tool schema types.""" """Convert raw string param values using the tool schema types."""
param_config: dict = {} param_config: dict = {}
if request and request.tools: if self.tools:
for tool in request.tools: for tool in self.tools:
if ( if (
hasattr(tool, "function") hasattr(tool, "function")
and tool.function.name == function_name and tool.function.name == function_name
...@@ -241,9 +240,7 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -241,9 +240,7 @@ class DeepSeekV32ToolParser(ToolParser):
invoke_name, invoke_body = complete_invokes[self.current_tool_index] invoke_name, invoke_body = complete_invokes[self.current_tool_index]
param_dict = self._parse_invoke_params(invoke_body) param_dict = self._parse_invoke_params(invoke_body)
converted = self._convert_params_with_schema( converted = self._convert_params_with_schema(invoke_name, param_dict)
invoke_name, param_dict, request
)
args_json = json.dumps(converted, ensure_ascii=False) args_json = json.dumps(converted, ensure_ascii=False)
idx = self.current_tool_index idx = self.current_tool_index
self.current_tool_index += 1 self.current_tool_index += 1
......
...@@ -189,7 +189,7 @@ class Glm4MoeModelToolParser(ToolParser): ...@@ -189,7 +189,7 @@ class Glm4MoeModelToolParser(ToolParser):
for key, value in pairs: for key, value in pairs:
arg_key = key.strip() arg_key = key.strip()
arg_val = value.strip() arg_val = value.strip()
if not self._is_string_type(tc_name, arg_key, request.tools): if not self._is_string_type(tc_name, arg_key, self.tools):
arg_val = self._deserialize(arg_val) arg_val = self._deserialize(arg_val)
logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val) logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
arg_dct[arg_key] = arg_val arg_dct[arg_key] = arg_val
...@@ -330,7 +330,7 @@ class Glm4MoeModelToolParser(ToolParser): ...@@ -330,7 +330,7 @@ class Glm4MoeModelToolParser(ToolParser):
key = (self._pending_key or "").strip() key = (self._pending_key or "").strip()
is_string = self._is_string_type( is_string = self._is_string_type(
self._current_tool_name, key, request.tools self._current_tool_name, key, self.tools
) )
if is_string: if is_string:
......
...@@ -200,7 +200,7 @@ class Internlm2ToolParser(ToolParser): ...@@ -200,7 +200,7 @@ class Internlm2ToolParser(ToolParser):
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ExtractedToolCallInformation: ) -> ExtractedToolCallInformation:
text = model_output text = model_output
tools = request.tools tools = self.tools
if "<|action_start|><|plugin|>" in text: if "<|action_start|><|plugin|>" in text:
text, action = text.split("<|action_start|><|plugin|>") text, action = text.split("<|action_start|><|plugin|>")
action = action.split("<|action_end|>".strip())[0] action = action.split("<|action_end|>".strip())[0]
......
...@@ -308,7 +308,7 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -308,7 +308,7 @@ class MinimaxM2ToolParser(ToolParser):
invoke_str = complete_invokes[self.current_tool_index] invoke_str = complete_invokes[self.current_tool_index]
tool_call = self._parse_single_invoke( tool_call = self._parse_single_invoke(
invoke_str, invoke_str,
request.tools if request else None, self.tools,
) )
if not tool_call: if not tool_call:
self.current_tool_index += 1 self.current_tool_index += 1
...@@ -358,9 +358,7 @@ class MinimaxM2ToolParser(ToolParser): ...@@ -358,9 +358,7 @@ class MinimaxM2ToolParser(ToolParser):
for tool_call_match in self.tool_call_complete_regex.findall(model_output): for tool_call_match in self.tool_call_complete_regex.findall(model_output):
# Find all invokes within this tool_call # Find all invokes within this tool_call
for invoke_match in self.invoke_complete_regex.findall(tool_call_match): for invoke_match in self.invoke_complete_regex.findall(tool_call_match):
tool_call = self._parse_single_invoke( tool_call = self._parse_single_invoke(invoke_match, self.tools)
invoke_match, request.tools if request else None
)
if tool_call: if tool_call:
tool_calls.append(tool_call) tool_calls.append(tool_call)
......
...@@ -314,7 +314,7 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -314,7 +314,7 @@ class Qwen3CoderToolParser(ToolParser):
) )
tool_calls = [ tool_calls = [
self._parse_xml_function_call(function_call_str, request.tools) self._parse_xml_function_call(function_call_str, self.tools)
for function_call_str in function_calls for function_call_str in function_calls
] ]
# Populate prev_tool_call_arr for serving layer to set finish_reason # Populate prev_tool_call_arr for serving layer to set finish_reason
...@@ -607,7 +607,7 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -607,7 +607,7 @@ class Qwen3CoderToolParser(ToolParser):
param_config = self._get_arguments_config( param_config = self._get_arguments_config(
self.current_function_name or "", self.current_function_name or "",
self.streaming_request.tools if self.streaming_request else None, self.tools,
) )
converted_value = self._convert_param_value( converted_value = self._convert_param_value(
...@@ -666,9 +666,7 @@ class Qwen3CoderToolParser(ToolParser): ...@@ -666,9 +666,7 @@ class Qwen3CoderToolParser(ToolParser):
try: try:
parsed_tool = self._parse_xml_function_call( parsed_tool = self._parse_xml_function_call(
func_content, func_content,
self.streaming_request.tools self.tools,
if self.streaming_request
else None,
) )
if parsed_tool and self.current_tool_index < len( if parsed_tool and self.current_tool_index < len(
self.prev_tool_call_arr self.prev_tool_call_arr
......
...@@ -1188,8 +1188,7 @@ class Qwen3XMLToolParser(ToolParser): ...@@ -1188,8 +1188,7 @@ class Qwen3XMLToolParser(ToolParser):
# Reset tool call tracking arrays for new extraction # Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = [] self.prev_tool_call_arr = []
self.streamed_args_for_tool = [] self.streamed_args_for_tool = []
if request: self.parser.set_tools(self.tools)
self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output) result = self.parser.parse_single_streaming_chunks(model_output)
if not result.tool_calls: if not result.tool_calls:
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
...@@ -1260,8 +1259,7 @@ class Qwen3XMLToolParser(ToolParser): ...@@ -1260,8 +1259,7 @@ class Qwen3XMLToolParser(ToolParser):
# Reset tool call tracking arrays for new streaming session # Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = [] self.prev_tool_call_arr = []
self.streamed_args_for_tool = [] self.streamed_args_for_tool = []
if request: self.parser.set_tools(self.tools)
self.parser.set_tools(request.tools)
# Model sometimes outputs separately causing delta_text to be empty. # Model sometimes outputs separately causing delta_text to be empty.
# If there were tool_calls before and all current tool_calls have ended, # If there were tool_calls before and all current tool_calls have ended,
......
...@@ -312,7 +312,7 @@ class SeedOssToolParser(ToolParser): ...@@ -312,7 +312,7 @@ class SeedOssToolParser(ToolParser):
) )
tool_calls = [ tool_calls = [
self._parse_xml_function_call(function_call_str, request.tools) self._parse_xml_function_call(function_call_str, self.tools)
for function_call_str in function_calls for function_call_str in function_calls
] ]
...@@ -566,7 +566,7 @@ class SeedOssToolParser(ToolParser): ...@@ -566,7 +566,7 @@ class SeedOssToolParser(ToolParser):
# Parse to get the complete arguments # Parse to get the complete arguments
try: try:
parsed_tool = self._parse_xml_function_call( parsed_tool = self._parse_xml_function_call(
func_content, request.tools if request else None func_content, self.tools
) )
if parsed_tool: if parsed_tool:
# Update existing entry in prev_tool_call_arr with complete arguments # Update existing entry in prev_tool_call_arr with complete arguments
......
...@@ -82,9 +82,8 @@ class Step3ToolParser(ToolParser): ...@@ -82,9 +82,8 @@ class Step3ToolParser(ToolParser):
self, self,
func_name: str, func_name: str,
params: dict[str, Any], params: dict[str, Any],
request: ChatCompletionRequest,
) -> dict[str, Any]: ) -> dict[str, Any]:
for tool in request.tools or []: for tool in self.tools or []:
if tool.function.name == func_name: if tool.function.name == func_name:
schema = tool.function.parameters or {} schema = tool.function.parameters or {}
properties = schema.get("properties", {}) properties = schema.get("properties", {})
...@@ -234,7 +233,6 @@ class Step3ToolParser(ToolParser): ...@@ -234,7 +233,6 @@ class Step3ToolParser(ToolParser):
final_args = self._cast_arguments( final_args = self._cast_arguments(
function_name, function_name,
tool_call_arr.get("parameters", {}), # type: ignore tool_call_arr.get("parameters", {}), # type: ignore
request,
) )
if final_args: if final_args:
final_args_json = json.dumps(final_args, ensure_ascii=False) final_args_json = json.dumps(final_args, ensure_ascii=False)
...@@ -291,7 +289,7 @@ class Step3ToolParser(ToolParser): ...@@ -291,7 +289,7 @@ class Step3ToolParser(ToolParser):
function_name, params_dict = self._parse_steptml_invoke(invoke_part) function_name, params_dict = self._parse_steptml_invoke(invoke_part)
if function_name and params_dict is not None: if function_name and params_dict is not None:
params_dict = self._cast_arguments(function_name, params_dict, request) params_dict = self._cast_arguments(function_name, params_dict)
params_str = json.dumps(params_dict, ensure_ascii=False) params_str = json.dumps(params_dict, ensure_ascii=False)
tool_calls.append( tool_calls.append(
ToolCall( ToolCall(
......
...@@ -1385,8 +1385,7 @@ class Step3p5ToolParser(ToolParser): ...@@ -1385,8 +1385,7 @@ class Step3p5ToolParser(ToolParser):
# Reset tool call tracking arrays for new extraction # Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = [] self.prev_tool_call_arr = []
self.streamed_args_for_tool = [] self.streamed_args_for_tool = []
if request: self.parser.set_tools(self.tools)
self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output) result = self.parser.parse_single_streaming_chunks(model_output)
if not result.tool_calls: if not result.tool_calls:
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
...@@ -1457,8 +1456,7 @@ class Step3p5ToolParser(ToolParser): ...@@ -1457,8 +1456,7 @@ class Step3p5ToolParser(ToolParser):
# Reset tool call tracking arrays for new streaming session # Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = [] self.prev_tool_call_arr = []
self.streamed_args_for_tool = [] self.streamed_args_for_tool = []
if request: self.parser.set_tools(self.tools)
self.parser.set_tools(request.tools)
# Model sometimes outputs separately causing delta_text to be empty. # Model sometimes outputs separately causing delta_text to be empty.
# If there were tool_calls before and all current tool_calls have ended, # If there were tool_calls before and all current tool_calls have ended,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment