backend_request_func.py 15.3 KB
Newer Older
1
2
import json
import os
3
import sys
4
import time
5
6
import traceback
from dataclasses import dataclass, field
7
from typing import List, Optional, Union
8
9

import aiohttp
10
import huggingface_hub.constants
11
from tqdm.asyncio import tqdm
12
13
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

# Timeout configuration shared by every client session created in this
# module: a total budget of 6 * 60 * 60 seconds (6 hours).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Parameters describing one request sent by an ``async_request_*`` function."""
    # Prompt text placed in the request payload.
    prompt: str
    # Endpoint the request is POSTed to; each request function asserts the
    # URL suffix it expects.
    api_url: str
    # Copied verbatim into RequestFuncOutput.prompt_len.
    prompt_len: int
    # Sent to the server as the generation limit
    # ("max_new_tokens" / "max_tokens").
    output_len: int
    # Sent as "model" in the OpenAI-style payloads.
    model: str
    # Forwarded as "best_of" by the TGI and OpenAI completions functions; the
    # TensorRT-LLM and DeepSpeed-MII functions assert that it is 1.
    best_of: int = 1
    # Every request function in this file asserts that this is False.
    use_beam_search: bool = False


@dataclass
class RequestFuncOutput:
    """Result and timing measurements collected for one request."""
    # Text returned by the server (accumulated across chunks when streaming).
    generated_text: str = ""
    # True only when the request completed and the response was parsed.
    success: bool = False
    # Elapsed time from sending the request to the end of the response,
    # measured with time.perf_counter().
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    # Copied from RequestFuncInput.prompt_len by the request functions.
    prompt_len: int = 0
    # HTTP reason phrase or formatted traceback when the request failed.
    error: str = ""
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI endpoint and time the response.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``generate_stream`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success ``generated_text``, ``ttft``,
        ``itl`` and ``latency`` are filled in. On failure ``success`` is
        False and ``error`` holds the HTTP reason phrase, a short message,
        or a formatted traceback.

    Raises:
        AssertionError: If the URL suffix or beam-search precondition is
            violated (checked outside the ``try`` block).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    # Last decoded event; stays None if the stream carried
                    # no data event at all.
                    data = None
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = remove_prefix(chunk_bytes, "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # Reuse the chunk timestamp so TTFT and the
                            # inter-token latencies share one clock reading.
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if data is None:
                        # Report an empty stream explicitly instead of
                        # failing on an unbound ``data`` below.
                        output.error = (
                            "Response stream contained no data events.")
                        output.success = False
                    else:
                        output.latency = most_recent_timestamp - st
                        output.success = True
                        # The final event is read for the full text.
                        output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: any failure is recorded on the output rather than
            # propagated to the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None explicitly: tqdm derives its truth value from
        # its total/iterable, so truthiness is not a presence check.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM endpoint and time it.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``generate_stream``, ``use_beam_search`` must be False and
            ``best_of`` must be 1.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success ``generated_text`` is the
        concatenation of every chunk's ``text_output`` and ``ttft``, ``itl``
        and ``latency`` are filled in. On failure ``success`` is False and
        ``error`` holds the HTTP reason phrase or a formatted traceback.

    Raises:
        AssertionError: If one of the preconditions above is violated
            (checked outside the ``try`` block).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # Reuse the chunk timestamp so TTFT and the
                            # inter-token latencies share one clock reading.
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: any failure is recorded on the output rather than
            # propagated to the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None explicitly: tqdm derives its truth value from
        # its total/iterable, so truthiness is not a presence check.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    Args:
        request_func_input: Request description. ``best_of`` must be 1 and
            ``use_beam_search`` must be False.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success ``generated_text`` is the first
        entry of the response's ``"text"`` list and ``latency`` is the time
        until the JSON body was parsed; ``ttft`` is always the placeholder
        0. On failure ``success`` is False and ``error`` holds the HTTP
        reason phrase or a formatted traceback.

    Raises:
        AssertionError: If one of the preconditions above is violated
            (checked outside the ``try`` block).
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: any failure is recorded on the output rather than
            # propagated to the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None explicitly: tqdm derives its truth value from
        # its total/iterable, so truthiness is not a presence check.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style completions endpoint.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``v1/completions`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success ``generated_text`` is the
        concatenation of the streamed token texts, and ``ttft`` and ``itl``
        are filled in. ``latency`` is measured at the ``[DONE]`` marker, or
        at the last token when the server never sends one. On failure
        ``success`` is False and ``error`` holds the HTTP reason phrase, a
        short message, or a formatted traceback.

    Raises:
        AssertionError: If the URL suffix or beam-search precondition is
            violated (checked outside the ``try`` block).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/completions"
    ), "OpenAI Completions API URL must end with 'v1/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Set when the "[DONE]" marker arrives; None means it never did.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                            continue

                        data = json.loads(chunk)
                        # A chunk may carry no choices (e.g. a usage-only
                        # summary); treat it as token-less instead of
                        # failing the whole request on an index error.
                        choices = data.get("choices")
                        text = choices[0].get("text") if choices else None
                        if not text:
                            continue

                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # do not want to include as inter-token-latency
                        elif data.get("usage", None) is None:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp
                        generated_text += text

                    if latency is None and not generated_text:
                        # Neither a token nor "[DONE]" was received.
                        output.error = (
                            "Response stream ended without any token or "
                            "[DONE] marker.")
                        output.success = False
                    else:
                        if latency is None:
                            # No "[DONE]" marker: fall back to the arrival
                            # time of the last token.
                            latency = most_recent_timestamp - st
                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: any failure is recorded on the output rather than
            # propagated to the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    # Compare against None explicitly: tqdm derives its truth value from its
    # total/iterable, so truthiness is not a presence check.
    if pbar is not None:
        pbar.update(1)
    return output


302
303
304
305
306
307
308
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style chat completions endpoint.

    The prompt is sent as a single ``user`` message.

    Args:
        request_func_input: Request description. ``api_url`` must end with
            ``v1/chat/completions`` and ``use_beam_search`` must be False.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success ``generated_text`` is the
        concatenation of the streamed ``delta.content`` pieces, and ``ttft``
        and ``itl`` are filled in. ``latency`` is measured at the ``[DONE]``
        marker, or at the last received chunk when the server never sends
        one. On failure ``success`` is False and ``error`` holds the HTTP
        reason phrase, a short message, or a formatted traceback.

    Raises:
        AssertionError: If the URL suffix or beam-search precondition is
            violated (checked outside the ``try`` block).
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": request_func_input.prompt,
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Set when the "[DONE]" marker arrives; None means it never did.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                            continue

                        timestamp = time.perf_counter()
                        data = json.loads(chunk)

                        # A chunk may carry no choices or no delta (e.g. a
                        # usage-only summary); treat it as content-less
                        # instead of failing the request on an index error.
                        choices = data.get("choices")
                        delta = choices[0].get("delta") if choices else None
                        content = delta.get("content") if delta else None
                        if content:
                            # First token
                            if ttft == 0.0:
                                ttft = timestamp - st
                                output.ttft = ttft

                            # Decoding phase
                            else:
                                output.itl.append(timestamp -
                                                  most_recent_timestamp)

                            generated_text += content

                        # Updated for every chunk, including content-less
                        # ones.
                        most_recent_timestamp = timestamp

                    if latency is None and not generated_text:
                        # Neither content nor "[DONE]" was received.
                        output.error = (
                            "Response stream ended without any content or "
                            "[DONE] marker.")
                        output.success = False
                    else:
                        if latency is None:
                            # No "[DONE]" marker: fall back to the arrival
                            # time of the last chunk.
                            latency = most_recent_timestamp - st
                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: any failure is recorded on the output rather than
            # propagated to the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    # Compare against None explicitly: tqdm derives its truth value from its
    # total/iterable, so truthiness is not a presence check.
    if pbar is not None:
        pbar.update(1)
    return output


386
387
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
388
389
390
391
392
393
def remove_prefix(text: str, prefix: str) -> str:
    """Return *text* without a leading *prefix*; unchanged if it is absent.

    Python 3.8-compatible stand-in for ``str.removeprefix``.
    """
    has_prefix = text.startswith(prefix)
    return text[len(prefix):] if has_prefix else text


394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
def get_model(pretrained_model_name_or_path: str) -> str:
    """Download a model snapshot without weight files and return its path.

    The hub is ModelScope when the ``VLLM_USE_MODELSCOPE`` environment
    variable equals ``true`` (case-insensitive), otherwise the Hugging Face
    Hub. ``*.pt``, ``*.safetensors`` and ``*.bin`` files are excluded from
    the download.

    Args:
        pretrained_model_name_or_path: Model identifier on the selected hub.

    Returns:
        The value returned by the hub's ``snapshot_download`` (the local
        snapshot path).
    """
    local_only = huggingface_hub.constants.HF_HUB_OFFLINE

    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download

        return snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=local_only,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

    from huggingface_hub import snapshot_download

    # huggingface_hub names these parameters ``repo_id`` and
    # ``ignore_patterns`` (glob-style); the ModelScope keyword names are
    # rejected by its ``snapshot_download``.
    return snapshot_download(
        repo_id=pretrained_model_name_or_path,
        local_files_only=local_only,
        ignore_patterns=["*.pt", "*.safetensors", "*.bin"])


def get_tokenizer(
    pretrained_model_name_or_path: str, trust_remote_code: bool
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a local path or a hub model identifier.

    When the argument is not None and does not exist on the local
    filesystem, it is first resolved through ``get_model`` and the tokenizer
    is loaded from the value that call returns.

    Args:
        pretrained_model_name_or_path: Local path or model identifier.
        trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer created by ``AutoTokenizer.from_pretrained``.
    """
    source = pretrained_model_name_or_path
    needs_download = source is not None and not os.path.exists(source)
    if needs_download:
        source = get_model(source)
    return AutoTokenizer.from_pretrained(
        source, trust_remote_code=trust_remote_code)


418
419
# Maps a backend name to the coroutine function that sends one request to
# it. "vllm", "lmdeploy" and "openai" all use the OpenAI completions
# function.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
}