"vscode:/vscode.git/clone" did not exist on "04147dcfa70fb7228ce9e2f88fa7dd41631d17f0"
backend_request_func.py 17.6 KB
Newer Older
1
2
import json
import os
3
import sys
4
import time
5
6
import traceback
from dataclasses import dataclass, field
7
from typing import List, Optional, Union
8
9

import aiohttp
10
import huggingface_hub.constants
11
from tqdm.asyncio import tqdm
12
13
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
# Total client-side timeout shared by every aiohttp.ClientSession created
# below: six hours, expressed in seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 3600)


@dataclass
class RequestFuncInput:
    """Parameters describing a single benchmark request.

    Each ``async_request_*`` function reads the fields relevant to the
    server API it talks to; unused fields are simply ignored.
    """

    # Prompt text sent to the server.
    prompt: str
    # Endpoint the request is POSTed to.
    api_url: str
    # Prompt length supplied by the caller; copied into the output and,
    # for TGI, also sent as the ``truncate`` parameter.
    prompt_len: int
    # Maximum number of tokens to generate (sent as max_tokens /
    # max_new_tokens / max_completion_tokens depending on the backend).
    output_len: int
    # Model identifier used in OpenAI-style payloads unless ``model_name``
    # is set.
    model: str
    # Optional override for the payload's "model" field.
    model_name: Optional[str] = None
    # Forwarded as ``best_of``; TensorRT-LLM and DeepSpeed-MII require 1.
    best_of: int = 1
    # Forwarded as ``logprobs`` by the OpenAI completions client.
    logprobs: Optional[int] = None
    # Extra keys merged into OpenAI-style payloads, overriding defaults.
    extra_body: Optional[dict] = None
    # Extra content item appended to the chat message's content list.
    multi_modal_content: Optional[dict] = None
    # Sent as ``ignore_eos`` to OpenAI-style servers; mapped to
    # ``min_length = output_len`` for TensorRT-LLM; not sent to TGI.
    ignore_eos: bool = False


@dataclass
class RequestFuncOutput:
    """Result text and timing measurements for one benchmark request.

    All fields have defaults so request functions can create an empty
    instance up front and fill it in as the response arrives.
    """

    # Text produced by the server (accumulated from streamed chunks or taken
    # from the final response, depending on the backend).
    generated_text: str = ""
    # Whether the request completed without an error being recorded.
    success: bool = False
    # Seconds from just before sending to the last received chunk (or, for
    # non-streaming backends, to the parsed response).
    latency: float = 0.0
    # Completion token count from a usage chunk, when the server sends one.
    output_tokens: int = 0
    # Time to first token, in seconds.
    ttft: float = 0.0
    # Inter-token latencies: seconds between consecutive received chunks.
    # A fresh list per instance (mutable default must not be shared).
    itl: List[float] = field(default_factory=list)
    # Average next-token latency; not assigned by the request functions
    # shown in this file.
    tpot: float = 0.0
    # Copied from the corresponding RequestFuncInput.prompt_len.
    prompt_len: int = 0
    # HTTP reason or formatted traceback when the request fails.
    error: str = ""


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a Text Generation Inference server.

    The request is POSTed to ``request_func_input.api_url`` (which must end
    with ``generate_stream``). Server-sent events are read as they arrive to
    measure time-to-first-token, inter-token latencies and total latency.
    The generated text is taken from the last event received.

    Args:
        request_func_input: Prompt, endpoint and generation settings.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. Failures (non-200 status, a stream without
        any data event, or any exception) are reported through ``success``
        and ``error`` rather than raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            "truncate": request_func_input.prompt_len,
            # TGI does not accept ignore_eos flag.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Last decoded event; stays None if the stream carries no data.
        last_data = None
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_str = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_str.startswith(":"):
                            continue
                        chunk = chunk_str.removeprefix("data:")

                        last_data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # Reuse `timestamp` so TTFT and ITL are measured
                            # against the same clock reading.
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    if last_data is None:
                        # No data event arrived: there is no text to read,
                        # so report an explicit failure.
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    else:
                        output.generated_text = last_data["generated_text"]
                        output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Each streamed event's ``text_output`` is appended to the generated text.
    Arrival times give time-to-first-token, inter-token latencies and total
    latency.

    Args:
        request_func_input: Prompt, endpoint and generation settings;
            ``best_of`` must be 1.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``; errors are recorded in ``success`` and
        ``error`` instead of being raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        if request_func_input.ignore_eos:
            # Emulate ignore_eos by requiring at least output_len tokens.
            payload["min_length"] = request_func_input.output_len

        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        first_token_latency = 0.0
        start = time.perf_counter()
        last_seen = start
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status != 200:
                    result.error = response.reason or ""
                    result.success = False
                else:
                    async for raw_line in response.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue

                        event = json.loads(
                            raw_line.decode("utf-8").removeprefix("data:"))
                        result.generated_text += event["text_output"]

                        now = time.perf_counter()
                        if first_token_latency == 0.0:
                            # First token.
                            first_token_latency = now - start
                            result.ttft = first_token_latency
                        else:
                            # Decoding phase.
                            result.itl.append(now - last_seen)
                        last_seen = now

                    result.latency = last_seen - start
                    result.success = True
        except Exception:
            result.success = False
            result.error = "".join(
                traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return result


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII server.

    Only end-to-end latency is measured. TTFT is a placeholder because the
    server does not stream. The generated text is the first entry of the
    response's ``text`` list.

    Args:
        request_func_input: Prompt, endpoint and generation settings;
            ``best_of`` must be 1.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``; errors are recorded in ``success`` and
        ``error`` instead of being raised.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        result = RequestFuncOutput()
        result.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        result.ttft = 0

        start = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status != 200:
                    result.error = response.reason or ""
                    result.success = False
                else:
                    body = await response.json()
                    result.latency = time.perf_counter() - start
                    result.generated_text = body["text"][0]
                    result.success = True
        except Exception:
            result.success = False
            result.error = "".join(
                traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return result


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-compatible Completions API.

    ``data:`` events are streamed from ``request_func_input.api_url``. The
    ``text`` of each choice chunk is concatenated, and the function records
    time-to-first-token, inter-token latencies and, when a usage chunk
    arrives, the reported completion token count. ``extra_body`` entries
    override the default payload keys. The bearer token is read from the
    ``OPENAI_API_KEY`` environment variable.

    Args:
        request_func_input: Prompt, endpoint, model and generation settings.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. A 200 response that never yields a choice
        chunk is marked as failed. Errors are recorded in ``success`` and
        ``error`` instead of being raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "ignore_eos": request_func_input.ignore_eos,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    first_chunk_received = False
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk == "[DONE]":
                            continue
                        data = json.loads(chunk)

                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # want to check a token was generated
                        if choices := data.get("choices"):
                            # Note that text could be empty here
                            # e.g. for special tokens. A missing/null text
                            # is treated as "" so concatenation cannot
                            # raise TypeError and fail the request.
                            text = choices[0].get("text") or ""
                            timestamp = time.perf_counter()
                            # First token
                            if not first_chunk_received:
                                first_chunk_received = True
                                # Reuse `timestamp` instead of sampling the
                                # clock a second time.
                                output.ttft = timestamp - st
                            # Decoding phase
                            else:
                                output.itl.append(timestamp -
                                                  most_recent_timestamp)

                            most_recent_timestamp = timestamp
                            generated_text += text
                        elif usage := data.get("usage"):
                            output.output_tokens = usage.get(
                                "completion_tokens")

                    if first_chunk_received:
                        output.success = True
                    else:
                        output.success = False
                        output.error = (
                            "Never received a valid chunk to calculate TTFT. "
                            "This response will be marked as failed!")
                    output.generated_text = generated_text
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-compatible Chat Completions API.

    The prompt is sent as a single user message. Its content list holds the
    text part plus ``multi_modal_content`` when given. Streamed delta
    contents are concatenated into ``generated_text``. The function records
    time-to-first-token, inter-token latencies and the usage-reported
    completion token count. ``extra_body`` entries override the default
    payload keys. The bearer token is read from ``OPENAI_API_KEY``.

    Args:
        request_func_input: Prompt, endpoint, model and generation settings.
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``; errors are recorded in ``success`` and
        ``error`` instead of being raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            content.append(request_func_input.multi_modal_content)
        payload = {
            "model": (request_func_input.model_name
                      or request_func_input.model),
            "messages": [
                {
                    "role": "user",
                    "content": content
                },
            ],
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": True,
            "ignore_eos": request_func_input.ignore_eos,
            "stream_options": {
                "include_usage": True,
            },
        }
        if request_func_input.extra_body:
            payload.update(request_func_input.extra_body)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = chunk_bytes.decode("utf-8").removeprefix(
                            "data: ")
                        if chunk != "[DONE]":
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            if choices := data.get("choices"):
                                # `.get` yields None when the delta carries
                                # no content (e.g. a finish-only chunk).
                                delta_text = choices[0]["delta"].get(
                                    "content")
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                # `str += None` would raise TypeError and
                                # fail the whole request; treat None as "".
                                generated_text += delta_text or ""
                            elif usage := data.get("usage"):
                                output.output_tokens = usage.get(
                                    "completion_tokens")

                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = most_recent_timestamp - st
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model identifier to the path the tokenizer should load from.

    If the ``VLLM_USE_MODELSCOPE`` environment variable equals ``"true"``
    (case-insensitive), the repository is fetched with ModelScope's
    ``snapshot_download`` and the local snapshot path is returned. Otherwise
    the identifier is returned unchanged.
    """
    env_value = os.getenv('VLLM_USE_MODELSCOPE', 'False')
    if env_value.lower() != 'true':
        return pretrained_model_name_or_path

    from modelscope import snapshot_download

    # Files matching the .pt / .safetensors / .bin patterns are skipped
    # (presumably weights, which a tokenizer does not need), and the
    # download honours Hugging Face's offline flag.
    return snapshot_download(
        model_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])


def get_tokenizer(
    pretrained_model_name_or_path: str,
    tokenizer_mode: str = "auto",
    trust_remote_code: bool = False,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model.

    Args:
        pretrained_model_name_or_path: Local path or model identifier. A
            non-existent path is first passed through ``get_model``.
        tokenizer_mode: ``"slow"`` forces ``use_fast=False``; ``"mistral"``
            loads vLLM's ``MistralTokenizer``; anything else uses
            ``AutoTokenizer``.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
        **kwargs: Extra keyword arguments for ``AutoTokenizer``.

    Raises:
        ValueError: ``tokenizer_mode == "slow"`` combined with a truthy
            ``use_fast`` kwarg.
        ImportError: ``"mistral"`` mode without the ``vllm`` package.
    """
    needs_resolution = (pretrained_model_name_or_path is not None
                        and not os.path.exists(pretrained_model_name_or_path))
    if needs_resolution:
        pretrained_model_name_or_path = get_model(
            pretrained_model_name_or_path)

    if tokenizer_mode == "slow":
        if kwargs.get("use_fast", False):
            raise ValueError(
                "Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False

    if tokenizer_mode != "mistral":
        return AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    try:
        from vllm.transformers_utils.tokenizer import MistralTokenizer
    except ImportError as e:
        raise ImportError("MistralTokenizer requires vllm package.\n"
                          "Please install it with `pip install vllm` "
                          "to use mistral tokenizer mode.") from e
    return MistralTokenizer.from_pretrained(
        str(pretrained_model_name_or_path))


# Lookup table from backend name to the coroutine that sends one request to
# that backend. Several servers expose an OpenAI-compatible completions
# endpoint and therefore share the same client function.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,  # OpenAI-compatible
    "lmdeploy": async_request_openai_completions,  # OpenAI-compatible
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,  # OpenAI-compatible
    "sglang": async_request_openai_completions,  # OpenAI-compatible
}