conftest.py 6.74 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
from __future__ import annotations

import json
import logging
from collections.abc import Callable
from typing import Any

10
11
import pytest

12
13
14
15
16
17
18
19
logger = logging.getLogger(__name__)

# Baseline environment variables shared by tests in this package.
BASE_TEST_ENV = dict(
    # The day vLLM said "hello world" on arxiv 🚀
    VLLM_SYSTEM_START_DATE="2023-09-12",
)

# Default ``max_retries`` for the retry helpers defined in this module.
DEFAULT_MAX_RETRIES = 3

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45

@pytest.fixture
def pairs_of_event_types() -> dict[str, str]:
    """Map every closing streaming event type to the type that opens it.

    Keep this table complete (all done <-> start pairs). A test that only
    allows some events should filter the returned dict instead of writing
    its own table; that avoids copy + paste mistakes and unexpected
    KeyErrors from events that were left out.
    """
    done_and_start = (
        ("response.completed", "response.created"),
        ("response.output_item.done", "response.output_item.added"),
        ("response.content_part.done", "response.content_part.added"),
        ("response.output_text.done", "response.output_text.delta"),
        ("response.reasoning_text.done", "response.reasoning_text.delta"),
        ("response.reasoning_part.done", "response.reasoning_part.added"),
        (
            "response.mcp_call_arguments.done",
            "response.mcp_call_arguments.delta",
        ),
        ("response.mcp_call.completed", "response.mcp_call.in_progress"),
        (
            "response.function_call_arguments.done",
            "response.function_call_arguments.delta",
        ),
        (
            "response.code_interpreter_call_code.done",
            "response.code_interpreter_call_code.delta",
        ),
        (
            "response.web_search_call.completed",
            "response.web_search_call.in_progress",
        ),
    )
    return dict(done_and_start)
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201


async def retry_for_tool_call(
    client,
    *,
    model: str,
    expected_tool_type: str,
    max_retries: int = DEFAULT_MAX_RETRIES,
    **create_kwargs: Any,
):
    """Call ``client.responses.create`` up to *max_retries* times, returning
    the first response that contains an output item of *expected_tool_type*.

    Returns the **last** response if none match so the caller's assertions
    fire with a clear diagnostic. Every non-matching attempt is logged at
    INFO level together with the output item types it produced, so retries
    that paper over flaky model behaviour stay visible in CI logs.

    Args:
        client: Async client exposing ``responses.create``.
        model: Model name passed through to ``responses.create``.
        expected_tool_type: Output item ``type`` that counts as a match.
        max_retries: Maximum number of requests to issue; must be >= 1.
        **create_kwargs: Extra keyword arguments for ``responses.create``.

    Raises:
        ValueError: If *max_retries* is less than 1 (previously a bare
            ``assert`` that is stripped under ``python -O``).
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    response = None
    for attempt in range(1, max_retries + 1):
        response = await client.responses.create(model=model, **create_kwargs)
        if has_output_type(response, expected_tool_type):
            return response
        logger.info(
            "retry_for_tool_call attempt %d/%d: no %r output item (got %s)",
            attempt,
            max_retries,
            expected_tool_type,
            [getattr(item, "type", None) for item in response.output],
        )
    return response


async def retry_streaming_for(
    client,
    *,
    model: str,
    validate_events: Callable[[list], bool],
    max_retries: int = DEFAULT_MAX_RETRIES,
    **create_kwargs: Any,
) -> list:
    """Call ``client.responses.create(stream=True)`` up to *max_retries*
    times, returning the first event list where *validate_events* returns
    ``True``.

    Each attempt consumes the whole stream before validating it. If no
    attempt validates, the events of the **last** attempt are returned so
    the caller's assertions fire with a clear diagnostic. Failed attempts
    are logged at INFO level.

    Args:
        client: Async client exposing ``responses.create``.
        model: Model name passed through to ``responses.create``.
        validate_events: Predicate applied to the collected event list.
        max_retries: Maximum number of requests to issue; must be >= 1.
        **create_kwargs: Extra keyword arguments for ``responses.create``.

    Raises:
        ValueError: If *max_retries* is less than 1 (previously this
            silently returned an empty list).
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    events: list = []
    for attempt in range(1, max_retries + 1):
        stream = await client.responses.create(
            model=model, stream=True, **create_kwargs
        )
        events = [event async for event in stream]
        if validate_events(events):
            return events
        logger.info(
            "retry_streaming_for attempt %d/%d: %d events failed validation",
            attempt,
            max_retries,
            len(events),
        )
    return events


def has_output_type(response, type_name: str) -> bool:
    """Report whether any item in ``response.output`` has type *type_name*.

    Items without a ``type`` attribute are treated as having type ``None``.
    """
    for output_item in response.output:
        if getattr(output_item, "type", None) == type_name:
            return True
    return False


def events_contain_type(events: list, type_substring: str) -> bool:
    """Report whether *type_substring* occurs in the ``type`` of any event.

    Events without a ``type`` attribute are treated as having type ``""``.
    """
    event_types = (getattr(event, "type", "") for event in events)
    return any(type_substring in event_type for event_type in event_types)


def validate_streaming_event_stack(
    events: list, pairs_of_event_types: dict[str, str]
) -> None:
    """Validate that streaming events are properly nested/paired.

    Events are replayed against a stack:

    * ``response.created``, any type ending in ``added`` and
      ``response.mcp_call.in_progress`` open a scope (pushed).
    * A type ending in ``delta`` opens a scope unless that same delta type
      is already on top of the stack, so a run of consecutive deltas is
      counted once.
    * ``response.completed``, any type ending in ``done`` and
      ``response.mcp_call.completed`` close a scope. The event must be a
      key of *pairs_of_event_types*, and the stack top must equal its
      mapped start event.
    * All other event types are ignored.

    Args:
        events: Streamed events; each must have a ``type`` attribute.
        pairs_of_event_types: Mapping of closing event type -> opening
            event type. It may be a filtered subset of the full table.

    Raises:
        AssertionError: On a closing event missing from the mapping, a
            closing event that does not match the stack top, or scopes
            still open after the last event.
    """
    stack: list[str] = []
    for event in events:
        etype = event.type
        if (
            etype == "response.created"
            or etype.endswith("added")
            or etype == "response.mcp_call.in_progress"
        ):
            stack.append(etype)
        elif etype.endswith("delta"):
            # Consecutive deltas of one type share a single open scope.
            if not stack or stack[-1] != etype:
                stack.append(etype)
        elif (
            etype == "response.completed"
            or etype.endswith("done")
            or etype == "response.mcp_call.completed"
        ):
            # Check membership first so a filtered mapping yields a clear
            # assertion rather than a bare KeyError (this used to happen for
            # ``response.completed``).
            assert etype in pairs_of_event_types, f"Unknown done event: {etype}"
            expected_start = pairs_of_event_types[etype]
            assert stack and stack[-1] == expected_start, (
                f"Stack mismatch for {etype}: "
                f"expected {expected_start}, "
                f"got {stack[-1] if stack else '<empty>'}"
            )
            stack.pop()
    assert not stack, f"Unclosed events on stack: {stack}"


def log_response_diagnostics(
    response,
    *,
    label: str = "Response Diagnostics",
) -> dict[str, Any]:
    """Extract and log diagnostic info from a Responses API response.

    Logs reasoning, tool-call attempts, MCP items, and output types so
    that CI output (``pytest -s`` or ``--log-cli-level=INFO``) gives
    full visibility into model behaviour even on passing runs.

    Optional data is treated as empty when it is missing or ``None``: a
    reasoning item's ``content`` and the response's ``output_messages``.
    This keeps a diagnostics-only helper from failing the test it is meant
    to explain.

    Args:
        response: Responses API response exposing ``output`` and
            ``output_text``. ``output_messages`` (dict-like messages) is
            read when present.
        label: Title shown in the log banner.

    Returns:
        The extracted data (keys ``model_attempted_tool_calls``,
        ``tool_call_attempts``, ``mcp_items``, ``reasoning``,
        ``output_text``, ``output_types``) so callers can make additional
        assertions if needed.
    """
    reasoning_texts = [
        text
        for item in response.output
        if getattr(item, "type", None) == "reasoning"
        # Guard against ``content`` being present but None.
        for content in (getattr(item, "content", None) or [])
        if (text := getattr(content, "text", None))
    ]

    # Tolerate responses where ``output_messages`` is absent or None.
    output_messages = getattr(response, "output_messages", None) or []
    tool_call_attempts = [
        {
            "recipient": msg.get("recipient"),
            "channel": msg.get("channel"),
        }
        for msg in output_messages
        # Only messages addressed to a "python..." recipient are counted.
        if (msg.get("recipient") or "").startswith("python")
    ]

    mcp_items = [
        {
            "name": getattr(item, "name", None),
            "status": getattr(item, "status", None),
        }
        for item in response.output
        if getattr(item, "type", None) == "mcp_call"
    ]

    output_types = [getattr(o, "type", None) for o in response.output]

    diagnostics = {
        "model_attempted_tool_calls": bool(tool_call_attempts),
        "tool_call_attempts": tool_call_attempts,
        "mcp_items": mcp_items,
        "reasoning": reasoning_texts,
        "output_text": response.output_text,
        "output_types": output_types,
    }

    logger.info(
        "\n====== %s ======\n%s\n==============================",
        label,
        json.dumps(diagnostics, indent=2, default=str),
    )

    return diagnostics