# backend_request_func.py
1
2
# SPDX-License-Identifier: Apache-2.0

3
import io
4
5
import json
import os
6
import sys
7
import time
8
9
import traceback
from dataclasses import dataclass, field
10
from typing import Optional, Union
11
12

import aiohttp
13
import huggingface_hub.constants
14
from tqdm.asyncio import tqdm
15
16
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
17

18
19
# NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed.
20

21
22
23
24
25
26
27
28
29
30
# Client timeout passed to every aiohttp.ClientSession created in this file:
# total = 6 * 60 * 60 seconds, i.e. six hours per request.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Parameters describing one benchmark request sent to a backend."""
    # Prompt text placed in the request payload.
    prompt: str
    # Full endpoint URL the request is POSTed to.
    api_url: str
    # Prompt length; copied verbatim into RequestFuncOutput.prompt_len
    # (the TGI backend also forwards it as the "truncate" parameter).
    prompt_len: int
    # Upper bound on generated tokens requested from the backend.
    output_len: int
    # Model identifier; used as the payload "model" when model_name is unset.
    model: str
    # Optional override for the payload "model" field.
    model_name: Optional[str] = None
    # Forwarded as "logprobs" by the OpenAI completions backend.
    logprobs: Optional[int] = None
    # Extra key/values merged into the payload by the OpenAI-style backends.
    extra_body: Optional[dict] = None
    # Extra message content (chat backend) or {"audio": (y, sr)} (audio backend).
    multi_modal_content: Optional[dict] = None
    # Ask the backend to keep generating past EOS, where supported.
    ignore_eos: bool = False
    # NOTE(review): not read by any request function visible in this file.
    language: Optional[str] = None
37
38
39
40
41
42


@dataclass
class RequestFuncOutput:
    """Result and timing measurements collected for one benchmark request."""
    # Text returned by the backend (concatenated across streamed chunks).
    generated_text: str = ""
    # True only when the request completed and a usable response was parsed.
    success: bool = False
    # Wall-clock duration measured with time.perf_counter().
    latency: float = 0.0
    # Completion token count; the TGI backend stores None here when
    # ignore_eos is off, hence Optional.
    output_tokens: Optional[int] = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(
        default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    # Copied from RequestFuncInput.prompt_len by the request functions.
    prompt_len: int = 0
    # HTTP reason phrase or formatted traceback describing a failure.
    error: str = ""
51
52
53
54
55
56
57
58
59


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Args:
        request_func_input: Request parameters; ``api_url`` must end with
            ``generate_stream``.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput holding the generated text, TTFT, inter-token
        latencies and total latency. On an HTTP error, an exception, or an
        empty stream, ``success`` is False and ``error`` describes why.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            "ignore_eos_token": request_func_input.ignore_eos,
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        # With ignore_eos the output length is known up front; otherwise
        # leave it as None for the caller to fill in.
        if request_func_input.ignore_eos:
            output.output_tokens = request_func_input.output_len
        else:
            output.output_tokens = None

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    # Last decoded event; stays None if the stream is empty.
                    data = None
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = chunk_bytes.removeprefix("data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # Use the same clock reading as the ITL samples.
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if data is None:
                        # Fail explicitly instead of relying on an
                        # UnboundLocalError being caught below.
                        output.error = (
                            "Never received a data chunk from the stream.")
                        output.success = False
                    else:
                        output.latency = most_recent_timestamp - st
                        output.success = True
                        # The final event is read for the complete text.
                        output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output rather than raising it.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Args:
        request_func_input: Request parameters; ``api_url`` must end with
            ``generate_stream``.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the concatenated ``text_output`` of every
        chunk plus TTFT, inter-token latencies and total latency. On an HTTP
        error or exception, ``success`` is False and ``error`` is set.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        # ignore_eos is expressed as a minimum length equal to the maximum.
        if request_func_input.ignore_eos:
            payload["min_length"] = request_func_input.output_len
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output rather than raising it.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    Args:
        request_func_input: Request parameters (prompt, output_len, api_url).
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the generated text and total latency;
        ``ttft`` is a 0 placeholder because the response is not streamed.
        ``success`` is True only when the JSON body contains ``choices`` or
        ``text``; otherwise ``error`` describes the failure.
    """
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    # success is set per branch so the unexpected-format
                    # failure is no longer overwritten with True.
                    if "choices" in parsed_resp:
                        output.generated_text = parsed_resp["choices"][0][
                            "text"]
                        output.success = True
                    elif "text" in parsed_resp:
                        output.generated_text = parsed_resp["text"][0]
                        output.success = True
                    else:
                        output.error = ("Unexpected response format: "
                                        "neither 'choices' nor 'text' found")
                        output.success = False
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output rather than raising it.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style completions endpoint.

    Args:
        request_func_input: Request parameters; ``api_url`` must end with
            ``completions`` or ``profile``.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the concatenated text, TTFT, inter-token
        latencies, total latency and, when the server sends a usage chunk,
        ``output_tokens``. The request counts as failed if no chunk with
        ``choices`` was received, on an HTTP error, or on an exception.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            # model_name, when set, overrides the model identifier.
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        # extra_body is merged last, so it can override the keys above.
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    # Same clock reading as the ITL samples.
                                    output.ttft = timestamp - st

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output rather than raising it.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


345
346
347
348
349
350
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style chat completions URL.

    Args:
        request_func_input: Request parameters; ``api_url`` must end with
            ``chat/completions`` or ``profile``. ``multi_modal_content``,
            when set, is appended to the user message content.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the concatenated delta content, TTFT,
        inter-token latencies, total latency and, when the server sends a
        usage chunk, ``output_tokens``. On an HTTP error or exception,
        ``success`` is False and ``error`` is set.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("chat/completions", "profile")
    ), ("OpenAI Chat Completions API URL must end with 'chat/completions' "
        "or 'profile'.")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            content.append(request_func_input.multi_modal_content)
        payload = {
            # model_name, when set, overrides the model identifier.
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "messages": [
                {
                    "role": "user",
                    "content": content
                },
            ],
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        # extra_body is merged last, so it can override the keys above.
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                # Named apart from the request `content`
                                # list built above.
                                delta_text = choices[0]["delta"].get(
                                    "content")
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += delta_text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")

                            # Updated for every parsed chunk, including the
                            # usage chunk.
                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output rather than raising it.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style audio endpoint.

    The audio in ``request_func_input.multi_modal_content['audio']`` is
    unpacked as ``(y, sr)``, encoded to WAV in memory and uploaded as
    multipart form data together with the flattened request parameters.

    Args:
        request_func_input: Request parameters; ``api_url`` must end with
            ``transcriptions`` or ``translations``.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the concatenated delta content, TTFT,
        inter-token latencies, total latency and, when the server sends a
        usage chunk, ``output_tokens``. On an HTTP error or exception,
        ``success`` is False and ``error`` is set.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile
    api_url = request_func_input.api_url
    # The message is parenthesized so both halves belong to the assert;
    # previously the second literal was a separate no-op statement.
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Audio API URL must end with 'transcriptions' "
        "or 'translations'.")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            # model_name, when set, overrides the model identifier.
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "language": "en",
            # Flattened due to multipart/form-data
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True
        }
        # extra_body is merged last, so it can override the keys above.
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        # Send audio file
        def to_bytes(y, sr):
            """Write audio data *y* at sample rate *sr* to a WAV buffer."""
            buffer = io.BytesIO()
            soundfile.write(buffer, y, sr, format="WAV")
            buffer.seek(0)
            return buffer

        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
            form = aiohttp.FormData()
            form.add_field('file', f, content_type='audio/wav')
            # Every payload value is sent as a string form field.
            for key, value in payload.items():
                form.add_field(key, str(value))

            output = RequestFuncOutput()
            output.prompt_len = request_func_input.prompt_len

            generated_text = ""
            ttft = 0.0
            st = time.perf_counter()
            most_recent_timestamp = st
            try:
                async with session.post(url=api_url,
                                        data=form,
                                        headers=headers) as response:
                    if response.status == 200:
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = chunk_bytes.decode("utf-8").removeprefix(
                                "data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    delta_text = choices[0]["delta"].get(
                                        "content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp)

                                    generated_text += delta_text or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens")

                                most_recent_timestamp = timestamp

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = most_recent_timestamp - st
                    else:
                        output.error = response.reason or ""
                        output.success = False
            except Exception:
                # Record the failure on the output rather than raising it.
                output.success = False
                exc_info = sys.exc_info()
                output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


545
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model name to a local path when ModelScope is enabled.

    When the ``VLLM_USE_MODELSCOPE`` environment variable equals ``true``
    (case-insensitive), the model snapshot is fetched through ModelScope,
    skipping weight files, and the resulting path is returned. Otherwise
    the argument is returned unchanged.

    Args:
        pretrained_model_name_or_path: Model identifier or path.

    Returns:
        The ModelScope snapshot path, or the input value as-is.
    """
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        # Imported lazily so modelscope and vllm are only required when
        # this branch is actually taken.
        from modelscope import snapshot_download

        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(pretrained_model_name_or_path):
            model_path = snapshot_download(
                model_id=pretrained_model_name_or_path,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

            return model_path
    return pretrained_model_name_or_path
561
562
563


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model name or local path.

    Args:
        pretrained_model_name_or_path: Model identifier or local path. When
            it is not an existing path it is first passed through
            ``get_model``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads vLLM's ``MistralTokenizer``; any other value uses
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        **kwargs: Extra keyword arguments for
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer.

    Raises:
        ValueError: If ``tokenizer_mode`` is ``"slow"`` and ``use_fast`` is
            truthy.
        ImportError: If ``tokenizer_mode`` is ``"mistral"`` and vllm cannot
            be imported.
    """
    if pretrained_model_name_or_path is not None and not os.path.exists(
            pretrained_model_name_or_path):
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False
    if tokenizer_mode == "mistral":
        # vllm is imported only for this mode so the benchmark script can
        # run without vLLM installed.
        try:
            from vllm.transformers_utils.tokenizer import MistralTokenizer
        except ImportError as e:
            raise ImportError("MistralTokenizer requires vllm package.\n"
                              "Please install it with `pip install vllm` "
                              "to use mistral tokenizer mode.") from e
        return MistralTokenizer.from_pretrained(
            str(pretrained_model_name_or_path))
    else:
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
593
594


595
596
# Maps each backend name to the coroutine function that issues its requests.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}
607
608
609
610
611
612

# Backend names whose request function is one of the two OpenAI-style
# (completions / chat completions) implementations.
OPENAI_COMPATIBLE_BACKENDS = [
    backend_name
    for backend_name, request_func in ASYNC_REQUEST_FUNCS.items()
    if request_func in (
        async_request_openai_completions,
        async_request_openai_chat_completions,
    )
]