tool.py 6.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import json
4
5
import os
from abc import ABC, abstractmethod
6
from typing import TYPE_CHECKING, Any
7

8
9
10
from openai.types.responses.response_function_tool_call_output_item import (
    ResponseFunctionToolCallOutputItem,
)
11
12
from openai_harmony import Author, Message, Role, TextContent

13
from vllm.logger import init_logger
14
from vllm.utils import random_uuid
15
16
17
18
19
20
21

# ConversationContext is only referenced in string annotations in this module.
if TYPE_CHECKING:
    # Avoid circular import: at runtime, the methods below import the concrete
    # context classes from vllm.entrypoints.context locally instead.
    from vllm.entrypoints.context import ConversationContext

# Module-level logger for this file.
logger = init_logger(__name__)

22
23
# Oldest gpt_oss release the tools in this module are known to work with.
MIN_GPT_OSS_VERSION = "0.0.7"


def validate_gpt_oss_install():
    """Ensure a usable ``gpt_oss`` package is installed.

    Checks that the ``gpt_oss`` distribution is installed and that its
    version is at least ``MIN_GPT_OSS_VERSION``.

    Raises:
        ImportError: if ``gpt_oss`` is not installed, reports a version
            string that cannot be parsed, or is older than
            ``MIN_GPT_OSS_VERSION``. (A missing ``packaging`` module also
            surfaces as ``ModuleNotFoundError``, an ``ImportError`` subclass,
            so callers catching ``ImportError`` handle every case.)
    """
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import InvalidVersion, Version

    try:
        pkg_version = Version(version("gpt_oss"))
    except PackageNotFoundError:
        raise ImportError("Package 'gpt_oss' is not installed.") from None
    except InvalidVersion as e:
        raise ImportError(f"Invalid version string for 'gpt_oss': {e}") from None

    if pkg_version < Version(MIN_GPT_OSS_VERSION):
        raise ImportError(
            f"gpt_oss >= {MIN_GPT_OSS_VERSION} is required, "
            f"but {pkg_version} is installed."
        )


49
50
51
52
53
class Tool(ABC):
    """Abstract interface for tools that can be invoked during generation.

    Each concrete tool provides one coroutine per kind of conversation
    context it may be driven from.
    """

    @abstractmethod
    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the tool for ``context`` and return its output."""

    @abstractmethod
    async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
        """Run the tool for a parsable-style ``context`` and return its output."""

58
59
60
61
62
63
64
65
66
67
68

class HarmonyBrowserTool(Tool):
    """Web-browsing tool backed by gpt_oss's ``SimpleBrowserTool`` and Exa.

    Instead of raising, the tool sets ``enabled = False`` and logs a warning
    when ``EXA_API_KEY`` is unset or ``gpt_oss`` cannot be validated/imported.
    In that case ``browser_tool`` is never created.
    """

    def __init__(self):
        self.enabled = True

        api_key = os.getenv("EXA_API_KEY")
        if not api_key:
            self.enabled = False
            logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
            return

        try:
            validate_gpt_oss_install()
            from gpt_oss.tools.simple_browser import SimpleBrowserTool
            from gpt_oss.tools.simple_browser.backend import ExaBackend
        except ImportError as err:
            self.enabled = False
            logger.warning_once(
                "gpt_oss is not installed properly (%s), browsing is disabled", err
            )
            return

        self.browser_tool = SimpleBrowserTool(
            backend=ExaBackend(source="web", api_key=api_key)
        )
        logger.info_once("Browser tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Feed the latest Harmony message to the browser.

        Returns:
            The list of messages yielded by the browser tool.
        """
        from vllm.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        latest = context.messages[-1]
        return [out_msg async for out_msg in self.browser_tool.process(latest)]

    async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
        """Browsing is not yet supported for parsable contexts."""
        raise NotImplementedError("Not implemented yet")

    @property
    def tool_config(self) -> Any:
        """Configuration exposed by the underlying gpt_oss browser tool."""
        return self.browser_tool.tool_config


class HarmonyPythonTool(Tool):
    """Code-interpreter tool backed by gpt_oss's Docker ``PythonTool``.

    Instead of raising, the tool sets ``enabled = False`` and logs a warning
    when ``gpt_oss`` cannot be validated/imported (in which case
    ``python_tool`` is never created) or when the ``validate`` smoke test
    fails.
    """

    def __init__(self):
        self.enabled = True

        try:
            validate_gpt_oss_install()
            from gpt_oss.tools.python_docker.docker_tool import PythonTool
        except ImportError as e:
            self.enabled = False
            logger.warning_once(
                "gpt_oss is not installed properly (%s), code interpreter is disabled",
                e,
            )
            return

        self.python_tool = PythonTool()

    async def validate(self):
        """Smoke-test the interpreter by running a trivial snippet.

        No-op if the tool is already disabled. On any failure the tool is
        disabled and a warning is logged; this method does not raise.
        """
        if not self.enabled:
            return
        try:
            message = Message(
                author=Author(role=Role.ASSISTANT),
                content=[TextContent(text="print('Hello, world!')")],
                channel="analysis",
                recipient="python",
                content_type="code",
            )
            msgs = [msg async for msg in self.python_tool.process(message)]
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would leave a broken interpreter enabled.
            output = msgs[0].content[0].text if msgs and msgs[0].content else None
            if output != "Hello, world!\n":
                raise RuntimeError(f"unexpected smoke-test output: {output!r}")
        except Exception as e:
            # Deliberately broad: any failure just disables the tool.
            self.enabled = False
            logger.warning_once(
                "Code interpreter tool failed to initialize (%s), code "
                "interpreter is disabled",
                e,
            )
            return
        logger.info_once("Code interpreter tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the latest Harmony message through the interpreter.

        Returns:
            The list of messages yielded by the python tool.
        """
        from vllm.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        return [msg async for msg in self.python_tool.process(last_msg)]

    async def get_result_parsable_context(self, context: "ConversationContext") -> Any:
        """Run a parsable-context function call through the gpt_oss python tool.

        The last response message is converted to a Harmony code message, and
        each tool output is converted back into a
        ``ResponseFunctionToolCallOutputItem``.
        The last message must carry JSON ``arguments`` with a ``code`` field.
        """
        from vllm.entrypoints.context import ParsableContext

        assert isinstance(context, ParsableContext)

        last_msg = context.parser.response_messages[-1]
        args = json.loads(last_msg.arguments)

        last_msg_harmony = Message(
            author=Author(role=Role.ASSISTANT, name=None),
            content=[TextContent(text=args["code"])],
            channel="analysis",
            recipient="python",
            content_type="code",
        )

        # A function_call_output must reference the call it answers, so reuse
        # the originating call's id. last_msg is presumably a function-call
        # item; fall back to a fresh id only if it carries no call_id.
        call_id = getattr(last_msg, "call_id", None) or f"call_{random_uuid()}"

        return [
            ResponseFunctionToolCallOutputItem(
                id=f"fco_{random_uuid()}",
                type="function_call_output",
                call_id=call_id,
                output=msg.content[0].text,
                status="completed",
            )
            async for msg in self.python_tool.process(last_msg_harmony)
        ]

    @property
    def tool_config(self) -> Any:
        """Configuration exposed by the underlying gpt_oss python tool."""
        return self.python_tool.tool_config