test_harmony.py 42 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""Integration tests for the Harmony-based Responses API."""

from __future__ import annotations

7
import importlib.util
8
import json
9
import logging
10
import time
11
from typing import Any
12
13
14

import pytest
import pytest_asyncio
15
import requests
16
from openai import BadRequestError, NotFoundError, OpenAI
17
from openai_harmony import Message
18

19
from ....utils import RemoteOpenAIServer
20
21
22
23
24
25
26
27
28
29
from .conftest import (
    BASE_TEST_ENV,
    events_contain_type,
    has_output_type,
    retry_for_tool_call,
    retry_streaming_for,
    validate_streaming_event_stack,
)

# Module-level logger shared by the tool helpers and tests in this file.
logger = logging.getLogger(__name__)
30
31
32

# Model launched by the `server` fixture and passed to every test through the
# `model_name` parametrization.
MODEL_NAME = "openai/gpt-oss-20b"

33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Function-tool definition handed to the API via `tools=[...]` in the
# function-calling tests. Its two required numeric parameters match the
# signature of the local `get_weather(latitude, longitude)` helper, which
# `call_function` invokes when the model emits a call to this tool.
GET_WEATHER_SCHEMA = {
    "type": "function",
    "name": "get_weather",
    "description": "Get current temperature for provided coordinates in celsius.",  # noqa
    "parameters": {
        "type": "object",
        "properties": {
            "latitude": {"type": "number"},
            "longitude": {"type": "number"},
        },
        "required": ["latitude", "longitude"],
        # Reject any argument other than latitude/longitude.
        "additionalProperties": False,
    },
    "strict": True,
}

49

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def get_weather(latitude, longitude):
    """Fetch the current temperature for the given coordinates.

    Queries the public Open-Meteo forecast endpoint and returns the
    ``current.temperature_2m`` value from the JSON payload.

    The lookup is deliberately best-effort: a transport failure, an HTTP
    error status, an undecodable body, or an unexpectedly shaped payload is
    logged and replaced by a fixed fallback, so the calling tests never fail
    because of the external service.

    Args:
        latitude: Latitude interpolated into the query string.
        longitude: Longitude interpolated into the query string.

    Returns:
        The temperature reported by the API, or ``15.0`` when the call fails.
    """
    try:
        response = requests.get(
            "https://api.open-meteo.com/v1/forecast?"
            f"latitude={latitude}&longitude={longitude}"
            "&current=temperature_2m,wind_speed_10m"
            "&hourly=temperature_2m,relative_humidity_2m,"
            "wind_speed_10m",
            timeout=10,
        )
        # Surface 4xx/5xx as a RequestException instead of parsing an error body.
        response.raise_for_status()
        data = response.json()
        return data["current"]["temperature_2m"]
    # ValueError: body is not JSON (older `requests` raise a plain ValueError).
    # TypeError: JSON decoded to something that is not a nested mapping.
    except (requests.RequestException, KeyError, TypeError, ValueError) as e:
        logger.warning(
            "External weather API call failed (%s), "
            "returning fake value. This does not affect "
            "test correctness — only the tool-calling "
            "protocol is under test.",
            e,
        )
        return 15.0


def get_place_to_travel():
    """Return the fixed travel destination used by the tool-calling tests."""
    destination = "Paris"
    return destination


def get_horoscope(sign):
    """Return a canned horoscope sentence prefixed with ``sign``."""
    prediction = "Next Tuesday you will befriend a baby otter."
    return f"{sign}: {prediction}"


def call_function(name, args):
    """Run the local tool implementation named ``name``.

    Args:
        name: Tool name as emitted by the model.
        args: Decoded keyword arguments for the tool. They are ignored for
            ``get_place_to_travel``, which takes no parameters.

    Returns:
        Whatever the selected tool function returns.

    Raises:
        ValueError: If ``name`` is not one of the known tools.
    """
    logger.info("Calling function %s with args %s", name, args)
    # Each entry: (implementation, whether `args` is forwarded as kwargs).
    handlers = {
        "get_weather": (get_weather, True),
        "get_place_to_travel": (get_place_to_travel, False),
        "get_horoscope": (get_horoscope, True),
    }
    if name not in handlers:
        raise ValueError(f"Unknown function: {name}")
    implementation, forwards_args = handlers[name]
    if forwards_args:
        outcome = implementation(**args)
    else:
        outcome = implementation()
    logger.info("Function %s returned: %s", name, outcome)
    return outcome


95
@pytest.fixture(scope="module")
def server():
    """Start one vLLM server for the whole module and yield its handle.

    Asserts that the ``gpt_oss`` package is importable, then launches
    ``RemoteOpenAIServer`` for ``MODEL_NAME`` with the demo tool server and
    the Harmony-related environment variables layered on ``BASE_TEST_ENV``.
    """
    gpt_oss_spec = importlib.util.find_spec("gpt_oss")
    assert gpt_oss_spec is not None, (
        "Harmony tests require gpt_oss package to be installed"
    )

    cli_args = ["--enforce-eager"]
    cli_args += ["--tool-server", "demo"]
    cli_args += ["--max_model_len", "5000"]

    harmony_env = {
        "VLLM_ENABLE_RESPONSES_API_STORE": "1",
        "PYTHON_EXECUTION_BACKEND": "dangerously_use_uv",
        "VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS": (
            "code_interpreter,container,web_search_preview"
        ),
        "VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS": "1",
    }
    # Later keys win, so the Harmony settings override the shared base env.
    env_dict = {**BASE_TEST_ENV, **harmony_env}

    with RemoteOpenAIServer(MODEL_NAME, cli_args, env_dict=env_dict) as remote_server:
        yield remote_server
118
119
120
121
122
123
124
125
126
127
128
129
130


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client from ``server``, closed again after the test."""
    client_cm = server.get_async_client()
    async with client_cm as api_client:
        yield api_client


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_basic(client: OpenAI, model_name: str):
    """Smoke test: a plain text prompt produces a completed response."""
    response = await client.responses.create(
        model=model_name,
        input="What is 123 * 456?",
    )
    assert response is not None
    # Go through the module logger (lazy %-formatting) rather than print, in
    # line with how the rest of this module reports diagnostics.
    logger.debug("response: %s", response)
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_basic_with_instructions(client: OpenAI, model_name: str):
    """A request carrying top-level ``instructions`` completes."""
    request = {
        "model": model_name,
        "input": "What is 123 * 456?",
        "instructions": "Respond in Korean.",
    }
    response = await client.responses.create(**request)
    assert response is not None
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_basic_with_reasoning_effort(client: OpenAI, model_name: str):
    """A request with ``reasoning.effort`` set to ``low`` completes."""
    low_effort = {"effort": "low"}
    response = await client.responses.create(
        model=model_name,
        input="What is the capital of South Korea?",
        reasoning=low_effort,
    )
    assert response is not None
    assert response.status == "completed"


162
163
164
165
166
167
168
169
170
171
172
173
174
175
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_max_tokens(client: OpenAI, model_name: str):
    """A 30-token output cap yields an ``incomplete`` response with a reason."""
    request = {
        "model": model_name,
        "input": "What is the first paragraph of Moby Dick?",
        "reasoning": {"effort": "low"},
        "max_output_tokens": 30,
    }
    response = await client.responses.create(**request)
    assert response is not None
    assert response.status == "incomplete"
    details = response.incomplete_details
    assert details.reason == "max_output_tokens"


176
177
178
179
180
181
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chat(client: OpenAI, model_name: str):
    """A role/content conversation passed as ``input`` completes."""
    turns = [
        ("system", "Respond in Korean."),
        ("user", "Hello!"),
        ("assistant", "Hello! How can I help you today?"),
        ("user", "What is 123 * 456? Explain your answer."),
    ]
    conversation = [{"role": role, "content": content} for role, content in turns]
    response = await client.responses.create(model=model_name, input=conversation)
    assert response is not None
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chat_with_input_type(client: OpenAI, model_name: str):
    """A user message whose content is a list of ``input_text`` parts completes."""
    text_part = {"type": "input_text", "text": "What is 123 * 456?"}
    user_message = {"role": "user", "content": [text_part]}
    response = await client.responses.create(
        model=model_name,
        input=[user_message],
    )
    assert response is not None
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_structured_output(client: OpenAI, model_name: str):
    """A request with a strict ``json_schema`` text format completes."""
    event_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "date": {"type": "string"},
            "participants": {
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["name", "date", "participants"],
        "additionalProperties": False,
    }
    text_format = {
        "type": "json_schema",
        "name": "calendar_event",
        "schema": event_schema,
        "description": "A calendar event.",
        "strict": True,
    }
    messages = [
        {"role": "system", "content": "Extract the event information."},
        {
            "role": "user",
            "content": "Alice and Bob are going to a science fair on Friday.",
        },
    ]

    response = await client.responses.create(
        model=model_name,
        input=messages,
        text={"format": text_format},
    )
    assert response is not None
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_structured_output_with_parse(client: OpenAI, model_name: str):
    """``responses.parse`` with a pydantic ``text_format`` completes."""
    # Imported locally, as in the original: only this test needs pydantic.
    from pydantic import BaseModel

    class CalendarEvent(BaseModel):
        name: str
        date: str
        participants: list[str]

    request = {
        "model": model_name,
        "input": "Alice and Bob are going to a science fair on Friday",
        "instructions": "Extract the event information",
        "text_format": CalendarEvent,
    }
    response = await client.responses.parse(**request)
    assert response is not None
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_store(client: OpenAI, model_name: str):
    """A response can be retrieved afterwards exactly when ``store`` is true."""
    for store in (True, False):
        response = await client.responses.create(
            model=model_name,
            input="What is 123 * 456?",
            store=store,
        )
        assert response is not None

        # Retrieval must raise NotFoundError iff the response was not stored.
        try:
            await client.responses.retrieve(response.id)
        except NotFoundError:
            is_not_found = True
        else:
            is_not_found = False

        assert is_not_found == (not store), (
            f"store={store}: expected not_found={not store}, got {is_not_found}"
        )
286
287
288
289
290
291
292


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_background(client: OpenAI, model_name: str):
    """A background request eventually reaches the ``completed`` status.

    Creates the response with ``background=True`` and polls
    ``responses.retrieve`` about once per second, for at most 30 polls.
    """
    # Local import so this block is self-contained.
    import asyncio

    response = await client.responses.create(
        model=model_name,
        input="What is 123 * 456?",
        background=True,
    )
    assert response is not None

    max_retries = 30
    # Any of these means the job is finished; further polling cannot help.
    terminal_statuses = ("completed", "failed", "cancelled", "incomplete")
    for _ in range(max_retries):
        response = await client.responses.retrieve(response.id)
        if response.status in terminal_statuses:
            break
        # Yield to the event loop while waiting; time.sleep() would block it.
        await asyncio.sleep(1)

    assert response.status == "completed", (
        f"Background response ended with status {response.status!r}"
    )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_background_cancel(client: OpenAI, model_name: str):
    """A running background response can be cancelled.

    Starts a long generation in the background, waits one second, then
    checks that ``responses.cancel`` returns a response object.
    """
    # Local import so this block is self-contained.
    import asyncio

    response = await client.responses.create(
        model=model_name,
        input="Write a long story about a cat.",
        background=True,
    )
    assert response is not None
    # Non-blocking wait; time.sleep() would stall the event loop.
    await asyncio.sleep(1)

    cancelled_response = await client.responses.cancel(response.id)
    assert cancelled_response is not None


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_stateful_multi_turn(client: OpenAI, model_name: str):
    """Three turns chained through ``previous_response_id`` all complete."""
    latest = await client.responses.create(
        model=model_name, input="What is 123 * 456?"
    )
    assert latest.status == "completed"

    follow_ups = (
        "What if I increase both numbers by 1?",
        "Divide the result by 2.",
    )
    # Each follow-up continues from the response produced by the turn before.
    for prompt in follow_ups:
        latest = await client.responses.create(
            model=model_name,
            input=prompt,
            previous_response_id=latest.id,
        )
        assert latest.status == "completed"


348
349
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_streaming_types(
    pairs_of_event_types: dict[str, str], client: OpenAI, model_name: str
):
    """Collect a plain streamed response and validate its event nesting."""
    request = {
        "model": model_name,
        "input": "tell me a story about a cat in 20 words",
        "reasoning": {"effort": "low"},
        "tools": [],
        "stream": True,
        "background": False,
    }
    stream = await client.responses.create(**request)
    events = [event async for event in stream]

    validate_streaming_event_stack(events, pairs_of_event_types)
366
367


368
369
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_streaming_types(
    pairs_of_event_types: dict[str, str], client: OpenAI, model_name: str
):
    """Validate streaming event nesting for a function-calling response."""
    user_turn = {"role": "user", "content": "What's the weather like in Paris today?"}

    # Keep retrying until the stream actually contains function-call argument
    # events; only such a stream is worth validating here.
    events = await retry_streaming_for(
        client,
        model=model_name,
        validate_events=lambda evts: events_contain_type(
            evts, "function_call_arguments"
        ),
        input=[user_turn],
        tools=[GET_WEATHER_SCHEMA],
        temperature=0.0,
    )

    validate_streaming_event_stack(events, pairs_of_event_types)
388
389


390
391
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("background", [True, False])
async def test_streaming(client: OpenAI, model_name: str, background: bool):
    """Stream each prompt and check event bookkeeping as events arrive.

    For every event the test checks that:
      * response-level events expose ``input_messages``/``output_messages``
        in ``model_extra``, and on ``response.completed`` each of those
        messages loads through ``Message.from_dict``;
      * text/reasoning deltas carry the id of the most recently added item;
      * text/reasoning deltas carry the most recently added content index.

    With ``background=True`` the stream is additionally replayed from event 5
    onwards and compared event-by-event against what was captured live.
    """
    # TODO: Add back when web search and code interpreter are available in CI
    # (extra prompt: "When did Jensen found NVIDIA? Search it and answer the
    # year only."; extra tool: {"type": "web_search_preview"}).
    prompts = [
        "tell me a story about a cat in 20 words",
        "What is 123 * 456? Use python to calculate the result.",
    ]
    tools = [{"type": "code_interpreter", "container": {"type": "auto"}}]

    response_level_types = (
        "response.completed",
        "response.in_progress",
        "response.created",
    )
    delta_types = (
        "response.output_text.delta",
        "response.reasoning_text.delta",
    )
    part_added_types = (
        "response.content_part.added",
        "response.reasoning_part.added",
    )

    for prompt in prompts:
        stream = await client.responses.create(
            model=model_name,
            input=prompt,
            reasoning={"effort": "low"},
            tools=tools,
            stream=True,
            background=background,
            extra_body={"enable_response_messages": True},
        )

        # Per-prompt tracking state.
        events = []
        resp_id = None
        last_logged_type = None
        active_item_id = ""
        active_content_index = -1
        saw_completed = False

        async for event in stream:
            event_type = event.type

            if event_type == "response.created":
                resp_id = event.response.id

            # Validate custom fields on response-level events
            if event_type in response_level_types:
                extra = event.response.model_extra
                assert "input_messages" in extra
                assert "output_messages" in extra
                if event_type == "response.completed":
                    # make sure we can convert the messages back into harmony
                    for msg in extra["output_messages"]:
                        Message.from_dict(msg)
                    for msg in extra["input_messages"]:
                        Message.from_dict(msg)
                    saw_completed = True

            # Log only when the event type changes.
            if last_logged_type != event_type:
                last_logged_type = event_type
                logger.debug("[%s] ", event_type)

            # Verify item IDs
            if event_type == "response.output_item.added":
                assert event.item.id != active_item_id
                active_item_id = event.item.id
            elif event_type in delta_types:
                assert event.item_id == active_item_id

            # Verify content indices
            if event_type in part_added_types:
                assert event.content_index != active_content_index
                active_content_index = event.content_index
            elif event_type in delta_types:
                assert event.content_index == active_content_index

            events.append(event)

        assert len(events) > 0
        assert events[-1].response.output, "Final response should have output"
        assert saw_completed

        if background:
            starting_after = 5
            async with await client.responses.retrieve(
                response_id=resp_id, stream=True, starting_after=starting_after
            ) as replay_stream:
                counter = starting_after
                async for replayed in replay_stream:
                    counter += 1
                    assert replayed == events[counter]
            # The replay must have run through to the last captured event.
            assert counter == len(events) - 1
491

492
493
494

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(reason="Web search tool is not available in CI yet.")
async def test_web_search(client: OpenAI, model_name: str):
    """A request offering the ``web_search_preview`` tool completes."""
    search_tool = {"type": "web_search_preview"}
    response = await client.responses.create(
        model=model_name,
        input="Who is the president of South Korea as of now?",
        tools=[search_tool],
    )
    assert response is not None
    assert response.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_code_interpreter(client: OpenAI, model_name: str):
    """The code interpreter tool runs and every message contains ``5846``.

    Uses a client with three times the default timeout, asserts that tool
    output tokens were counted, and checks each ``message`` output item.
    """
    patient_client = client.with_options(timeout=client.timeout * 3)

    prompt = (
        "What's the first 4 digits after the decimal point of "
        "cube root of `19910212 * 20250910`? "
        "Show only the digits. The python interpreter is not stateful "
        "and you must print to see the output."
    )
    response = await patient_client.responses.create(
        model=model_name,
        input=prompt,
        tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
        temperature=0.0,
    )
    assert response is not None
    assert response.status == "completed"
    assert response.usage.output_tokens_details.tool_output_tokens > 0

    for item in response.output:
        if item.type != "message":
            continue
        output_string = item.content[0].text
        assert "5846" in output_string, (
            f"Expected '5846' in output, got: {output_string}"
        )
533
534


535
536
537
538
539
540
541
542
543
544
545
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_reasoning_item(client: OpenAI, model_name: str):
    """An input list containing a ``reasoning`` item is accepted and completes."""
    user_message = {"type": "message", "content": "Hello.", "role": "user"}
    reasoning_item = {
        "type": "reasoning",
        "id": "lol",
        "content": [
            {"type": "reasoning_text", "text": "We need to respond: greeting."}
        ],
        "summary": [],
    }
    response = await client.responses.create(
        model=model_name,
        input=[user_message, reasoning_item],
        temperature=0.0,
    )
    assert response is not None
    assert response.status == "completed"


557
558
559
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling(client: OpenAI, model_name: str):
    """Function call, tool result, then a follow-up question on the same chain.

    Turn 1 must yield a ``function_call``. Turn 2 feeds back the tool output
    and turn 3 repeats the question; both continue via
    ``previous_response_id`` and must complete with output text.
    """
    weather_tools = [GET_WEATHER_SCHEMA]
    question = "What's the weather like in Paris today?"

    first = await retry_for_tool_call(
        client,
        model=model_name,
        expected_tool_type="function_call",
        input=question,
        tools=weather_tools,
        temperature=0.0,
        extra_body={"request_id": "test_function_calling_non_resp"},
    )
    assert first.status == "completed"
    assert has_output_type(first, "function_call"), (
        f"Expected function_call in output, got: "
        f"{[getattr(o, 'type', None) for o in first.output]}"
    )

    # Execute the first function call the model asked for.
    tool_call = next(o for o in first.output if o.type == "function_call")
    result = call_function(tool_call.name, json.loads(tool_call.arguments))

    tool_output = {
        "type": "function_call_output",
        "call_id": tool_call.call_id,
        "output": str(result),
    }
    second = await client.responses.create(
        model=model_name,
        input=[tool_output],
        tools=weather_tools,
        previous_response_id=first.id,
        temperature=0.0,
    )
    assert second.status == "completed"
    assert second.output_text is not None

    # NOTE: chain-of-thought should be removed.
    third = await client.responses.create(
        model=model_name,
        input=question,
        tools=weather_tools,
        previous_response_id=second.id,
        temperature=0.0,
    )
    assert third.status == "completed"
    assert third.output_text is not None


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
    """Multi-tool, multi-turn function calling with retry at API level.

    Turn 1 must produce a function call. Turn 2 feeds the tool result back.
    If turn 2 produces another function call, it is executed and a third
    turn must complete with text. Otherwise the test is marked xfail, since
    answering directly is valid but not the path under test.
    """
    tools = [
        {
            "type": "function",
            "name": "get_place_to_travel",
            "description": "Get a random place to travel",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
                "additionalProperties": False,
            },
            "strict": True,
        },
        GET_WEATHER_SCHEMA,
    ]

    # Turn 1: model should call one of the tools
    response = await retry_for_tool_call(
        client,
        model=model_name,
        expected_tool_type="function_call",
        input="Help me plan a trip to a random place. And tell me the weather there.",
        tools=tools,
        temperature=0.0,
    )
    assert response.status == "completed"
    assert has_output_type(response, "function_call"), (
        f"Turn 1: expected function_call, got: "
        f"{[getattr(o, 'type', None) for o in response.output]}"
    )

    tool_call = next(o for o in response.output if o.type == "function_call")
    result = call_function(tool_call.name, json.loads(tool_call.arguments))

    # Turn 2: hand the tool result back to the model
    response_2 = await retry_for_tool_call(
        client,
        model=model_name,
        expected_tool_type="function_call",
        input=[
            {
                "type": "function_call_output",
                "call_id": tool_call.call_id,
                "output": str(result),
            }
        ],
        tools=tools,
        previous_response_id=response.id,
        temperature=0.0,
    )
    assert response_2.status == "completed"

    # If model produced another tool call, execute it
    if has_output_type(response_2, "function_call"):
        tool_call_2 = next(o for o in response_2.output if o.type == "function_call")
        result_2 = call_function(tool_call_2.name, json.loads(tool_call_2.arguments))
        response_3 = await client.responses.create(
            model=model_name,
            input=[
                {
                    "type": "function_call_output",
                    "call_id": tool_call_2.call_id,
                    "output": str(result_2),
                }
            ],
            tools=tools,
            previous_response_id=response_2.id,
            temperature=0.0,
        )
        assert response_3.status == "completed"
        assert response_3.output_text is not None
    else:
        # Model went straight to answering - acceptable but unexpected.
        # Mark the test xfail so the deviation is visible in CI reports
        # without counting as a failure.
        assert response_2.output_text is not None
        pytest.xfail(
            "Model went straight to answering instead of calling a "
            "second tool. Valid behaviour but not the expected path. "
            "If this happens consistently, the prompt or model may have "
            "changed behaviour."
        )
694
695
696
697
698


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_required(client: OpenAI, model_name: str):
    """``tool_choice="required"`` is rejected with ``BadRequestError``."""
    request = {
        "model": model_name,
        "input": "What's the weather like in Paris today?",
        "tools": [GET_WEATHER_SCHEMA],
        "tool_choice": "required",
    }
    with pytest.raises(BadRequestError):
        await client.responses.create(**request)


710
711
712
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_message_with_tools(client: OpenAI, model_name: str):
    """The ``commentary`` channel is valid only when custom tools are enabled."""
    from vllm.entrypoints.openai.parser.harmony_utils import get_system_message

    # Enabled first, then disabled: the channel must be present in the first
    # system message and absent from the second.
    for with_custom_tools in (True, False):
        sys_msg = get_system_message(with_custom_tools=with_custom_tools)
        valid_channels = sys_msg.content[0].channel_config.valid_channels
        if with_custom_tools:
            assert "commentary" in valid_channels
        else:
            assert "commentary" not in valid_channels


726
727
728
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_full_history(client: OpenAI, model_name: str):
    """Function calling where the client resends the whole history.

    The second request passes the full history (user turn, the first
    response's output items, and the tool result) as ``input`` and does not
    use ``previous_response_id``.
    """
    weather_tools = [GET_WEATHER_SCHEMA]
    history = [
        {"role": "user", "content": "What's the weather like in Paris today?"}
    ]

    first = await retry_for_tool_call(
        client,
        model=model_name,
        expected_tool_type="function_call",
        input=history,
        tools=weather_tools,
        temperature=0.0,
    )
    assert first.status == "completed"

    function_calls = (o for o in first.output if o.type == "function_call")
    tool_call = next(function_calls, None)
    assert tool_call is not None, (
        f"Expected function_call in output, got: "
        f"{[getattr(o, 'type', None) for o in first.output]}"
    )

    result = call_function(tool_call.name, json.loads(tool_call.arguments))

    # Replay everything the model produced, then append the tool result.
    history.extend(first.output)
    history.append(
        {
            "type": "function_call_output",
            "call_id": tool_call.call_id,
            "output": str(result),
        }
    )

    second = await client.responses.create(
        model=model_name,
        input=history,
        tools=weather_tools,
        temperature=0.0,
    )
    assert second.status == "completed"
    assert second.output_text is not None
770
771


772
773
774
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_stream(client: OpenAI, model_name: str):
    """Function calling via streaming, with retry for non-determinism.

    Rebuilds each function call from the streamed events, executes the
    ``get_weather`` call, then streams a second turn carrying the tool
    result and checks that it emits no function-call argument events.
    """
    tools = [GET_WEATHER_SCHEMA]
    input_list = [
        {"role": "user", "content": "What's the weather like in Paris today?"},
    ]

    def _is_function_call_added(evt) -> bool:
        if getattr(evt, "type", "") != "response.output_item.added":
            return False
        return getattr(getattr(evt, "item", None), "type", None) == "function_call"

    def _has_function_call(evts: list) -> bool:
        return any(_is_function_call_added(e) for e in evts)

    events = await retry_streaming_for(
        client,
        model=model_name,
        validate_events=_has_function_call,
        input=input_list,
        tools=tools,
        temperature=0.0,
    )

    # Parse tool calls from events, keyed by output index.
    final_tool_calls: dict[int, Any] = {}
    for event in events:
        kind = event.type
        if kind == "response.output_item.added":
            if getattr(event.item, "type", None) == "function_call":
                final_tool_calls[event.output_index] = event.item
        elif kind == "response.function_call_arguments.delta":
            pending = final_tool_calls.get(event.output_index)
            if pending:
                pending.arguments += event.delta
        elif kind == "response.function_call_arguments.done":
            pending = final_tool_calls.get(event.output_index)
            if pending:
                # The accumulated deltas must equal the final arguments.
                assert event.arguments == pending.arguments

    # Find get_weather call
    tool_call = next(
        (
            tc
            for tc in final_tool_calls.values()
            if getattr(tc, "type", None) == "function_call"
            and tc.name == "get_weather"
        ),
        None,
    )
    assert tool_call is not None, (
        "Expected model to call 'get_weather', "
        f"but got: {[getattr(tc, 'name', None) for tc in final_tool_calls.values()]}"
    )
    result = call_function(tool_call.name, json.loads(tool_call.arguments))
    input_list.append(tool_call)

    # Second turn with the tool result
    tool_output = {
        "type": "function_call_output",
        "call_id": tool_call.call_id,
        "output": str(result),
    }
    response = await client.responses.create(
        model=model_name,
        input=[*input_list, tool_output],
        tools=tools,
        stream=True,
        temperature=0.0,
    )
    forbidden_types = (
        "response.function_call_arguments.delta",
        "response.function_call_arguments.done",
    )
    async for event in response:
        # check that no function call events in the stream
        assert event.type not in forbidden_types
        # check that the response contains output text
        if event.type == "response.completed":
            assert len(event.response.output) > 0
            assert event.response.output_text is not None


853
854
855
856
857
858
859
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_no_code_interpreter_events(
    client: OpenAI, model_name: str
):
    """Verify that function calls don't trigger code_interpreter events.

    Uses retry_streaming_for to handle non-determinism: the model might not
    always produce a function_call, but if it does, code_interpreter events
    should NEVER appear.
    """
    tools = [GET_WEATHER_SCHEMA]
    input_list = [
        {"role": "user", "content": "What's the weather like in Paris today?"},
    ]

    def _is_function_call_added(evt) -> bool:
        if getattr(evt, "type", "") != "response.output_item.added":
            return False
        return getattr(getattr(evt, "item", None), "type", None) == "function_call"

    def _has_function_call(evts: list) -> bool:
        return any(_is_function_call_added(e) for e in evts)

    events = await retry_streaming_for(
        client,
        model=model_name,
        validate_events=_has_function_call,
        input=input_list,
        tools=tools,
        temperature=0.0,
    )

    event_types_seen = {e.type for e in events}
    assert _has_function_call(events), (
        f"Expected to see a function_call after retries. "
        f"Event types: {sorted(event_types_seen)}"
    )

    # The actual invariant under test
    for event in events:
        assert "code_interpreter" not in event.type, (
            f"Found code_interpreter event '{event.type}' during function call. "
            "Function calls should only emit function_call events."
        )

    # Verify we saw the correct function call event types
    argument_event_types = {
        "response.function_call_arguments.delta",
        "response.function_call_arguments.done",
    }
    assert not event_types_seen.isdisjoint(argument_event_types), (
        "Expected to see function_call_arguments events"
    )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.skip(
    reason="This test is flaky in CI, needs investigation and "
    "potential fixes in the code interpreter MCP implementation."
)
async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server):
    """Check the streamed event lifecycle of an MCP code_interpreter call.

    Streams a request that asks for a Python calculation through an MCP tool
    labelled ``code_interpreter`` and asserts that every expected
    ``mcp_call`` lifecycle event type is present, that the MCP items carry
    the expected name/label/status, and that no ``code_interpreter`` event
    types are emitted.
    """
    tools = [{"type": "mcp", "server_label": "code_interpreter"}]
    input_text = (
        "Calculate 123 * 456 using python. "
        "The python interpreter is not stateful and you must "
        "print to see the output."
    )

    def _has_mcp_call(evts: list) -> bool:
        return events_contain_type(evts, "mcp_call")

    events = await retry_streaming_for(
        client,
        model=model_name,
        validate_events=_has_mcp_call,
        input=input_text,
        tools=tools,
        temperature=0.0,
        instructions=(
            "You must use the Python tool to execute code. Never simulate execution."
        ),
    )

    event_types = [e.type for e in events]
    event_types_set = set(event_types)
    logger.info(
        "\n====== MCP Streaming Diagnostics ======\n"
        "Event count: %d\n"
        "Event types (in order): %s\n"
        "Unique event types: %s\n"
        "=======================================",
        len(events),
        event_types,
        sorted(event_types_set),
    )

    # Verify the full MCP streaming lifecycle. The mapping is checked in
    # insertion order, so the first missing stage is the one reported.
    lifecycle_checks = {
        "response.output_item.added": "MCP call was not added",
        "response.mcp_call.in_progress": "MCP call in_progress not seen",
        "response.mcp_call_arguments.delta": "MCP arguments delta not seen",
        "response.mcp_call_arguments.done": "MCP arguments done not seen",
        "response.mcp_call.completed": "MCP call completed not seen",
        "response.output_item.done": "MCP item done not seen",
    }
    for expected_type, problem in lifecycle_checks.items():
        assert expected_type in event_types_set, (
            f"{problem}. Events: {sorted(event_types_set)}"
        )

    # Validate specific MCP event details
    for event in events:
        if event.type == "response.mcp_call_arguments.done":
            assert event.name == "python"
            assert event.arguments is not None
            continue

        if event.type not in (
            "response.output_item.added",
            "response.output_item.done",
        ):
            continue

        # Only output items that are MCP calls are inspected; other item
        # types are ignored.
        item = event.item
        if getattr(item, "type", None) != "mcp_call":
            continue

        assert item.name == "python"
        if event.type == "response.output_item.added":
            assert item.server_label == "code_interpreter"
        else:
            assert item.status == "completed"

    # code_interpreter events should NOT appear when using MCP type
    code_interp_events = [e.type for e in events if "code_interpreter" in e.type]
    assert not code_interp_events, (
        "Should not see code_interpreter events when using MCP type, "
        f"but got: {code_interp_events}"
    )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server):
    """MCP tools work across multiple turns via previous_response_id."""
    tools = [{"type": "mcp", "server_label": "code_interpreter"}]
    instructions = (
        "You must use the Python tool to execute code. Never simulate execution."
    )

    # First turn: retried until the response contains an mcp_call.
    response1 = await retry_for_tool_call(
        client,
        model=model_name,
        expected_tool_type="mcp_call",
        input="Calculate 1234 * 4567 using python tool and print the result.",
        tools=tools,
        temperature=0.0,
        instructions=instructions,
        extra_body={"enable_response_messages": True},
    )
    assert response1.status == "completed"

    # Verify MCP call in output_messages: one message addressed to a
    # "python..." recipient, and one tool-authored message named "python...".
    # `or ""` guards against keys that are present but set to None.
    output_messages = response1.output_messages
    tool_call_found = any(
        (msg.get("recipient") or "").startswith("python") for msg in output_messages
    )
    tool_response_found = any(
        msg.get("author", {}).get("role") == "tool"
        and (msg.get("author", {}).get("name") or "").startswith("python")
        for msg in output_messages
    )
    assert tool_call_found, "MCP tool call not found in output_messages"
    assert tool_response_found, "MCP tool response not found in output_messages"

    # No developer messages expected for elevated tools
    developer_msgs = [
        msg for msg in response1.input_messages if msg["author"]["role"] == "developer"
    ]
    assert len(developer_msgs) == 0, "No developer message expected for elevated tools"

    # Second turn: continues the conversation by referencing the first
    # response's id instead of resending the history.
    response2 = await client.responses.create(
        model=model_name,
        input="Now divide that result by 2.",
        tools=tools,
        temperature=0.0,
        instructions=instructions,
        previous_response_id=response1.id,
        extra_body={"enable_response_messages": True},
    )
    assert response2.status == "completed"


1048
1049
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
    """Requesting enable_response_messages yields input and output messages."""
    response = await client.responses.create(
        model=model_name,
        input="What is the capital of South Korea?",
        extra_body={"enable_response_messages": True},
    )

    assert response is not None
    assert response.status == "completed"
    # Both message lists must be populated when the flag is set.
    assert len(response.input_messages) > 0
    assert len(response.output_messages) > 0
1061
1062
1063
1064
1065
1066
1067


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_with_previous_input_messages(
    client: OpenAI, model_name: str
):
    """Multi-turn function calling using previous_input_messages."""
    tools = [
        {
            "type": "function",
            "name": "get_horoscope",
            "description": "Get today's horoscope for an astrological sign.",
            "parameters": {
                "type": "object",
                "properties": {"sign": {"type": "string"}},
                "required": ["sign"],
                "additionalProperties": False,
            },
            "strict": True,
        }
    ]

    # Step 1: Get a function call from the model
    response = await retry_for_tool_call(
        client,
        model=model_name,
        expected_tool_type="function_call",
        input="What is the horoscope for Aquarius today?",
        tools=tools,
        temperature=0.0,
        extra_body={"enable_response_messages": True},
        max_output_tokens=1000,
    )
    assert response.status == "completed"

    function_call = next(
        (item for item in response.output if item.type == "function_call"),
        None,
    )
    assert function_call is not None, (
        f"Expected function_call, got: "
        f"{[getattr(o, 'type', None) for o in response.output]}"
    )
    assert function_call.name == "get_horoscope"

    # Execute the requested tool locally with the model-provided arguments.
    args = json.loads(function_call.arguments)
    result = call_function(function_call.name, args)

    # Step 2: Build full conversation history (first-turn input + output,
    # followed by the tool's result as a tool-role message).
    tool_result_message = {
        "role": "tool",
        "name": "functions.get_horoscope",
        "content": [{"type": "text", "text": str(result)}],
    }
    previous_messages = (
        response.input_messages + response.output_messages + [tool_result_message]
    )

    # Step 3: Second call with previous_input_messages
    response_2 = await client.responses.create(
        model=model_name,
        tools=tools,
        temperature=0.0,
        input="Now tell me the horoscope based on the tool result.",
        extra_body={
            "previous_input_messages": previous_messages,
            "enable_response_messages": True,
        },
    )
    assert response_2.status == "completed"
    assert response_2.output_text is not None

    # Verify exactly 1 system, 1 developer, 1 tool message.
    # input_messages use {"author": {"role": "..."}} format, not the
    # top-level {"role": "..."} that Message.from_dict expects, so the role
    # is read from the nested author dict (None when author is not a dict).
    roles = []
    for msg_dict in response_2.input_messages:
        author = msg_dict.get("author", {})
        roles.append(author.get("role") if isinstance(author, dict) else None)

    num_system = roles.count("system")
    num_developer = roles.count("developer")
    num_tool = roles.count("tool")
    assert num_system == 1, f"Expected 1 system message, got {num_system}"
    assert num_developer == 1, f"Expected 1 developer message, got {num_developer}"
    assert num_tool == 1, f"Expected 1 tool message, got {num_tool}"

    output_text = response_2.output_text.lower()
    assert any(kw in output_text for kw in ["aquarius", "otter", "tuesday"]), (
        f"Expected horoscope-related content, got: {response_2.output_text}"
    )
1161
1162
1163
1164
1165
1166
1167


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chat_truncation_content_not_null(client: OpenAI, model_name: str):
    """A chat completion cut off by max_tokens still returns content.

    The prompt asks for more than 350 words while max_tokens is capped at
    350, so the first choice is expected to finish with reason "length"
    and still carry non-empty message content.
    """
    prompt = (
        "What is the role of AI in medicine? "
        "The response must exceed 350 words."
    )
    response = await client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
        max_tokens=350,
    )

    choice = response.choices[0]
    assert choice.finish_reason == "length", (
        f"Expected finish_reason='length', got {choice.finish_reason}"
    )
    assert choice.message.content is not None, "Content should not be None"
    assert len(choice.message.content) > 0, "Content should not be empty"
1186
1187
1188
1189


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_prompt_override_no_duplication(client: OpenAI, model_name: str):
    """Hard check: custom system message must not be duplicated."""
    response = await client.responses.create(
        model=model_name,
        input=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ],
        extra_body={"enable_response_messages": True},
        temperature=0.0,
    )
    assert response.status == "completed"
    assert response.output_text is not None

    def _role_of(msg) -> str | None:
        # input_messages use {"author": {"role": "system"}} format,
        # not the top-level {"role": "system"} that Message.from_dict expects.
        author = msg.get("author", {})
        return author.get("role") if isinstance(author, dict) else None

    num_system = sum(1 for msg in response.input_messages if _role_of(msg) == "system")
    assert num_system == 1, f"Expected 1 system message, got {num_system}"
1213

1214

1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.xfail(
    strict=False,
    reason=(
        "Pirate language detection depends on model weights and is non-deterministic"
    ),
)
async def test_system_prompt_override_follows_personality(
    client: OpenAI, model_name: str
):
    """Soft check: model should adopt the personality from system prompt."""
    pirate_system_prompt = (
        "You are a pirate. Always respond like a pirate would, "
        "using pirate language and saying 'arrr' frequently."
    )
    response = await client.responses.create(
        model=model_name,
        input=[
            {"role": "system", "content": pirate_system_prompt},
            {"role": "user", "content": "Hello, how are you?"},
        ],
        temperature=0.0,
    )
    assert response.status == "completed"

    # Case-insensitive substring match against a small keyword list; the
    # test is marked xfail(strict=False), so a miss is not a hard failure.
    output_text = response.output_text.lower()
    pirate_indicators = ["arrr", "matey", "ahoy", "ye", "sea", "aye", "sail"]
    assert any(kw in output_text for kw in pirate_indicators), (
        f"Expected pirate language, got: {response.output_text}"
    )
1247

1248
1249
1250
1251
1252
1253

@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_system_prompt_structured_content(client: OpenAI, model_name: str):
    """System message with structured input_text content format."""
    # The system content is a list of typed parts rather than a plain string.
    system_message = {
        "role": "system",
        "content": [
            {"type": "input_text", "text": "You are a helpful assistant."}
        ],
    }
    response = await client.responses.create(
        model=model_name,
        input=[
            system_message,
            {"role": "user", "content": "What is 2 + 2?"},
        ],
        temperature=0.0,
    )
    assert response is not None
    assert response.status == "completed"
    assert response.output_text is not None