test_api_server.py 4.29 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import copyreg
5
import os
6
7
8
9
10
11
12
13
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import urllib3.exceptions


def _pickle_new_connection_error(obj):
    """Reduce a NewConnectionError to a picklable ``(callable, args)`` pair.

    The first positional argument of the exception is treated as text of
    the form ``"<conn>: <message>"``. Only the part after the first
    ``": "`` is kept; if no such separator is present the text is kept
    unchanged. An exception without arguments yields an empty message.

    Args:
        obj: The exception instance being pickled.

    Returns:
        A tuple of ``_unpickle_new_connection_error`` and a one-element
        argument tuple holding the extracted message.
    """
    raw_text = obj.args[0] if obj.args else ""
    # Nothing to strip when there is no "conn: " style prefix.
    if ': ' not in raw_text:
        return _unpickle_new_connection_error, (raw_text, )
    # Split only on the first separator so colons inside the message survive.
    message_only = raw_text.split(': ', 1)[1]
    return _unpickle_new_connection_error, (message_only, )


def _unpickle_new_connection_error(message):
    """Rebuild a NewConnectionError from the message stored by the pickler.

    Args:
        message: Message text to pass to the exception constructor.

    Returns:
        A new ``urllib3.exceptions.NewConnectionError`` created with ``None``
        as its first (connection) argument.
    """
    # Only the message is carried through pickling, so there is no
    # connection object to hand back; None stands in for it.
    error_cls = urllib3.exceptions.NewConnectionError
    return error_cls(None, message)


# Register _pickle_new_connection_error with copyreg as the reduction
# function for NewConnectionError, so pickling such an exception goes
# through the custom pickle/unpickle pair above (tblib compatibility).
copyreg.pickle(urllib3.exceptions.NewConnectionError,
               _pickle_new_connection_error)
38
39


40
def _query_server(prompt: str, max_tokens: int = 5) -> dict:
    """Send one generation request to the local API server.

    Issues a POST to ``http://localhost:8000/generate`` with the prompt and
    sampling options encoded as a JSON body.

    Args:
        prompt: Text prompt to send.
        max_tokens: Value sent as ``max_tokens`` in the request body.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error
            status (raised by ``raise_for_status``).
        requests.exceptions.ConnectionError: If the connection cannot be
            established (callers use this to detect a server that is not
            up yet).
    """
    response = requests.post("http://localhost:8000/generate",
                             json={
                                 "prompt": prompt,
                                 "max_tokens": max_tokens,
                                 "temperature": 0,
                                 "ignore_eos": True
                             })
    # Surface HTTP-level failures instead of returning an error payload.
    response.raise_for_status()
    return response.json()


52
53
54
55
def _query_server_long(prompt: str) -> dict:
    """Query the server for a long completion (``max_tokens=500``).

    Kept as a module-level function so it can be handed to a
    multiprocessing pool.

    Args:
        prompt: Text prompt to send.

    Returns:
        Whatever ``_query_server`` returns for this prompt.
    """
    long_reply = _query_server(prompt, max_tokens=500)
    return long_reply


56
@pytest.fixture
def api_server(distributed_executor_backend: str):
    """Start the API server script in a subprocess for the duration of a test.

    Launches ``api_server_async_engine.py`` (located next to this file) with
    the current Python interpreter, yields while the test runs, and then
    shuts the process down.

    Args:
        distributed_executor_backend: Value passed to the script's
            ``--distributed-executor-backend`` option.

    Yields:
        None. The fixture is used only for its setup/teardown side effects.
    """
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    commands = [
        sys.executable,
        "-u",
        str(script_path),
        "--model",
        "facebook/opt-125m",
        "--host",
        "127.0.0.1",
        "--distributed-executor-backend",
        distributed_executor_backend,
    ]

    # API Server Test Requires V0.
    my_env = os.environ.copy()
    my_env["VLLM_USE_V1"] = "0"
    uvicorn_process = subprocess.Popen(commands, env=my_env)
    yield
    uvicorn_process.terminate()
    # terminate() only sends the signal. Wait for the process to actually
    # exit so it is reaped and no longer holds the port when the next
    # parametrized run starts its own server.
    try:
        uvicorn_process.wait(timeout=30)
    except subprocess.TimeoutExpired:
        # The server did not shut down in time; force it.
        uvicorn_process.kill()
        uvicorn_process.wait()


80
@pytest.mark.timeout(300)
@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"])
def test_api_server(api_server, distributed_executor_backend: str):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["warm up"] * 1
        result = None
        while not result:
            try:
                for r in pool.map(_query_server, prompts):
                    result = r
                    break
            except requests.exceptions.ConnectionError:
                # Server is not accepting connections yet; retry shortly.
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        # Nothing has been cancelled so far, so the counter must be zero.
        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

    with Pool(32) as pool:
        # Cancel requests
        prompts = ["canceled requests"] * 100
        # Fire the long requests without waiting for results, then kill the
        # workers almost immediately so the client side goes away mid-request.
        pool.map_async(_query_server_long, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # check cancellation stats
        # give it some time to update the stats
        time.sleep(1)

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["test prompt after canceled"] * 100
        for result in pool.map(_query_server, prompts):
            assert result