backend_request_func.py 24.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import io
5
6
import json
import os
7
import sys
8
import time
9
10
import traceback
from dataclasses import dataclass, field
11
from typing import Optional, Union
12
13

import aiohttp
14
import huggingface_hub.constants
15
from tqdm.asyncio import tqdm
16
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
17

18
19
# NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed.
20

21
22
23
24
25
26
27
28
29
30
# Total timeout used by each request function's aiohttp ClientSession:
# 6 hours (6 * 60 * 60 seconds), far above aiohttp's 5-minute default —
# presumably so long-running generations are not cut off mid-stream.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
31
    model_name: Optional[str] = None
32
    logprobs: Optional[int] = None
33
    extra_body: Optional[dict] = None
34
    multi_modal_content: Optional[dict | list[dict]] = None
35
    ignore_eos: bool = False
36
    language: Optional[str] = None
37
    request_id: Optional[str] = None
38
39
40
41
42
43


@dataclass
class RequestFuncOutput:
    """Result and client-side timing metrics for one benchmark request.

    The request functions create one of these per request and return it;
    HTTP and stream failures are recorded in ``success``/``error``.
    """

    # Text produced by the server (accumulated from a stream, or taken from
    # the final / non-streaming response body).
    generated_text: str = ""
    # Whether the request is considered successful.
    success: bool = False
    # End-to-end latency in seconds, measured on the client.
    latency: float = 0.0
    # Number of generated tokens. Optional because some backends (TGI without
    # ignore_eos) store None when no count is available from the response.
    output_tokens: Optional[int] = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    # Prompt length copied from the corresponding RequestFuncInput.
    prompt_len: int = 0
    # Failure description: HTTP reason, formatted traceback, or a message.
    error: str = ""


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one generation from a Text Generation Inference server.

    POSTs to a ``generate_stream`` endpoint and records TTFT, inter-token
    latencies and end-to-end latency from the streamed events. The final
    event is expected to carry the complete ``generated_text``.

    Args:
        request_func_input: Prompt, URL and length settings for the request.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput. HTTP/stream failures are reported through
        ``success=False`` and ``error`` (HTTP reason, traceback, or message).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(
        trust_env=True, timeout=AIOHTTP_TIMEOUT
    ) as session:
        params = {
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            "ignore_eos_token": request_func_input.ignore_eos,
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        headers = None
        if request_func_input.request_id:
            headers = {"x-request-id": request_func_input.request_id}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        # With EOS ignored the output length is taken to be output_len;
        # otherwise no token count is parsed from TGI's stream here.
        if request_func_input.ignore_eos:
            output.output_tokens = request_func_input.output_len
        else:
            output.output_tokens = None

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Last decoded event; stays None if the stream carried no data.
        data = None
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = chunk_bytes.removeprefix("data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token: use the same timestamp as the ITL
                        # bookkeeping so both metrics share one clock read.
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft
                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    if data is None:
                        # Empty stream: report it explicitly instead of
                        # failing on an unbound name.
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!"
                        )
                    else:
                        output.success = True
                        output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one request to a TensorRT-LLM ``generate_stream`` endpoint.

    Text is accumulated from the ``text_output`` field of every streamed
    event. TTFT, inter-token latencies and total latency are measured on the
    client with ``time.perf_counter``.

    Args:
        request_func_input: Prompt, URL and length settings for the request.
        pbar: Optional progress bar, advanced by one once the request ends.

    Returns:
        The populated RequestFuncOutput. HTTP and stream errors are reported
        through ``success=False`` and ``error`` rather than raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(
        trust_env=True, timeout=AIOHTTP_TIMEOUT
    ) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        if request_func_input.ignore_eos:
            # A minimum equal to the target length presumably makes the
            # server produce the full output_len tokens.
            payload["min_length"] = request_func_input.output_len

        request_id = request_func_input.request_id
        headers = {"x-request-id": request_id} if request_id else None

        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        start = time.perf_counter()
        prev_ts = start
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status != 200:
                    result.error = response.reason or ""
                    result.success = False
                else:
                    async for raw_line in response.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue
                        event = json.loads(
                            raw_line.decode("utf-8").removeprefix("data:")
                        )
                        result.generated_text += event["text_output"]

                        now = time.perf_counter()
                        if ttft == 0.0:
                            # First streamed event: time to first token.
                            ttft = now - start
                            result.ttft = ttft
                        else:
                            # Later events: gap since the previous event.
                            result.itl.append(now - prev_ts)
                        prev_ts = now

                    result.latency = prev_ts - start
                    result.success = True
        except Exception:
            result.success = False
            result.error = traceback.format_exc()

        if pbar:
            pbar.update(1)
        return result


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming completion request to a DeepSpeed-MII server.

    Only end-to-end latency is measured; TTFT is a 0 placeholder because the
    server does not stream. Both an OpenAI-style ``{"choices": [...]}`` body
    and a ``{"text": [...]}`` body are accepted.

    Args:
        request_func_input: Prompt, URL and length settings for the request.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput. An unrecognised body, an HTTP error or an
        exception yields ``success=False`` with ``error`` populated.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("completions", "profile")), (
        "OpenAI Completions API URL must end with 'completions' or 'profile'."
    )

    async with aiohttp.ClientSession(
        trust_env=True, timeout=AIOHTTP_TIMEOUT
    ) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    if "choices" in parsed_resp:
                        output.generated_text = parsed_resp["choices"][0]["text"]
                        output.success = True
                    elif "text" in parsed_resp:
                        output.generated_text = parsed_resp["text"][0]
                        output.success = True
                    else:
                        # Unrecognised body: keep the request marked as
                        # failed rather than reporting an empty success.
                        output.error = (
                            "Unexpected response format: "
                            "neither 'choices' nor 'text' found"
                        )
                        output.success = False
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one request to an OpenAI-compatible ``/completions`` endpoint.

    TTFT and inter-token latencies are measured from streamed chunks that
    carry ``choices``. ``output_tokens`` is taken from any chunk that reports
    ``usage`` (requested via ``stream_options.include_usage``).

    Args:
        request_func_input: Prompt, URL, sampling and extra-body settings.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput. If no chunk with ``choices`` arrives, or on
        HTTP/stream errors, ``success`` is False and ``error`` explains why.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("completions", "profile")), (
        "OpenAI Completions API URL must end with 'completions' or 'profile'."
    )

    async with aiohttp.ClientSession(
        trust_env=True, timeout=AIOHTTP_TIMEOUT
    ) as session:
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name
            else request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "repetition_penalty": 1.0,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # Collect streamed pieces and join once at the end.
        text_parts: list[str] = []
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        line = chunk_bytes.decode("utf-8")
                        # SSE comment lines (often keep-alive pings) start
                        # with a colon and carry no JSON payload.
                        if line.startswith(":"):
                            continue
                        chunk = line.removeprefix("data: ")
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # want to check a token was generated
                        if choices := data.get("choices"):
                            # Note that text could be empty here
                            # e.g. for special tokens
                            text = choices[0].get("text")
                            timestamp = time.perf_counter()
                            # First token
                            if not first_chunk_received:
                                first_chunk_received = True
                                output.ttft = timestamp - st
                            # Decoding phase
                            else:
                                output.itl.append(timestamp - most_recent_timestamp)

                            most_recent_timestamp = timestamp
                            text_parts.append(text or "")
                        if usage := data.get("usage"):
                            output.output_tokens = usage.get("completion_tokens")

                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!"
                        )
                    output.generated_text = "".join(text_parts)
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


368
369
370
371
372
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one request to an OpenAI-compatible ``/chat/completions`` endpoint.

    The prompt is sent as a single user message: a text part followed by any
    parts from ``multi_modal_content``. TTFT and inter-token latencies are
    measured from chunks carrying ``choices``; ``output_tokens`` comes from a
    usage-only chunk.

    Args:
        request_func_input: Prompt, URL, sampling, multimodal and extra-body
            settings.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput. If no chunk with ``choices`` arrives, or on
        HTTP/stream errors, ``success`` is False and ``error`` explains why.

    Raises:
        TypeError: If ``multi_modal_content`` is truthy but neither a dict
            nor a list.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("chat/completions", "profile")), (
        "OpenAI Chat Completions API URL must end with 'chat/completions' "
        "or 'profile'."
    )

    async with aiohttp.ClientSession(
        trust_env=True, timeout=AIOHTTP_TIMEOUT
    ) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            mm_content = request_func_input.multi_modal_content
            if isinstance(mm_content, list):
                content.extend(mm_content)
            elif isinstance(mm_content, dict):
                content.append(mm_content)
            else:
                raise TypeError(
                    "multi_modal_content must be a dict or list[dict] for openai-chat"
                )
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name
            else request_func_input.model,
            "messages": [
                {"role": "user", "content": content},
            ],
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        first_chunk_received = False
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")
                        # NOTE: SSE comments (often used as pings) start with a colon.
                        # These are not JSON data payload and should be skipped.
                        if chunk_bytes.startswith(":"):
                            continue

                        chunk = chunk_bytes.removeprefix("data: ")

                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                # Distinct name so the request ``content``
                                # list built above is not shadowed.
                                delta_text = choices[0]["delta"].get("content")
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    output.ttft = timestamp - st
                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                generated_text += delta_text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get("completion_tokens")

                            # Every non-[DONE] chunk (including usage-only
                            # ones) advances the end-of-stream timestamp.
                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    if first_chunk_received:
                        output.success = True
                    else:
                        # Without a content chunk TTFT would silently be 0.
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!"
                        )
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


477
478
479
480
481
482
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one transcription/translation request to an OpenAI audio endpoint.

    The audio in ``multi_modal_content["audio"]`` is encoded as WAV and sent
    as multipart/form-data together with the request options. TTFT and
    inter-token latencies are measured from chunks carrying ``choices``.

    Args:
        request_func_input: Metadata, URL, audio payload and optional
            ``language`` (``"en"`` is used when it is unset).
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput. If no chunk with ``choices`` arrives, or on
        HTTP/stream errors, ``success`` is False and ``error`` explains why.

    Raises:
        TypeError: If ``multi_modal_content`` is not a dict with an
            ``"audio"`` entry.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile

    api_url = request_func_input.api_url
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Audio API URL must end with 'transcriptions' or 'translations'."
    )

    async with aiohttp.ClientSession(
        trust_env=True, timeout=AIOHTTP_TIMEOUT
    ) as session:
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name
            else request_func_input.model,
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            # Honor a per-request language; default to English when unset.
            "language": request_func_input.language or "en",
            # Flattened due to multipart/form-data
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True,
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }
        if request_func_input.request_id:
            headers["x-request-id"] = request_func_input.request_id

        # Send audio file
        def to_bytes(y, sr):
            """Encode samples ``y`` at rate ``sr`` into an in-memory WAV file."""
            buffer = io.BytesIO()
            soundfile.write(buffer, y, sr, format="WAV")
            buffer.seek(0)
            return buffer

        mm_audio = request_func_input.multi_modal_content
        if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
            raise TypeError("multi_modal_content must be a dict containing 'audio'")
        # Presumably mm_audio["audio"] is a (waveform, sample_rate) pair; it is
        # unpacked directly into to_bytes(y, sr).
        with to_bytes(*mm_audio["audio"]) as f:
            form = aiohttp.FormData()
            form.add_field("file", f, content_type="audio/wav")
            # Multipart fields are text, so each option is str()-converted
            # (booleans are sent as "True"/"False").
            for key, value in payload.items():
                form.add_field(key, str(value))

            output = RequestFuncOutput()
            output.prompt_len = request_func_input.prompt_len

            generated_text = ""
            first_chunk_received = False
            st = time.perf_counter()
            most_recent_timestamp = st
            try:
                async with session.post(
                    url=api_url, data=form, headers=headers
                ) as response:
                    if response.status == 200:
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            line = chunk_bytes.decode("utf-8")
                            # SSE comment lines (often keep-alive pings) start
                            # with a colon and carry no JSON payload.
                            if line.startswith(":"):
                                continue
                            chunk = line.removeprefix("data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    delta_text = choices[0]["delta"].get("content")
                                    # First token
                                    if not first_chunk_received:
                                        first_chunk_received = True
                                        output.ttft = timestamp - st
                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    generated_text += delta_text or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens"
                                    )

                                most_recent_timestamp = timestamp

                        output.generated_text = generated_text
                        if first_chunk_received:
                            output.success = True
                        else:
                            # Without a content chunk TTFT would silently be 0.
                            output.success = False
                            output.error = (
                                "Never received a valid chunk to calculate TTFT. "
                                "This response will be marked as failed!"
                            )
                        output.latency = most_recent_timestamp - st
                    else:
                        output.error = response.reason or ""
                        output.success = False
            except Exception:
                output.success = False
                exc_info = sys.exc_info()
                output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


589
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model reference to a path or hub identifier.

    Unless ``VLLM_USE_MODELSCOPE`` is ``"true"`` (case-insensitive), the
    argument is returned unchanged. Otherwise the repository is fetched with
    ModelScope's ``snapshot_download``, skipping weight files, and the local
    snapshot path is returned. That branch imports vLLM.
    """
    use_modelscope = os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true"
    if not use_modelscope:
        return pretrained_model_name_or_path

    from modelscope import snapshot_download

    from vllm.model_executor.model_loader.weight_utils import get_lock

    # A file lock keeps concurrent processes from downloading the same
    # model at the same time.
    with get_lock(pretrained_model_name_or_path):
        return snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
        )
606
607
608


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer used to count benchmark tokens.

    A reference that is not an existing local path is first resolved through
    ``get_model``.

    Args:
        pretrained_model_name_or_path: Local path or hub model identifier.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads vLLM's MistralTokenizer; anything else uses AutoTokenizer.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        **kwargs: Extra keyword arguments for ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: ``use_fast`` was requested together with ``"slow"`` mode.
        ImportError: ``"mistral"`` mode was requested without vLLM installed.
    """
    model_ref = pretrained_model_name_or_path
    if model_ref is not None and not os.path.exists(model_ref):
        model_ref = get_model(model_ref)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if tokenizer_mode != "mistral":
        return AutoTokenizer.from_pretrained(
            model_ref,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    try:
        from vllm.transformers_utils.tokenizer import MistralTokenizer
    except ImportError as e:
        raise ImportError(
            "MistralTokenizer requires vllm package.\n"
            "Please install it with `pip install vllm` "
            "to use mistral tokenizer mode."
        ) from e
    return MistralTokenizer.from_pretrained(str(model_ref))
638
639


640
641
# Backend name -> coroutine that issues one benchmark request for it.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
    "llama.cpp": async_request_openai_completions,
}

# Backends served by the OpenAI completions / chat-completions request
# functions, listed in the registration order above.
OPENAI_COMPATIBLE_BACKENDS = [
    backend_name
    for backend_name, request_fn in ASYNC_REQUEST_FUNCS.items()
    if request_fn is async_request_openai_completions
    or request_fn is async_request_openai_chat_completions
]