backend_request_func.py 18.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import json
import os
5
import sys
6
import time
7
8
import traceback
from dataclasses import dataclass, field
9
from typing import Optional, Union
10
11

import aiohttp
12
import huggingface_hub.constants
13
from tqdm.asyncio import tqdm
14
15
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
16

17
18
# NOTE(simon): do not import vLLM here so the benchmark script
# can run without vLLM installed.
19

20
21
22
23
24
25
26
27
28
29
# Shared client timeout handed to every aiohttp.ClientSession created in this
# module. The total budget is 6 * 60 * 60 (six hours, assuming aiohttp
# interprets `total` in seconds), so long-running benchmark requests are not
# cut off by the client.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Description of a single benchmark request handed to a backend function.

    The first five fields are mandatory; the remaining ones are optional
    knobs that individual backend functions read when building their payload.
    """

    # --- required -----------------------------------------------------------
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str

    # --- optional -----------------------------------------------------------
    # When set, the OpenAI-style backends send this as "model" instead of
    # `model`.
    model_name: Optional[str] = None
    logprobs: Optional[int] = None
    # Extra key/values merged into the OpenAI-style payloads via dict.update.
    extra_body: Optional[dict] = None
    # Appended to the user message content of chat-completions requests.
    multi_modal_content: Optional[dict] = None
    ignore_eos: bool = False
35
36
37
38
39
40


@dataclass
class RequestFuncOutput:
    """Result and timing measurements of a single benchmark request.

    Every field has a default so a backend function can create an empty
    instance up front and fill it in as the response arrives.
    """

    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    # May be None: the TGI backend stores None when `ignore_eos` is off, and
    # the OpenAI-style backends store whatever the server's usage block
    # reports for "completion_tokens" (possibly nothing).
    output_tokens: Optional[int] = 0
    ttft: float = 0.0  # Time to first token
    itl: list[float] = field(
        default_factory=list)  # list of inter-token latencies
    tpot: float = 0.0  # avg next-token latencies
    prompt_len: int = 0
    # Empty on success; otherwise an HTTP reason phrase, a formatted
    # traceback, or a backend-specific message.
    error: str = ""
49
50
51
52
53
54
55
56
57


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one generation request to a TGI endpoint and time it.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``generate_stream``.
        pbar: Optional progress bar, advanced by one when the request is done.

    Returns:
        A ``RequestFuncOutput``. ``ttft`` is the delay until the first parsed
        data chunk, ``itl`` holds the gaps between later chunks and
        ``latency`` runs up to the last chunk. ``output_tokens`` is
        ``output_len`` when ``ignore_eos`` is set and ``None`` otherwise. On a
        non-200 status ``error`` holds the HTTP reason; any exception raised
        while sending or reading is stored in ``error`` as a formatted
        traceback instead of propagating.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    payload = {
        "inputs": request_func_input.prompt,
        "parameters": {
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            "ignore_eos_token": request_func_input.ignore_eos,
        },
    }

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len
    output.output_tokens = (request_func_input.output_len
                            if request_func_input.ignore_eos else None)

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        first_token_latency = 0.0
        started_at = time.perf_counter()
        last_chunk_at = started_at
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for raw_line in response.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue
                        line = raw_line.decode("utf-8")

                        # TGI occasionally emits a ping line that carries no
                        # data; it starts with ":" and must be skipped.
                        if line.startswith(":"):
                            continue

                        data = json.loads(line.removeprefix("data:"))
                        now = time.perf_counter()
                        if first_token_latency == 0.0:
                            # First token.
                            first_token_latency = (time.perf_counter() -
                                                   started_at)
                            output.ttft = first_token_latency
                        else:
                            # Decoding phase.
                            output.itl.append(now - last_chunk_at)
                        last_chunk_at = now

                    output.latency = last_chunk_at - started_at
                    output.success = True
                    # The final streamed chunk carries the full text.
                    output.generated_text = data["generated_text"]
        except Exception:
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one generation request to a TensorRT-LLM endpoint and time it.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``generate_stream``.
        pbar: Optional progress bar, advanced by one when the request is done.

    Returns:
        A ``RequestFuncOutput`` whose ``generated_text`` is the concatenation
        of every chunk's ``text_output``. ``ttft``/``itl``/``latency`` are
        derived from the arrival time of each chunk. On a non-200 status
        ``error`` holds the HTTP reason; exceptions raised while sending or
        reading are stored in ``error`` as a formatted traceback.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    payload = {
        "accumulate_tokens": True,
        "text_input": request_func_input.prompt,
        "temperature": 0.0,
        "top_p": 1.0,
        "max_tokens": request_func_input.output_len,
        "stream": True,
    }
    if request_func_input.ignore_eos:
        # Force the full requested length when EOS is to be ignored.
        payload["min_length"] = request_func_input.output_len

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        first_token_latency = 0.0
        started_at = time.perf_counter()
        last_chunk_at = started_at
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for raw_line in response.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue

                        body = raw_line.decode("utf-8").removeprefix("data:")
                        data = json.loads(body)
                        output.generated_text += data["text_output"]

                        now = time.perf_counter()
                        if first_token_latency == 0.0:
                            # First token.
                            first_token_latency = now - started_at
                            output.ttft = first_token_latency
                        else:
                            # Decoding phase.
                            output.itl.append(now - last_chunk_at)
                        last_chunk_at = now

                    output.latency = last_chunk_at - started_at
                    output.success = True
        except Exception:
            output.success = False
            output.error = traceback.format_exc()

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming generation request to a DeepSpeed-MII server.

    Args:
        request_func_input: Request description (prompt, URL, lengths).
        pbar: Optional progress bar, advanced by one when the request is done.

    Returns:
        A ``RequestFuncOutput``. ``latency`` is the time until the JSON body
        was fully read; ``ttft`` is always the placeholder ``0``. The request
        is marked successful only when the body contains a ``choices`` or
        ``text`` entry from which the generated text could be extracted.
        A non-200 status stores the HTTP reason in ``error``; exceptions
        raised while sending or parsing are stored as a formatted traceback.
    """
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    # Success is decided per branch so that an unrecognised
                    # body is reported as a failure instead of being
                    # overwritten by an unconditional `success = True`.
                    if "choices" in parsed_resp:
                        output.generated_text = parsed_resp["choices"][0][
                            "text"]
                        output.success = True
                    elif "text" in parsed_resp:
                        output.generated_text = parsed_resp["text"][0]
                        output.success = True
                    else:
                        output.error = ("Unexpected response format: "
                                        "neither 'choices' nor 'text' found")
                        output.success = False
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one request to an OpenAI-compatible Completions endpoint.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``completions`` or ``profile``. ``model_name`` (when set)
            overrides ``model`` in the payload, and ``extra_body`` is merged
            into the payload last.
        pbar: Optional progress bar, advanced by one when the request is done.

    Returns:
        A ``RequestFuncOutput``. Timing is derived from chunks that carry
        ``choices``: ``ttft`` is the delay to the first such chunk, ``itl``
        the gaps between later ones, and ``latency`` runs up to the last one,
        so ``ttft + sum(itl) == latency``. ``output_tokens`` comes from the
        ``usage`` chunk. The request is marked failed if no chunk with
        ``choices`` ever arrived, on a non-200 status (HTTP reason in
        ``error``), or when an exception is raised while sending or reading
        (formatted traceback in ``error``).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.ignore_eos:
            payload["ignore_eos"] = request_func_input.ignore_eos
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if choices := data.get("choices"):
                                # Note that text could be empty here
                                # e.g. for special tokens
                                text = choices[0].get("text")
                                timestamp = time.perf_counter()
                                # First token
                                if not first_chunk_received:
                                    first_chunk_received = True
                                    # Reuse `timestamp` (no second clock
                                    # read) so TTFT and the inter-token
                                    # latencies add up to `latency`.
                                    output.ttft = timestamp - st

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")
                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


343
344
345
346
347
348
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Stream one request to an OpenAI-compatible Chat Completions endpoint.

    The prompt is sent as a single user message; ``multi_modal_content``
    (when present) is appended to that message's content list.

    Args:
        request_func_input: Request description; ``api_url`` must end with
            ``chat/completions`` or ``profile``.
        pbar: Optional progress bar, advanced by one when the request is done.

    Returns:
        A ``RequestFuncOutput``. ``ttft`` is the delay to the first chunk
        carrying ``choices`` and ``itl`` the gaps for later ones;
        ``latency`` runs up to the last non-``[DONE]`` chunk (a trailing
        usage chunk included). ``output_tokens`` comes from the ``usage``
        chunk. A non-200 status stores the HTTP reason in ``error``;
        exceptions raised while sending or reading are stored as a formatted
        traceback.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("chat/completions", "profile")
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    user_content = [{"type": "text", "text": request_func_input.prompt}]
    if request_func_input.multi_modal_content:
        user_content.append(request_func_input.multi_modal_content)

    payload = {
        "model": (request_func_input.model_name
                  if request_func_input.model_name else
                  request_func_input.model),
        "messages": [
            {
                "role": "user",
                "content": user_content
            },
        ],
        "temperature": 0.0,
        "max_completion_tokens": request_func_input.output_len,
        "stream": True,
        "stream_options": {
            "include_usage": True,
        },
    }
    if request_func_input.ignore_eos:
        payload["ignore_eos"] = request_func_input.ignore_eos
    if request_func_input.extra_body:
        payload.update(request_func_input.extra_body)

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
    }

    output = RequestFuncOutput()
    output.prompt_len = request_func_input.prompt_len

    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        pieces = ""
        first_token_latency = 0.0
        started_at = time.perf_counter()
        last_chunk_at = started_at
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for raw_line in response.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue

                        event = raw_line.decode("utf-8").removeprefix(
                            "data: ")
                        if event == "[DONE]":
                            continue

                        # Stamp arrival before parsing the chunk.
                        now = time.perf_counter()
                        data = json.loads(event)

                        if choices := data.get("choices"):
                            delta_text = choices[0]["delta"].get("content")
                            if first_token_latency == 0.0:
                                # First token.
                                first_token_latency = now - started_at
                                output.ttft = first_token_latency
                            else:
                                # Decoding phase.
                                output.itl.append(now - last_chunk_at)
                            pieces += delta_text or ""
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                        # Updated for every parsed chunk, usage included.
                        last_chunk_at = now

                    output.generated_text = pieces
                    output.success = True
                    output.latency = last_chunk_at - started_at
        except Exception:
            output.success = False
            output.error = traceback.format_exc()

    if pbar:
        pbar.update(1)
    return output


439
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model identifier to the path tokenizer loading should use.

    When the ``VLLM_USE_MODELSCOPE`` environment variable equals ``true``
    (case-insensitive), the model snapshot is fetched through ModelScope,
    skipping ``.pt``/``.safetensors``/``.bin`` files, and the snapshot path
    is returned. Otherwise the argument is returned unchanged.

    Args:
        pretrained_model_name_or_path: Model id or path to resolve.

    Returns:
        The ModelScope snapshot path, or the input itself.
    """
    modelscope_flag = os.getenv('VLLM_USE_MODELSCOPE', 'False')
    if modelscope_flag.lower() == 'true':
        # Imported lazily so neither package is needed unless the
        # ModelScope path is actually taken.
        from modelscope import snapshot_download
        from vllm.model_executor.model_loader.weight_utils import get_lock

        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(pretrained_model_name_or_path):
            return snapshot_download(
                model_id=pretrained_model_name_or_path,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])
    return pretrained_model_name_or_path
455
456
457


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model name or local path.

    Args:
        pretrained_model_name_or_path: Model id or path. If it is not an
            existing filesystem path it is first passed through
            ``get_model``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads vLLM's ``MistralTokenizer``; any other value uses
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer.

    Raises:
        ValueError: ``tokenizer_mode`` is ``"slow"`` but ``use_fast`` was
            passed as truthy.
        ImportError: ``tokenizer_mode`` is ``"mistral"`` and vLLM cannot be
            imported.
    """
    needs_resolution = (pretrained_model_name_or_path is not None
                        and not os.path.exists(pretrained_model_name_or_path))
    if needs_resolution:
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if tokenizer_mode != "mistral":
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    # Mistral mode: vLLM is imported only here so the rest of the module
    # keeps working without it.
    try:
        from vllm.transformers_utils.tokenizer import MistralTokenizer
    except ImportError as e:
        raise ImportError("MistralTokenizer requires vllm package.\n"
                          "Please install it with `pip install vllm` "
                          "to use mistral tokenizer mode.") from e
    return MistralTokenizer.from_pretrained(
        str(pretrained_model_name_or_path))
487
488


489
490
# Backend name -> coroutine function that issues one request against it.
# Several backends are served through the OpenAI Completions request function.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}