tool_server.py 6.43 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
5
from typing import TYPE_CHECKING, Any, Optional
6

7
from openai_harmony import ToolDescription, ToolNamespaceConfig
8
9
10
11
12
13

from vllm.entrypoints.tool import HarmonyBrowserTool, HarmonyPythonTool, Tool
from vllm.logger import init_logger

logger = init_logger(__name__)

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
if TYPE_CHECKING:
    from mcp.types import ListToolsResult


async def list_server_and_tools(server_url: str):
    """Open a one-off MCP session over SSE and fetch server metadata.

    Args:
        server_url: Full SSE endpoint URL of the MCP server.

    Returns:
        A ``(initialize_result, list_tools_result)`` tuple, taken from
        ``ClientSession.initialize()`` and ``ClientSession.list_tools()``.
    """
    # Imported lazily so the module can be imported without `mcp` installed.
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    async with sse_client(url=server_url) as streams:
        async with ClientSession(*streams) as session:
            server_info = await session.initialize()
            tools_info = await session.list_tools()
    return server_info, tools_info


def trim_schema(schema: dict) -> dict:
    """Rewrite an MCP-generated JSON Schema into Harmony's variant, in place.

    Transformations applied:

    - ``title`` keys are removed.
    - ``default`` is removed when it is ``None``.
    - ``anyOf: [{"type": A}, {"type": B}, ...]`` is collapsed into
      ``type: [A, B, ...]``. ``"null"`` is dropped when at least one other
      type remains, since Harmony ignores it. If every alternative is
      ``"null"``, it is kept so that ``type`` is never an empty list.
      NOTE: only each alternative's ``type`` survives the collapse; other
      keys on the alternatives (e.g. ``items``) are discarded.
    - If any ``anyOf`` alternative has no ``type`` key (e.g. a ``$ref``),
      the ``anyOf`` is left in place and each dict alternative is trimmed
      recursively, rather than failing with ``KeyError``.
    - Each schema under ``properties`` is trimmed recursively.

    Args:
        schema: JSON Schema dict. It is mutated.

    Returns:
        The same ``schema`` object.
    """
    if "title" in schema:
        del schema["title"]
    if "default" in schema and schema["default"] is None:
        del schema["default"]
    if "anyOf" in schema:
        alternatives = schema["anyOf"]
        collapsible = all(
            isinstance(alt, dict) and "type" in alt for alt in alternatives)
        if collapsible:
            all_types = [alt["type"] for alt in alternatives]
            # Drop "null" only if something else is left; an empty "type"
            # list would not be a usable schema.
            schema["type"] = [t for t in all_types if t != "null"] or all_types
            del schema["anyOf"]
        else:
            # Alternatives such as {"$ref": ...} cannot be expressed as a
            # plain type list, so keep anyOf but still trim each branch.
            schema["anyOf"] = [
                trim_schema(alt) if isinstance(alt, dict) else alt
                for alt in alternatives
            ]
    if "properties" in schema:
        schema["properties"] = {
            key: trim_schema(sub_schema)
            for key, sub_schema in schema["properties"].items()
        }
    return schema


def post_process_tools_description(
        list_tools_result: "ListToolsResult") -> "ListToolsResult":
    """Adapt an MCP ``list_tools`` result for use in a Harmony prompt.

    Every tool's ``inputSchema`` is rewritten with :func:`trim_schema`.
    Tools whose ``annotations`` carry a falsy ``include_in_prompt`` are
    then dropped, e.g. simple text-in/text-out tools such as Python.
    Tools with no annotations, or annotations lacking that attribute, are
    kept.

    Args:
        list_tools_result: Result object returned by ``list_tools()``. It
            is mutated in place.

    Returns:
        The same ``list_tools_result`` object, with its ``tools`` list
        replaced by the filtered one.
    """
    visible_tools = []
    for mcp_tool in list_tools_result.tools:
        mcp_tool.inputSchema = trim_schema(mcp_tool.inputSchema)
        # getattr on None (no annotations) falls back to the default True.
        if getattr(mcp_tool.annotations, "include_in_prompt", True):
            visible_tools.append(mcp_tool)
    list_tools_result.tools = visible_tools
    return list_tools_result

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

class ToolServer(ABC):
    """Abstract interface for a collection of tools addressable by name.

    Implementations report which tools they serve, provide a Harmony
    namespace description per tool, and open sessions against a tool.
    """

    @abstractmethod
    def has_tool(self, tool_name: str) -> bool:
        """Tell whether ``tool_name`` is served by this server.

        Returns:
            True when supported, False otherwise.
        """
        ...

    @abstractmethod
    def get_tool_description(self,
                             tool_name: str) -> Optional[ToolNamespaceConfig]:
        """Look up the description of ``tool_name``.

        Returns:
            The tool's ``ToolNamespaceConfig``, or None when the tool is not
            supported.
        """
        ...

    @abstractmethod
    def new_session(self, tool_name: str) -> AbstractAsyncContextManager[Any]:
        """Open a session for ``tool_name``.

        Returns:
            An async context manager whose entered value is the session.
        """
        ...


96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class MCPToolServer(ToolServer):
    """Tool server backed by remote MCP servers reached over SSE.

    Tools are registered by awaiting :meth:`add_tool_server`. Each MCP
    server becomes one Harmony tool namespace, keyed by the server name
    it reports during initialization. :meth:`new_session` opens a fresh SSE
    connection to that server on every call.
    """

    def __init__(self):
        try:
            import mcp  # noqa: F401
        except ImportError:
            raise ImportError(
                "mcp is not installed. Please run `pip install mcp` to use "
                "MCPToolServer.") from None
        # MCP server name -> Harmony namespace description.
        self.harmony_tool_descriptions: dict[str, ToolNamespaceConfig] = {}
        # MCP server name -> SSE endpoint used by new_session(). This is
        # initialized here too, so new_session() raises KeyError rather
        # than AttributeError when called before add_tool_server().
        self.urls: dict[str, str] = {}

    async def add_tool_server(self, server_url: str):
        """Discover and register tools from a comma-separated server list.

        Each entry is interpolated into ``http://<entry>/sse``. Whitespace
        around entries is ignored and empty entries are skipped.
        Registrations from any earlier call are discarded. When two servers
        report the same name, the first one is kept for both its
        description and its session URL, and the later one is ignored with
        a warning.

        Args:
            server_url: Comma-separated list of server addresses.
        """
        self.harmony_tool_descriptions = {}
        self.urls = {}
        for entry in server_url.split(","):
            entry = entry.strip()
            if not entry:
                continue
            url = f"http://{entry}/sse"
            initialize_response, list_tools_response = (
                await list_server_and_tools(url))

            list_tools_response = post_process_tools_description(
                list_tools_response)

            tool_from_mcp = ToolNamespaceConfig(
                name=initialize_response.serverInfo.name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(name=tool.name,
                                        description=tool.description,
                                        parameters=tool.inputSchema)
                    for tool in list_tools_response.tools
                ])
            if tool_from_mcp.name in self.urls:
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    tool_from_mcp.name, url)
                continue
            # Register the description and URL together, so the prompt
            # description always matches the server sessions connect to.
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            self.urls[tool_from_mcp.name] = url

        logger.info("MCPToolServer initialized with tools: %s",
                    list(self.harmony_tool_descriptions.keys()))

    def has_tool(self, tool_name: str) -> bool:
        """Return True if an MCP server named ``tool_name`` is registered."""
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(self,
                             tool_name: str) -> Optional[ToolNamespaceConfig]:
        """Return the namespace for ``tool_name``, or None if unknown."""
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def new_session(self, tool_name: str):
        """Open an initialized MCP ``ClientSession`` for ``tool_name``.

        Raises:
            KeyError: if no server named ``tool_name`` is registered.
        """
        from mcp import ClientSession
        from mcp.client.sse import sse_client
        url = self.urls.get(tool_name)
        if not url:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        async with sse_client(url=url) as streams, ClientSession(
                *streams) as session:
            await session.initialize()
            yield session


157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
class DemoToolServer(ToolServer):
    """In-process tool server exposing the built-in Harmony demo tools.

    Candidates are ``"browser"`` (``HarmonyBrowserTool``) and ``"python"``
    (``HarmonyPythonTool``). A tool is registered only if its ``enabled``
    attribute is truthy when the server is constructed.
    """

    def __init__(self):
        self.tools: dict[str, Tool] = {}
        # Instantiate the candidates in order and keep the enabled ones.
        for registered_name, tool_cls in (("browser", HarmonyBrowserTool),
                                          ("python", HarmonyPythonTool)):
            candidate = tool_cls()
            if candidate.enabled:
                self.tools[registered_name] = candidate
        logger.info("DemoToolServer initialized with tools: %s",
                    list(self.tools.keys()))

    def has_tool(self, tool_name: str) -> bool:
        """Return True if ``tool_name`` was registered at construction."""
        return tool_name in self.tools

    def get_tool_description(self,
                             tool_name: str) -> Optional[ToolNamespaceConfig]:
        """Return the built-in Harmony namespace for a registered tool.

        Returns:
            None when ``tool_name`` is not registered.

        Raises:
            ValueError: if ``tool_name`` is registered but has no built-in
                description (only possible if ``self.tools`` was modified
                after construction).
        """
        if tool_name not in self.tools:
            return None
        if tool_name == "browser":
            return ToolNamespaceConfig.browser()
        if tool_name == "python":
            return ToolNamespaceConfig.python()
        raise ValueError(f"Unknown tool {tool_name}")

    @asynccontextmanager
    async def new_session(self, tool_name: str):
        """Yield the registered tool object itself as the session.

        Raises:
            KeyError: if ``tool_name`` is not registered.
        """
        if tool_name not in self.tools:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        yield self.tools[tool_name]