tool.py 4.53 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from abc import ABC, abstractmethod
5
from typing import TYPE_CHECKING, Any
6

7
8
from openai_harmony import Author, Message, Role, TextContent

9
10
11
12
13
14
15
16
from vllm.logger import init_logger

if TYPE_CHECKING:
    # Avoid circular import.
    from vllm.entrypoints.context import ConversationContext

logger = init_logger(__name__)

# Oldest gpt_oss release accepted by validate_gpt_oss_install(). It is parsed
# with packaging.version.Version there, so it must be a valid version string.
MIN_GPT_OSS_VERSION = "0.0.7"

19

20
21
def validate_gpt_oss_install():
    """
    Check that the ``gpt_oss`` package is installed at a supported version.

    The installed distribution version must be at least
    ``MIN_GPT_OSS_VERSION``.

    Raises:
        ImportError: if ``gpt_oss`` is not installed, its version string
            cannot be parsed, or it is older than ``MIN_GPT_OSS_VERSION``.
    """
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import InvalidVersion, Version

    try:
        pkg_version = Version(version("gpt_oss"))
    except PackageNotFoundError:
        raise ImportError("Package 'gpt_oss' is not installed.") from None
    except InvalidVersion as e:
        raise ImportError(f"Invalid version string for 'gpt_oss': {e}") from None

    # No exception is active here, so there is no context to suppress.
    if pkg_version < Version(MIN_GPT_OSS_VERSION):
        raise ImportError(
            f"gpt_oss >= {MIN_GPT_OSS_VERSION} is required, "
            f"but {pkg_version} is installed."
        )


44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class Tool(ABC):
    """Abstract interface for a tool that is invoked with a conversation context.

    Concrete tools implement the coroutine :meth:`get_result`.
    """

    @abstractmethod
    async def get_result(self, context: "ConversationContext") -> Any:
        """Compute this tool's output for ``context``; subclasses must override."""
        ...


class HarmonyBrowserTool(Tool):
    """Browsing tool built on gpt_oss's ``SimpleBrowserTool`` with an ``ExaBackend``.

    Instead of raising, construction sets ``enabled = False`` when
    ``EXA_API_KEY`` is unset or ``validate_gpt_oss_install()`` / the gpt_oss
    imports raise ``ImportError``. In that case ``browser_tool`` is never
    assigned, so callers should check ``enabled`` before calling
    :meth:`get_result` or reading :attr:`tool_config`.
    """

    def __init__(self):
        self.enabled = True
        exa_api_key = os.getenv("EXA_API_KEY")
        if not exa_api_key:
            self.enabled = False
            logger.warning_once("EXA_API_KEY is not set, browsing is disabled")
            return

        try:
            validate_gpt_oss_install()
            # Imported lazily so a missing/old gpt_oss only disables this tool.
            from gpt_oss.tools.simple_browser import SimpleBrowserTool
            from gpt_oss.tools.simple_browser.backend import ExaBackend
        except ImportError as e:
            self.enabled = False
            logger.warning_once(
                "gpt_oss is not installed properly (%s), browsing is disabled", e
            )
            return

        browser_backend = ExaBackend(source="web", api_key=exa_api_key)
        self.browser_tool = SimpleBrowserTool(backend=browser_backend)
        logger.info_once("Browser tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the browser tool on the last message of ``context``.

        Args:
            context: conversation context; must be a ``HarmonyContext``.

        Returns:
            List of the messages yielded by ``browser_tool.process``.
        """
        # Local import, for the same circular-import reason as the
        # TYPE_CHECKING import of vllm.entrypoints.context.
        from vllm.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        return [msg async for msg in self.browser_tool.process(last_msg)]

    @property
    def tool_config(self) -> Any:
        """Tool configuration exposed by the underlying browser tool."""
        return self.browser_tool.tool_config


class HarmonyPythonTool(Tool):
    """Code-interpreter tool built on ``PythonTool`` from ``gpt_oss.tools.python_docker``.

    Construction only imports and instantiates ``PythonTool``; call the
    coroutine :meth:`validate` to smoke-test it. Failures set
    ``enabled = False`` and log a warning instead of raising. If the import
    fails, ``python_tool`` is never assigned, so callers should check
    ``enabled`` first.
    """

    def __init__(self):
        self.enabled = True

        try:
            validate_gpt_oss_install()
            # Imported lazily so a missing/old gpt_oss only disables this tool.
            from gpt_oss.tools.python_docker.docker_tool import PythonTool
        except ImportError as e:
            self.enabled = False
            logger.warning_once(
                "gpt_oss is not installed properly (%s), code interpreter is disabled",
                e,
            )
            return

        self.python_tool = PythonTool()

    async def validate(self):
        """Smoke-test the interpreter by running ``print('Hello, world!')``.

        No-op when the tool is already disabled. Disables the tool and logs
        a warning if running the snippet raises, yields no messages, or
        the first message's text is not ``"Hello, world!\\n"``.
        """
        if not self.enabled:
            return
        try:
            message = Message(
                author=Author(role=Role.ASSISTANT),
                content=[TextContent(text="print('Hello, world!')")],
                channel="analysis",
                recipient="python",
                content_type="code",
            )
            msgs = [msg async for msg in self.python_tool.process(message)]
            # Explicit checks instead of `assert`: `python -O` strips asserts,
            # which would report a broken interpreter as initialized.
            if not msgs:
                raise RuntimeError("code interpreter returned no messages")
            output = msgs[0].content[0].text
            if output != "Hello, world!\n":
                raise RuntimeError(
                    f"unexpected code interpreter output: {output!r}"
                )
        except Exception as e:
            # Deliberate best-effort boundary: any failure disables the tool.
            self.enabled = False
            logger.warning_once(
                "Code interpreter tool failed to initialize (%s), code "
                "interpreter is disabled",
                e,
            )
            return
        logger.info_once("Code interpreter tool initialized")

    async def get_result(self, context: "ConversationContext") -> Any:
        """Run the code interpreter on the last message of ``context``.

        Args:
            context: conversation context; must be a ``HarmonyContext``.

        Returns:
            List of the messages yielded by ``python_tool.process``.
        """
        # Local import, for the same circular-import reason as the
        # TYPE_CHECKING import of vllm.entrypoints.context.
        from vllm.entrypoints.context import HarmonyContext

        assert isinstance(context, HarmonyContext)
        last_msg = context.messages[-1]
        return [msg async for msg in self.python_tool.process(last_msg)]

    @property
    def tool_config(self) -> Any:
        """Tool configuration exposed by the underlying Python tool."""
        return self.python_tool.tool_config