endpoint_request_func.py 15.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""The request function for API endpoints."""

5
import io
6
7
8
9
10
11
import json
import os
import sys
import time
import traceback
from dataclasses import dataclass, field
12
from typing import Optional, Union
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

import aiohttp
from tqdm.asyncio import tqdm

# Client timeout with a total limit of 6 hours (6 * 60 * 60 seconds).
# NOTE(review): not referenced by the request functions visible here, which
# receive an already-built session -- presumably meant for the caller that
# creates the aiohttp.ClientSession; confirm.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """The input for the request function."""
    # Prompt sent to the endpoint.
    prompt: str
    # Full URL the request is POSTed to; each request function asserts the
    # suffix it expects.
    api_url: str
    # Prompt length supplied by the caller.
    prompt_len: int
    # Requested generation length (sent as max_tokens /
    # max_completion_tokens by the streaming request functions).
    output_len: int
    model: str
    # When set, the completions, chat and audio request functions send this
    # value as the "model" field instead of `model`.
    model_name: Optional[str] = None
    logprobs: Optional[int] = None
    # Extra key/value pairs merged into the request payload when set.
    extra_body: Optional[dict] = None
    # A single content dict or a list of them (chat), or a dict with an
    # "audio" entry (audio).
    multi_modal_content: Optional[Union[dict, list[dict]]] = None
    ignore_eos: bool = False
    language: Optional[str] = None
    # Sent as the "x-request-id" header when set.
    request_id: Optional[str] = None
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53


@dataclass
class RequestFuncOutput:
    """Result of one request together with the metrics recorded for it."""
    # Concatenation of the text pieces received from the endpoint.
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    output_tokens: int = 0
    # Time to first token.
    ttft: float = 0.0
    # Inter-token latencies; a fresh list is created for every instance.
    itl: list[float] = field(default_factory=list)
    # Average next-token latency.
    tpot: float = 0.0
    prompt_len: int = 0
    error: str = ""


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Completions API.

    Sends a streaming completion request and, while reading the streamed
    response, records time to first token, inter-token latencies, total
    latency and the generated text.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. A non-200 status, an exception
        during the request, or a stream without any chunk carrying
        ``choices`` is reported through ``success=False`` and ``error``
        instead of being raised.

    Raises:
        AssertionError: If the API URL does not end with 'completions' or
            'profile'.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name else request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "repetition_penalty": 1.0,
        "max_tokens": request_func_input.output_len,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    # extra_body is applied last so it can override any default above.
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    }
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    st = time.perf_counter()
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload,
                                headers=headers) as response:
            if response.status == 200:
                first_chunk_received = False
                async for chunk_bytes in response.content:
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue
                    chunk_bytes = chunk_bytes.decode("utf-8")
                    # NOTE: SSE comments (often used as pings) start with
                    # a colon. These are not JSON data payload and should
                    # be skipped.
                    if chunk_bytes.startswith(":"):
                        continue

                    chunk = chunk_bytes.removeprefix("data: ")
                    if chunk == "[DONE]":
                        continue

                    data = json.loads(chunk)

                    # NOTE: Some completion API might have a last
                    # usage summary response without a token so we
                    # want to check a token was generated
                    if choices := data.get("choices"):
                        # Note that text could be empty here
                        # e.g. for special tokens
                        text = choices[0].get("text")
                        timestamp = time.perf_counter()
                        if not first_chunk_received:
                            # First token: reuse the chunk timestamp so
                            # TTFT and the first ITL share one reference.
                            first_chunk_received = True
                            output.ttft = timestamp - st
                        else:
                            # Decoding phase
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp
                        generated_text += text or ""
                    elif usage := data.get("usage"):
                        output.output_tokens = usage.get(
                            "completion_tokens")

                if first_chunk_received:
                    output.success = True
                else:
                    output.success = False
                    output.error = (
                        "Never received a valid chunk to calculate TTFT. "
                        "This response will be marked as failed!")
                output.generated_text = generated_text
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        # Best effort: a failed request is recorded, not propagated.
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


167
168
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Chat Completions API.

    Sends the prompt (plus any multi-modal content) as a single user message
    in a streaming request and records time to first token, inter-token
    latencies, total latency and the generated text.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. A non-200 status or an exception
        during the request is reported through ``success=False`` and
        ``error`` instead of being raised.

    Raises:
        AssertionError: If the API URL does not end with 'chat/completions'
            or 'profile'.
        TypeError: If ``multi_modal_content`` is set but is neither a dict
            nor a list.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("chat/completions", "profile")), (
        "OpenAI Chat Completions API URL must end with 'chat/completions' "
        "or 'profile'.")

    content = [{"type": "text", "text": request_func_input.prompt}]
    if request_func_input.multi_modal_content:
        mm_content = request_func_input.multi_modal_content
        if isinstance(mm_content, list):
            content.extend(mm_content)
        elif isinstance(mm_content, dict):
            content.append(mm_content)
        else:
            raise TypeError(
                "multi_modal_content must be a dict or list[dict] "
                "for openai-chat"
            )
    payload = {
        "model":
        request_func_input.model_name
        if request_func_input.model_name else request_func_input.model,
        "messages": [
            {
                "role": "user",
                "content": content
            },
        ],
        "temperature":
        0.0,
        "max_completion_tokens":
        request_func_input.output_len,
        "stream":
        True,
        "stream_options": {
            "include_usage": True,
        },
    }
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    # extra_body is applied last so it can override any default above.
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    # 0.0 doubles as the "no token chunk seen yet" marker.
    ttft = 0.0
    st = time.perf_counter()
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload,
                                headers=headers) as response:
            if response.status == 200:
                async for chunk_bytes in response.content:
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue
                    chunk_bytes = chunk_bytes.decode("utf-8")
                    # NOTE: SSE comments (often used as pings) start with
                    # a colon. These are not JSON data payload and should
                    # be skipped.
                    if chunk_bytes.startswith(":"):
                        continue

                    chunk = chunk_bytes.removeprefix("data: ")
                    if chunk == "[DONE]":
                        continue

                    timestamp = time.perf_counter()
                    data = json.loads(chunk)

                    if choices := data.get("choices"):
                        # Named separately from `content` (the request
                        # message parts) to avoid shadowing it.
                        delta_content = choices[0]["delta"].get("content")
                        if ttft == 0.0:
                            # First token
                            ttft = timestamp - st
                            output.ttft = ttft
                        else:
                            # Decoding phase
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        generated_text += delta_content or ""
                    elif usage := data.get("usage"):
                        output.output_tokens = usage.get(
                            "completion_tokens")

                    # Updated for every data chunk, including usage-only
                    # chunks.
                    most_recent_timestamp = timestamp

                output.generated_text = generated_text
                output.success = True
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        # Best effort: a failed request is recorded, not propagated.
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI audio endpoints.

    Encodes the audio from ``multi_modal_content["audio"]`` as WAV, uploads
    it as multipart form data in a streaming request and records time to
    first token, inter-token latencies, total latency and the generated
    text.

    Args:
        request_func_input: The input for the request function. Its
            ``multi_modal_content`` must be a dict with an "audio" entry
            that unpacks into the two arguments passed to
            ``soundfile.write`` after the buffer (presumably samples and
            sample rate -- confirm against the caller).
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. A non-200 status or an exception
        during the request is reported through ``success=False`` and
        ``error`` instead of being raised.

    Raises:
        AssertionError: If the API URL does not end with 'transcriptions'
            or 'translations'.
        TypeError: If ``multi_modal_content`` is not a dict containing
            'audio'.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Audio API URL must end with 'transcriptions' "
        "or 'translations'.")

    payload = {
        "model":
        request_func_input.model_name
        if request_func_input.model_name else request_func_input.model,
        "temperature":
        0.0,
        "max_completion_tokens":
        request_func_input.output_len,
        "stream":
        True,
        "language":
        "en",
        # Flattened due to multipart/form-data
        "stream_include_usage":
        True,
        "stream_continuous_usage_stats":
        True,
    }
    # extra_body is applied last so it can override any default above.
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    # Send audio file
    def to_bytes(y, sr):
        buffer = io.BytesIO()
        soundfile.write(buffer, y, sr, format="WAV")
        buffer.seek(0)
        return buffer

    mm_audio = request_func_input.multi_modal_content
    if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
        raise TypeError(
            "multi_modal_content must be a dict containing 'audio'")
    with to_bytes(*mm_audio["audio"]) as f:
        form = aiohttp.FormData()
        form.add_field("file", f, content_type="audio/wav")
        # Form fields are text, so every payload value is stringified.
        for key, value in payload.items():
            form.add_field(key, str(value))

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        # 0.0 doubles as the "no token chunk seen yet" marker.
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url,
                                    data=form,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")
                        # NOTE: SSE comments (often used as pings) start
                        # with a colon. These are not JSON data payload
                        # and should be skipped.
                        if chunk_bytes.startswith(":"):
                            continue

                        chunk = chunk_bytes.removeprefix("data: ")
                        if chunk == "[DONE]":
                            continue

                        timestamp = time.perf_counter()
                        data = json.loads(chunk)

                        if choices := data.get("choices"):
                            delta_content = choices[0]["delta"].get(
                                "content")
                            if ttft == 0.0:
                                # First token
                                ttft = timestamp - st
                                output.ttft = ttft
                            else:
                                # Decoding phase
                                output.itl.append(
                                    timestamp - most_recent_timestamp)

                            generated_text += delta_content or ""
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                        # Updated for every data chunk, including
                        # usage-only chunks.
                        most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: a failed request is recorded, not propagated.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
async def async_request_openai_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Embeddings API.

    Sends one non-streaming embeddings request and records its latency and
    the prompt token count reported in the response's ``usage`` field.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. A non-200 status or an exception
        during the request is reported through ``success=False`` and
        ``error`` instead of being raised.

    Raises:
        AssertionError: If the API URL does not end with 'embeddings'.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "embeddings"
    ), "OpenAI Embeddings API URL must end with 'embeddings'."

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    payload = {
        # Same model selection as the other request functions in this file.
        "model": request_func_input.model_name
        if request_func_input.model_name else request_func_input.model,
        "input": request_func_input.prompt,
    }
    # extra_body is applied last so it can override any default above.
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)

    output = RequestFuncOutput()
    st = time.perf_counter()
    try:
        async with session.post(
            url=api_url,
            headers=headers,
            json=payload
        ) as response:
            if response.status == 200:
                # Latency is taken before the JSON body is read and parsed.
                output.latency = time.perf_counter() - st
                data = await response.json()
                output.success = True
                output.generated_text = ""
                output.prompt_len = data.get(
                    "usage", {}).get(
                    "prompt_tokens", 0)
            else:
                output.success = False
                output.error = response.reason or ""
    except Exception:
        # A full traceback is kept because str(exc) can be empty for some
        # exceptions, which would hide the failure cause.
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


445
446
# TODO: Add more request functions for different API protocols.
# Maps a backend name to the coroutine function that issues its requests.
ASYNC_REQUEST_FUNCS = {
    "vllm": async_request_openai_completions,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "openai-embeddings": async_request_openai_embeddings,
}

# Backend names whose request function is the completions or the chat
# completions implementation; audio and embeddings are excluded.
OPENAI_COMPATIBLE_BACKENDS = [
    k for k, v in ASYNC_REQUEST_FUNCS.items()
    if v in (async_request_openai_completions,
             async_request_openai_chat_completions)
]