"tests/entrypoints/pooling/embed/test_online.py" did not exist on "f256ebe4df6757d76f1f1642d7e110268a2f8190"
endpoint_request_func.py 18.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""The request function for API endpoints."""

5
import io
6
7
8
9
10
import json
import os
import sys
import time
import traceback
11
from collections.abc import Awaitable
12
from dataclasses import dataclass, field
13
from typing import Optional, Protocol, Union
14
15
16
17
18
19
20

import aiohttp
from tqdm.asyncio import tqdm

# Shared aiohttp client timeout with a total budget of 6 * 60 * 60 = 21600,
# i.e. six hours given that aiohttp expresses `total` in seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
class StreamedResponseHandler:
    """Reassemble a streamed HTTP (SSE) body into complete messages.

    Raw byte chunks are fed in via `add_chunk`; text that does not yet form
    a complete message is kept in `buffer` until later chunks complete it.
    """

    def __init__(self):
        # Imported locally so the class does not rely on a module-level
        # import that the rest of the file does not otherwise need.
        import codecs

        # Decoded text that has not yet been emitted as a message.
        self.buffer = ""
        # Stateful decoder: network chunks can end in the middle of a
        # multi-byte UTF-8 sequence, and decoding every chunk independently
        # would raise UnicodeDecodeError on such a split.
        self._decoder = codecs.getincrementaldecoder("utf-8")()

    def add_chunk(self, chunk_bytes: bytes) -> list[str]:
        """Add a chunk of bytes to the buffer and return any complete
        messages.

        Args:
            chunk_bytes: The next raw bytes read from the response body.

        Returns:
            The complete messages now available, whitespace-stripped and in
            arrival order. Empty if the buffered text is still incomplete.
        """
        # Bytes of an unfinished multi-byte character stay inside the
        # decoder and are emitted together with the next chunk.
        self.buffer += self._decoder.decode(chunk_bytes)

        # Split by double newlines (SSE message separator); whatever follows
        # the last separator is the still-unterminated tail.
        *terminated, self.buffer = self.buffer.split("\n\n")
        messages = [text for text in map(str.strip, terminated) if text]

        # The tail may already be a whole message whose separator is missing
        # (for example because the caller stripped the chunk). Accept it if
        # it is the [DONE] sentinel or if its payload parses as JSON.
        if self.buffer.startswith("data: "):
            message_content = self.buffer.removeprefix("data: ").strip()
            if message_content == "[DONE]":
                messages.append(self.buffer.strip())
                self.buffer = ""
            elif message_content:
                try:
                    json.loads(message_content)
                except json.JSONDecodeError:
                    # Incomplete JSON, wait for more chunks.
                    pass
                else:
                    messages.append(self.buffer.strip())
                    self.buffer = ""

        return messages


62
63
64
@dataclass
class RequestFuncInput:
    """The input for the request function."""

    # Prompt text placed in the request payload.
    prompt: str
    # URL the request is POSTed to; each request function asserts its suffix.
    api_url: str
    # Prompt length recorded on the output (copied to `prompt_len` there).
    prompt_len: int
    # Requested generation length, sent as the max-tokens style field.
    output_len: int
    # Model identifier sent as "model" unless `model_name` is set.
    model: str
    # When set, the completions, chat and audio request functions send this
    # as "model" instead of `model`.
    model_name: Optional[str] = None
    # Forwarded as "logprobs" by the completions request function.
    logprobs: Optional[int] = None
    # Extra HTTP headers merged over the default headers.
    extra_headers: Optional[dict] = None
    # Extra payload entries merged over the default payload.
    extra_body: Optional[dict] = None
    # Chat: dict or list of dicts appended to the message content.
    # Audio: a dict that must contain an "audio" entry.
    multi_modal_content: Optional[Union[dict, list[dict]]] = None
    # When truthy, "ignore_eos" is added to the payload.
    ignore_eos: bool = False
    # NOTE(review): not read by any request function visible in this file.
    language: Optional[str] = None
    # When set, sent as the "x-request-id" header.
    request_id: Optional[str] = None
79
80
81
82
83


@dataclass
class RequestFuncOutput:
    """The output of the request function including metrics."""

    # Concatenation of all text received from the stream.
    generated_text: str = ""
    # Whether the request is counted as successful.
    success: bool = False
    # Time from sending the request to the last recorded timestamp.
    latency: float = 0.0
    # "completion_tokens" taken from the usage payload, when reported.
    output_tokens: int = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    # Prompt length (copied from the input or reported by the server).
    prompt_len: int = 0
    # Failure description: HTTP reason or formatted exception.
    error: str = ""
    # `time.perf_counter()` reading taken just before the request is sent.
    start_time: float = 0.0
95
96


97
98
99
100
101
102
class RequestFunc(Protocol):
    """Structural type of a request function.

    A conforming callable takes the request input, an aiohttp client session
    and an optional progress bar, and returns an awaitable that resolves to
    the request output.
    """

    def __call__(
        self,
        request_func_input: RequestFuncInput,
        session: aiohttp.ClientSession,
        pbar: Optional[tqdm] = None,
    ) -> Awaitable[RequestFuncOutput]: ...
104
105


106
107
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Completions API.

    Sends one streaming request and records time to first token,
    inter-token latencies, total latency and the generated text.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. Failures (non-200 status, an
        exception, or a stream without any choice chunk) are reported via
        `success=False` and `error` rather than raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("completions", "profile")), (
        "OpenAI Completions API URL must end with 'completions' or 'profile'."
    )

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "prompt": request_func_input.prompt,
        "temperature": 0.0,
        "repetition_penalty": 1.0,
        "max_tokens": request_func_input.output_len,
        "logprobs": request_func_input.logprobs,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
    if request_func_input.extra_headers:
        headers |= request_func_input.extra_headers
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                first_chunk_received = False
                handler = StreamedResponseHandler()

                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payload and should
                        # be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")

                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    # Derived from the same timestamp as the
                                    # ITL samples so TTFT plus the ITLs add
                                    # up to the reported latency.
                                    output.ttft = timestamp - st

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")
                if first_chunk_received:
                    output.success = True
                else:
                    output.success = False
                    output.error = (
                        "Never received a valid chunk to calculate TTFT. "
                        "This response will be marked as failed!"
                    )
                output.generated_text = generated_text
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        # Benchmark boundary: report the failure on the output instead of
        # aborting the whole run.
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


225
226
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Chat Completions API.

    Sends the prompt (plus any multi-modal content) as a single user message
    in a streaming request and records timing metrics from the stream.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. A non-200 status or an exception
        during the request is reported via `success=False` and `error`.

    Raises:
        TypeError: If `multi_modal_content` is set but is neither a dict nor
            a list.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("chat/completions", "profile")), (
        "OpenAI Chat Completions API URL must end with 'chat/completions' "
        "or 'profile'."
    )

    content = [{"type": "text", "text": request_func_input.prompt}]
    if request_func_input.multi_modal_content:
        mm_content = request_func_input.multi_modal_content
        if isinstance(mm_content, list):
            content.extend(mm_content)
        elif isinstance(mm_content, dict):
            content.append(mm_content)
        else:
            raise TypeError(
                "multi_modal_content must be a dict or list[dict] for openai-chat"
            )
    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "messages": [
            {"role": "user", "content": content},
        ],
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    if request_func_input.extra_headers:
        headers |= request_func_input.extra_headers
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    generated_text = ""
    # 0.0 doubles as the "no token received yet" marker.
    ttft = 0.0
    st = time.perf_counter()
    output.start_time = st
    most_recent_timestamp = st
    try:
        async with session.post(url=api_url, json=payload, headers=headers) as response:
            if response.status == 200:
                handler = StreamedResponseHandler()
                async for chunk_bytes in response.content.iter_any():
                    chunk_bytes = chunk_bytes.strip()
                    if not chunk_bytes:
                        continue

                    messages = handler.add_chunk(chunk_bytes)
                    for message in messages:
                        # NOTE: SSE comments (often used as pings) start with
                        # a colon. These are not JSON data payload and should
                        # be skipped.
                        if message.startswith(":"):
                            continue

                        chunk = message.removeprefix("data: ")

                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                # Named apart from the request `content`
                                # list built above.
                                delta_text = choices[0]["delta"].get("content")
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += delta_text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")

                            most_recent_timestamp = timestamp

                output.generated_text = generated_text
                output.success = True
                output.latency = most_recent_timestamp - st
            else:
                output.error = response.reason or ""
                output.success = False
    except Exception:
        # Benchmark boundary: report the failure on the output instead of
        # aborting the whole run.
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI audio endpoints.

    Encodes the audio from `multi_modal_content["audio"]` as an in-memory
    WAV file, uploads it as multipart form data in a streaming request and
    records timing metrics from the stream.

    Args:
        request_func_input: The input for the request function. Its
            `multi_modal_content` must be a dict with an "audio" entry whose
            value is unpacked as the two arguments `(y, sr)` handed to
            `soundfile.write` (presumably samples and sample rate; verify
            against the caller).
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. A non-200 status or an exception
        during the request is reported via `success=False` and `error`.

    Raises:
        TypeError: If `multi_modal_content` is not a dict containing "audio".
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Chat Completions API URL must end with 'transcriptions' "
        "or `translations`."
    )

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "language": "en",
        # Flattened due to multipart/form-data
        "stream_include_usage": True,
        "stream_continuous_usage_stats": True,
    }
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    if request_func_input.extra_headers:
        headers |= request_func_input.extra_headers
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    # Send audio file
    def to_bytes(y, sr):
        # Write the audio as WAV into memory and rewind so the upload starts
        # reading from the beginning.
        buffer = io.BytesIO()
        soundfile.write(buffer, y, sr, format="WAV")
        buffer.seek(0)
        return buffer

    mm_audio = request_func_input.multi_modal_content
    if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
        raise TypeError("multi_modal_content must be a dict containing 'audio'")
    with to_bytes(*mm_audio["audio"]) as f:
        form = aiohttp.FormData()
        form.add_field("file", f, content_type="audio/wav")
        # Form fields are sent as text, so every payload value is
        # stringified.
        for key, value in payload.items():
            form.add_field(key, str(value))

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        # 0.0 doubles as the "no token received yet" marker.
        ttft = 0.0
        st = time.perf_counter()
        output.start_time = st
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, data=form, headers=headers
            ) as response:
                if response.status == 200:
                    handler = StreamedResponseHandler()

                    async for chunk_bytes in response.content.iter_any():
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        messages = handler.add_chunk(chunk_bytes)
                        for message in messages:
                            # NOTE: SSE comments (often used as pings) start
                            # with a colon. These are not JSON data payload
                            # and should be skipped.
                            if message.startswith(":"):
                                continue

                            # The handler already yields decoded `str`
                            # messages, so no further decoding is needed.
                            chunk = message.removeprefix("data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    delta_text = choices[0]["delta"].get("content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    generated_text += delta_text or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens"
                                    )

                                most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Benchmark boundary: report the failure on the output instead
            # of aborting the whole run.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


454
455
456
457
458
459
async def async_request_openai_embeddings(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """The async request function for the OpenAI Embeddings API.

    Sends one non-streaming request. Like the other request functions in
    this file it honours `model_name`, `extra_headers`, `extra_body` and
    `request_id` from the input.

    Args:
        request_func_input: The input for the request function.
        session: The aiohttp client session used to send the request.
        pbar: The progress bar to display the progress.

    Returns:
        The output of the request function. `generated_text` is always empty
        and `prompt_len` is taken from the response's usage block. A non-200
        status or an exception is reported via `success=False` and `error`.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("embeddings"), (
        "OpenAI Embeddings API URL must end with 'embeddings'."
    )

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }
    if request_func_input.extra_headers:
        headers |= request_func_input.extra_headers
    if request_func_input.request_id:
        headers["x-request-id"] = request_func_input.request_id

    payload = {
        "model": request_func_input.model_name
        if request_func_input.model_name
        else request_func_input.model,
        "input": request_func_input.prompt,
    }
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)

    output = RequestFuncOutput()
    st = time.perf_counter()
    output.start_time = st
    try:
        async with session.post(url=api_url, headers=headers, json=payload) as response:
            if response.status == 200:
                # Latency is taken when the response status is available,
                # before the JSON body is read.
                output.latency = time.perf_counter() - st
                data = await response.json()
                output.success = True
                output.generated_text = ""
                output.prompt_len = data.get("usage", {}).get("prompt_tokens", 0)
            else:
                output.success = False
                output.error = response.reason or ""
    except Exception:
        # Record the full traceback: `str(exc)` is empty for several
        # exception types, which would hide the failure reason.
        output.success = False
        exc_info = sys.exc_info()
        output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


497
# TODO: Add more request functions for different API protocols.
# Maps a backend name to the request function that implements it.
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
    "vllm": async_request_openai_completions,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "openai-embeddings": async_request_openai_embeddings,
}

# Backend names served by the completions or chat-completions request
# functions.
OPENAI_COMPATIBLE_BACKENDS = [
    k
    for k, v in ASYNC_REQUEST_FUNCS.items()
    if v in (async_request_openai_completions, async_request_openai_chat_completions)
]