tool_server.py 7.51 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
5
from typing import TYPE_CHECKING, Any
6

7
from openai_harmony import ToolDescription, ToolNamespaceConfig
8
9
10
11
12
13

from vllm.entrypoints.tool import HarmonyBrowserTool, HarmonyPythonTool, Tool
from vllm.logger import init_logger

# Module-level logger, created via vLLM's `init_logger` with this module's name.
logger = init_logger(__name__)

14
15
16
17
18
19
20
if TYPE_CHECKING:
    from mcp.types import ListToolsResult


async def list_server_and_tools(server_url: str):
    """Connect to an MCP server over SSE and fetch its handshake and tool list.

    Args:
        server_url: SSE endpoint of the MCP server (this module builds it as
            ``http://<host:port>/sse``).

    Returns:
        A ``(initialize_result, list_tools_result)`` tuple, as returned by
        ``ClientSession.initialize()`` and ``ClientSession.list_tools()``.
    """
    # Imported lazily so that `mcp` stays an optional dependency.
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    async with sse_client(url=server_url) as streams:
        async with ClientSession(*streams) as client:
            handshake = await client.initialize()
            catalogue = await client.list_tools()
            return handshake, catalogue


def trim_schema(schema: dict) -> dict:
    """Convert an MCP-generated JSON Schema into Harmony's variant, in place.

    * Removes ``title`` and any ``default`` whose value is ``None``.
    * Collapses ``"anyOf": [{"type": "a"}, {"type": "b"}]`` into
      ``"type": ["a", "b"]``. ``"null"`` is dropped when other types are
      present, because Harmony ignores it. A lone ``"null"`` is kept so the
      resulting ``type`` list is never emptied.
    * Leaves ``anyOf`` untouched when any branch is not a plain
      ``{"type": ...}`` object (e.g. a ``$ref``). Such a union cannot be
      expressed as a type list, and previously this raised ``KeyError``.
    * Recurses into ``properties``.

    Args:
        schema: JSON Schema dict. It is mutated.

    Returns:
        The same ``schema`` object after trimming.
    """
    schema.pop("title", None)
    if "default" in schema and schema["default"] is None:
        del schema["default"]

    branches = schema.get("anyOf")
    if isinstance(branches, list) and all(
        isinstance(branch, dict) and "type" in branch for branch in branches
    ):
        types = [branch["type"] for branch in branches]
        non_null = [t for t in types if t != "null"]
        # Only drop "null" when something else remains.
        schema["type"] = non_null or types
        del schema["anyOf"]

    if "properties" in schema:
        schema["properties"] = {
            k: trim_schema(v) for k, v in schema["properties"].items()
        }
    return schema


def post_process_tools_description(
    list_tools_result: "ListToolsResult",
) -> "ListToolsResult":
    """Adapt an MCP ``list_tools`` result for use in a Harmony prompt.

    Every tool's ``inputSchema`` is rewritten with :func:`trim_schema`. Tools
    whose annotations set ``include_in_prompt`` to a falsy value are then
    dropped. Such tools need no schema in the prompt, e.g. simple text-in /
    text-out tools like Python. A missing annotation counts as "include".

    Args:
        list_tools_result: Result of ``ClientSession.list_tools()``. It is
            mutated.

    Returns:
        The same result object, with its ``tools`` list replaced.
    """
    kept = []
    for candidate in list_tools_result.tools:
        candidate.inputSchema = trim_schema(candidate.inputSchema)
        if getattr(candidate.annotations, "include_in_prompt", True):
            kept.append(candidate)
    list_tools_result.tools = kept
    return list_tools_result

73
74
75
76
77
78
79
80
81
82

class ToolServer(ABC):
    """Abstract interface for a provider of Harmony tools.

    Implementations report which tools they expose, describe them as Harmony
    namespace configs, and create per-request sessions for the tools.
    """

    @abstractmethod
    def has_tool(self, tool_name: str) -> bool:
        """Tell whether ``tool_name`` is served here (True) or not (False)."""
        ...

    @abstractmethod
    def get_tool_description(
        self, tool_name: str, allowed_tools: list[str] | None = None
    ) -> ToolNamespaceConfig | None:
        """Describe ``tool_name`` as a Harmony namespace config.

        Implementations return ``None`` when the tool is not supported.
        """
        ...

    @abstractmethod
    def new_session(
        self, tool_name: str, session_id: str, headers: dict[str, str] | None = None
    ) -> AbstractAsyncContextManager[Any]:
        """Return an async context manager that yields a session for the tool."""
        ...


102
103
104
105
106
107
108
class MCPToolServer(ToolServer):
    """Tool server backed by one or more remote MCP servers reached over SSE.

    Call :meth:`add_tool_server` after construction to discover the servers'
    tools. Until then no tools are registered.
    """

    def __init__(self):
        try:
            import mcp  # noqa: F401
        except ImportError:
            raise ImportError(
                "mcp is not installed. Please run `pip install mcp` to use "
                "MCPToolServer."
            ) from None
        # Harmony namespace config, keyed by MCP server name.
        self.harmony_tool_descriptions: dict[str, ToolNamespaceConfig] = {}
        # SSE endpoint URL, keyed by MCP server name. Created here (not only
        # in add_tool_server) so new_session never hits AttributeError.
        self.urls: dict[str, str] = {}

    async def add_tool_server(self, server_url: str):
        """Discover tools from a comma-separated list of MCP servers.

        Args:
            server_url: Comma-separated entries, each contacted at
                ``http://<entry>/sse``. Whitespace around entries is ignored
                and empty entries are skipped.

        Previously registered servers are discarded. When two servers report
        the same name, the first one is kept (both its description and its
        URL) and the later one is logged and ignored.
        """
        self.harmony_tool_descriptions = {}
        self.urls = {}
        for entry in server_url.split(","):
            entry = entry.strip()
            if not entry:
                continue
            url = f"http://{entry}/sse"
            initialize_response, list_tools_response = await list_server_and_tools(url)
            list_tools_response = post_process_tools_description(list_tools_response)

            tool_from_mcp = ToolNamespaceConfig(
                name=initialize_response.serverInfo.name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(
                        name=tool.name,
                        description=tool.description,
                        parameters=tool.inputSchema,
                    )
                    for tool in list_tools_response.tools
                ],
            )

            if tool_from_mcp.name in self.urls:
                # Keep description and URL of the first server together;
                # overwriting only the description would pair it with the
                # wrong endpoint.
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    tool_from_mcp.name,
                    url,
                )
                continue
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            self.urls[tool_from_mcp.name] = url

        logger.info(
            "MCPToolServer initialized with tools: %s",
            list(self.harmony_tool_descriptions.keys()),
        )

    def has_tool(self, tool_name: str) -> bool:
        """Return True if an MCP server named ``tool_name`` is registered."""
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(
        self,
        server_label: str,
        allowed_tools: list[str] | None = None,
    ) -> ToolNamespaceConfig | None:
        """Return the Harmony namespace config for an MCP server.

        Args:
            server_label: MCP server name, as reported by its initialize
                response.
            allowed_tools: Optional allow-list of tool names. ``None`` exposes
                every tool of the server.

        Returns:
            The stored config when ``allowed_tools`` is ``None``. Otherwise a
            new config with only the allowed tools. Returns ``None`` if the
            server is unknown or none of its tools are allowed.
        """
        cfg = self.harmony_tool_descriptions.get(server_label)
        if cfg is None:
            return None

        # No restrictions: all tools from this MCP server.
        if allowed_tools is None:
            return cfg

        allowed = set(allowed_tools)
        filtered = [t for t in cfg.tools if t.name in allowed]
        if not filtered:
            return None

        return ToolNamespaceConfig(
            name=cfg.name,
            description=cfg.description,
            tools=filtered,
        )

    @asynccontextmanager
    async def new_session(
        self, tool_name: str, session_id: str, headers: dict[str, str] | None = None
    ):
        """Open an initialized MCP client session for ``tool_name``.

        ``session_id`` is sent as the ``x-session-id`` header. Caller-supplied
        ``headers`` are merged on top and may override it.

        Raises:
            KeyError: If no MCP server is registered under ``tool_name``.
        """
        from mcp import ClientSession
        from mcp.client.sse import sse_client

        url = self.urls.get(tool_name)
        if not url:
            raise KeyError(f"Tool '{tool_name}' is not supported")

        request_headers = {"x-session-id": session_id}
        if headers is not None:
            request_headers.update(headers)

        async with (
            sse_client(url=url, headers=request_headers) as streams,
            ClientSession(*streams) as session,
        ):
            await session.initialize()
            yield session


197
198
199
class DemoToolServer(ToolServer):
    """In-process tool server exposing the built-in Harmony browser/python tools.

    Tools are registered by :meth:`init_and_validate`. Only tools whose
    ``enabled`` flag is set are kept.
    """

    def __init__(self):
        # Enabled tools keyed by namespace name ("browser" / "python").
        self.tools: dict[str, Tool] = {}

    async def init_and_validate(self):
        """Instantiate the demo tools, validate python, register enabled ones."""
        candidates = {
            "browser": HarmonyBrowserTool(),
            "python": HarmonyPythonTool(),
        }
        # Only the python tool is validated here.
        await candidates["python"].validate()
        for name, candidate in candidates.items():
            if candidate.enabled:
                self.tools[name] = candidate
        logger.info(
            "DemoToolServer initialized with tools: %s", list(self.tools.keys())
        )

    def has_tool(self, tool_name: str) -> bool:
        """Return True when ``tool_name`` was registered as enabled."""
        registered = self.tools
        return tool_name in registered

    def get_tool_description(
        self, tool_name: str, allowed_tools: list[str] | None = None
    ) -> ToolNamespaceConfig | None:
        """Return Harmony's built-in namespace config for a registered tool.

        ``allowed_tools`` is accepted for interface compatibility and ignored.
        Returns ``None`` for unregistered tools. Raises ``ValueError`` for a
        registered name that has no built-in config.
        """
        if tool_name not in self.tools:
            return None
        builders = {
            "browser": ToolNamespaceConfig.browser,
            "python": ToolNamespaceConfig.python,
        }
        build = builders.get(tool_name)
        if build is None:
            raise ValueError(f"Unknown tool {tool_name}")
        return build()

    @asynccontextmanager
    async def new_session(
        self, tool_name: str, session_id: str, headers: dict[str, str] | None = None
    ):
        """Yield the registered tool object itself as the "session".

        ``session_id`` and ``headers`` are unused by this implementation.

        Raises:
            KeyError: If ``tool_name`` is not registered.
        """
        if tool_name not in self.tools:
            raise KeyError(f"Tool '{tool_name}' is not supported")
        yield self.tools[tool_name]