"docs/usage/troubleshooting.md" did not exist on "3610fb49302867af5b2598b218b3011bc9ed52aa"
backend_request_func.py 23.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import io
4
5
import json
import os
6
import sys
7
import time
8
9
import traceback
from dataclasses import dataclass, field
10
from typing import Optional, Union
11
12

import aiohttp
13
import huggingface_hub.constants
14
from tqdm.asyncio import tqdm
15
16
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
17

18
19
# NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed.
20

21
22
23
24
25
26
27
28
29
30
# Default aiohttp timeout for every benchmark session: the whole request
# operation may take up to six hours (value is in seconds).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 3600)


@dataclass
class RequestFuncInput:
    """Description of a single benchmark request.

    The first five fields are required. The rest are optional knobs that
    individual backend request functions may or may not honour.
    """

    # Text sent to the server.
    prompt: str
    # Full endpoint URL the request is POSTed to.
    api_url: str
    # Prompt length as supplied by the caller (copied into the output;
    # TGI also uses it as its "truncate" parameter).
    prompt_len: int
    # Maximum number of tokens to generate.
    output_len: int
    # Model identifier sent with the request.
    model: str

    # When set, OpenAI-style backends send this instead of `model`.
    model_name: Optional[str] = field(default=None)
    # Forwarded as "logprobs" by the OpenAI completions backend.
    logprobs: Optional[int] = field(default=None)
    # Extra fields merged into the request payload by OpenAI-style backends.
    extra_body: Optional[dict] = field(default=None)
    # Extra (e.g. image/audio) content attached to the request.
    multi_modal_content: Optional[dict] = field(default=None)
    # Ask the server to keep generating past EOS, where supported.
    ignore_eos: bool = field(default=False)
    # Optional language hint (not read by every backend).
    language: Optional[str] = field(default=None)
37
38
39
40
41
42


@dataclass
class RequestFuncOutput:
    """Result and timing measurements for a single benchmark request."""

    # Text produced by the server (accumulated from the stream).
    generated_text: str = ""
    # True only when the request completed and produced a usable response.
    success: bool = False
    # Seconds from sending the request to the last timed stream event.
    latency: float = 0.0
    # Number of generated tokens. May be None when the backend cannot
    # report it (e.g. TGI without ignore_eos, or a stream with no usage).
    output_tokens: Optional[int] = 0
    ttft: float = 0.0  # Time to first token
    # Inter-token latencies, one entry per streamed event after the first.
    itl: list[float] = field(default_factory=list)
    tpot: float = 0.0  # avg next-token latencies
    # Copied from the corresponding RequestFuncInput.
    prompt_len: int = 0
    # HTTP reason or formatted traceback when success is False.
    error: str = ""
51
52
53
54
55
56
57
58
59


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Every parsed stream event is timed: the first sets ``ttft`` and each
    later one appends an inter-token latency to ``itl``. The final text is
    taken from the ``generated_text`` field of the last event.

    Args:
        request_func_input: Description of the request to send.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput. On a non-200 status, an empty stream or any
        exception, ``success`` is False and ``error`` explains why.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            "ignore_eos_token": request_func_input.ignore_eos,
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        # With ignore_eos the output length is taken to be output_len;
        # otherwise the count is left unknown (None).
        if request_func_input.ignore_eos:
            output.output_tokens = request_func_input.output_len
        else:
            output.output_tokens = None

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        last_data = None  # last parsed stream event, if any
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = chunk_bytes.removeprefix("data:")

                        last_data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        if ttft == 0.0:
                            # First token. Reuse `timestamp` so ttft and
                            # itl[0] are measured from the same instant.
                            ttft = timestamp - st
                            output.ttft = ttft
                        else:
                            # Decoding phase
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if last_data is None:
                        # Nothing but pings/blank lines: nothing to measure.
                        output.success = False
                        output.error = "No data received from TGI stream."
                    else:
                        output.latency = most_recent_timestamp - st
                        output.success = True
                        output.generated_text = last_data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one completion from a TensorRT-LLM ``generate_stream`` endpoint.

    Each non-empty streamed line (optionally prefixed with ``data:``) is
    parsed as JSON and its ``text_output`` is appended to the generated
    text. The first line sets ``ttft``; later ones add entries to ``itl``.

    Args:
        request_func_input: Description of the request to send.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput; on a non-200 status or any exception
        ``success`` is False and ``error`` is populated.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = dict(
            accumulate_tokens=True,
            text_input=request_func_input.prompt,
            temperature=0.0,
            top_p=1.0,
            max_tokens=request_func_input.output_len,
            stream=True,
        )
        if request_func_input.ignore_eos:
            # Request a minimum length equal to output_len so generation is
            # not cut short.
            payload["min_length"] = request_func_input.output_len

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        start = time.perf_counter()
        last_ts = start
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for raw_line in response.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue

                        text = raw_line.decode("utf-8")
                        event = json.loads(text.removeprefix("data:"))
                        output.generated_text += event["text_output"]

                        now = time.perf_counter()
                        if ttft != 0.0:
                            # Decoding phase: gap since the previous event.
                            output.itl.append(now - last_ts)
                        else:
                            # First token.
                            ttft = now - start
                            output.ttft = ttft
                        last_ts = now

                    output.latency = last_ts - start
                    output.success = True
        except Exception:
            output.success = False
            output.error = "".join(
                traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII server.

    Only end-to-end latency is measured; ``ttft`` is a 0.0 placeholder.
    The generated text is read from ``choices[0]["text"]`` or, failing
    that, from ``text[0]``.

    Args:
        request_func_input: Description of the request to send.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput. ``success`` is True only for a 200 response
        in one of the two recognised formats.
    """
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0.0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    # Mark success only in the recognised branches; the
                    # unexpected-format branch must stay a failure.
                    if "choices" in parsed_resp:
                        output.generated_text = parsed_resp["choices"][0][
                            "text"]
                        output.success = True
                    elif "text" in parsed_resp:
                        output.generated_text = parsed_resp["text"][0]
                        output.success = True
                    else:
                        output.error = ("Unexpected response format: "
                                        "neither 'choices' nor 'text' found")
                        output.success = False
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one request to an OpenAI-compatible completions endpoint.

    Stream lines are expected as ``data: <json>`` and terminated by
    ``data: [DONE]``. Chunks carrying ``choices`` are timed: the first sets
    ``ttft`` and later ones extend ``itl``. A chunk with only ``usage``
    supplies ``output_tokens``. A 200 response that never yields a
    ``choices`` chunk is reported as a failure.

    Args:
        request_func_input: Description of the request to send.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with timing, generated text and status.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name else request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "repetition_penalty": 1.0,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                if not first_chunk_received:
                                    # First token. Reuse `timestamp` so ttft
                                    # and itl[0] share a reference point.
                                    first_chunk_received = True
                                    output.ttft = timestamp - st
                                else:
                                    # Decoding phase
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


347
348
349
350
351
352
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one request to an OpenAI-compatible chat completions endpoint.

    The prompt is sent as a single user message. Any multi-modal content is
    appended to that message's content list. Stream lines are expected as
    ``data: <json>`` and terminated by ``data: [DONE]``. Chunks with
    ``choices`` are timed (first sets ``ttft``, later ones extend ``itl``).
    A ``usage`` chunk supplies ``output_tokens``. A 200 response that never
    yields a ``choices`` chunk is reported as a failure, matching the
    completions request function.

    Args:
        request_func_input: Description of the request to send.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with timing, generated text and status.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(("chat/completions", "profile")), (
        "OpenAI Chat Completions API URL must end with 'chat/completions' "
        "or 'profile'.")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        message_content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            message_content.append(request_func_input.multi_modal_content)
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name else request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": message_content
                },
            ],
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        first_chunk_received = False
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk == "[DONE]":
                            continue

                        timestamp = time.perf_counter()
                        data = json.loads(chunk)

                        if choices := data.get("choices"):
                            delta_text = choices[0]["delta"].get("content")
                            if not first_chunk_received:
                                # First token
                                first_chunk_received = True
                                output.ttft = timestamp - st
                            else:
                                # Decoding phase
                                output.itl.append(timestamp -
                                                  most_recent_timestamp)
                            generated_text += delta_text or ""
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                        # Every parsed chunk (including a trailing usage
                        # chunk) advances the latency end point.
                        most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
async def async_request_openai_audio(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one OpenAI-style transcription or translation request.

    The audio comes from ``request_func_input.multi_modal_content['audio']``.
    It is unpacked as ``(y, sr)``, written to an in-memory WAV file with
    ``soundfile`` and uploaded as multipart/form-data together with the
    other request fields. Stream parsing and timing follow the chat
    completions format (``choices[0]["delta"]["content"]`` plus an optional
    ``usage`` chunk).

    Args:
        request_func_input: Description of the request. ``language``
            defaults to ``"en"`` when not set.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A RequestFuncOutput with timing, generated text and status.
    """
    # Lazy import without PlaceholderModule to avoid vllm dep.
    import soundfile
    api_url = request_func_input.api_url
    assert api_url.endswith(("transcriptions", "translations")), (
        "OpenAI Audio API URL must end with 'transcriptions' "
        "or 'translations'.")

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model_name
            if request_func_input.model_name else request_func_input.model,
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            # Honour the requested language; fall back to English.
            "language": request_func_input.language or "en",
            # Flattened due to multipart/form-data
            "stream_include_usage": True,
            "stream_continuous_usage_stats": True
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        # Send audio file
        def to_bytes(y, sr):
            buffer = io.BytesIO()
            soundfile.write(buffer, y, sr, format="WAV")
            buffer.seek(0)
            return buffer

        with to_bytes(*request_func_input.multi_modal_content['audio']) as f:
            form = aiohttp.FormData()
            form.add_field('file', f, content_type='audio/wav')
            # multipart/form-data fields are strings.
            for key, value in payload.items():
                form.add_field(key, str(value))

            output = RequestFuncOutput()
            output.prompt_len = request_func_input.prompt_len

            generated_text = ""
            ttft = 0.0
            st = time.perf_counter()
            most_recent_timestamp = st
            try:
                async with session.post(url=api_url,
                                        data=form,
                                        headers=headers) as response:
                    if response.status == 200:
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = chunk_bytes.decode("utf-8").removeprefix(
                                "data: ")
                            if chunk != "[DONE]":
                                timestamp = time.perf_counter()
                                data = json.loads(chunk)

                                if choices := data.get("choices"):
                                    delta_text = choices[0]["delta"].get(
                                        "content")
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp)

                                    generated_text += delta_text or ""
                                elif usage := data.get("usage"):
                                    output.output_tokens = usage.get(
                                        "completion_tokens")

                                most_recent_timestamp = timestamp

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = most_recent_timestamp - st
                    else:
                        output.error = response.reason or ""
                        output.success = False
            except Exception:
                output.success = False
                exc_info = sys.exc_info()
                output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


547
def get_model(pretrained_model_name_or_path: str) -> str:
    """Return a path/identifier that ``from_pretrained`` can load.

    If the ``VLLM_USE_MODELSCOPE`` environment variable is not ``"true"``
    (case-insensitive), the argument is returned unchanged. Otherwise the
    repository is fetched with ModelScope's ``snapshot_download``, skipping
    files matching the ``.pt``/``.safetensors``/``.bin`` patterns, and the
    resulting local path is returned.
    """
    flag = os.getenv('VLLM_USE_MODELSCOPE', 'False')
    if flag.lower() != 'true':
        return pretrained_model_name_or_path

    from modelscope import snapshot_download

    from vllm.model_executor.model_loader.weight_utils import get_lock

    # A file lock keeps several processes from downloading the same
    # model at the same time.
    with get_lock(pretrained_model_name_or_path):
        return snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
563
564
565


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer used to count prompt/output tokens.

    Args:
        pretrained_model_name_or_path: Local path or model identifier. When
            it is not None and does not exist on disk, it is first resolved
            through ``get_model``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads vLLM's ``MistralTokenizer``; any other value goes through
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        **kwargs: Extra arguments for ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: ``tokenizer_mode == "slow"`` together with a truthy
            ``use_fast`` keyword.
        ImportError: ``tokenizer_mode == "mistral"`` when
            ``MistralTokenizer`` cannot be imported from vLLM.
    """
    needs_resolution = (pretrained_model_name_or_path is not None
                        and not os.path.exists(pretrained_model_name_or_path))
    if needs_resolution:
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if tokenizer_mode != "mistral":
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    # vLLM is optional for the benchmark, so import it only on demand.
    try:
        from vllm.transformers_utils.tokenizer import MistralTokenizer
    except ImportError as e:
        raise ImportError("MistralTokenizer requires vllm package.\n"
                          "Please install it with `pip install vllm` "
                          "to use mistral tokenizer mode.") from e
    return MistralTokenizer.from_pretrained(
        str(pretrained_model_name_or_path))
595
596


597
598
# Backend name -> coroutine function that issues one benchmark request.
# Several serving frameworks expose an OpenAI-compatible completions API and
# therefore share async_request_openai_completions.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "openai-audio": async_request_openai_audio,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}

# Backends served by the OpenAI completions / chat completions request
# functions, in registry order.
_OPENAI_REQUEST_FUNCS = (async_request_openai_completions,
                         async_request_openai_chat_completions)
OPENAI_COMPATIBLE_BACKENDS = [
    backend for backend, request_func in ASYNC_REQUEST_FUNCS.items()
    if request_func in _OPENAI_REQUEST_FUNCS
]