test_basic.py 2.38 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import openai  # use the official client for correctness check
import pytest


@pytest.mark.asyncio
async def test_simple_input(client: openai.AsyncOpenAI):
    """Send a plain-string input and check both the answer and the reasoning.

    The response is expected to begin with a ``reasoning`` item that has
    non-empty text and to end with a ``message`` item whose text contains the
    product 13 * 24 = 312.
    """
    response = await client.responses.create(input="What is 13 * 24?")
    print(response)

    outputs = response.output
    # Fail with a readable message instead of an IndexError below.
    assert outputs, "response.output is empty"

    # Whether the output contains the answer.
    assert outputs[-1].type == "message"
    assert "312" in outputs[-1].content[0].text

    # Whether the output contains the reasoning.
    assert outputs[0].type == "reasoning"
    # Guard against a missing/empty content list before indexing into it.
    assert outputs[0].content, "reasoning item has no content"
    assert outputs[0].content[0].text != ""
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75


@pytest.mark.asyncio
async def test_instructions(client: openai.AsyncOpenAI):
    """Check that the ``instructions`` field steers the final answer.

    The instructions ask the model to finish with "QED"; the final message
    must contain both the product 13 * 24 = 312 and that marker.
    """
    response = await client.responses.create(
        instructions="Finish the answer with QED.",
        input="What is 13 * 24?",
    )
    print(response)

    final_item = response.output[-1]
    # Reasoning items also carry ``content[0].text``; make sure we inspect the
    # assistant's answer rather than a reasoning trace.
    assert final_item.type == "message"
    output_text = final_item.content[0].text
    assert "312" in output_text
    assert "QED" in output_text


@pytest.mark.asyncio
async def test_chat(client: openai.AsyncOpenAI):
    """Send a multi-turn conversation as a list of role/content messages.

    The system message asks for a trailing "QED"; the final user turn asks to
    double the previous result (15), so the answer must mention 30 and QED.
    """
    response = await client.responses.create(input=[
        {
            "role": "system",
            "content": "Finish the answer with QED."
        },
        {
            "role": "user",
            "content": "What is 5 * 3?"
        },
        {
            "role": "assistant",
            "content": "15. QED."
        },
        {
            "role": "user",
            "content": "Multiply the result by 2."
        },
    ], )
    print(response)

    final_item = response.output[-1]
    # Reasoning items also carry ``content[0].text``; make sure we inspect the
    # assistant's answer rather than a reasoning trace.
    assert final_item.type == "message"
    output_text = final_item.content[0].text
    assert "30" in output_text
    assert "QED" in output_text


@pytest.mark.asyncio
async def test_chat_with_input_type(client: openai.AsyncOpenAI):
    """Accept a user turn whose content is a list of typed input parts.

    Only checks that the request is served to completion; the reply text
    itself is not inspected.
    """
    text_part = {"type": "input_text", "text": "Hello!"}
    conversation = [{"role": "user", "content": [text_part]}]

    response = await client.responses.create(input=conversation)
    print(response)

    assert response.status == "completed"
76
77
78
79
80
81
82
83
84
85
86
87
88


@pytest.mark.asyncio
async def test_logprobs(client: openai.AsyncOpenAI):
    """Request per-token logprobs and check that ``top_logprobs`` is honored.

    The last content part of the last output item must carry a non-empty
    logprobs list, and its first token must list exactly ``num_top``
    alternatives.
    """
    num_top = 5
    response = await client.responses.create(
        input="What is 13 * 24?",
        include=["message.output_text.logprobs"],
        top_logprobs=num_top,
    )
    print(response)

    last_part = response.output[-1].content[-1]
    token_logprobs = last_part.logprobs
    assert token_logprobs
    assert len(token_logprobs[0].top_logprobs) == num_top