# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer, models_path_prefix
# Any model with a chat template should work here. The checkpoint path is
# built by joining the project's ``models_path_prefix`` with the model id.
MODEL_NAME = os.path.join(models_path_prefix, "Qwen/Qwen3-0.6B")


@pytest.fixture(scope="module")
def server():
    """Start one OpenAI-compatible server shared by every test in this module.

    Chunked prefill is enabled with a 1000-token batch budget. The repeated
    prompts used by the tests are presumably longer than that budget, so
    their prefill has to be split into chunks.
    """
    server_args = []
    # use half precision for speed and memory savings in CI environment
    server_args += ["--dtype", "bfloat16"]
    server_args += ["--max-model-len", "8192"]
    server_args.append("--enforce-eager")
    server_args += ["--max-num-seqs", "128"]
    server_args.append("--enable-chunked-prefill")
    server_args += ["--max-num-batched-tokens", "1000"]

    with RemoteOpenAIServer(MODEL_NAME, server_args) as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI client connected to the ``server`` fixture."""
    client_context = server.get_async_client()
    async with client_context as async_openai_client:
        yield async_openai_client


@pytest.mark.asyncio
async def test_completion_stream_options_and_logprobs_with_long_prompts(
    client: openai.AsyncOpenAI,
):
    """Stream a text completion for a long prompt and validate usage stats.

    With ``continuous_usage_stats`` requested, every streamed chunk is
    expected to carry a usage block. Each block must be internally
    consistent. Once a chunk reports a ``finish_reason``, the reported
    ``completion_tokens`` must equal the number of tokens counted from the
    per-chunk logprobs.
    """
    # Test stream with long prompt
    prompt = "What is the capital of France?" * 400

    stream = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        logprobs=5,
    )

    tokens_received = 0
    finished = False
    async for chunk in stream:
        usage = chunk.usage
        assert usage.prompt_tokens >= 0
        assert usage.completion_tokens >= 0
        assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens

        # Choices are only inspected up to and including the finishing
        # chunk; any later chunk is checked for usage only.
        if not finished:
            choice = chunk.choices[0]
            assert choice.text
            # Count actual tokens from logprobs since multiple tokens
            # can be batched into a single chunk
            assert choice.logprobs and choice.logprobs.tokens
            tokens_received += len(choice.logprobs.tokens)

            if choice.finish_reason is not None:
                finished = True

        if finished:
            assert usage.completion_tokens == tokens_received

    # Without this, a stream that never reports a finish_reason would skip
    # the completion-token check above and pass vacuously.
    assert finished


@pytest.mark.asyncio
async def test_chat_completion_stream_options_and_logprobs_with_long_prompts(
    client: openai.AsyncOpenAI,
):
    """Stream a chat completion for a long prompt and validate usage stats.

    With ``continuous_usage_stats`` requested, every streamed chunk is
    expected to carry a usage block. Each block must be internally
    consistent. Chunks with empty content must report zero completion tokens
    and no logprobs, and at most one such chunk may appear before the finish.
    Once a chunk reports a ``finish_reason``, the reported
    ``completion_tokens`` must equal the number of tokens counted from the
    per-chunk logprobs.
    """
    # Test stream with long prompt
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?" * 400},
    ]
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        logprobs=True,
        top_logprobs=5,
    )

    tokens_received = 0
    empty_chunks_received = 0
    finished = False
    async for chunk in stream:
        usage = chunk.usage
        assert usage.prompt_tokens >= 0
        assert usage.completion_tokens >= 0
        assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens

        # Choices are only inspected up to and including the finishing
        # chunk; any later chunk is checked for usage only.
        if not finished:
            choice = chunk.choices[0]
            if choice.delta.content == "":
                # No tokens have been generated for this chunk.
                assert usage.completion_tokens == 0
                assert choice.logprobs is None
                empty_chunks_received += 1
            else:
                # Count actual tokens from logprobs since multiple tokens
                # can be batched into a single chunk
                assert choice.logprobs and choice.logprobs.content
                tokens_received += len(choice.logprobs.content)

            if choice.finish_reason is not None:
                finished = True

        if finished:
            assert usage.completion_tokens == tokens_received

    # Without this, a stream that never reports a finish_reason would skip
    # the completion-token check above and pass vacuously.
    assert finished
    assert empty_chunks_received <= 1