backend_request_func.py 23.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import io
4
5
import json
import os
6
import sys
7
import time
8
9
import traceback
from dataclasses import dataclass, field
10
from typing import Optional, Union
11
12

import aiohttp
13
import huggingface_hub.constants
14
from tqdm.asyncio import tqdm
15
16
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
17

18
19
# NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed.
20

21
22
23
24
25
26
27
28
29
30
# Timeout shared by every request function in this module:
# 6 * 60 * 60 seconds (six hours) in total per client session request.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Description of a single benchmark request handed to a request function."""
    # Prompt text sent to the backend.
    prompt: str
    # Endpoint URL the request is POSTed to.
    api_url: str
    # Prompt length; copied onto the output and used as TGI's "truncate".
    prompt_len: int
    # Maximum number of tokens requested from the backend.
    output_len: int
    # Model identifier; used in the payload when `model_name` is not set.
    model: str
    # Optional name that takes precedence over `model` in the payload.
    model_name: Optional[str] = None
    # Forwarded as "logprobs" by the OpenAI completions request function.
    logprobs: Optional[int] = None
    # Extra key/value pairs merged into the request payload.
    extra_body: Optional[dict] = None
    # Extra content item for chat requests, or the audio entry for
    # transcription/translation requests.
    multi_modal_content: Optional[dict] = None
    # When true, backends that support it are asked to ignore EOS.
    ignore_eos: bool = False
    # NOTE(review): not read by any request function visible in this file.
    language: Optional[str] = None
37
38
39
40
41
42


@dataclass
class RequestFuncOutput:
    """Result and timing measurements collected for a single request."""
    # Full text produced by the backend.
    generated_text: str = ""
    # True only when the request completed and a response was parsed.
    success: bool = False
    # Seconds from sending the request to the last received chunk/response.
    latency: float = 0.0
    # Number of generated tokens. The TGI request function sets this to
    # None when `ignore_eos` is off, hence Optional.
    output_tokens: Optional[int] = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(
        default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    # Prompt length copied from the corresponding RequestFuncInput.
    prompt_len: int = 0
    # HTTP reason or formatted traceback when the request failed.
    error: str = ""
51
52
53
54
55
56
57
58
59


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI `generate_stream` endpoint.

    Args:
        request_func_input: Request description; `api_url` must end with
            "generate_stream".
        pbar: Optional progress bar advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the generated text and timings, or with
        `success=False` and `error` populated on failure.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            "ignore_eos_token": request_func_input.ignore_eos,
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        # The token count is only known up front when EOS is ignored;
        # otherwise it is left as None for the caller to fill in.
        if request_func_input.ignore_eos:
            output.output_tokens = request_func_input.output_len
        else:
            output.output_tokens = None

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = chunk_bytes.removeprefix("data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    # If the stream carried no data chunk, `data` is unbound
                    # here and the resulting exception is recorded below as
                    # a failed request.
                    output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM `generate_stream` endpoint.

    Args:
        request_func_input: Request description; `api_url` must end with
            "generate_stream".
        pbar: Optional progress bar advanced by one when the request ends.

    Returns:
        A RequestFuncOutput whose `generated_text` is the concatenation of
        every chunk's "text_output", together with timings; on failure
        `success` is False and `error` is populated.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        # Ignoring EOS is expressed as a minimum output length.
        if request_func_input.ignore_eos:
            payload["min_length"] = request_func_input.output_len
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    Args:
        request_func_input: Request description (prompt, URL, lengths).
        pbar: Optional progress bar advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the generated text and total latency. The
        request is marked successful only when the JSON response contains
        a "choices" or "text" entry; otherwise `error` is populated.
    """
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    if "choices" in parsed_resp:
                        output.generated_text = parsed_resp["choices"][0][
                            "text"]
                        output.success = True
                    elif "text" in parsed_resp:
                        output.generated_text = parsed_resp["text"][0]
                        output.success = True
                    else:
                        # Keep the failure: success is set per branch so an
                        # unrecognised response is not reported as a success.
                        output.error = ("Unexpected response format: "
                                        "neither 'choices' nor 'text' found")
                        output.success = False
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style completions endpoint.

    Args:
        request_func_input: Request description; `api_url` must end with
            "completions" or "profile".
        pbar: Optional progress bar advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the concatenated text, timings and the
        "completion_tokens" usage value when the server reports it. The
        request is marked failed if no chunk with choices was received.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            # `model_name`, when set, takes precedence over `model`.
            "model": request_func_input.model_name
            or request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "repetition_penalty": 1.0,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        # Caller-supplied entries may override any of the defaults above.
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT."
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


346
347
348
349
350
351
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style chat completions endpoint.

    Args:
        request_func_input: Request description; `api_url` must end with
            "chat/completions" or "profile". `multi_modal_content`, when
            set, is appended to the user message content.
        pbar: Optional progress bar advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the concatenated delta content, timings
        and the "completion_tokens" usage value when the server reports
        it; on failure `success` is False and `error` is populated.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("chat/completions", "profile")
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            content.append(request_func_input.multi_modal_content)
        payload = {
            # `model_name`, when set, takes precedence over `model`.
            "model": request_func_input.model_name
            or request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": content
                },
            ],
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        # Caller-supplied entries may override any of the defaults above.
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                # Separate name so the request's message
                                # content list above is not rebound.
                                delta_text = choices[0]["delta"].get(
                                    "content")
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += delta_text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")

                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming audio transcription/translation request.

    The audio is encoded as WAV in memory and uploaded as multipart form
    data together with the flattened request options.

    Args:
        request_func_input: Request description; `api_url` must end with
            "transcriptions" or "translations".
            `multi_modal_content['audio']` is unpacked into the two
            arguments written by `soundfile.write` (presumably the sample
            data and the sample rate -- confirm against the caller).
        pbar: Optional progress bar advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with the concatenated delta content, timings
        and the "completion_tokens" usage value when the server reports
        it; on failure `success` is False and `error` is populated.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile
    api_url = request_func_input.api_url
    # The message is parenthesized so both string parts belong to the
    # assert instead of the second one being a separate no-op statement.
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Chat Completions API URL must end with 'transcriptions' "
        "or `translations`.")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            # `model_name`, when set, takes precedence over `model`.
            "model": request_func_input.model_name
            or request_func_input.model,
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "language": "en",
            # Flattened due to multipart/form-data
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        # Send audio file
        def to_bytes(y, sr):
            buffer = io.BytesIO()
            soundfile.write(buffer, y, sr, format="WAV")
            buffer.seek(0)
            return buffer

        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
            form = aiohttp.FormData()
            form.add_field('file', f, content_type='audio/wav')
            # Form fields are sent as strings, so every value is stringified.
            for key, value in payload.items():
                form.add_field(key, str(value))

            output = RequestFuncOutput()
            output.prompt_len = request_func_input.prompt_len

            generated_text = ""
            ttft = 0.0
            st = time.perf_counter()
            most_recent_timestamp = st
            try:
                async with session.post(url=api_url,
                                        data=form,
                                        headers=headers) as response:
                    if response.status == 200:
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = chunk_bytes.decode("utf-8").removeprefix(
                                "data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    delta_text = choices[0]["delta"].get(
                                        "content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp)

                                    generated_text += delta_text or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens")

                                most_recent_timestamp = timestamp

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = most_recent_timestamp - st
                    else:
                        output.error = response.reason or ""
                        output.success = False
            except Exception:
                output.success = False
                exc_info = sys.exc_info()
                output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


546
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model name to a local path when ModelScope is enabled.

    Args:
        pretrained_model_name_or_path: Model identifier or path.

    Returns:
        The ModelScope snapshot path when the `VLLM_USE_MODELSCOPE`
        environment variable is "true" (case-insensitive); otherwise the
        argument unchanged.
    """
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        # Imported lazily so neither package is needed on the default path.
        from modelscope import snapshot_download

        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(pretrained_model_name_or_path):
            model_path = snapshot_download(
                model_id=pretrained_model_name_or_path,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

            return model_path
    return pretrained_model_name_or_path
562
563
564


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model name or local path.

    Args:
        pretrained_model_name_or_path: Model identifier or path. When it is
            not an existing local path it is first passed through
            `get_model`.
        tokenizer_mode: "slow" forces `use_fast=False`; "mistral" loads
            vLLM's MistralTokenizer; any other value uses AutoTokenizer.
        trust_remote_code: Forwarded to `AutoTokenizer.from_pretrained`.
        **kwargs: Extra keyword arguments forwarded to
            `AutoTokenizer.from_pretrained`.

    Returns:
        The loaded tokenizer.

    Raises:
        ValueError: If `tokenizer_mode` is "slow" but `use_fast=True` was
            passed.
        ImportError: If `tokenizer_mode` is "mistral" and vllm cannot be
            imported.
    """
    if pretrained_model_name_or_path is not None and not os.path.exists(
            pretrained_model_name_or_path):
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)
    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False
    if tokenizer_mode == "mistral":
        # Imported lazily so the benchmark can run without vLLM installed.
        try:
            from vllm.transformers_utils.tokenizer import MistralTokenizer
        except ImportError as e:
            raise ImportError("MistralTokenizer requires vllm package.\n"
                              "Please install it with `pip install vllm` "
                              "to use mistral tokenizer mode.") from e
        return MistralTokenizer.from_pretrained(
            str(pretrained_model_name_or_path))
    else:
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
594
595


596
597
# Maps each supported backend name to the coroutine function that issues
# a request against it.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}
608
609
610
611
612
613

# Backend names served by the OpenAI completions or chat completions
# request functions, in ASYNC_REQUEST_FUNCS order.
OPENAI_COMPATIBLE_BACKENDS = [
    backend_name
    for backend_name, request_func in ASYNC_REQUEST_FUNCS.items()
    if request_func is async_request_openai_completions
    or request_func is async_request_openai_chat_completions
]