stream_harmony.py 3.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Harmony-specific streaming delta extraction for chat completions.

This module handles the extraction of DeltaMessage objects from
harmony parser state during streaming chat completions.
"""

from openai_harmony import StreamableParser

from vllm.entrypoints.chat_utils import make_tool_call_id
13
from vllm.entrypoints.openai.engine.protocol import (
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
)


def _function_call_name(channel: str | None, recipient: str | None) -> str | None:
    """Return the function name if (channel, recipient) is a function call.

    A message counts as a function call when it is on the ``commentary`` or
    ``analysis`` channel and its recipient starts with ``functions.``; the
    returned name is the recipient with that prefix removed. Returns None
    for anything else.
    """
    if (
        channel in ("commentary", "analysis")
        and recipient
        and recipient.startswith("functions.")
    ):
        return recipient.removeprefix("functions.")
    return None


def extract_harmony_streaming_delta(
    harmony_parser: StreamableParser,
    cur_channel: str | None,
    cur_recipient: str | None,
    prev_recipient: str | None,
    delta_text: str,
    include_reasoning: bool,
) -> tuple[DeltaMessage | None, bool]:
    """
    Extract a DeltaMessage from harmony parser state during streaming.

    Routing by channel / recipient:

    * ``final`` channel -> ``content``.
    * ``commentary`` or ``analysis`` addressed to ``functions.<name>`` -> a
      tool-call delta. If ``cur_recipient`` differs from ``prev_recipient``
      a new call is opened (fresh id, function name, and any argument text
      already present in ``delta_text``); otherwise non-empty ``delta_text``
      is appended to the arguments of the in-progress call.
    * ``commentary`` with any other recipient -> ``content`` (preambles).
    * ``analysis`` with any other recipient -> ``reasoning`` when
      ``include_reasoning`` is True, otherwise nothing.
    * Any other channel -> nothing.

    The tool-call index is the number of function-call messages already
    present in ``harmony_parser.messages``.

    Args:
        harmony_parser: The StreamableParser instance tracking parse state
        cur_channel: Current channel ("final", "analysis", "commentary", etc.)
        cur_recipient: Current recipient (e.g., "functions.my_func")
        prev_recipient: Previous recipient for detecting tool call transitions
        delta_text: The text delta to include in the message
        include_reasoning: Whether to include reasoning content

    Returns:
        A tuple ``(delta_message, tools_streamed)``. ``delta_message`` is
        None when there is nothing to emit; ``tools_streamed`` is True only
        when a tool-call delta was produced.
    """
    if cur_channel == "final":
        return DeltaMessage(content=delta_text), False

    tool_name = _function_call_name(cur_channel, cur_recipient)
    if tool_name is not None:
        # Completed function-call messages determine this call's index.
        tool_index = sum(
            1
            for msg in harmony_parser.messages
            if _function_call_name(msg.channel, msg.recipient) is not None
        )
        if prev_recipient != cur_recipient:
            # A new call just started. Several tokens may be parsed in one
            # step, so delta_text can already hold the start of the
            # arguments; send it with the header rather than dropping it.
            tool_call = DeltaToolCall(
                id=make_tool_call_id(),
                type="function",
                function=DeltaFunctionCall(name=tool_name, arguments=delta_text),
                index=tool_index,
            )
        elif delta_text:
            tool_call = DeltaToolCall(
                index=tool_index,
                function=DeltaFunctionCall(arguments=delta_text),
            )
        else:
            return None, False
        return DeltaMessage(tool_calls=[tool_call]), True

    if cur_channel == "commentary":
        # Tool call preambles meant to be shown to the user
        return DeltaMessage(content=delta_text), False
    if cur_channel == "analysis" and include_reasoning:
        return DeltaMessage(reasoning=delta_text), False
    return None, False