tool_server.py 6.62 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from contextlib import AbstractAsyncContextManager, asynccontextmanager
5
from typing import TYPE_CHECKING, Any, Optional
6

7
from openai_harmony import ToolDescription, ToolNamespaceConfig
8
9
10
11
12
13

from vllm.entrypoints.tool import HarmonyBrowserTool, HarmonyPythonTool, Tool
from vllm.logger import init_logger

logger = init_logger(__name__)

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
if TYPE_CHECKING:
    from mcp.types import ListToolsResult


async def list_server_and_tools(server_url: str):
    """
    Query an MCP server for its handshake response and its tool list.

    An SSE client is opened on ``server_url``, the streams it yields are
    wrapped in a ``ClientSession``, and the session is initialized before
    the tool list is requested. Both context managers are exited on return.

    Args:
        server_url: URL passed verbatim to ``sse_client``.

    Returns:
        A tuple ``(initialize_response, list_tools_response)`` holding the
        results of ``session.initialize()`` and ``session.list_tools()``.
    """
    # Function-scope imports: `mcp` is only required when this is called.
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    async with sse_client(url=server_url) as transport:
        async with ClientSession(*transport) as mcp_session:
            handshake = await mcp_session.initialize()
            tool_listing = await mcp_session.list_tools()
            return handshake, tool_listing


def trim_schema(schema: dict) -> dict:
    """
    Turn a JSON Schema generated by MCP into Harmony's variant.

    The schema is modified in place and the same object is returned:

    * a ``"title"`` entry is removed;
    * a ``"default"`` entry whose value is ``None`` is removed;
    * ``"anyOf": [{"type": "type-1"}, {"type": "type-2"}]`` is collapsed
      into ``"type": ["type-1", "type-2"]``; the ``"null"`` type is always
      left out of the resulting list;
    * every schema under ``"properties"`` is trimmed recursively.

    An ``"anyOf"`` that contains an alternative without a ``"type"`` key
    (for example ``{"$ref": ...}``) cannot be expressed as a type list, so
    it is left untouched instead of raising ``KeyError``.

    Args:
        schema: The JSON Schema dictionary to adapt.

    Returns:
        The same ``schema`` dictionary after trimming.
    """
    schema.pop("title", None)
    if "default" in schema and schema["default"] is None:
        del schema["default"]
    if "anyOf" in schema:
        alternatives = schema["anyOf"]
        # Only collapse when every alternative actually names a type;
        # otherwise keep "anyOf" as is rather than crash on the lookup.
        if all(
                isinstance(alternative, dict) and "type" in alternative
                for alternative in alternatives):
            schema["type"] = [
                alternative["type"] for alternative in alternatives
                if alternative["type"] != 'null'
            ]
            del schema["anyOf"]
    if "properties" in schema:
        schema["properties"] = {
            k: trim_schema(v)
            for k, v in schema["properties"].items()
        }
    return schema


def post_process_tools_description(
        list_tools_result: "ListToolsResult") -> "ListToolsResult":
    """
    Adapt an MCP tool listing for Harmony.

    Each tool's ``inputSchema`` is replaced by its trimmed form, then the
    listing is narrowed to the tools that should appear in the prompt. The
    result object is modified in place and returned.

    Args:
        list_tools_result: The tool listing returned by an MCP server.

    Returns:
        The same ``list_tools_result`` with trimmed schemas and filtered
        tools.
    """
    # First pass: rewrite every schema, including tools dropped below.
    for mcp_tool in list_tools_result.tools:
        mcp_tool.inputSchema = trim_schema(mcp_tool.inputSchema)

    # Second pass: some tool schemas don't need to be part of the prompt
    # (e.g. simple text in, text out for Python). A tool is kept unless its
    # annotations carry a falsy `include_in_prompt`.
    prompt_tools = []
    for mcp_tool in list_tools_result.tools:
        if getattr(mcp_tool.annotations, "include_in_prompt", True):
            prompt_tools.append(mcp_tool)
    list_tools_result.tools = prompt_tools

    return list_tools_result

69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

class ToolServer(ABC):
    """
    Abstract interface for an object that exposes tools by name.

    Implementations report which tools they support, describe them, and
    open sessions on them.
    """

    @abstractmethod
    def has_tool(self, tool_name: str) -> bool:
        """
        Tell whether ``tool_name`` is supported.

        Returns True for a supported tool and False otherwise.
        """

    @abstractmethod
    def get_tool_description(self,
                             tool_name: str) -> Optional[ToolNamespaceConfig]:
        """
        Look up the description of ``tool_name``.

        Returns None when the tool is not supported.
        """

    @abstractmethod
    def new_session(self, tool_name: str,
                    session_id: str) -> AbstractAsyncContextManager[Any]:
        """
        Open a session for ``tool_name``.

        The returned object is an async context manager.
        """


97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
class MCPToolServer(ToolServer):
    """
    Tool server backed by remote MCP servers reached over SSE.

    ``add_tool_server`` queries each configured server once and records,
    per server name, its Harmony tool description and its SSE URL.
    ``new_session`` then opens a fresh MCP client session on that URL.
    """

    def __init__(self):
        """
        Create an empty server registry.

        Raises:
            ImportError: If the ``mcp`` package is not installed.
        """
        try:
            import mcp  # noqa: F401
        except ImportError:
            raise ImportError(
                "mcp is not installed. Please run `pip install mcp` to use "
                "MCPToolServer.") from None
        self.harmony_tool_descriptions: dict[str, ToolNamespaceConfig] = {}
        # Initialized here as well so that `new_session` raises KeyError,
        # not AttributeError, when no server has been added yet.
        self.urls: dict[str, str] = {}

    async def add_tool_server(self, server_url: str):
        """
        Discover the tools of one or more MCP servers.

        Any previously registered servers are discarded. Each entry of
        ``server_url`` is expanded to ``http://<entry>/sse`` and queried.
        When two servers report the same name, the first one wins and the
        later one is ignored with a warning.

        Args:
            server_url: Comma-separated list of ``host[:port]`` entries.
        """
        self.harmony_tool_descriptions = {}
        self.urls = {}
        for host in server_url.split(","):
            # Tolerate "host-a, host-b": surrounding whitespace is never
            # part of a valid host.
            url = f"http://{host.strip()}/sse"
            initialize_response, list_tools_response = (
                await list_server_and_tools(url))

            list_tools_response = post_process_tools_description(
                list_tools_response)

            tool_from_mcp = ToolNamespaceConfig(
                name=initialize_response.serverInfo.name,
                description=initialize_response.instructions,
                tools=[
                    ToolDescription.new(name=tool.name,
                                        description=tool.description,
                                        parameters=tool.inputSchema)
                    for tool in list_tools_response.tools
                ],
            )

            if tool_from_mcp.name in self.urls:
                # Keep description and URL from the same (first) server;
                # storing the duplicate's description would pair it with
                # the first server's URL.
                logger.warning(
                    "Tool %s already exists. Ignoring duplicate tool server %s",
                    tool_from_mcp.name, url)
                continue
            self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp
            self.urls[tool_from_mcp.name] = url

        logger.info("MCPToolServer initialized with tools: %s",
                    list(self.harmony_tool_descriptions.keys()))

    def has_tool(self, tool_name: str) -> bool:
        """
        Return True if a server named ``tool_name`` has been registered.
        """
        return tool_name in self.harmony_tool_descriptions

    def get_tool_description(self,
                             tool_name: str) -> Optional[ToolNamespaceConfig]:
        """
        Return the registered description for ``tool_name``, or None.
        """
        return self.harmony_tool_descriptions.get(tool_name)

    @asynccontextmanager
    async def new_session(self, tool_name: str, session_id: str):
        """
        Open an initialized MCP client session on the server of a tool.

        Args:
            tool_name: Name of a server registered by ``add_tool_server``.
            session_id: Sent to the server in the ``x-session-id`` header.

        Yields:
            The initialized ``ClientSession``.

        Raises:
            KeyError: If no URL is registered for ``tool_name``.
        """
        from mcp import ClientSession
        from mcp.client.sse import sse_client

        url = self.urls.get(tool_name)
        if not url:
            raise KeyError(f"Tool '{tool_name}' is not supported")

        headers = {"x-session-id": session_id}
        async with sse_client(url=url,
                              headers=headers) as streams, ClientSession(
                                  *streams) as session:
            await session.initialize()
            yield session


161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
class DemoToolServer(ToolServer):
    """
    Tool server that serves the in-process Harmony browser and Python tools.

    A tool is registered under ``"browser"`` or ``"python"`` only when its
    instance reports ``enabled``.
    """

    def __init__(self):
        """
        Instantiate the built-in tools and register the enabled ones.
        """
        self.tools: dict[str, Tool] = {}
        # Construct and check one tool at a time, browser first.
        for name, tool_cls in (("browser", HarmonyBrowserTool),
                               ("python", HarmonyPythonTool)):
            candidate = tool_cls()
            if candidate.enabled:
                self.tools[name] = candidate
        logger.info("DemoToolServer initialized with tools: %s",
                    list(self.tools.keys()))

    def has_tool(self, tool_name: str) -> bool:
        """
        Return True if ``tool_name`` is one of the registered tools.
        """
        return tool_name in self.tools

    def get_tool_description(self,
                             tool_name: str) -> Optional[ToolNamespaceConfig]:
        """
        Return the Harmony namespace config for a registered tool.

        Returns None when ``tool_name`` is not registered.

        Raises:
            ValueError: If the tool is registered but is neither
                ``"browser"`` nor ``"python"``.
        """
        if tool_name not in self.tools:
            return None
        if tool_name == "browser":
            return ToolNamespaceConfig.browser()
        if tool_name == "python":
            return ToolNamespaceConfig.python()
        raise ValueError(f"Unknown tool {tool_name}")

    @asynccontextmanager
    async def new_session(self, tool_name: str, session_id: str):
        """
        Yield the registered tool object for ``tool_name``.

        ``session_id`` is accepted for interface compatibility and is not
        used here.

        Raises:
            KeyError: If ``tool_name`` is not registered.
        """
        if tool_name in self.tools:
            yield self.tools[tool_name]
        else:
            raise KeyError(f"Tool '{tool_name}' is not supported")