# test_basic.py
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from http import HTTPStatus
6
from unittest.mock import AsyncMock, Mock
7

8
import openai
9
import pytest
10
import pytest_asyncio
11
import requests
12
from fastapi import Request
13

14
from vllm.v1.engine.exceptions import EngineDeadError
15
16
17
18
from vllm.version import __version__ as VLLM_VERSION

from ...utils import RemoteOpenAIServer

19
MODEL_NAME = "Qwen/Qwen3-0.6B"
20
21


22
@pytest.fixture(scope="module")
def server_args(request: pytest.FixtureRequest) -> list[str]:
    """Provide extra arguments to the server via indirect parametrization.

    Usage:

    >>> @pytest.mark.parametrize(
    >>>     "server_args",
    >>>     [
    >>>         ["--disable-frontend-multiprocessing"],
    >>>         [
    >>>             "--model=NousResearch/Hermes-3-Llama-3.1-70B",
    >>>             "--enable-auto-tool-choice",
    >>>         ],
    >>>     ],
    >>>     indirect=True,
    >>> )
    >>> def test_foo(server, client):
    >>>     ...

    This will run `test_foo` twice with servers with:
    - `--disable-frontend-multiprocessing`
    - `--model=NousResearch/Hermes-3-Llama-3.1-70B --enable-auto-tool-choice`.

    A single string parameter is wrapped into a one-element list. When the
    fixture is not parametrized at all, no extra arguments are supplied.
    """
    if not hasattr(request, "param"):
        # Not indirectly parametrized: the server runs with its defaults only.
        return []

    val = request.param

    if isinstance(val, str):
        # Shorthand: a bare flag such as "--enforce-eager".
        return [val]

    # Copy into a list so tuple parameters satisfy the declared return type.
    return list(val)


58
@pytest.fixture(scope="module")
def server(server_args):
    """Start a module-scoped ``RemoteOpenAIServer`` serving ``MODEL_NAME``.

    The base arguments keep the CI run fast and small. ``server_args``
    (from indirect parametrization) is appended after them, so test-supplied
    flags appear later on the command line than the defaults.
    """
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        "--max-num-seqs",
        "128",
        *server_args,
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


76
77
78
79
@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI client for ``server``.

    The client's async context is exited on teardown.
    """
    client_context = server.get_async_client()
    async with client_context as openai_client:
        yield openai_client
80
81


82
83
84
85
@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param([], id="default-frontend-multiprocessing"),
        pytest.param(
            ["--disable-frontend-multiprocessing"],
            id="disable-frontend-multiprocessing",
        ),
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_show_version(server: RemoteOpenAIServer):
    """The ``version`` endpoint reports the installed vLLM version."""
    response = requests.get(server.url_for("version"))
    response.raise_for_status()

    assert response.json() == {"version": VLLM_VERSION}


101
102
103
104
@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param([], id="default-frontend-multiprocessing"),
        pytest.param(
            ["--disable-frontend-multiprocessing"],
            id="disable-frontend-multiprocessing",
        ),
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_check_health(server: RemoteOpenAIServer):
    """The ``health`` endpoint answers 200 OK on a running server."""
    response = requests.get(server.url_for("health"))

    assert response.status_code == HTTPStatus.OK
117
118
119
120
121


@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param(
            ["--max-model-len", "10100"], id="default-frontend-multiprocessing"
        ),
        pytest.param(
            ["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
            id="disable-frontend-multiprocessing",
        ),
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_request_cancellation(server: RemoteOpenAIServer):
    """Requests abandoned by a client timeout must not keep the server busy.

    Clunky test: send an ungodly amount of load with very short client-side
    timeouts, then ensure the server still answers a small request quickly.
    """
    chat_input = [{"role": "user", "content": "Write a long story"}]

    # Close the short-timeout client once every request has finished.
    async with server.get_async_client(timeout=0.5) as client:
        # 200 requests x 10000 forced tokens: about 2 million tokens of load.
        tasks = [
            asyncio.create_task(
                client.chat.completions.create(
                    messages=chat_input,
                    model=MODEL_NAME,
                    max_tokens=10000,
                    extra_body={"min_tokens": 10000},
                )
            )
            for _ in range(200)
        ]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)

    # Make sure all requests were sent to the server and timed out
    # (We don't want to hide other errors like 400s that would invalidate this
    # test)
    assert len(pending) == 0
    for task in done:
        with pytest.raises(openai.APITimeoutError):
            task.result()

    # If the server had not cancelled all the other requests, then it would not
    # be able to respond to this one within the timeout
    async with server.get_async_client(timeout=5) as client:
        response = await client.chat.completions.create(
            messages=chat_input, model=MODEL_NAME, max_tokens=10
        )

    assert len(response.choices) == 1
170
171
172
173
174
175
176
177
178
179
180
181


@pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
    """A chat request sent as form-urlencoded is answered with an error status."""
    chat_input = [{"role": "user", "content": "Write a long story"}]

    # Use the client as a context manager so its connections are closed.
    async with server.get_async_client() as client:
        with pytest.raises(openai.APIStatusError):
            await client.chat.completions.create(
                messages=chat_input,
                model=MODEL_NAME,
                max_tokens=10000,
                extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
184
185
186
187


@pytest.mark.parametrize(
    "server_args",
    [pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")],
    indirect=True,
)
@pytest.mark.asyncio
async def test_server_load(server: RemoteOpenAIServer):
    """The ``load`` endpoint tracks in-flight requests.

    ``server_load`` is expected to be 0 when idle, 1 while a single long
    completion is running, and back to 0 after it has finished.
    """

    def get_server_load():
        """Query the ``load`` endpoint and return its ``server_load`` value."""
        response = requests.get(server.url_for("load"))
        assert response.status_code == HTTPStatus.OK
        return response.json().get("server_load")

    def make_long_completion_request():
        return requests.post(
            server.url_for("v1/completions"),
            headers={"Content-Type": "application/json"},
            json={
                "prompt": "Give me a long story",
                "max_tokens": 1000,
                "temperature": 0,
            },
        )

    # Check initial server load.
    assert get_server_load() == 0

    # Run the blocking completion request in a worker thread so this
    # coroutine can query the load while it is in flight.
    completion_future = asyncio.create_task(
        asyncio.to_thread(make_long_completion_request)
    )

    # Give a short delay to ensure the request has started.
    await asyncio.sleep(0.1)

    # Check server load while the completion request is running.
    assert get_server_load() == 1

    # Wait for the completion request to finish. It must have succeeded,
    # otherwise the load observations above say nothing about real work.
    completion_response = await completion_future
    assert completion_response.status_code == HTTPStatus.OK
    await asyncio.sleep(0.1)

    # Check server load after the completion request has finished.
    assert get_server_load() == 0
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250


@pytest.mark.asyncio
async def test_health_check_engine_dead_error():
    """``health`` answers 503 when the engine client raises ``EngineDeadError``.

    The endpoint function is called directly with a mocked request, so no
    server is started.
    """
    # Import the health function directly to test it in isolation
    from vllm.entrypoints.openai.api_server import health

    # Create a mock request that simulates what FastAPI would provide:
    # request.app.state.engine_client.check_health() raises EngineDeadError.
    mock_request = Mock(spec=Request)
    mock_app_state = Mock()
    mock_engine_client = AsyncMock()
    mock_engine_client.check_health.side_effect = EngineDeadError()
    mock_app_state.engine_client = mock_engine_client
    mock_request.app.state = mock_app_state

    # Simulate what would happen if the engine dies.
    response = await health(mock_request)

    # Same HTTPStatus comparison style as the other status checks in this file.
    assert response.status_code == HTTPStatus.SERVICE_UNAVAILABLE