# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# Model served by the module-scoped server fixture and named in every
# request below; any model with a chat template should work here.
MODEL_NAME = "Qwen/Qwen3-0.6B"


@pytest.fixture(scope="module")
def server():
    """Start one ``RemoteOpenAIServer`` for ``MODEL_NAME``, shared module-wide.

    The server is launched with chunked prefill enabled and a batched-token
    budget of 1000. It is yielded from inside the ``with`` block, so the
    context manager's exit runs once the module's tests are done.
    """
    # use half precision for speed and memory savings in CI environment
    model_flags = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
    ]
    # Scheduling limits, including the chunked-prefill switch under test.
    scheduler_flags = [
        "--max-num-seqs",
        "128",
        "--enable-chunked-prefill",
        "--max-num-batched-tokens",
        "1000",
    ]
    launch_args = model_flags + scheduler_flags

    with RemoteOpenAIServer(MODEL_NAME, launch_args) as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the ``server`` fixture.

    The client is used as an async context manager, so it is exited after
    the requesting test completes.
    """
    async with server.get_async_client() as api_client:
        yield api_client


@pytest.mark.asyncio
async def test_completion_stream_options_and_logprobs_with_long_prompts(
    client: openai.AsyncOpenAI,
):
    """Stream a completion for a long prompt and check per-chunk usage stats.

    Sends a prompt built by repeating one sentence 400 times (presumably long
    enough to exceed the server's 1000-token batch budget so that prefill is
    chunked -- confirm against the tokenizer) and verifies that:

    * every streamed chunk carries self-consistent usage counters;
    * every chunk up to and including the one with a finish reason has
      non-empty text;
    * once finished, ``completion_tokens`` equals the number of chunks
      counted as tokens.
    """
    # Test stream with long prompt
    prompt = "What is the capital of France?" * 400

    stream = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        # Usage is read from every chunk in the loop below, which relies on
        # both of these options being honoured by the server.
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        logprobs=5,
    )

    tokens_received = 0
    finished = False
    async for chunk in stream:
        # Usage counters must be present and add up on every chunk.
        assert chunk.usage.prompt_tokens >= 0
        assert chunk.usage.completion_tokens >= 0
        assert chunk.usage.total_tokens == (
            chunk.usage.prompt_tokens + chunk.usage.completion_tokens
        )

        if not finished:
            # Each chunk before the finish is counted as one generated token
            # and must carry text.
            tokens_received += 1
            assert chunk.choices[0].text

            if chunk.choices[0].finish_reason is not None:
                finished = True

        # Checked on the finishing chunk and on any chunk that follows it;
        # ``choices`` is deliberately not indexed after the finish.
        if finished:
            assert chunk.usage.completion_tokens == tokens_received


@pytest.mark.asyncio
async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
    client: openai.AsyncOpenAI,
):
    """Stream a chat completion for a long prompt and check usage stats.

    Sends a user message built by repeating one sentence 400 times
    (presumably long enough to exceed the server's 1000-token batch budget so
    that prefill is chunked -- confirm against the tokenizer) and verifies
    that:

    * every streamed chunk carries self-consistent usage counters;
    * a chunk with empty delta content reports zero completion tokens and no
      logprobs, and at most one such chunk is seen before the finish;
    * once finished, ``completion_tokens`` equals the number of non-empty
      chunks received.
    """
    # Test stream with long prompt
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?" * 400},
    ]
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        # Usage is read from every chunk in the loop below, which relies on
        # both of these options being honoured by the server.
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        logprobs=True,
        top_logprobs=5,
    )

    tokens_received = 0
    empty_chunks_received = 0
    finished = False
    async for chunk in stream:
        # Usage counters must be present and add up on every chunk.
        assert chunk.usage.prompt_tokens >= 0
        assert chunk.usage.completion_tokens >= 0
        assert chunk.usage.total_tokens == (
            chunk.usage.prompt_tokens + chunk.usage.completion_tokens
        )

        if not finished:
            if chunk.choices[0].delta.content == "":
                # No tokens have been generated yet, so nothing may be
                # counted or scored on this chunk.
                assert chunk.usage.completion_tokens == 0
                assert chunk.choices[0].logprobs is None
                empty_chunks_received += 1
            else:
                tokens_received += 1

            if chunk.choices[0].finish_reason is not None:
                finished = True

        # Checked on the finishing chunk and on any chunk that follows it;
        # ``choices`` is deliberately not indexed after the finish.
        if finished:
            assert chunk.usage.completion_tokens == tokens_received

    # At most one empty-content chunk is tolerated over the whole stream.
    assert empty_chunks_received <= 1