# test_api_server.py
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests


def _query_server(prompt: str) -> dict:
    """Send one generation request to the local server.

    Posts ``prompt`` with fixed sampling settings to the ``/generate``
    endpoint on ``localhost:8000`` and returns the decoded JSON body.
    An HTTP error status is raised as an exception by ``raise_for_status``.
    """
    payload = {
        "prompt": prompt,
        "max_tokens": 100,
        "temperature": 0,
        "ignore_eos": True,
    }
    reply = requests.post("http://localhost:8000/generate", json=payload)
    # Surface 4xx/5xx responses as exceptions instead of returning them.
    reply.raise_for_status()
    return reply.json()


@pytest.fixture
def api_server():
    """Run the API server script in a subprocess for the duration of a test.

    Launches ``api_server_async_engine.py`` (located next to this file) with
    the ``facebook/opt-125m`` model, yields control to the test, and shuts
    the process down afterwards. The fixture yields no value and does not
    wait for the server to become ready; the test is responsible for that.
    """
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    # pylint: disable=consider-using-with
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
    ])
    try:
        yield
    finally:
        # Ask the server to stop, then wait so the child is reaped and is
        # no longer running once the fixture has been torn down.
        uvicorn_process.terminate()
        try:
            uvicorn_process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            # The server ignored the termination request; force it down.
            uvicorn_process.kill()
            uvicorn_process.wait()


# pylint: disable=redefined-outer-name, unused-argument
def test_api_server(api_server):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.

    Raises:
        TimeoutError: if the server does not answer a request within the
            readiness deadline.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["Hello world"] * 1
        result = None
        last_error = None
        # Bound the wait so a server that never starts fails the test
        # instead of hanging it forever.
        deadline = time.monotonic() + 600
        while not result:
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    "API server did not become ready in time"
                ) from last_error
            # Any ordinary failure means "not ready yet"; unlike a bare
            # except, this lets KeyboardInterrupt/SystemExit propagate.
            # pylint: disable=broad-except
            try:
                result = pool.map(_query_server, prompts)[0]
            except Exception as err:
                last_error = err
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

        # Cancel requests
        pool.map_async(_query_server, prompts)
        # Give the workers a moment to send requests before killing them.
        time.sleep(0.01)
        pool.terminate()
        pool.join()

        # check cancellation stats
        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result