# backend_request_func.py
1
2
import json
import os
3
import sys
4
import time
5
6
import traceback
from dataclasses import dataclass, field
7
from typing import List, Optional, Union
8
9

import aiohttp
10
import huggingface_hub.constants
11
from tqdm.asyncio import tqdm
12
13
from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

# Timeout applied to every aiohttp.ClientSession created in this file:
# a total budget of 6 hours (6 * 60 * 60 seconds).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Inputs for a single request made by the ``async_request_*`` functions."""
    # Prompt text placed in the request payload.
    prompt: str
    # Endpoint URL the request is POSTed to.
    api_url: str
    # Prompt length; copied unchanged into ``RequestFuncOutput.prompt_len``
    # (presumably a token count -- confirm against the caller).
    prompt_len: int
    # Generation limit, sent as "max_new_tokens" (TGI) or "max_tokens".
    output_len: int
    # Model name, sent as "model" by the OpenAI-style request functions.
    model: str
    # Sent as "best_of" by the TGI and OpenAI completions functions; the
    # TensorRT-LLM and DeepSpeed-MII functions assert that it equals 1.
    best_of: int = 1
    # Every request function in this file asserts that this is False.
    use_beam_search: bool = False


@dataclass
class RequestFuncOutput:
    """Result and timing measurements of a single benchmark request."""
    # Text returned by the backend (empty until a request succeeds).
    generated_text: str = ""
    # True only when the request completed with HTTP 200 and was parsed.
    success: bool = False
    # End-to-end request latency, measured with time.perf_counter().
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    # default_factory gives every instance its own list.
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    # Copied from RequestFuncInput.prompt_len by the request functions.
    prompt_len: int = 0
    # HTTP reason phrase or formatted traceback when the request failed.
    error: str = ""
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Issue one streaming request to a TGI ``generate_stream`` endpoint.

    Posts the prompt with sampling parameters, reads the streamed response
    line by line, and records time to first token, inter-token latencies,
    total latency and the generated text of the last data chunk.

    Args:
        request_func_input: Prompt, endpoint URL and generation settings.
            ``api_url`` must end with ``generate_stream`` and
            ``use_beam_search`` must be False (both enforced by ``assert``).
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. On a non-200 status, a stream without any
        data chunk, or an exception during the request, ``success`` is False
        and ``error`` says why; exceptions are stored as a formatted
        traceback instead of being raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0  # 0.0 doubles as the "no token seen yet" marker.
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    data = None  # Last parsed data chunk, if any.
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_text = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_text.startswith(":"):
                            continue
                        chunk = remove_prefix(chunk_text, "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token. Reuse ``timestamp`` so that TTFT plus
                        # the inter-token latencies add up to the latency.
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if data is None:
                        # Report an empty stream explicitly instead of
                        # failing on an unbound ``data`` variable.
                        output.error = (
                            "Stream ended without any data chunk.")
                        output.success = False
                    else:
                        output.latency = most_recent_timestamp - st
                        output.success = True
                        output.generated_text = data["generated_text"]
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output so that one bad request does
            # not abort the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Issue one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Posts the prompt, appends the ``text_output`` of every streamed chunk to
    the generated text, and records time to first token, inter-token
    latencies and total latency.

    Args:
        request_func_input: Prompt, endpoint URL and generation settings.
            ``api_url`` must end with ``generate_stream``, ``best_of`` must
            be 1 and ``use_beam_search`` must be False (enforced by
            ``assert``).
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. On a non-200 status or an exception during
        the request, ``success`` is False and ``error`` holds the HTTP
        reason or a formatted traceback; the exception is not raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0  # 0.0 doubles as the "no token seen yet" marker.
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token. Reuse ``timestamp`` so that TTFT plus
                        # the inter-token latencies add up to the latency.
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output so that one bad request does
            # not abort the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Issue one non-streaming request to a DeepSpeed-MII endpoint.

    Posts the prompt, waits for the complete JSON response and records the
    end-to-end latency. No per-token timing is collected: ``ttft`` is left
    at the placeholder 0 and ``itl`` stays empty.

    Args:
        request_func_input: Prompt, endpoint URL and generation settings.
            ``best_of`` must be 1 and ``use_beam_search`` must be False
            (enforced by ``assert``).
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. On a non-200 status or an exception during
        the request, ``success`` is False and ``error`` holds the HTTP
        reason or a formatted traceback; the exception is not raised.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    # The generated text is the first entry of "text".
                    output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output so that one bad request does
            # not abort the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Issue one streaming request to an OpenAI-style completions endpoint.

    Posts the prompt with ``stream`` enabled, concatenates the text of every
    streamed chunk, and records time to first token and inter-token
    latencies. The total latency is taken when the ``[DONE]`` line arrives.

    Args:
        request_func_input: Prompt, endpoint URL, model name and generation
            settings. ``api_url`` must end with ``completions`` and
            ``use_beam_search`` must be False (enforced by ``assert``).
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. On a non-200 status or an exception during
        the request (including a stream that never sends ``[DONE]``),
        ``success`` is False and ``error`` holds the HTTP reason or a
        formatted traceback; the exception is not raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0  # 0.0 doubles as the "no token seen yet" marker.
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token. Reuse ``timestamp`` so that
                                # TTFT and the inter-token latencies share
                                # one clock reading.
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase. Only tokens after the
                                # first one count as inter-token latency;
                                # the first is already covered by TTFT.
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output so that one bad request does
            # not abort the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


301
302
303
304
305
306
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Issue one streaming request to an OpenAI-style chat completions URL.

    Sends the prompt as a single ``user`` message with ``stream`` enabled,
    concatenates the ``content`` of every streamed delta, and records time
    to first token and inter-token latencies. The total latency is taken
    when the ``[DONE]`` line arrives.

    Args:
        request_func_input: Prompt, endpoint URL, model name and generation
            settings. ``api_url`` must end with ``chat/completions`` and
            ``use_beam_search`` must be False (enforced by ``assert``).
        pbar: Optional progress bar, advanced by one when the request ends.

    Returns:
        A ``RequestFuncOutput``. On a non-200 status or an exception during
        the request (including a stream that never sends ``[DONE]``),
        ``success`` is False and ``error`` holds the HTTP reason or a
        formatted traceback; the exception is not raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": request_func_input.prompt,
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0  # 0.0 doubles as the "no token seen yet" marker.
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            delta = data["choices"][0]["delta"]
                            if delta.get("content", None):
                                # First token. Reuse ``timestamp`` so that
                                # TTFT and the inter-token latencies share
                                # one clock reading.
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += delta["content"]

                            # Updated for every data chunk, including ones
                            # that carry no content.
                            most_recent_timestamp = timestamp

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Record the failure on the output so that one bad request does
            # not abort the caller.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


385
386
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` stripped, if it has one.

    Stand-in for ``str.removeprefix`` (introduced in Python 3.9), needed
    because vllm must support Python 3.8.
    """
    return text[len(prefix):] if text.startswith(prefix) else text


393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
def get_model(pretrained_model_name_or_path: str) -> str:
    """Download a model snapshot without weight files; return its path.

    The snapshot comes from ModelScope when the ``VLLM_USE_MODELSCOPE``
    environment variable is ``true`` (case-insensitive), and from the
    Hugging Face Hub otherwise. Files ending in ``.pt``, ``.safetensors``
    or ``.bin`` are excluded from the download.

    Args:
        pretrained_model_name_or_path: Model identifier on the selected hub.

    Returns:
        Whatever the selected ``snapshot_download`` returns (the local
        snapshot directory).
    """
    offline = huggingface_hub.constants.HF_HUB_OFFLINE

    if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true':
        from modelscope import snapshot_download

        # ModelScope keywords; patterns kept exactly as originally written.
        return snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=offline,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"])

    from huggingface_hub import snapshot_download

    # huggingface_hub names these parameters ``repo_id`` and
    # ``ignore_patterns`` and matches patterns as shell-style globs, so the
    # ModelScope keywords and ".*.ext" patterns cannot be reused here.
    return snapshot_download(
        repo_id=pretrained_model_name_or_path,
        local_files_only=offline,
        ignore_patterns=["*.pt", "*.safetensors", "*.bin"])


def get_tokenizer(
    pretrained_model_name_or_path: str, trust_remote_code: bool
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load a tokenizer from a local path or a model identifier.

    A value that is not None and does not exist on the local filesystem is
    first resolved to a local path with ``get_model``; the result is then
    passed to ``AutoTokenizer.from_pretrained``.
    """
    source = pretrained_model_name_or_path
    is_local_path = source is not None and os.path.exists(source)
    if source is not None and not is_local_path:
        source = get_model(source)
    return AutoTokenizer.from_pretrained(
        source, trust_remote_code=trust_remote_code)


417
418
# Maps a backend name to the coroutine function used to send its requests.
# Several backends share the OpenAI-style completions implementation.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
    "scalellm": async_request_openai_completions,
}