# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any

from openai_harmony import Author, Message, Role, TextContent

from vllm.logger import init_logger

if TYPE_CHECKING:
    # Avoid circular import.
    from vllm.entrypoints.context import ConversationContext

logger = init_logger(__name__)


def validate_gpt_oss_install():
    """
    Check that the `gpt_oss` package is installed with version >= 0.0.3.

    Raises:
        ImportError: if `gpt_oss` is not installed, if its version string
            cannot be parsed, or if the installed version is older than 0.0.3.
    """
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import InvalidVersion, Version

    # Single source of truth for both the comparison and the error message.
    min_version = Version("0.0.3")

    try:
        pkg_version_str = version("gpt_oss")  # e.g., "0.0.5"
        pkg_version = Version(pkg_version_str)
    except PackageNotFoundError:
        raise ImportError("Package 'gpt_oss' is not installed.") from None
    except InvalidVersion as e:
        raise ImportError(f"Invalid version string for 'gpt_oss': {e}") from None

    if pkg_version < min_version:
        raise ImportError(
            f"gpt_oss >= {min_version} is required, but {pkg_version} is installed."
        ) from None


class Tool(ABC):
    """Abstract interface for a tool that produces a result from a conversation."""

    @abstractmethod
    async def get_result(self, context: "ConversationContext") -> Any:
        """Produce this tool's result for ``context``; subclasses must override."""


class HarmonyBrowserTool(Tool):
    """Browsing tool that wraps gpt_oss's ``SimpleBrowserTool`` with an Exa backend.

    ``enabled`` is set to False (and ``browser_tool`` is never assigned) when
    the ``EXA_API_KEY`` environment variable is unset/empty or when ``gpt_oss``
    cannot be imported; check ``enabled`` before using the tool.
    """

    def __init__(self):
        self.enabled = True
        exa_api_key = os.getenv("EXA_API_KEY")
        if not exa_api_key:
            self.enabled = False
            logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
            return

        try:
            validate_gpt_oss_install()
            # Imported lazily so gpt_oss is only required when browsing is used.
            from gpt_oss.tools.simple_browser import SimpleBrowserTool
            from gpt_oss.tools.simple_browser.backend import ExaBackend
        except ImportError as e:
            self.enabled = False
            logger.warning_once(
                "gpt_oss is not installed properly (%s), browsing is disabled", e
            )
            return

        browser_backend = ExaBackend(source="web", api_key=exa_api_key)
        self.browser_tool = SimpleBrowserTool(backend=browser_backend)
        logger.info_once("Browser tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the browser tool on the last message of ``context``.

        Returns:
            A list of every message yielded by ``browser_tool.process``.
        """
        # Local import avoids a circular import with vllm.entrypoints.context.
        from vllm.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        return [msg async for msg in self.browser_tool.process(last_msg)]

    @property
    def tool_config(self) -> Any:
        """Return the ``tool_config`` of the underlying browser tool."""
        return self.browser_tool.tool_config


class HarmonyPythonTool(Tool):
    """Code-interpreter tool that wraps gpt_oss's docker-based ``PythonTool``.

    ``enabled`` is set to False when ``gpt_oss`` cannot be imported (in which
    case ``python_tool`` is never assigned) or when ``validate()`` fails;
    check ``enabled`` before using the tool.
    """

    def __init__(self):
        self.enabled = True

        try:
            validate_gpt_oss_install()
            # Imported lazily so gpt_oss is only required when the tool is used.
            from gpt_oss.tools.python_docker.docker_tool import PythonTool
        except ImportError as e:
            self.enabled = False
            logger.warning_once(
                "gpt_oss is not installed properly (%s), code interpreter is disabled",
                e,
            )
            return

        self.python_tool = PythonTool()

    async def validate(self):
        """Sanity-check the interpreter by running a trivial ``print`` snippet.

        Does nothing if the tool is already disabled. On any failure
        (exception from the tool, no output message, or unexpected output)
        the tool is marked disabled and a warning is logged; nothing is raised.
        """
        if not self.enabled:
            return
        try:
            message = Message(
                author=Author(role=Role.ASSISTANT),
                content=[TextContent(text="print('Hello, world!')")],
                channel="analysis",
                recipient="python",
                content_type="code",
            )
            msgs = [msg async for msg in self.python_tool.process(message)]
            output = msgs[0].content[0].text
            # Explicit check rather than `assert`: assertions are stripped
            # under `python -O`, which would silently skip this validation.
            if output != "Hello, world!\n":
                raise RuntimeError(f"unexpected sanity-check output: {output!r}")
        except Exception as e:
            self.enabled = False
            logger.warning_once(
                "Code interpreter tool failed to initialize (%s), code "
                "interpreter is disabled",
                e,
            )
            return
        logger.info_once("Code interpreter tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the Python tool on the last message of ``context``.

        Returns:
            A list of every message yielded by ``python_tool.process``.
        """
        # Local import avoids a circular import with vllm.entrypoints.context.
        from vllm.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        return [msg async for msg in self.python_tool.process(last_msg)]

    @property
    def tool_config(self) -> Any:
        """Return the ``tool_config`` of the underlying Python tool."""
        return self.python_tool.tool_config