Unverified Commit 8d6b3d5d authored by Taneem Ibrahim's avatar Taneem Ibrahim Committed by GitHub
Browse files

[Misc] Refactored 5 duplicate helper functions that were copied-pasted across...


[Misc] Refactored 5 duplicate helper functions that were copied-pasted across multiple parsers (#36436)
Signed-off-by: default avatarTaneem Ibrahim <taneem.ibrahim@gmail.com>
parent 4b87ffbe
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import json
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation, ExtractedToolCallInformation,
FunctionCall,
ToolCall,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import (
UnexpectedAstError,
compute_tool_delta,
handle_single_tool,
make_valid_python,
)
logger = init_logger(__name__) logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class Llama4PythonicToolParser(ToolParser): class Llama4PythonicToolParser(ToolParser):
""" """
Toolcall parser for Llama4 that produce tool calls in a pythonic style Toolcall parser for Llama4 that produce tool calls in a pythonic style
...@@ -103,15 +100,13 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -103,15 +100,13 @@ class Llama4PythonicToolParser(ToolParser):
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=True, tools_called=True,
tool_calls=[ tool_calls=[
_handle_single_tool(e) # type: ignore handle_single_tool(e) # type: ignore
for e in parsed.elts for e in parsed.elts
], ],
content=None, content=None,
) )
else: else:
raise _UnexpectedAstError( raise UnexpectedAstError("Tool output must be a list of function calls")
"Tool output must be a list of function calls"
)
except Exception: except Exception:
logger.exception("Error in extracting tool call from response.") logger.exception("Error in extracting tool call from response.")
# Treat as regular text # Treat as regular text
...@@ -140,7 +135,7 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -140,7 +135,7 @@ class Llama4PythonicToolParser(ToolParser):
current_text = current_text[len("<|python_start|>") :] current_text = current_text[len("<|python_start|>") :]
if current_text.endswith("<|python_end|>"): if current_text.endswith("<|python_end|>"):
current_text = current_text[: current_text.rfind("<|python_end|>")] current_text = current_text[: current_text.rfind("<|python_end|>")]
valid_and_added_text = _make_valid_python(current_text) valid_and_added_text = make_valid_python(current_text)
if valid_and_added_text is None: if valid_and_added_text is None:
return None return None
valid_text, added_text = valid_and_added_text valid_text, added_text = valid_and_added_text
...@@ -150,11 +145,9 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -150,11 +145,9 @@ class Llama4PythonicToolParser(ToolParser):
if not isinstance(parsed, ast.List) or not all( if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts isinstance(e, ast.Call) for e in parsed.elts
): ):
raise _UnexpectedAstError( raise UnexpectedAstError("Tool output must be a list of function calls")
"Tool output must be a list of function calls"
)
tool_calls = [ tool_calls = [
_handle_single_tool(e) # type: ignore handle_single_tool(e) # type: ignore
for e in parsed.elts for e in parsed.elts
] ]
...@@ -180,7 +173,7 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -180,7 +173,7 @@ class Llama4PythonicToolParser(ToolParser):
# Strings get single quotes in the model-produced string. # Strings get single quotes in the model-produced string.
# JSON requires double quotes. # JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"') withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta( delta = compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix self.streamed_args_for_tool[index], new_call, index, withheld_suffix
) )
...@@ -214,130 +207,3 @@ class Llama4PythonicToolParser(ToolParser): ...@@ -214,130 +207,3 @@ class Llama4PythonicToolParser(ToolParser):
"Skipping chunk as a result of tool streaming extraction error" "Skipping chunk as a result of tool streaming extraction error"
) )
return None return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(name=function_name, arguments=json.dumps(arguments)),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import json
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -13,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation, ExtractedToolCallInformation,
FunctionCall,
ToolCall,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import (
UnexpectedAstError,
compute_tool_delta,
handle_single_tool,
make_valid_python,
)
logger = init_logger(__name__) logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class Olmo3PythonicToolParser(ToolParser): class Olmo3PythonicToolParser(ToolParser):
""" """
Tool call parser for Olmo 3 models that produce tool calls as Tool call parser for Olmo 3 models that produce tool calls as
...@@ -113,15 +110,13 @@ class Olmo3PythonicToolParser(ToolParser): ...@@ -113,15 +110,13 @@ class Olmo3PythonicToolParser(ToolParser):
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=True, tools_called=True,
tool_calls=[ tool_calls=[
_handle_single_tool(e) # type: ignore handle_single_tool(e) # type: ignore
for e in parsed.elts for e in parsed.elts
], ],
content=None, content=None,
) )
else: else:
raise _UnexpectedAstError( raise UnexpectedAstError("Tool output must be a list of function calls")
"Tool output must be a list of function calls"
)
except Exception: except Exception:
logger.exception("Error in extracting tool call from response.") logger.exception("Error in extracting tool call from response.")
# Treat as regular text # Treat as regular text
...@@ -151,7 +146,7 @@ class Olmo3PythonicToolParser(ToolParser): ...@@ -151,7 +146,7 @@ class Olmo3PythonicToolParser(ToolParser):
if current_text.endswith("</function_calls>"): if current_text.endswith("</function_calls>"):
current_text = current_text[: -len("</function_calls>")] current_text = current_text[: -len("</function_calls>")]
valid_and_added_text = _make_valid_python(current_text) valid_and_added_text = make_valid_python(current_text)
if valid_and_added_text is None: if valid_and_added_text is None:
return None return None
valid_text, added_text = valid_and_added_text valid_text, added_text = valid_and_added_text
...@@ -166,11 +161,11 @@ class Olmo3PythonicToolParser(ToolParser): ...@@ -166,11 +161,11 @@ class Olmo3PythonicToolParser(ToolParser):
if not isinstance(parsed, ast.List) or not all( if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts isinstance(e, ast.Call) for e in parsed.elts
): ):
raise _UnexpectedAstError( raise UnexpectedAstError(
"Tool output must be a sequence of newline-separated calls" "Tool output must be a sequence of newline-separated calls"
) )
tool_calls = [ tool_calls = [
_handle_single_tool(e) # type: ignore handle_single_tool(e) # type: ignore
for e in parsed.elts for e in parsed.elts
] ]
...@@ -194,7 +189,7 @@ class Olmo3PythonicToolParser(ToolParser): ...@@ -194,7 +189,7 @@ class Olmo3PythonicToolParser(ToolParser):
# Strings get single quotes in the model-produced string. # Strings get single quotes in the model-produced string.
# JSON requires double quotes. # JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"') withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta( delta = compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix self.streamed_args_for_tool[index], new_call, index, withheld_suffix
) )
...@@ -228,141 +223,3 @@ class Olmo3PythonicToolParser(ToolParser): ...@@ -228,141 +223,3 @@ class Olmo3PythonicToolParser(ToolParser):
"Skipping chunk as a result of tool streaming extraction error" "Skipping chunk as a result of tool streaming extraction error"
) )
return None return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
# The model may return function calls where the values are null/true/false
# because the system prompt has API description in json.
elif isinstance(val, ast.Name) and val.id in ["null", "true", "false"]:
if val.id == "null":
return None
elif val.id == "true":
return True
elif val.id == "false":
return False
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import json
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any
import regex as re import regex as re
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
...@@ -14,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -14,25 +12,23 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation, ExtractedToolCallInformation,
FunctionCall,
ToolCall,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import ( from vllm.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import (
UnexpectedAstError,
compute_tool_delta,
handle_single_tool,
make_valid_python,
)
logger = init_logger(__name__) logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
pass
class PythonicToolParser(ToolParser): class PythonicToolParser(ToolParser):
""" """
Tool call parser for models that produce tool calls in a pythonic style, Tool call parser for models that produce tool calls in a pythonic style,
...@@ -99,15 +95,13 @@ class PythonicToolParser(ToolParser): ...@@ -99,15 +95,13 @@ class PythonicToolParser(ToolParser):
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tools_called=True, tools_called=True,
tool_calls=[ tool_calls=[
_handle_single_tool(e) # type: ignore handle_single_tool(e) # type: ignore
for e in parsed.elts for e in parsed.elts
], ],
content=None, content=None,
) )
else: else:
raise _UnexpectedAstError( raise UnexpectedAstError("Tool output must be a list of function calls")
"Tool output must be a list of function calls"
)
except Exception: except Exception:
logger.exception("Error in extracting tool call from response.") logger.exception("Error in extracting tool call from response.")
# Treat as regular text # Treat as regular text
...@@ -129,7 +123,7 @@ class PythonicToolParser(ToolParser): ...@@ -129,7 +123,7 @@ class PythonicToolParser(ToolParser):
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
try: try:
valid_and_added_text = _make_valid_python(current_text) valid_and_added_text = make_valid_python(current_text)
if valid_and_added_text is None: if valid_and_added_text is None:
return None return None
valid_text, added_text = valid_and_added_text valid_text, added_text = valid_and_added_text
...@@ -139,11 +133,9 @@ class PythonicToolParser(ToolParser): ...@@ -139,11 +133,9 @@ class PythonicToolParser(ToolParser):
if not isinstance(parsed, ast.List) or not all( if not isinstance(parsed, ast.List) or not all(
isinstance(e, ast.Call) for e in parsed.elts isinstance(e, ast.Call) for e in parsed.elts
): ):
raise _UnexpectedAstError( raise UnexpectedAstError("Tool output must be a list of function calls")
"Tool output must be a list of function calls"
)
tool_calls = [ tool_calls = [
_handle_single_tool(e) # type: ignore handle_single_tool(e) # type: ignore
for e in parsed.elts for e in parsed.elts
] ]
...@@ -169,7 +161,7 @@ class PythonicToolParser(ToolParser): ...@@ -169,7 +161,7 @@ class PythonicToolParser(ToolParser):
# Strings get single quotes in the model-produced string. # Strings get single quotes in the model-produced string.
# JSON requires double quotes. # JSON requires double quotes.
withheld_suffix = withheld_suffix.replace("'", '"') withheld_suffix = withheld_suffix.replace("'", '"')
delta = _compute_tool_delta( delta = compute_tool_delta(
self.streamed_args_for_tool[index], new_call, index, withheld_suffix self.streamed_args_for_tool[index], new_call, index, withheld_suffix
) )
...@@ -203,132 +195,3 @@ class PythonicToolParser(ToolParser): ...@@ -203,132 +195,3 @@ class PythonicToolParser(ToolParser):
"Skipping chunk as a result of tool streaming extraction error" "Skipping chunk as a result of tool streaming extraction error"
) )
return None return None
def _get_parameter_value(val: ast.expr) -> Any:
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: _get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [_get_parameter_value(v) for v in val.elts]
else:
raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
if not isinstance(call.func, ast.Name):
raise _UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = _get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
),
)
def _make_valid_python(text: str) -> tuple[str, str] | None:
bracket_stack = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise _UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise _UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise _UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
# Treat an escaped quote as a regular character
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
# Double quote within a single quote string or vice versa.
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None # Incomplete property name within parameter value
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None # Incomplete parameter name
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None # Incomplete function name
added_text = ""
for char in reversed(bracket_stack):
if char == "[":
added_text += "]"
elif char == "(":
added_text += ")"
elif char == "{":
added_text += "}"
elif char == "'":
added_text += "'"
elif char == '"':
added_text += '"'
return text + added_text, added_text
def _compute_tool_delta(
previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
) -> DeltaToolCall | None:
new_call_args = new_call.function.arguments
if withheld_suffix:
assert new_call_args.endswith(withheld_suffix)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
)
if arg_diff
else None
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json import json
from json import JSONDecodeError, JSONDecoder from json import JSONDecodeError, JSONDecoder
from typing import Any from typing import Any
...@@ -17,6 +18,15 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -17,6 +18,15 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionNamedToolChoiceParam,
ChatCompletionToolsParam, ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaToolCall,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def find_common_prefix(s1: str, s2: str) -> str: def find_common_prefix(s1: str, s2: str) -> str:
...@@ -212,3 +222,202 @@ def get_json_schema_from_tools( ...@@ -212,3 +222,202 @@ def get_json_schema_from_tools(
return _get_json_schema_from_tools(tools) return _get_json_schema_from_tools(tools)
# tool_choice: "auto" # tool_choice: "auto"
return None return None
# ---------------------------------------------------------------------------
# Shared utilities for pythonic-style tool call parsers
# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser)
# ---------------------------------------------------------------------------
class UnexpectedAstError(Exception):
"""Raised when the AST structure does not match the expected
pythonic tool call format."""
pass
_JSON_NAME_LITERALS = {
"null": None,
"true": True,
"false": False,
}
def get_parameter_value(val: ast.expr) -> Any:
"""Extract a Python literal value from an AST expression node.
Handles constants, dicts, lists, and JSON-style name literals
(null, true, false) that some models produce instead of Python
literals (None, True, False).
Raises:
UnexpectedAstError: If the AST node is not a supported literal type.
"""
if isinstance(val, ast.Constant):
return val.value
elif isinstance(val, ast.Dict):
if not all(isinstance(k, ast.Constant) for k in val.keys):
logger.warning(
"Dict argument keys are not all literals: %s",
ast.dump(val),
)
raise UnexpectedAstError("Dict tool call arguments must have literal keys")
return {
k.value: get_parameter_value(v) # type: ignore
for k, v in zip(val.keys, val.values)
}
elif isinstance(val, ast.List):
return [get_parameter_value(v) for v in val.elts]
elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
return _JSON_NAME_LITERALS[val.id]
else:
logger.warning(
"Unsupported AST node type in tool call arguments: %s",
ast.dump(val),
)
raise UnexpectedAstError("Tool call arguments must be literals")
def handle_single_tool(call: ast.Call) -> ToolCall:
"""Convert a single AST function call node into a ToolCall object.
Raises:
UnexpectedAstError: If the call node does not have a simple
function name (e.g. it's an attribute access or subscript).
"""
if not isinstance(call.func, ast.Name):
logger.warning(
"Tool call has non-simple function name: %s",
ast.dump(call.func),
)
raise UnexpectedAstError("Invalid tool call name")
function_name = call.func.id
arguments = {}
for keyword in call.keywords:
arguments[keyword.arg] = get_parameter_value(keyword.value)
return ToolCall(
type="function",
function=FunctionCall(
name=function_name,
arguments=json.dumps(arguments, ensure_ascii=False),
),
)
def make_valid_python(text: str) -> tuple[str, str] | None:
"""Attempt to close all open brackets/quotes to make partial Python valid.
Used during streaming to parse incomplete tool call expressions by
appending the necessary closing characters.
Returns:
A tuple of (completed_text, added_suffix) if the text can be
made valid, or None if the text is too incomplete to complete
meaningfully (e.g. mid-parameter-name or mid-dict-key).
Raises:
UnexpectedAstError: If mismatched brackets or parentheses
are detected.
"""
bracket_stack: list[str] = []
for index, char in enumerate(text):
if char in {"[", "(", "{"}:
bracket_stack.append(char)
elif char == "]":
if not bracket_stack or bracket_stack.pop() != "[":
raise UnexpectedAstError("Mismatched square brackets")
elif char == ")":
if not bracket_stack or bracket_stack.pop() != "(":
raise UnexpectedAstError("Mismatched parentheses")
elif char == "}":
if not bracket_stack or bracket_stack.pop() != "{":
raise UnexpectedAstError("Mismatched curly braces")
elif char in {"'", '"'}:
if bracket_stack and bracket_stack[-1] == char:
if index > 0 and text[index - 1] == "\\":
pass
else:
bracket_stack.pop()
elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
pass
else:
bracket_stack.append(char)
text = text.rstrip()
if text.endswith("=") or text.endswith(":"):
return None
if bracket_stack and bracket_stack[-1] == "{":
trailing_dict_text = text[: text.rfind("{")]
num_keys = trailing_dict_text.count(":")
num_values = trailing_dict_text.count(",")
if num_keys <= num_values:
return None
if bracket_stack and bracket_stack[-1] == "(":
trailing_params_text = text[: text.rfind("(")]
num_full_param_names = trailing_params_text.count("=")
num_full_param_values = trailing_params_text.count(",")
if num_full_param_names <= num_full_param_values:
return None
if text.endswith(","):
text = text[:-1]
if (
bracket_stack
and bracket_stack[-1] == "["
and not text.endswith("[")
and not text.endswith(")")
):
return None
_CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}
added_text = ""
for char in reversed(bracket_stack):
added_text += _CLOSING[char]
return text + added_text, added_text
def compute_tool_delta(
previously_sent_args: str,
new_call: ToolCall,
index: int,
withheld_suffix: str,
) -> DeltaToolCall | None:
"""Compute the incremental delta between previously streamed arguments
and the current tool call state.
Returns:
A DeltaToolCall with only the new argument characters, or None
if there is no difference from what was previously sent.
"""
new_call_args = new_call.function.arguments
if withheld_suffix:
if not new_call_args.endswith(withheld_suffix):
msg = (
f"Tool call arguments '{new_call_args}' do not end with "
f"expected withheld suffix '{withheld_suffix}'"
)
logger.error(msg)
raise ValueError(msg)
new_call_args = new_call_args[: -len(withheld_suffix)]
if not previously_sent_args:
return DeltaToolCall(
id=new_call.id,
type="function",
index=index,
function=DeltaFunctionCall(
name=new_call.function.name,
arguments=new_call_args,
),
)
arg_diff = new_call_args[len(previously_sent_args) :]
return (
DeltaToolCall(
id=None,
index=index,
function=DeltaFunctionCall(arguments=arg_diff),
)
if arg_diff
else None
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment