# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from __future__ import annotations

import datetime
import json
from collections.abc import Iterable, Sequence
from typing import Literal, Union

from openai.types.responses import (
    ResponseFunctionToolCall,
    ResponseOutputItem,
    ResponseOutputMessage,
    ResponseOutputText,
    ResponseReasoningItem,
)
from openai.types.responses.response_function_web_search import (
    ActionFind,
    ActionOpenPage,
    ActionSearch,
    ResponseFunctionWebSearch,
)
from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai_harmony import (
    Author,
    ChannelConfig,
    Conversation,
    DeveloperContent,
    HarmonyEncodingName,
    Message,
    ReasoningEffort,
    Role,
    StreamableParser,
    SystemContent,
    TextContent,
    ToolDescription,
    load_harmony_encoding,
)

from vllm import envs
from vllm.entrypoints.openai.protocol import (
    ChatCompletionToolsParam,
    ResponseInputOutputItem,
)
from vllm.utils import random_uuid

# Maps the API's reasoning-effort strings to harmony ReasoningEffort values.
REASONING_EFFORT = dict(
    high=ReasoningEffort.HIGH,
    medium=ReasoningEffort.MEDIUM,
    low=ReasoningEffort.LOW,
)

# Cached harmony encoding; filled in lazily by get_encoding().
_harmony_encoding = None

# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools is stringified.
BUILTIN_TOOLS = {
    "web_search_preview",
    "code_interpreter",
    "container",
}


def has_custom_tools(tool_types: list[str]) -> bool:
    """Return True if any entry of ``tool_types`` is not in ``BUILTIN_TOOLS``."""
    return any(tool_type not in BUILTIN_TOOLS for tool_type in tool_types)


def get_encoding():
    """Return the shared gpt-oss harmony encoding, loading it on first use."""
    global _harmony_encoding
    if _harmony_encoding is not None:
        return _harmony_encoding
    _harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    return _harmony_encoding


def get_system_message(
    model_identity: str | None = None,
    reasoning_effort: Literal["high", "medium", "low"] | None = None,
    start_date: str | None = None,
    browser_description: str | None = None,
    python_description: str | None = None,
    container_description: str | None = None,
    instructions: str | None = None,
    with_custom_tools: bool = False,
) -> Message:
    """Assemble the harmony system message.

    Args:
        model_identity: Optional model identity text.
        reasoning_effort: ``"high"``/``"medium"``/``"low"``, mapped through
            ``REASONING_EFFORT``.
        start_date: Conversation start date; defaults to today's local date
            formatted as ``YYYY-MM-DD``.
        browser_description, python_description, container_description:
            Built-in tool descriptions, each added via ``with_tools`` when
            not ``None``.
        instructions: Only used when
            ``envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS`` is set: appended
            to the model identity on a new line (or used as the identity if
            the current one is empty).
        with_custom_tools: When ``False``, ``"commentary"`` is removed from
            the valid channels.

    Returns:
        A system-role ``Message``.
    """
    system_content = SystemContent.new()
    if model_identity is not None:
        system_content = system_content.with_model_identity(model_identity)

    if instructions is not None and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
        existing = system_content.model_identity
        system_content = system_content.with_model_identity(
            instructions if not existing else f"{existing}\n{instructions}"
        )

    if reasoning_effort is not None:
        effort = REASONING_EFFORT[reasoning_effort]
        system_content = system_content.with_reasoning_effort(effort)

    # NOTE(woosuk): This brings non-determinism in vLLM. Be careful.
    conversation_date = (
        start_date
        if start_date is not None
        else datetime.datetime.now().strftime("%Y-%m-%d")
    )
    system_content = system_content.with_conversation_start_date(conversation_date)

    for tool_description in (
        browser_description,
        python_description,
        container_description,
    ):
        if tool_description is not None:
            system_content = system_content.with_tools(tool_description)

    if not with_custom_tools:
        kept_channels = [
            channel
            for channel in system_content.channel_config.valid_channels
            if channel != "commentary"
        ]
        system_content = system_content.with_channel_config(
            ChannelConfig.require_channels(kept_channels)
        )

    return Message.from_role_and_content(Role.SYSTEM, system_content)


def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
    """Build a harmony ``ToolDescription`` from a function tool.

    A ``ChatCompletionToolsParam`` keeps its name/description/parameters under
    ``.function``; any other tool object is read directly.
    """
    source = tool.function if isinstance(tool, ChatCompletionToolsParam) else tool
    return ToolDescription.new(
        name=source.name,
        description=source.description,
        parameters=source.parameters,
    )


def get_developer_message(
    instructions: str | None = None,
    tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None,
) -> Message:
    """Build the harmony developer message with instructions and function tools.

    Args:
        instructions: Added as developer instructions unless
            ``envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS`` is set (in that
            case ``get_system_message`` consumes them instead).
        tools: Requested tools. Types listed in ``BUILTIN_TOOLS`` and
            ``"mcp"`` are skipped; ``"function"`` tools are converted with
            ``create_tool_definition``.

    Returns:
        A developer-role ``Message``.

    Raises:
        ValueError: If a tool has any other type.
    """
    dev_msg_content = DeveloperContent.new()
    if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
        dev_msg_content = dev_msg_content.with_instructions(instructions)
    if tools is not None:
        function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
        for tool in tools:
            # Use the shared BUILTIN_TOOLS set so this list cannot drift from
            # has_custom_tools().
            if tool.type in BUILTIN_TOOLS or tool.type == "mcp":
                # These are built-in tools that are added to the system message.
                # Adding in MCP for now until we support MCP tools executed
                # server side.
                continue
            if tool.type != "function":
                raise ValueError(f"tool type {tool.type} not supported")
            function_tools.append(tool)
        if function_tools:
            dev_msg_content = dev_msg_content.with_function_tools(
                [create_tool_definition(tool) for tool in function_tools]
            )
    return Message.from_role_and_content(Role.DEVELOPER, dev_msg_content)


def get_user_message(content: str) -> Message:
    """Wrap ``content`` in a harmony user-role message."""
    user_msg = Message.from_role_and_content(Role.USER, content)
    return user_msg


def parse_response_input(
    response_msg: ResponseInputOutputItem,
    prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]],
) -> Message:
    """Convert one Responses-API input item into a harmony ``Message``.

    Args:
        response_msg: The input item, as a dict or as a pydantic model
            (converted with ``model_dump()``). Items whose ``"type"`` is
            missing or ``None`` are treated as chat messages.
        prev_responses: Earlier output items; searched newest-first for the
            ``ResponseFunctionToolCall`` matching a ``function_call_output``.

    Returns:
        The equivalent harmony message.

    Raises:
        ValueError: If a ``function_call_output`` has no matching call, a
            ``reasoning`` item does not have exactly one content part, or the
            item type is unknown.
    """
    if not isinstance(response_msg, dict):
        response_msg = response_msg.model_dump()
    msg_type = response_msg.get("type")

    if msg_type is None or msg_type == "message":
        role = response_msg["role"]
        content = response_msg["content"]
        if role == "system":
            # User is trying to set a system message. Change it to a
            # developer message prefixed with "Instructions:":
            # <|start|>developer<|message|>Instructions:
            # {instructions}<|end|>
            role = "developer"
            text_prefix = "Instructions:\n"
        else:
            text_prefix = ""
        if isinstance(content, str):
            msg = Message.from_role_and_content(role, text_prefix + content)
        else:
            # Prefix only the first part so a multi-part system message gets
            # a single "Instructions:" header instead of one per part.
            contents = [
                TextContent(text=(text_prefix if i == 0 else "") + part["text"])
                for i, part in enumerate(content)
            ]
            msg = Message.from_role_and_contents(role, contents)
        if role == "assistant":
            msg = msg.with_channel("final")
    elif msg_type == "function_call_output":
        call_id = response_msg["call_id"]
        # Newest-first, so the most recent call with this id is used.
        call_response: ResponseFunctionToolCall | None = next(
            (
                prev
                for prev in reversed(prev_responses)
                if isinstance(prev, ResponseFunctionToolCall)
                and prev.call_id == call_id
            ),
            None,
        )
        if call_response is None:
            raise ValueError(f"No call message found for {call_id}")
        msg = Message.from_author_and_content(
            Author.new(Role.TOOL, f"functions.{call_response.name}"),
            response_msg["output"],
        )
    elif msg_type == "reasoning":
        # Validate explicitly: this is client input, and `assert` would be
        # stripped under `python -O`.
        content = response_msg.get("content") or []
        if len(content) != 1:
            raise ValueError(
                "Expected exactly one content part in reasoning input, "
                f"got {len(content)}"
            )
        msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
    elif msg_type == "function_call":
        msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
        msg = msg.with_channel("commentary")
        msg = msg.with_recipient(f"functions.{response_msg['name']}")
        msg = msg.with_content_type("json")
    else:
        raise ValueError(f"Unknown input type: {msg_type}")
    return msg


def parse_chat_input(chat_msg) -> list[Message]:
    """Convert one Chat Completions message into harmony ``Message`` objects.

    ``chat_msg`` may be a dict or a pydantic model (dumped with
    ``exclude_none=True``). Explicit ``None`` values in dict input are
    treated like missing keys.

    Returns:
        One commentary-channel JSON message per tool call for an assistant
        message with ``tool_calls``; otherwise a single message.
    """
    if not isinstance(chat_msg, dict):
        # Handle Pydantic models
        chat_msg = chat_msg.model_dump(exclude_none=True)

    role = chat_msg.get("role")

    # Assistant message with tool calls: one message per call, addressed to
    # ``functions.<name>`` on the commentary channel with JSON content.
    tool_calls = chat_msg.get("tool_calls")
    if role == "assistant" and tool_calls:
        msgs: list[Message] = []
        for call in tool_calls:
            func = call.get("function") or {}
            name = func.get("name") or ""
            arguments = func.get("arguments") or ""
            msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
            msg = msg.with_channel("commentary")
            msg = msg.with_recipient(f"functions.{name}")
            msg = msg.with_content_type("json")
            msgs.append(msg)
        return msgs

    # Tool role message (tool output)
    if role == "tool":
        # NOTE(review): the tool name comes only from a "name" key; a message
        # carrying just a tool_call_id ends up authored by "functions." —
        # confirm callers provide "name".
        name = chat_msg.get("name") or ""
        content = chat_msg.get("content") or ""
        if isinstance(content, list):
            # Handle array format for tool message content
            # by concatenating all text parts.
            content = "".join(
                item.get("text") or ""
                for item in content
                if isinstance(item, dict) and item.get("type") == "text"
            )
        msg = Message.from_author_and_content(
            Author.new(Role.TOOL, f"functions.{name}"), content
        ).with_channel("commentary")
        return [msg]

    # Default: user/assistant/system messages with content.
    content = chat_msg.get("content")
    if content is None:
        # Missing or explicit null content (allowed for assistant messages).
        content = ""
    if isinstance(content, str):
        contents = [TextContent(text=content)]
    else:
        # TODO: Support refusal.
        contents = [TextContent(text=c.get("text") or "") for c in content]
    msg = Message.from_role_and_contents(role, contents)
    return [msg]


def render_for_completion(messages: list[Message]) -> list[int]:
    """Render ``messages`` to token ids for an assistant-role completion."""
    convo = Conversation.from_messages(messages)
    encoding = get_encoding()
    return encoding.render_conversation_for_completion(convo, Role.ASSISTANT)


def _browser_call_to_web_search_item(
    recipient: str, message: Message
) -> ResponseFunctionWebSearch:
    """Translate a ``browser.*`` call message into a web-search output item.

    Raises:
        ValueError: If the message does not have exactly one content part or
            the action is not ``search``/``open``/``find``.
    """
    if len(message.content) != 1:
        raise ValueError("Invalid number of contents in browser message")
    browser_call = json.loads(message.content[0].text)
    # TODO: translate to url properly!
    if recipient == "browser.search":
        action = ActionSearch(
            query=f"cursor:{browser_call.get('query', '')}", type="search"
        )
    elif recipient == "browser.open":
        action = ActionOpenPage(
            url=f"cursor:{browser_call.get('url', '')}", type="open_page"
        )
    elif recipient == "browser.find":
        action = ActionFind(
            pattern=browser_call["pattern"],
            url=f"cursor:{browser_call.get('url', '')}",
            type="find",
        )
    else:
        raise ValueError(f"Unknown browser action: {recipient}")
    return ResponseFunctionWebSearch(
        id=f"ws_{random_uuid()}",
        action=action,
        status="completed",
        type="web_search_call",
    )


def _reasoning_items_from_message(message: Message) -> list[ResponseReasoningItem]:
    """Wrap each content part of ``message`` in its own reasoning item."""
    return [
        ResponseReasoningItem(
            id=f"rs_{random_uuid()}",
            summary=[],
            type="reasoning",
            content=[
                ResponseReasoningTextContent(text=content.text, type="reasoning_text")
            ],
            status=None,
        )
        for content in message.content
    ]


def _function_calls_from_message(
    recipient: str, message: Message
) -> list[ResponseFunctionToolCall]:
    """Emit one function-call item per content part, named after ``recipient``."""
    function_name = recipient.split(".")[-1]
    calls: list[ResponseFunctionToolCall] = []
    for content in message.content:
        # call_id and id share one uuid so the two stay correlated.
        random_id = random_uuid()
        calls.append(
            ResponseFunctionToolCall(
                arguments=content.text,
                call_id=f"call_{random_id}",
                type="function_call",
                name=function_name,
                id=f"fc_{random_id}",
            )
        )
    return calls


def _final_message_item(message: Message) -> ResponseOutputMessage:
    """Collect the text parts of a final-channel message into one output message."""
    contents = [
        ResponseOutputText(
            text=content.text,
            annotations=[],  # TODO
            type="output_text",
            logprobs=None,  # TODO
        )
        for content in message.content
    ]
    return ResponseOutputMessage(
        id=f"msg_{random_uuid()}",
        content=contents,
        role=message.author.role,
        status="completed",
        type="message",
    )


def parse_output_message(message: Message) -> list[ResponseOutputItem]:
    """
    Parse a Harmony message into a list of output response items.

    Routing:
      * non-assistant authors -> ``[]``
      * recipient ``browser.*`` -> one web-search item
      * ``analysis`` channel -> one reasoning item per content part
      * ``commentary`` to ``functions.*`` -> one function-call item per part
      * ``commentary`` to ``python``/``browser``/``container`` -> reasoning items
      * ``final`` channel -> one output message

    Raises:
        ValueError: For unknown browser actions, commentary recipients, or
            channels.
    """
    if message.author.role != "assistant":
        # This is a message from a tool to the assistant (e.g., search result).
        # Don't include it in the final output for now. This aligns with
        # OpenAI's behavior on models like o4-mini.
        return []

    recipient = message.recipient
    if recipient is not None and recipient.startswith("browser."):
        return [_browser_call_to_web_search_item(recipient, message)]
    if message.channel == "analysis":
        return _reasoning_items_from_message(message)
    if message.channel == "commentary":
        if recipient is not None and recipient.startswith("functions."):
            return _function_calls_from_message(recipient, message)
        if recipient is not None and recipient.startswith(
            ("python", "browser", "container")
        ):
            return _reasoning_items_from_message(message)
        raise ValueError(f"Unknown recipient: {recipient}")
    if message.channel == "final":
        return [_final_message_item(message)]
    raise ValueError(f"Unknown channel: {message.channel}")


def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
    """Turn the message still open in ``parser`` into output items.

    Only an assistant message with non-empty current content on the
    ``analysis`` or ``final`` channel produces an item. A message addressed
    to ``browser.*``, or on any other channel, yields ``[]``.

    Returns:
        A one-element list holding a reasoning item (``analysis``) or an
        ``incomplete`` assistant output message (``final``), or ``[]``.
    """
    if not parser.current_content or parser.current_role != Role.ASSISTANT:
        return []
    target = parser.current_recipient
    if target is not None and target.startswith("browser."):
        return []

    channel = parser.current_channel
    if channel == "analysis":
        return [
            ResponseReasoningItem(
                id=f"rs_{random_uuid()}",
                summary=[],
                type="reasoning",
                content=[
                    ResponseReasoningTextContent(
                        text=parser.current_content, type="reasoning_text"
                    )
                ],
                status=None,
            )
        ]
    if channel == "final":
        partial_text = ResponseOutputText(
            text=parser.current_content,
            annotations=[],  # TODO
            type="output_text",
            logprobs=None,  # TODO
        )
        return [
            ResponseOutputMessage(
                id=f"msg_{random_uuid()}",
                content=[partial_text],
                role="assistant",
                # if the parser still has messages (ie if the generator got cut
                # abruptly), this should be incomplete
                status="incomplete",
                type="message",
            )
        ]
    return []


def get_stop_tokens_for_assistant_actions() -> list[int]:
    """Return the harmony encoding's stop tokens for assistant actions."""
    encoding = get_encoding()
    return encoding.stop_tokens_for_assistant_actions()


def get_streamable_parser_for_assistant() -> StreamableParser:
    """Create a fresh ``StreamableParser`` for assistant-role output."""
    encoding = get_encoding()
    return StreamableParser(encoding, role=Role.ASSISTANT)


def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
    """Feed every token in ``token_ids`` to a new assistant parser and return it."""
    streaming_parser = get_streamable_parser_for_assistant()
    for tok in token_ids:
        streaming_parser.process(tok)
    return streaming_parser


def parse_chat_output(
    token_ids: Sequence[int],
) -> tuple[str | None, str | None, bool]:
    """Split generated harmony tokens into reasoning and final content.

    Completed messages, plus the message still open when ``token_ids`` ran
    out, are routed by channel. Text on the ``final`` channel becomes the
    final content; text on any other channel is treated as reasoning. Texts
    of the same kind are joined with newlines.

    Returns:
        ``(reasoning_content, final_content, is_tool_call)``. Either text is
        ``None`` when nothing of that kind was produced. ``is_tool_call`` is
        currently always ``False``.
    """
    parser = parse_output_into_messages(token_ids)
    is_tool_call = False  # TODO: update this when tool call is supported

    reasoning_parts: list[str] = []
    final_parts: list[str] = []
    for msg in parser.messages:
        # Route by channel rather than by position: output may contain no
        # analysis message, or several before the final one.
        text = "".join(part.text for part in msg.content)
        if msg.channel == "final":
            final_parts.append(text)
        else:
            reasoning_parts.append(text)

    # A message is still open if generation stopped mid-message; keep its
    # partial text rather than dropping it.
    current_channel = parser.current_channel
    current_content = parser.current_content
    if current_channel == "final":
        final_parts.append(current_content)
    elif current_channel is not None or current_content:
        reasoning_parts.append(current_content)

    reasoning_content = "\n".join(reasoning_parts) if reasoning_parts else None
    final_content = "\n".join(final_parts) if final_parts else None
    return reasoning_content, final_content, is_tool_call