"scripts/conversion_glide.py" did not exist on "929e1c032846d40d941f89b7cee999c433cbaf69"
test_api_server.py 2.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests


def _query_server(prompt: str) -> dict:
    """Send one generation request to the local server and return its JSON.

    Posts ``prompt`` to ``http://localhost:8000/generate`` asking for 100
    tokens at temperature 0 with ``ignore_eos`` enabled. A non-success HTTP
    status raises through ``raise_for_status``.
    """
    payload = {
        "prompt": prompt,
        "max_tokens": 100,
        "temperature": 0,
        "ignore_eos": True,
    }
    reply = requests.post("http://localhost:8000/generate", json=payload)
    reply.raise_for_status()
    return reply.json()


@pytest.fixture
def api_server():
    """Run ``api_server_async_engine.py`` in a subprocess around a test.

    The script sitting next to this file is launched with the current
    interpreter (unbuffered, ``--model facebook/opt-125m``). The fixture
    yields nothing; the test talks to the server over HTTP. On teardown the
    process is terminated and then waited on so it is fully reaped before
    the next test starts.
    """
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
    ])
    yield
    uvicorn_process.terminate()
    # terminate() only sends the signal; wait so the child is reaped and has
    # released its resources. Escalate to kill() if it ignores the request.
    try:
        uvicorn_process.wait(timeout=30)
    except subprocess.TimeoutExpired:
        uvicorn_process.kill()
        uvicorn_process.wait()


def test_api_server(api_server):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready: keep retrying a single warm-up
        # request while the connection is refused.
        prompts = ["warm up"]
        result = None
        while not result:
            try:
                result = pool.map(_query_server, prompts)[0]
            except requests.exceptions.ConnectionError:
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        # Nothing has been cancelled yet, so the counter must still be zero.
        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["test prompt"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

        # Cancel requests: fire them asynchronously, give the workers a
        # moment to start sending, then tear the pool down mid-flight.
        prompts = ["canceled requests"] * 100
        pool.map_async(_query_server, prompts)
        time.sleep(0.001)
        pool.terminate()
        pool.join()

        # check cancellation stats
        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["test prompt after canceled"] * 100
        for result in pool.map(_query_server, prompts):
            assert result