test_chat_echo.py 3.74 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from typing import NamedTuple

zhuwenwen's avatar
zhuwenwen committed
6
import os
7
8
9
10
import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

11
12
from vllm.config import ModelConfig

zhuwenwen's avatar
zhuwenwen committed
13
from ...utils import RemoteOpenAIServer, models_path_prefix
14
15

# any model with a chat template should work here
zhuwenwen's avatar
zhuwenwen committed
16
# Model identifier used by every test in this module: the shared
# ``models_path_prefix`` joined with the "Qwen/Qwen2-1.5B-Instruct" repo path.
MODEL_NAME = os.path.join(models_path_prefix, "Qwen/Qwen2-1.5B-Instruct")
17
18


19
20
21
22
23
24
25
26
27
def get_vocab_size(model_name):
    """Return the vocabulary size reported by ``ModelConfig`` for a model.

    Args:
        model_name: Model identifier or path passed to ``ModelConfig`` as
            ``model``.

    Returns:
        Whatever ``ModelConfig.get_vocab_size()`` returns for a config built
        with ``seed=0`` and ``dtype="float16"``.
    """
    config_kwargs = {"model": model_name, "seed": 0, "dtype": "float16"}
    return ModelConfig(**config_kwargs).get_vocab_size()


28
29
30
31
32
33
34
35
36
@pytest.fixture(scope="module")
def server():
    """Start one ``RemoteOpenAIServer`` for ``MODEL_NAME`` per test module.

    Yields:
        The ``RemoteOpenAIServer`` instance. Because the ``yield`` sits inside
        the ``with`` block and the fixture is module-scoped, the server
        context manager is exited only after the module's tests are done.
    """
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        "--enforce-eager",
        "--max-model-len",
        "4080",
        "--max-logprobs",  # test prompt_logprobs equal to -1
        # NOTE(review): presumably the vocabulary size of MODEL_NAME, so that
        # "-1" requests are within the limit -- confirm against
        # get_vocab_size(MODEL_NAME).
        "151936",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the ``server`` fixture.

    The client comes from ``server.get_async_client()`` and is used as an
    async context manager, so it is exited when the requesting test finishes.
    """
    client_cm = server.get_async_client()
    async with client_cm as openai_client:
        yield openai_client


class TestCase(NamedTuple):
    """Parameter bundle for the echo / continue_final_message chat test."""

    # The class name matches pytest's ``Test*`` collection pattern, but this
    # is a data holder, not a test class. Opting out avoids pytest's
    # "cannot collect test class" warning.
    __test__ = False

    # Model identifier sent as ``model`` in the chat completion request.
    model_name: str
    # Value sent as the ``echo`` extra-body parameter.
    echo: bool


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(model_name=MODEL_NAME, echo=True),
        TestCase(model_name=MODEL_NAME, echo=False),
    ],
)
async def test_chat_session_with_echo_and_continue_final_message(
    client: openai.AsyncOpenAI, test_case: TestCase
):
    """Check ``echo`` combined with ``continue_final_message``.

    The conversation ends with a partial assistant message (``saying``). The
    request sets ``continue_final_message=True`` and
    ``add_generation_prompt=False``. With ``echo=True`` the returned content
    must contain ``saying``; with ``echo=False`` it must not.
    """
    saying: str = "Here is a common saying about apple. An apple a day, keeps"
    # test echo with continue_final_message parameter
    chat_completion = await client.chat.completions.create(
        model=test_case.model_name,
        messages=[
            {"role": "user", "content": "tell me a common saying"},
            {"role": "assistant", "content": saying},
        ],
        extra_body={
            "echo": test_case.echo,
            "continue_final_message": True,
            "add_generation_prompt": False,
        },
    )
    assert chat_completion.id is not None
    assert len(chat_completion.choices) == 1

    choice = chat_completion.choices[0]
    assert choice.finish_reason == "stop"

    message = choice.message
    if test_case.echo:
        assert message.content is not None and saying in message.content
    else:
        assert message.content is not None and saying not in message.content
    assert message.role == "assistant"
93
94
95
96


@pytest.mark.asyncio
async def test_prompt_logprobs(client: openai.AsyncOpenAI):
    """Check that ``prompt_logprobs=-1`` yields a non-empty ``prompt_logprobs``.

    Sends a short system + user conversation with the ``prompt_logprobs``
    extra-body parameter set to ``-1`` and asserts that the response carries a
    non-empty ``prompt_logprobs`` field.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Beijing is the capital of which country?"},
    ]

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        extra_body={"prompt_logprobs": -1},
    )

    assert completion.prompt_logprobs is not None
    assert len(completion.prompt_logprobs) > 0
110
111
112
113


@pytest.mark.asyncio
async def test_top_logprobs(client: openai.AsyncOpenAI):
    """Check that ``top_logprobs=-1`` returns one entry per vocabulary token.

    Generates a single token with logprobs enabled and ``top_logprobs=-1``,
    then asserts that the first token's ``top_logprobs`` list has exactly
    ``get_vocab_size(MODEL_NAME)`` entries.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Beijing is the capital of which country?"},
    ]

    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=1,
        extra_body={
            "top_logprobs": -1,
            # Send a real boolean rather than the string "true", so the
            # request does not depend on server-side string-to-bool coercion.
            "logprobs": True,
        },
    )
    assert completion.choices[0].logprobs is not None
    assert completion.choices[0].logprobs.content is not None
    assert len(completion.choices[0].logprobs.content) > 0
    assert len(
        completion.choices[0].logprobs.content[0].top_logprobs
    ) == get_vocab_size(MODEL_NAME)