# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""The request function for API endpoints."""

5
import io
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """The input for the request function."""
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    model_name: Optional[str] = None
    logprobs: Optional[int] = None
    extra_body: Optional[dict] = None
31
    multi_modal_content: Optional[dict | list[dict]] = None
32
    ignore_eos: bool = False
33
    language: Optional[str] = None
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52


@dataclass
class RequestFuncOutput:
    """Result of one benchmark request together with its timing metrics.

    Every field has a default, so an empty instance can be created up front
    and filled in while the response is streamed. The request functions in
    this module store timings as differences of ``time.perf_counter()``
    readings (seconds).
    """
    # Concatenation of every streamed text piece.
    generated_text: str = ""
    # Whether the request completed with a usable response.
    success: bool = False
    # End-to-end latency.
    latency: float = 0.0
    # Completion token count reported by the server's usage data.
    output_tokens: int = 0
    # Time to first token.
    ttft: float = 0.0
    # Inter-token latencies, one entry per streamed chunk after the first.
    itl: list[float] = field(default_factory=list)
    # Average next-token latency; not assigned in this module, presumably
    # derived by the caller.
    tpot: float = 0.0
    # Prompt length copied from the request input.
    prompt_len: int = 0
    # HTTP reason or formatted traceback when the request failed.
    error: str = ""


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Completions API.

    Sends a streaming completion request. While the stream is consumed, it
    records the time to first token (TTFT), the inter-token latencies (ITL)
    and the end-to-end latency, and accumulates the generated text.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. Errors raised while sending the
        request or parsing the stream are not propagated. They are recorded
        in ``error``, with ``success`` set to False.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    payload = {
        "model": request_func_input.model_name \
            if request_func_input.model_name else request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "repetition_penalty": 1.0,
        "max_tokens": request_func_input.output_len,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    # Caller-supplied keys take precedence over the defaults above.
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    }

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    st = time.perf_counter()
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload,
                                headers=headers) as response:
            if response.status == 200:
                first_chunk_received = False
                async for chunk_bytes in response.content:
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue
                    chunk_bytes = chunk_bytes.decode("utf-8")
                    # NOTE: SSE comments (often used as pings) start with
                    # a colon. These are not JSON data payload and should
                    # be skipped.
                    if chunk_bytes.startswith(":"):
                        continue

                    chunk = chunk_bytes.removeprefix("data: ")
                    if chunk == "[DONE]":
                        continue

                    data = json.loads(chunk)

                    # NOTE: Some completion API might have a last
                    # usage summary response without a token so we
                    # want to check a token was generated
                    if choices := data.get("choices"):
                        # Note that text could be empty here
                        # e.g. for special tokens
                        text = choices[0].get("text")
                        timestamp = time.perf_counter()
                        if not first_chunk_received:
                            # First token. Reuse this chunk's timestamp so
                            # TTFT and ITL are based on the same reading.
                            first_chunk_received = True
                            output.ttft = timestamp - st
                        else:
                            # Decoding phase
                            output.itl.append(timestamp -
                                              most_recent_timestamp)
                        most_recent_timestamp = timestamp
                        generated_text += text or ""
                    elif usage := data.get("usage"):
                        output.output_tokens = usage.get(
                            "completion_tokens")

                output.success = first_chunk_received
                if not first_chunk_received:
                    output.error = (
                        "Never received a valid chunk to calculate TTFT. "
                        "This response will be marked as failed!")
                output.generated_text = generated_text
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


164
165
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Chat Completions API.

    Sends the prompt as a single user message, followed by any multimodal
    content parts, as a streaming request. While the stream is consumed, it
    records TTFT, inter-token latencies and end-to-end latency, and
    accumulates the generated text.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. Errors raised while sending the
        request or parsing the stream are recorded in ``error``, with
        ``success`` set to False, instead of being propagated.

    Raises:
        TypeError: If ``multi_modal_content`` is neither a dict nor a list.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("chat/completions", "profile")), (
        "OpenAI Chat Completions API URL must end with 'chat/completions' "
        "or 'profile'.")

    content = [{"type": "text", "text": request_func_input.prompt}]
    mm_content = request_func_input.multi_modal_content
    if mm_content:
        if isinstance(mm_content, list):
            content.extend(mm_content)
        elif isinstance(mm_content, dict):
            content.append(mm_content)
        else:
            raise TypeError(
                "multi_modal_content must be a dict or list[dict] "
                "for openai-chat"
            )

    payload = {
        "model":
        request_func_input.model_name
        if request_func_input.model_name else request_func_input.model,
        "messages": [
            {
                "role": "user",
                "content": content
            },
        ],
        "temperature":
        0.0,
        "max_completion_tokens":
        request_func_input.output_len,
        "stream":
        True,
        "stream_options": {
            "include_usage": True,
        },
    }
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    # Caller-supplied keys take precedence over the defaults above.
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    first_chunk_received = False
    st = time.perf_counter()
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload,
                                headers=headers) as response:
            if response.status == 200:
                async for chunk_bytes in response.content:
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue
                    chunk_bytes = chunk_bytes.decode("utf-8")
                    # NOTE: SSE comments (often used as pings) start with
                    # a colon. These are not JSON data payload and should
                    # be skipped.
                    if chunk_bytes.startswith(":"):
                        continue

                    chunk = chunk_bytes.removeprefix("data: ")
                    if chunk == "[DONE]":
                        continue

                    timestamp = time.perf_counter()
                    data = json.loads(chunk)

                    if choices := data.get("choices"):
                        # Named apart from the message ``content`` list
                        # above to avoid shadowing it.
                        delta_text = choices[0]["delta"].get("content")
                        if not first_chunk_received:
                            # First token
                            first_chunk_received = True
                            output.ttft = timestamp - st
                        else:
                            # Decoding phase
                            output.itl.append(timestamp -
                                              most_recent_timestamp)
                        generated_text += delta_text or ""
                    elif usage := data.get("usage"):
                        output.output_tokens = usage.get(
                            "completion_tokens")

                    # Usage-only chunks also advance this timestamp, so
                    # latency runs to the last data chunk received.
                    most_recent_timestamp = timestamp

                output.generated_text = generated_text
                # Without any choices chunk TTFT is undefined, so the
                # request is reported as failed rather than with TTFT 0.
                output.success = first_chunk_received
                if not first_chunk_received:
                    output.error = (
                        "Never received a valid chunk to calculate TTFT. "
                        "This response will be marked as failed!")
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI audio transcription and
    translation APIs.

    The waveform in ``request_func_input.multi_modal_content["audio"]`` is
    encoded as an in-memory WAV file and uploaded as multipart/form-data,
    together with the request parameters. The streamed response is timed
    the same way as a chat completions stream.

    Args:
        request_func_input: The input for the request function. Its
            ``multi_modal_content`` must be a dict whose ``"audio"`` entry
            unpacks into ``(samples, sample_rate)`` for ``soundfile.write``.
        session: The aiohttp session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. Errors raised while sending the
        request or parsing the stream are recorded in ``error``, with
        ``success`` set to False, instead of being propagated.

    Raises:
        TypeError: If ``multi_modal_content`` is not a dict containing an
            ``"audio"`` entry.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Audio API URL must end with 'transcriptions' "
        "or 'translations'.")

    payload = {
        "model":
        request_func_input.model_name
        if request_func_input.model_name else request_func_input.model,
        "temperature":
        0.0,
        "max_completion_tokens":
        request_func_input.output_len,
        "stream":
        True,
        # Use the input's language when given; otherwise default to English.
        "language":
        request_func_input.language or "en",
        # Flattened due to multipart/form-data
        "stream_include_usage":
        True,
        "stream_continuous_usage_stats":
        True,
    }
    # Caller-supplied keys take precedence over the defaults above.
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }

    def to_bytes(y, sr):
        """Encode samples ``y`` at rate ``sr`` as an in-memory WAV file."""
        buffer = io.BytesIO()
        soundfile.write(buffer, y, sr, format="WAV")
        buffer.seek(0)
        return buffer

    mm_audio = request_func_input.multi_modal_content
    if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
        raise TypeError("multi_modal_content must be a dict containing 'audio'")
    # The buffer stays open for the whole request because the form streams
    # the file from it.
    with to_bytes(*mm_audio["audio"]) as f:
        form = aiohttp.FormData()
        form.add_field("file", f, content_type="audio/wav")
        # Form fields are strings, so non-string values are sent via str().
        for key, value in payload.items():
            form.add_field(key, str(value))

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        first_chunk_received = False
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url,
                                    data=form,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")
                        # NOTE: SSE comments (often used as pings) start
                        # with a colon. They carry no JSON payload and
                        # would otherwise make json.loads fail.
                        if chunk_bytes.startswith(":"):
                            continue

                        chunk = chunk_bytes.removeprefix("data: ")
                        if chunk == "[DONE]":
                            continue

                        timestamp = time.perf_counter()
                        data = json.loads(chunk)

                        if choices := data.get("choices"):
                            delta_text = choices[0]["delta"].get("content")
                            if not first_chunk_received:
                                # First token
                                first_chunk_received = True
                                output.ttft = timestamp - st
                            else:
                                # Decoding phase
                                output.itl.append(
                                    timestamp - most_recent_timestamp)
                            generated_text += delta_text or ""
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                        # Usage-only chunks also advance this timestamp, so
                        # latency runs to the last data chunk received.
                        most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    # Without any choices chunk TTFT is undefined, so the
                    # request is reported as failed rather than with TTFT 0.
                    output.success = first_chunk_received
                    if not first_chunk_received:
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


390
391
# TODO: Add more request functions for different API protocols.
# Maps a backend name to the coroutine function that sends one request.
ASYNC_REQUEST_FUNCS = {
    "vllm": async_request_openai_completions,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
}

# Names of the backends handled by the completions or chat completions
# request functions ("openai-audio" is not included).
OPENAI_COMPATIBLE_BACKENDS = [
    k for k, v in ASYNC_REQUEST_FUNCS.items()
    if v in (async_request_openai_completions,
             async_request_openai_chat_completions)
]