Unverified Commit 29e48707 authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Refactor] Consolidate Tool type alias in tool_parsers/utils.py (#38265)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent 4ac22722
...@@ -5,17 +5,14 @@ import importlib ...@@ -5,17 +5,14 @@ import importlib
import os import os
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from functools import cached_property from functools import cached_property
from typing import TypeAlias
from openai.types.responses import ( from openai.types.responses import (
ResponseFormatTextJSONSchemaConfig, ResponseFormatTextJSONSchemaConfig,
ResponseTextConfig, ResponseTextConfig,
) )
from openai.types.responses.tool import Tool as ResponsesTool
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam,
) )
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
...@@ -29,13 +26,13 @@ from vllm.sampling_params import ( ...@@ -29,13 +26,13 @@ from vllm.sampling_params import (
StructuredOutputsParams, StructuredOutputsParams,
) )
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.utils import get_json_schema_from_tools from vllm.tool_parsers.utils import Tool, get_json_schema_from_tools
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import import_from_path from vllm.utils.import_utils import import_from_path
logger = init_logger(__name__) __all__ = ["Tool"]
Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool logger = init_logger(__name__)
class ToolParser: class ToolParser:
......
...@@ -4,14 +4,14 @@ ...@@ -4,14 +4,14 @@
import ast import ast
import json import json
from json import JSONDecodeError, JSONDecoder from json import JSONDecodeError, JSONDecoder
from typing import Any from typing import Any, TypeAlias
import partial_json_parser import partial_json_parser
from openai.types.responses import ( from openai.types.responses import (
FunctionTool, FunctionTool,
ToolChoiceFunction, ToolChoiceFunction,
) )
from openai.types.responses.tool import Tool from openai.types.responses.tool import Tool as ResponsesTool
from partial_json_parser.core.options import Allow from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.chat_completion.protocol import (
...@@ -26,6 +26,8 @@ from vllm.entrypoints.openai.engine.protocol import ( ...@@ -26,6 +26,8 @@ from vllm.entrypoints.openai.engine.protocol import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -130,7 +132,7 @@ def consume_space(i: int, s: str) -> int: ...@@ -130,7 +132,7 @@ def consume_space(i: int, s: str) -> int:
def _extract_tool_info( def _extract_tool_info(
tool: Tool | ChatCompletionToolsParam, tool: Tool,
) -> tuple[str, dict[str, Any] | None]: ) -> tuple[str, dict[str, Any] | None]:
if isinstance(tool, FunctionTool): if isinstance(tool, FunctionTool):
return tool.name, tool.parameters return tool.name, tool.parameters
...@@ -140,7 +142,7 @@ def _extract_tool_info( ...@@ -140,7 +142,7 @@ def _extract_tool_info(
raise TypeError(f"Unsupported tool type: {type(tool)}") raise TypeError(f"Unsupported tool type: {type(tool)}")
def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict: def _get_tool_schema_from_tool(tool: Tool) -> dict:
name, params = _extract_tool_info(tool) name, params = _extract_tool_info(tool)
params = params if params else {"type": "object", "properties": {}} params = params if params else {"type": "object", "properties": {}}
return { return {
...@@ -153,7 +155,7 @@ def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict: ...@@ -153,7 +155,7 @@ def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
def _get_tool_schema_defs( def _get_tool_schema_defs(
tools: list[Tool | ChatCompletionToolsParam], tools: list[Tool],
) -> dict: ) -> dict:
all_defs: dict[str, dict[str, Any]] = {} all_defs: dict[str, dict[str, Any]] = {}
for tool in tools: for tool in tools:
...@@ -172,7 +174,7 @@ def _get_tool_schema_defs( ...@@ -172,7 +174,7 @@ def _get_tool_schema_defs(
def _get_json_schema_from_tools( def _get_json_schema_from_tools(
tools: list[Tool | ChatCompletionToolsParam], tools: list[Tool],
) -> dict: ) -> dict:
json_schema = { json_schema = {
"type": "array", "type": "array",
...@@ -190,7 +192,7 @@ def _get_json_schema_from_tools( ...@@ -190,7 +192,7 @@ def _get_json_schema_from_tools(
def get_json_schema_from_tools( def get_json_schema_from_tools(
tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam, tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
tools: list[FunctionTool | ChatCompletionToolsParam] | None, tools: list[Tool] | None,
) -> str | dict | None: ) -> str | dict | None:
# tool_choice: "none" # tool_choice: "none"
if tool_choice in ("none", None) or tools is None: if tool_choice in ("none", None) or tools is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment