# backend_request_func.py (15.9 KB)
1
2
import json
import os
3
import sys
4
import time
5
6
import traceback
from dataclasses import dataclass, field
7
from typing import List, Optional, Union
8
9

import aiohttp
10
import huggingface_hub.constants
11
from tqdm.asyncio import tqdm
12
13
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
14
15
16
17
18
19
20
21
22
23
24
25
26

# Timeout handed to every aiohttp.ClientSession created in this file:
# a total budget of 6 hours (6 * 60 * 60 seconds).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Everything a request function needs to issue one benchmark request."""
    # Prompt text placed in the request payload.
    prompt: str
    # URL the request is POSTed to.
    api_url: str
    # Copied verbatim onto RequestFuncOutput.prompt_len.
    # NOTE(review): presumably a token count -- confirm against the caller.
    prompt_len: int
    # Forwarded as the backend's max-token limit ("max_tokens" or
    # "max_new_tokens", depending on the backend).
    output_len: int
    # Model name sent in the payload by the OpenAI-style request functions.
    model: str
    # Forwarded as "best_of" where supported; some backends assert it is 1.
    best_of: int = 1
    # Every request function in this file asserts this is False.
    use_beam_search: bool = False
    # Forwarded as "logprobs" by the OpenAI completions request function.
    logprobs: Optional[int] = None
    # Extra content entry appended to the chat message when set.
    multi_modal_content: Optional[dict] = None
    # Forwarded as "ignore_eos" (OpenAI-style) or emulated via "min_length"
    # (TensorRT-LLM); not sent to TGI or DeepSpeed-MII.
    ignore_eos: bool = False
30
31
32
33
34
35


@dataclass
class RequestFuncOutput:
    """Result and timing measurements collected for one benchmark request."""
    # Text returned by the backend (accumulated across stream chunks where
    # the backend streams deltas).
    generated_text: str = ""
    # True only when the request completed with HTTP 200 and no exception.
    success: bool = False
    # Seconds from just before the POST until the request finished.
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    # Copied from RequestFuncInput.prompt_len.
    prompt_len: int = 0
    # HTTP reason phrase or formatted traceback when the request failed.
    error: str = ""
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Args:
        request_func_input: Prompt, endpoint URL and sampling settings.
            ``api_url`` must end with ``generate_stream`` and
            ``use_beam_search`` must be False (both are asserted).
        pbar: Optional progress bar; advanced by one when the request ends,
            whether it succeeded or failed.

    Returns:
        A RequestFuncOutput. On HTTP 200 it carries the time to first token,
        the inter-token latencies, the total latency and the
        ``generated_text`` of the last streamed event. On any other status,
        on an exception, or when the stream carried no data event,
        ``success`` is False and ``error`` describes the failure.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
            # TGI does not accept ignore_eos flag.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Most recently decoded event; stays None if the stream is empty.
        data = None
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = remove_prefix(chunk_bytes, "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token. Reuse `timestamp` so that
                        # ttft + sum(itl) adds up to the reported latency.
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if data is None:
                        # Report an empty stream explicitly instead of
                        # letting it surface as a NameError traceback.
                        output.error = (
                            "TGI stream ended without any data event.")
                        output.success = False
                    else:
                        output.latency = most_recent_timestamp - st
                        output.success = True
                        output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Benchmark boundary: record the failure on the output rather
            # than aborting the whole run.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None: a tqdm bar's truth value depends on its
        # total (and can raise TypeError), so `if pbar:` is unreliable.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Args:
        request_func_input: Prompt, endpoint URL and sampling settings.
            ``api_url`` must end with ``generate_stream``,
            ``use_beam_search`` must be False and ``best_of`` must be 1
            (all asserted). When ``ignore_eos`` is set, ``min_length`` is
            sent equal to ``output_len``.
        pbar: Optional progress bar; advanced by one when the request ends,
            whether it succeeded or failed.

    Returns:
        A RequestFuncOutput with the concatenated ``text_output`` of every
        streamed chunk plus timing data, or ``success=False`` and an
        ``error`` string on a non-200 status or an exception.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        if request_func_input.ignore_eos:
            payload["min_length"] = request_func_input.output_len
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token. Reuse `timestamp` so that
                        # ttft + sum(itl) adds up to the reported latency.
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Benchmark boundary: record the failure on the output rather
            # than aborting the whole run.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None: a tqdm bar's truth value depends on its
        # total (and can raise TypeError), so `if pbar:` is unreliable.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    Args:
        request_func_input: Prompt, endpoint URL and sampling settings.
            ``best_of`` must be 1 and ``use_beam_search`` must be False
            (both asserted).
        pbar: Optional progress bar; advanced by one when the request ends,
            whether it succeeded or failed.

    Returns:
        A RequestFuncOutput whose ``generated_text`` is the first entry of
        the response's ``text`` list and whose ``latency`` covers the whole
        round trip. ``ttft`` is always 0 and ``itl`` stays empty because the
        response is not streamed. On a non-200 status or an exception,
        ``success`` is False and ``error`` describes the failure.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Benchmark boundary: record the failure on the output rather
            # than aborting the whole run.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None: a tqdm bar's truth value depends on its
        # total (and can raise TypeError), so `if pbar:` is unreliable.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Completions endpoint.

    The bearer token is read from the ``OPENAI_API_KEY`` environment
    variable.

    Args:
        request_func_input: Prompt, endpoint URL, model name and sampling
            settings. ``api_url`` must end with ``completions`` or
            ``profile`` and ``use_beam_search`` must be False (asserted).
        pbar: Optional progress bar; advanced by one when the request ends,
            whether it succeeded or failed.

    Returns:
        A RequestFuncOutput with the concatenated chunk text and timing
        data. ``latency`` is measured at the ``[DONE]`` sentinel; if the
        stream ends without one, the timestamp of the last received token
        is used instead. On a non-200 status or an exception, ``success``
        is False and ``error`` describes the failure.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        ("completions", "profile")
    ), "OpenAI Completions API URL must end with 'completions' or 'profile'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "logprobs": request_func_input.logprobs,
            "stream": True,
            "ignore_eos": request_func_input.ignore_eos,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Set when the "[DONE]" sentinel arrives; None if it never does.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # want to check a token was generated. Such a
                        # chunk may carry an empty "choices" list, which
                        # must not be indexed.
                        choices = data["choices"]
                        if not choices or not choices[0]["text"]:
                            continue

                        timestamp = time.perf_counter()
                        # First token. Reuse `timestamp` so that
                        # ttft + sum(itl) matches the token timeline.
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp
                        generated_text += choices[0]["text"]

                    if latency is None:
                        # No "[DONE]" was seen; avoid an unbound variable
                        # and fall back to the last token's arrival time.
                        latency = most_recent_timestamp - st
                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Benchmark boundary: record the failure on the output rather
            # than aborting the whole run.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    # Compare against None: a tqdm bar's truth value depends on its total
    # (and can raise TypeError), so `if pbar:` is unreliable.
    if pbar is not None:
        pbar.update(1)
    return output


310
311
312
313
314
315
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Chat Completions URL.

    The prompt is sent as a single user message; when
    ``multi_modal_content`` is set it is appended to that message's content
    list. The bearer token is read from the ``OPENAI_API_KEY`` environment
    variable.

    Args:
        request_func_input: Prompt, endpoint URL, model name and sampling
            settings. ``api_url`` must end with ``chat/completions`` and
            ``use_beam_search`` must be False (both asserted).
        pbar: Optional progress bar; advanced by one when the request ends,
            whether it succeeded or failed.

    Returns:
        A RequestFuncOutput with the concatenated delta content and timing
        data. ``latency`` is measured at the ``[DONE]`` sentinel; if the
        stream ends without one, the timestamp of the last received chunk
        is used instead. On a non-200 status or an exception, ``success``
        is False and ``error`` describes the failure.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        content = [{"type": "text", "text": request_func_input.prompt}]
        if request_func_input.multi_modal_content:
            content.append(request_func_input.multi_modal_content)
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": content
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "ignore_eos": request_func_input.ignore_eos,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Set when the "[DONE]" sentinel arrives; None if it never does.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                            continue

                        timestamp = time.perf_counter()
                        data = json.loads(chunk)

                        # A chunk may carry an empty "choices" list (for
                        # example a usage-only summary); treat it as having
                        # no delta instead of indexing into it.
                        choices = data["choices"]
                        delta = choices[0]["delta"] if choices else {}
                        if delta.get("content", None):
                            # First token. Reuse `timestamp` so that
                            # ttft + sum(itl) matches the chunk timeline.
                            if ttft == 0.0:
                                ttft = timestamp - st
                                output.ttft = ttft

                            # Decoding phase
                            else:
                                output.itl.append(timestamp -
                                                  most_recent_timestamp)

                            generated_text += delta["content"]

                        # Updated for every data chunk, with or without
                        # content.
                        most_recent_timestamp = timestamp

                    if latency is None:
                        # No "[DONE]" was seen; avoid an unbound variable
                        # and fall back to the last chunk's arrival time.
                        latency = most_recent_timestamp - st
                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Benchmark boundary: record the failure on the output rather
            # than aborting the whole run.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    # Compare against None: a tqdm bar's truth value depends on its total
    # (and can raise TypeError), so `if pbar:` is unreliable.
    if pbar is not None:
        pbar.update(1)
    return output


398
399
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
400
401
402
403
404
405
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` without a leading ``prefix``.

    Only one leading occurrence is stripped; ``text`` is returned unchanged
    when it does not start with ``prefix``.
    """
    has_prefix = text.startswith(prefix)
    return text[len(prefix):] if has_prefix else text


406
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model identifier, downloading via ModelScope if enabled.

    When the ``VLLM_USE_MODELSCOPE`` environment variable equals ``true``
    (case-insensitive), the model is fetched with ModelScope's
    ``snapshot_download`` and the path it returns is used. Otherwise the
    identifier is returned unchanged.

    Args:
        pretrained_model_name_or_path: Model identifier to resolve.

    Returns:
        The path returned by ``snapshot_download``, or the input unchanged.
    """
    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        # Imported lazily so modelscope is only required when enabled.
        from modelscope import snapshot_download

        # The ignore patterns name .pt/.safetensors/.bin files --
        # presumably so model weights are not downloaded; confirm against
        # the modelscope documentation.
        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

        return model_path
    return pretrained_model_name_or_path
417
418
419
420
421
422
423
424
425
426
427
428
429


def get_tokenizer(
    pretrained_model_name_or_path: str, trust_remote_code: bool
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model name or local path.

    A value that is None or names an existing filesystem path is handed to
    ``AutoTokenizer.from_pretrained`` as-is; anything else is first resolved
    through ``get_model``.

    Args:
        pretrained_model_name_or_path: Model identifier or local path.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer produced by ``AutoTokenizer.from_pretrained``.
    """
    source = pretrained_model_name_or_path
    if source is None or os.path.exists(source):
        resolved = source
    else:
        resolved = get_model(source)
    return AutoTokenizer.from_pretrained(
        resolved, trust_remote_code=trust_remote_code)


430
431
# Maps each supported backend name to the coroutine function that issues a
# request to it. Several backends share the OpenAI completions function.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
    "sglang": async_request_openai_completions,
}