# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
5
6
7
8
9
10
11
12
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests
13
14
import os
from ..utils import models_path_prefix
15
16


def _query_server(prompt: str, max_tokens: int = 5) -> dict:
    """Send a single ``/generate`` request to the local server.

    The request body carries the prompt, the token budget, ``temperature=0``
    and ``ignore_eos=True``. A non-2xx HTTP status is raised via
    ``raise_for_status()``; otherwise the decoded JSON body is returned.
    """
    payload = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
        "ignore_eos": True,
    }
    reply = requests.post("http://localhost:8000/generate", json=payload)
    reply.raise_for_status()
    return reply.json()


def _query_server_long(prompt: str) -> dict:
    """Same as ``_query_server`` but with a 500-token budget.

    The larger budget keeps a request busy long enough that the client can be
    torn down while the request is presumably still being served.
    """
    long_budget = 500
    return _query_server(prompt, max_tokens=long_budget)


@pytest.fixture
def api_server(tokenizer_pool_size: int, distributed_executor_backend: str):
    """Launch ``api_server_async_engine.py`` in a subprocess for one test.

    ``tokenizer_pool_size`` and ``distributed_executor_backend`` come from the
    test's ``parametrize`` markers and are forwarded as CLI flags. The fixture
    does not wait for readiness; the test polls until a request succeeds.

    On teardown the server is terminated and reaped. This ensures the listening
    port (the test talks to localhost:8000) is released before the next
    parametrized case starts its own server.
    """
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    commands = [
        sys.executable,
        "-u",
        str(script_path),
        "--model",
        os.path.join(models_path_prefix, "facebook/opt-125m"),
        "--host",
        "127.0.0.1",
        "--tokenizer-pool-size",
        str(tokenizer_pool_size),
        "--distributed-executor-backend",
        distributed_executor_backend,
    ]

    # API Server Test Requires V0.
    my_env = os.environ.copy()
    my_env["VLLM_USE_V1"] = "0"
    uvicorn_process = subprocess.Popen(commands, env=my_env)
    try:
        yield
    finally:
        uvicorn_process.terminate()
        try:
            # Reap the child: otherwise it can linger as a zombie or keep the
            # port bound while the next parametrization starts.
            uvicorn_process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            # terminate() was not honored in time; force the shutdown.
            uvicorn_process.kill()
            uvicorn_process.wait()


@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"])
def test_api_server(api_server, tokenizer_pool_size: int,
                    distributed_executor_backend: str):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.

    ``tokenizer_pool_size`` and ``distributed_executor_backend`` are not used
    in the body; they are declared so pytest parametrizes the ``api_server``
    fixture with them.
    """
    # Upper bound on how long to wait for the server to start answering.
    # Without it, a server that dies during startup makes the test hang.
    server_start_timeout_s = 600

    with Pool(32) as pool:
        # Wait until the server is ready: retry one request until the server
        # accepts connections and returns a non-empty response.
        prompts = ["warm up"]
        deadline = time.monotonic() + server_start_timeout_s
        result = None
        while not result:
            if time.monotonic() > deadline:
                pytest.fail("API server did not become ready within "
                            f"{server_start_timeout_s} seconds")
            try:
                result = pool.map(_query_server, prompts)[0]
            except requests.exceptions.ConnectionError:
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        # The stats endpoint must expose the abort counter. An exact
        # `num_aborted_requests == 0` assertion is intentionally not enforced
        # here (it was commented out), so only the key's presence is checked.
        stats = requests.get("http://localhost:8000/stats").json()
        assert "num_aborted_requests" in stats

        # Try with 100 prompts
        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

    with Pool(32) as pool:
        # Cancel requests: submit long generations, then terminate the client
        # workers shortly after, presumably while requests are in flight.
        prompts = ["canceled requests"] * 100
        pool.map_async(_query_server_long, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # Check cancellation stats; give the server some time to update them.
        time.sleep(1)
        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["test prompt after canceled"] * 100
        for result in pool.map(_query_server, prompts):
            assert result