# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import AsyncGenerator
from typing import Any

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
)
from vllm.entrypoints.openai.engine.protocol import (
    UsageInfo,
)


def _sse_data_payloads(chunk_str: str) -> list[str]:
    """Return the payload of every ``data:`` line in one SSE chunk.

    A single chunk may carry several events (``data: {...}\\n\\ndata: {...}``),
    so the chunk is split on ``"\\n"`` and each ``data:`` line is handled on
    its own. ``str.splitlines`` is deliberately not used: it also splits on
    characters such as U+2028 that may appear unescaped inside a JSON string
    payload. Empty payloads and the ``[DONE]`` sentinel are dropped.
    """
    payloads: list[str] = []
    for raw_line in chunk_str.split("\n"):
        line = raw_line.strip()
        if not line.startswith("data:"):
            # Blank separators, SSE comments (": ...") and other fields.
            continue
        # SSE allows an optional space after the colon; strip() covers both.
        payload = line[len("data:") :].strip()
        if payload and payload != "[DONE]":
            payloads.append(payload)
    return payloads


def _merge_tool_call_delta(
    accumulated_tool_calls: list[dict[str, Any]], tool_call_delta: Any
) -> None:
    """Fold one streamed tool-call delta into ``accumulated_tool_calls``.

    ``tool_call_delta`` is read through ``.index``, ``.id`` and ``.function``
    (with ``.name`` / ``.arguments``). Name and argument fragments are
    appended to the entry at ``.index``; a non-empty ``.id`` overwrites it.
    """
    # Deltas address tool calls by position; grow the list so the slot exists.
    while len(accumulated_tool_calls) <= tool_call_delta.index:
        accumulated_tool_calls.append(
            {
                "id": None,
                "type": "function",
                "function": {"name": "", "arguments": ""},
            }
        )

    entry = accumulated_tool_calls[tool_call_delta.index]
    if tool_call_delta.id:
        entry["id"] = tool_call_delta.id

    function_delta = tool_call_delta.function
    if function_delta:
        if function_delta.name:
            entry["function"]["name"] += function_delta.name
        if function_delta.arguments:
            entry["function"]["arguments"] += function_delta.arguments


async def accumulate_streaming_response(
    stream_generator: AsyncGenerator[str, None],
) -> ChatCompletionResponse:
    """
    Accumulate streaming SSE chunks into a complete ChatCompletionResponse.

    Each string yielded by ``stream_generator`` may hold one or more SSE
    events of the form ``data: {json}``. The ``data: [DONE]`` sentinel and
    payloads that are not valid JSON are skipped. Other parse and validation
    errors propagate.

    The deltas of every choice are merged into a single message:

    - content and reasoning fragments are concatenated;
    - tool-call fragments are merged by their ``index``;
    - the last reported role, finish reason and choice index win.

    This assumes a single-choice (``n=1``) stream; deltas from several
    choices would be interleaved into one message.

    ``id``/``created``/``model`` are taken from the first parsed chunk. When
    nothing parseable arrives, they fall back to the placeholders
    ``"chatcmpl-test"``, ``0`` and ``"test-model"``. The role defaults to
    ``"assistant"`` and the finish reason to ``"stop"``. Usage is always
    reported as zero tokens; it is not accumulated from the stream.
    """
    # Collect fragments in lists and join once, instead of repeated `+=`.
    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    accumulated_tool_calls: list[dict[str, Any]] = []
    role = None
    finish_reason = None
    response_id = None
    created = None
    model = None
    index = 0

    async for chunk_str in stream_generator:
        for json_str in _sse_data_payloads(chunk_str):
            try:
                chunk_data = json.loads(json_str)
            except json.JSONDecodeError:
                # Best effort: ignore payloads that are not valid JSON.
                continue
            chunk = ChatCompletionStreamResponse(**chunk_data)

            # Response metadata comes from the first parsed chunk.
            if response_id is None:
                response_id = chunk.id
                created = chunk.created
                model = chunk.model

            for choice in chunk.choices:
                delta = choice.delta
                if delta.role:
                    role = delta.role
                if delta.content:
                    content_parts.append(delta.content)
                if delta.reasoning:
                    reasoning_parts.append(delta.reasoning)
                for tool_call_delta in delta.tool_calls or []:
                    _merge_tool_call_delta(accumulated_tool_calls, tool_call_delta)
                if choice.finish_reason:
                    finish_reason = choice.finish_reason
                if choice.index is not None:
                    index = choice.index

    # Build the final message; empty content/reasoning are reported as None.
    content = "".join(content_parts)
    message_kwargs: dict[str, Any] = {
        "role": role or "assistant",
        "content": content or None,
        "reasoning": "".join(reasoning_parts) if reasoning_parts else None,
    }

    # Only include tool_calls if there are any.
    if accumulated_tool_calls:
        message_kwargs["tool_calls"] = [
            {"id": tc["id"], "type": tc["type"], "function": tc["function"]}
            for tc in accumulated_tool_calls
        ]

    message = ChatMessage(**message_kwargs)

    choice = ChatCompletionResponseChoice(
        index=index,
        message=message,
        finish_reason=finish_reason or "stop",
    )

    # Token counts are not derived from the stream; dummy values for tests.
    usage = UsageInfo(
        prompt_tokens=0,
        completion_tokens=0,
        total_tokens=0,
    )

    return ChatCompletionResponse(
        id=response_id or "chatcmpl-test",
        object="chat.completion",
        created=created or 0,
        model=model or "test-model",
        choices=[choice],
        usage=usage,
    )


def verify_harmony_messages(
    messages: list[Any], expected_messages: list[dict[str, Any]]
):
    """Assert that ``messages`` match ``expected_messages`` position by position.

    The two lists must have the same length. Each expectation dict is checked
    only for the keys it contains, so an empty dict accepts any message.
    Recognised keys and what they are compared against:

    - ``role``: ``msg.author.role``
    - ``author_name``: ``msg.author.name``
    - ``channel`` / ``recipient`` / ``content_type``: the same-named attribute
    - ``content``: ``msg.content[0].text``
    - ``tool_definitions``: list of tool names from
      ``msg.content[0].tools["functions"].tools``

    Keys not listed above are ignored. Assertion messages name the message
    index and the field, so a failure points at the exact mismatch.
    """
    assert len(messages) == len(expected_messages), (
        f"expected {len(expected_messages)} messages, got {len(messages)}"
    )

    # Ordered so checks run in the same sequence as the key list above; each
    # getter runs only when its key is present in the expectation.
    field_getters = {
        "role": lambda m: m.author.role,
        "author_name": lambda m: m.author.name,
        "channel": lambda m: m.channel,
        "recipient": lambda m: m.recipient,
        "content": lambda m: m.content[0].text,
        "content_type": lambda m: m.content_type,
        # Tool definitions are compared by tool name only.
        "tool_definitions": lambda m: [
            tool.name for tool in m.content[0].tools["functions"].tools
        ],
    }

    for position, (msg, expected) in enumerate(zip(messages, expected_messages)):
        for field, get_actual in field_getters.items():
            if field not in expected:
                continue
            actual = get_actual(msg)
            assert actual == expected[field], (
                f"message {position}: {field} mismatch: "
                f"{actual!r} != {expected[field]!r}"
            )


def verify_chat_response(
    response: ChatCompletionResponse,
    content: str | None = None,
    reasoning: str | None = None,
    tool_calls: list[tuple[str, str]] | None = None,
):
    """Assert that a single-choice chat response carries the expected message.

    Args:
        response: Response to check; it must contain exactly one choice.
        content: Expected ``message.content``. ``None`` means the content
            must be empty (``None`` or ``""``).
        reasoning: Expected ``message.reasoning``. ``None`` means it must be
            empty.
        tool_calls: Expected ``(function name, arguments string)`` pairs, in
            order. ``None`` or an empty list means the message must carry no
            tool calls.

    Assertion messages state which field (and tool-call index) mismatched.
    """
    assert len(response.choices) == 1, (
        f"expected exactly 1 choice, got {len(response.choices)}"
    )
    message = response.choices[0].message

    if content is not None:
        assert message.content == content, (
            f"content mismatch: {message.content!r} != {content!r}"
        )
    else:
        assert not message.content, f"unexpected content: {message.content!r}"

    if reasoning is not None:
        assert message.reasoning == reasoning, (
            f"reasoning mismatch: {message.reasoning!r} != {reasoning!r}"
        )
    else:
        assert not message.reasoning, (
            f"unexpected reasoning: {message.reasoning!r}"
        )

    if not tool_calls:
        assert not message.tool_calls, (
            f"unexpected tool calls: {message.tool_calls!r}"
        )
        return

    assert message.tool_calls is not None, "expected tool calls, got None"
    assert len(message.tool_calls) == len(tool_calls), (
        f"expected {len(tool_calls)} tool calls, got {len(message.tool_calls)}"
    )
    for position, (tc, (expected_name, expected_args)) in enumerate(
        zip(message.tool_calls, tool_calls)
    ):
        assert tc.function.name == expected_name, (
            f"tool call {position}: name {tc.function.name!r} "
            f"!= {expected_name!r}"
        )
        assert tc.function.arguments == expected_args, (
            f"tool call {position}: arguments {tc.function.arguments!r} "
            f"!= {expected_args!r}"
        )