test_async_tokenization.py 3.69 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import asyncio
import contextlib
import random
import time
from typing import Callable

import openai
import pytest
import pytest_asyncio
import requests

from tests.utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server():  # noqa: F811
    """Start one remote OpenAI-compatible server and share it module-wide.

    The fixture is module-scoped, so every test in this file talks to the
    same ``RemoteOpenAIServer`` instance serving ``MODEL_NAME``. The server's
    context manager is exited once the fixture is torn down.

    Yields:
        The running ``RemoteOpenAIServer``.
    """
    # use half precision for speed and memory savings in CI environment
    launch_args = ["--dtype", "bfloat16"]
    launch_args += ["--max-model-len", "8192"]
    launch_args.append("--enforce-eager")
    launch_args += ["--max-num-seqs", "128"]
    # NOTE(review): presumably skips loading real checkpoint weights; the
    # tests below only look at status codes and latency -- confirm.
    launch_args += ["--load-format", "dummy"]

    with RemoteOpenAIServer(MODEL_NAME, launch_args) as running_server:
        yield running_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the ``server`` fixture.

    The client's async context manager is entered before the test body runs
    and exited when the fixture is torn down.
    """
    client_context = server.get_async_client()
    async with client_context as openai_client:
        yield openai_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (
            lambda x: x.completions.create,
            {"prompt": " ".join(['A'] * 10_000)},
        ),
        (
            lambda x: x.chat.completions.create,
            {
                "messages": [{
                    "role": "user",
                    "content": " ".join(['A'] * 10_000),
                }],
            },
        ),
    ],
    ids=["completion", "chat"],
)
async def test_with_and_without_truncate(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    """Mix truncated and untruncated requests and expect no HTTP 500.

    Ten requests with the same long prompt are sent concurrently; half of
    them set ``truncate_prompt_tokens`` to 1000 and half leave it ``None``,
    in random order. Any status code other than 500 is accepted.

    Args:
        server: Module-scoped server fixture (requested so the server is up).
        client: Async OpenAI client bound to that server.
        create_func_gen: Maps the client to the ``create`` coroutine function
            under test (completions or chat completions).
        content_body: Endpoint-specific payload (``prompt`` or ``messages``).
    """
    create_func = create_func_gen(client)

    # Keyword arguments common to every request.
    shared_kwargs = {"model": MODEL_NAME}
    shared_kwargs.update(content_body)
    shared_kwargs["max_tokens"] = 10

    num_requests = 10
    num_truncated = num_requests // 2
    truncation_limits = [1000] * num_truncated
    truncation_limits += [None] * (num_requests - num_truncated)
    # Interleave both kinds of request in an arbitrary order.
    random.shuffle(truncation_limits)

    async def status_for(limit):
        """Send one request with the given limit; return its status code."""
        request_kwargs = dict(shared_kwargs)
        request_kwargs["extra_body"] = {'truncate_prompt_tokens': limit}
        try:
            await create_func(**request_kwargs)
        except openai.APIStatusError as err:
            return err.status_code
        return 200

    status_codes = await asyncio.gather(
        *[status_for(limit) for limit in truncation_limits])
    assert 500 not in status_codes


@pytest.mark.asyncio
@pytest.mark.parametrize(
    ids=["single completion", "multiple completions", "chat"],
    argnames=["create_func_gen", "content_body"],
    argvalues=[
        (lambda x: x.completions.create, {
            "prompt": " ".join(['A'] * 300_000)
        }),
        (lambda x: x.completions.create, {
            "prompt": [" ".join(['A'] * 300_000)] * 2
        }),
        (lambda x: x.chat.completions.create, {
            "messages": [{
                "role": "user",
                "content": " ".join(['A'] * 300_000)
            }]
        }),
    ],
)
async def test_healthcheck_response_time(
    server: RemoteOpenAIServer,
    client: openai.AsyncOpenAI,
    create_func_gen: Callable,
    content_body: dict,
):
    """Check that the health endpoint stays fast while requests are in flight.

    The ``health`` URL is timed once with no load, then ``num_requests``
    requests with a very long prompt are scheduled concurrently and the
    endpoint is timed again about one second later. The loaded latency must
    be below both 100x the unloaded latency and 0.1 seconds.

    Requests that fail with ``openai.APIStatusError`` are tolerated
    (presumably the prompt is too long for the server to accept -- the test
    only cares about health-check latency); any other exception fails the
    test.

    Args:
        server: Module-scoped server fixture; provides the health URL.
        client: Async OpenAI client bound to that server.
        create_func_gen: Maps the client to the ``create`` coroutine function
            under test (completions or chat completions).
        content_body: Endpoint-specific payload (``prompt`` or ``messages``).
    """
    num_requests = 50

    create_func = create_func_gen(client)
    body = {"model": MODEL_NAME, **content_body, "max_tokens": 10}

    def get_response_time(url):
        """Return the seconds one blocking GET of ``url`` takes (must be 200)."""
        start_time = time.monotonic()
        res = requests.get(url)
        end_time = time.monotonic()
        assert res.status_code == 200
        return end_time - start_time

    no_load_response_time = get_response_time(server.url_for("health"))
    tasks = [
        asyncio.create_task(create_func(**body)) for _ in range(num_requests)
    ]
    try:
        await asyncio.sleep(1)  # give the tasks a chance to start running
        load_response_time = get_response_time(server.url_for("health"))
    finally:
        # Wait for every request, not just the first one to fail.
        # ``asyncio.gather`` propagates the first exception without
        # cancelling the remaining awaitables, so suppressing around a
        # single gather would leave requests running (and their exceptions
        # unretrieved) after the test returns, loading the shared server
        # during later tests. Awaiting task by task costs no concurrency:
        # the tasks are already running.
        for task in tasks:
            with contextlib.suppress(openai.APIStatusError):
                await task

    assert load_response_time < 100 * no_load_response_time
    assert load_response_time < 0.1