"vllm/vscode:/vscode.git/clone" did not exist on "62fe9a486e23d625a0a36cdbffcfd83510e14d42"
stream_harmony.py 5.99 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Harmony-specific streaming delta extraction for chat completions.

This module handles the extraction of DeltaMessage objects from
harmony parser state during streaming chat completions.
"""

10
11
from typing import NamedTuple

12
13
14
from openai_harmony import StreamableParser

from vllm.entrypoints.chat_utils import make_tool_call_id
15
from vllm.entrypoints.openai.engine.protocol import (
16
17
18
19
20
21
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
)


22
23
24
25
26
27
class TokenState(NamedTuple):
    channel: str | None
    recipient: str | None
    text: str


28
29
def extract_harmony_streaming_delta(
    harmony_parser: StreamableParser,
    token_states: list[TokenState],
    prev_recipient: str | None,
    include_reasoning: bool,
) -> tuple[DeltaMessage | None, bool]:
    """
    Extract a DeltaMessage from harmony parser state during streaming.

    Consecutive token states sharing the same channel and recipient are
    merged into groups, and each group is routed as follows:

    * ``final`` channel -> appended to ``content``.
    * ``commentary``/``analysis`` channel with a ``functions.*`` recipient
      -> tool call deltas. A group whose recipient differs from the
      previously seen recipient opens a new tool call (name + empty
      arguments); any text in the group is emitted as an arguments delta.
    * ``commentary`` channel otherwise -> appended to ``content``
      (tool call preambles meant to be shown to the user).
    * ``analysis`` channel otherwise -> appended to ``reasoning``, but only
      when ``include_reasoning`` is true.
    * anything else is dropped.

    Args:
        harmony_parser: The StreamableParser instance tracking parse state.
            Only ``harmony_parser.messages`` is read, to count the tool
            calls it already holds.
        token_states: List of TokenState tuples for each token
        prev_recipient: Previous recipient for detecting tool call transitions
        include_reasoning: Whether to include reasoning content

    Returns:
        A tuple of (DeltaMessage or None, tools_streamed_flag). The message
        is None when ``token_states`` is empty or nothing in it produced
        content, reasoning or tool call deltas. The flag is True exactly
        when the returned message carries tool call deltas.
    """
    if not token_states:
        return None, False

    # Group consecutive tokens with same channel/recipient
    groups: list[TokenState] = []
    for state in token_states:
        last = groups[-1] if groups else None
        if (
            last is not None
            and last.channel == state.channel
            and last.recipient == state.recipient
        ):
            groups[-1] = TokenState(
                last.channel, last.recipient, last.text + state.text
            )
        else:
            groups.append(TokenState(state.channel, state.recipient, state.text))

    # Number of tool calls already present in the parser's message list;
    # tool call indices for this chunk are derived from it.
    base_index = sum(
        1
        for msg in harmony_parser.messages
        if msg.channel in ("commentary", "analysis")
        and msg.recipient
        and msg.recipient.startswith("functions.")
    )

    # current_tool_index: index of the call that argument deltas belong to.
    # next_tool_index: index the next newly opened call will receive.
    current_tool_index: int | None
    if prev_recipient and prev_recipient.startswith("functions."):
        # A call carried over from the previous chunk sits at base_index,
        # so the next new call starts right after it.
        current_tool_index = base_index
        next_tool_index = base_index + 1
    else:
        current_tool_index = None
        next_tool_index = base_index

    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    tool_messages: list[DeltaToolCall] = []
    # Tracked separately from the text so that an empty content group still
    # yields a message with content="".
    content_encountered = False

    for group in groups:
        if group.channel == "final":
            content_parts.append(group.text)
            content_encountered = True
        elif (
            group.channel in ("commentary", "analysis")
            and group.recipient
            and group.recipient.startswith("functions.")
        ):
            if prev_recipient != group.recipient:
                # New tool call - emit the opening message
                tool_name = group.recipient.split("functions.", 1)[1]
                tool_messages.append(
                    DeltaToolCall(
                        id=make_tool_call_id(),
                        type="function",
                        function=DeltaFunctionCall(
                            name=tool_name,
                            arguments="",
                        ),
                        index=next_tool_index,
                    )
                )
                prev_recipient = group.recipient
                # Remember which call is now active so that later groups
                # in this chunk continuing it (e.g. after interleaved
                # content) stream their arguments to this index rather
                # than to the call carried over from the previous chunk.
                current_tool_index = next_tool_index
                next_tool_index += 1

            if group.text:
                # Stream arguments for the active tool call. The fallback
                # to base_index is kept purely as a defensive default.
                tool_messages.append(
                    DeltaToolCall(
                        index=(
                            current_tool_index
                            if current_tool_index is not None
                            else base_index
                        ),
                        function=DeltaFunctionCall(arguments=group.text),
                    )
                )
        elif group.channel == "commentary":
            # Tool call preambles meant to be shown to the user
            content_parts.append(group.text)
            content_encountered = True
        elif group.channel == "analysis" and include_reasoning:
            reasoning_parts.append(group.text)

    combined_reasoning = "".join(reasoning_parts)
    if not (content_encountered or combined_reasoning or tool_messages):
        return None, False

    # Combine all non-empty fields into a single message
    delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
    if content_encountered:
        delta_kwargs["content"] = "".join(content_parts)
    if combined_reasoning:
        delta_kwargs["reasoning"] = combined_reasoning
    if tool_messages:
        delta_kwargs["tool_calls"] = tool_messages

    return DeltaMessage(**delta_kwargs), bool(tool_messages)