test_basic.py 6.45 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
6
from http import HTTPStatus

7
import openai
8
import pytest
9
import pytest_asyncio
10
11
12
13
14
15
16
17
18
import requests

from vllm.version import __version__ as VLLM_VERSION

from ...utils import RemoteOpenAIServer

# Model launched by the module-scoped ``server`` fixture and named in requests.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def server_args(request: pytest.FixtureRequest) -> list[str]:
    """Provide extra arguments to the server via indirect parametrization.

    Usage::

        @pytest.mark.parametrize(
            "server_args",
            [
                ["--disable-frontend-multiprocessing"],
                [
                    "--model=NousResearch/Hermes-3-Llama-3.1-70B",
                    "--enable-auto-tool-choice",
                ],
            ],
            indirect=True,
        )
        def test_foo(server, client):
            ...

    This will run ``test_foo`` twice, with servers started with:

    - ``--disable-frontend-multiprocessing``
    - ``--model=NousResearch/Hermes-3-Llama-3.1-70B --enable-auto-tool-choice``

    Tests that are not parametrized on ``server_args`` get no extra arguments,
    and a single string parameter is treated as a one-element argument list.

    Returns:
        A fresh list of extra CLI arguments to append to the server's defaults.
    """
    # Without indirect parametrization the request carries no ``param``.
    if not hasattr(request, "param"):
        return []

    val = request.param

    # A bare string would otherwise be unpacked character by character when
    # splatted into the server's argument list.
    if isinstance(val, str):
        return [val]

    # Normalise tuples/other iterables to a list (honouring the annotation)
    # and hand back a copy so the shared parametrization value is never
    # mutated by a consumer.
    return list(val)


@pytest.fixture(scope="module")
def server(server_args):
    """Launch a module-scoped RemoteOpenAIServer serving ``MODEL_NAME``.

    A fixed set of lightweight defaults is passed first, followed by any
    extra arguments supplied through the ``server_args`` fixture.
    """
    base_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype", "bfloat16",
        "--max-model-len", "8192",
        "--enforce-eager",
        "--max-num-seqs", "128",
    ]
    launch_args = [*base_args, *server_args]

    with RemoteOpenAIServer(MODEL_NAME, launch_args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI client for ``server``.

    The client is released when its ``async with`` block exits.
    """
    async with server.get_async_client() as openai_client:
        yield openai_client


@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param([], id="default-frontend-multiprocessing"),
        pytest.param(
            ["--disable-frontend-multiprocessing"],
            id="disable-frontend-multiprocessing",
        ),
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_show_version(server: RemoteOpenAIServer):
    """The ``version`` endpoint reports the installed vLLM version."""
    version_url = server.url_for("version")
    resp = requests.get(version_url)
    resp.raise_for_status()

    expected_payload = {"version": VLLM_VERSION}
    assert resp.json() == expected_payload


@pytest.mark.parametrize(
    "server_args",
    [
        pytest.param([], id="default-frontend-multiprocessing"),
        pytest.param(
            ["--disable-frontend-multiprocessing"],
            id="disable-frontend-multiprocessing",
        ),
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_check_health(server: RemoteOpenAIServer):
    """The ``health`` endpoint answers with HTTP 200 OK."""
    health_status = requests.get(server.url_for("health")).status_code
    assert health_status == HTTPStatus.OK


@pytest.mark.parametrize(
    "server_args",
    [
        # Raise the context length above the 8192 set by the ``server``
        # fixture so the 10000-token requests below are not rejected up front.
        pytest.param(
            ["--max-model-len", "10100"], id="default-frontend-multiprocessing"
        ),
        pytest.param(
            ["--disable-frontend-multiprocessing", "--max-model-len", "10100"],
            id="disable-frontend-multiprocessing",
        ),
    ],
    indirect=True,
)
@pytest.mark.asyncio
async def test_request_cancellation(server: RemoteOpenAIServer):
    """Requests abandoned by a timed-out client must not keep the server busy.

    Clunky test: send an ungodly amount of load in with short timeouts, then
    ensure that the server still responds quickly afterwards. Both clients are
    opened with ``async with`` so they are closed when no longer needed.
    """
    chat_input = [{"role": "user", "content": "Write a long story"}]

    # Request about 2 million tokens (200 requests x 10000 forced tokens).
    async with server.get_async_client(timeout=0.5) as client:
        tasks = [
            asyncio.create_task(
                client.chat.completions.create(
                    messages=chat_input,
                    model=MODEL_NAME,
                    max_tokens=10000,
                    extra_body={"min_tokens": 10000},
                )
            )
            for _ in range(200)
        ]

        done, pending = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)

    # Make sure all requests were sent to the server and timed out
    # (We don't want to hide other errors like 400s that would invalidate this
    # test)
    assert len(pending) == 0
    for d in done:
        with pytest.raises(openai.APITimeoutError):
            d.result()

    # If the server had not cancelled all the other requests, then it would not
    # be able to respond to this one within the timeout
    async with server.get_async_client(timeout=5) as client:
        response = await client.chat.completions.create(
            messages=chat_input, model=MODEL_NAME, max_tokens=10
        )

    assert len(response.choices) == 1


@pytest.mark.asyncio
async def test_request_wrong_content_type(server: RemoteOpenAIServer):
    """A chat request sent with a form-encoded Content-Type is rejected.

    The OpenAI client raises ``APIStatusError`` for a non-success response.
    The client is opened with ``async with`` so it is closed afterwards.
    """
    chat_input = [{"role": "user", "content": "Write a long story"}]

    async with server.get_async_client() as client:
        with pytest.raises(openai.APIStatusError):
            await client.chat.completions.create(
                messages=chat_input,
                model=MODEL_NAME,
                max_tokens=10000,
                # Override the JSON content type the client would normally send.
                extra_headers={"Content-Type": "application/x-www-form-urlencoded"},
            )


@pytest.mark.parametrize(
    "server_args",
    [pytest.param(["--enable-server-load-tracking"], id="enable-server-load-tracking")],
    indirect=True,
)
@pytest.mark.asyncio
async def test_server_load(server: RemoteOpenAIServer):
    """The ``load`` endpoint reflects the number of in-flight requests.

    With ``--enable-server-load-tracking`` the reported ``server_load`` is
    expected to be 0 while idle, 1 while one long completion is running, and
    0 again once that completion has finished.
    """
    # Check initial server load
    response = requests.get(server.url_for("load"))
    assert response.status_code == HTTPStatus.OK
    assert response.json().get("server_load") == 0

    def make_long_completion_request():
        return requests.post(
            server.url_for("v1/completions"),
            headers={"Content-Type": "application/json"},
            json={
                "prompt": "Give me a long story",
                "max_tokens": 1000,
                "temperature": 0,
            },
        )

    # Run the blocking completion request in a background thread so it stays
    # in flight while the load checks below are made.
    completion_future = asyncio.create_task(
        asyncio.to_thread(make_long_completion_request)
    )

    # Give a short delay to ensure the request has started.
    await asyncio.sleep(0.1)

    # Check server load while the completion request is running.
    response = requests.get(server.url_for("load"))
    assert response.status_code == HTTPStatus.OK
    assert response.json().get("server_load") == 1

    # Wait for the completion request to finish and make sure it succeeded;
    # a failed request would make the load readings above meaningless.
    completion_response = await completion_future
    assert completion_response.status_code == HTTPStatus.OK
    await asyncio.sleep(0.1)

    # Check server load after the completion request has finished.
    response = requests.get(server.url_for("load"))
    assert response.status_code == HTTPStatus.OK
    assert response.json().get("server_load") == 0