# backend_request_func.py
1
2
import json
import os
3
import sys
4
import time
5
6
7
import traceback
from dataclasses import dataclass, field
from typing import List, Optional
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

import aiohttp
from tqdm.asyncio import tqdm

# Client timeout passed to every aiohttp session created in this module.
# The total budget is 6 * 60 * 60 seconds (six hours) — presumably chosen so
# that very long generations are never cut off by the HTTP client.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    """Everything a backend request function needs to issue one request."""

    # Text sent to the server as the prompt / user message.
    prompt: str
    # Full endpoint URL; each request function checks the suffix it expects.
    api_url: str
    # Copied verbatim onto the result; presumably the prompt's token count —
    # confirm against the caller.
    prompt_len: int
    # Forwarded as the generation length limit (max_tokens / max_new_tokens).
    output_len: int
    # Model name placed in the payload of the OpenAI-style endpoints.
    model: str
    # Forwarded as ``best_of`` where supported; other backends require 1.
    best_of: int = 1
    # Every request function in this module asserts that this is False.
    use_beam_search: bool = False


@dataclass
class RequestFuncOutput:
    """Result and timing measurements of a single benchmark request.

    All timing values are differences of ``time.perf_counter()`` readings
    taken by the request functions, i.e. they are expressed in seconds.
    """

    # Text produced by the server (empty until the request succeeds).
    generated_text: str = ""
    # True only when the request completed with HTTP 200 and was parsed.
    success: bool = False
    # Time from sending the request to the end of the response.
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    # A fresh list per instance; a shared mutable default would leak
    # measurements between requests.
    itl: List[float] = field(
        default_factory=list)  # List of inter-token latencies
    # Copied from the corresponding RequestFuncInput.
    prompt_len: int = 0
    # HTTP reason phrase or formatted traceback when the request failed.
    error: str = ""
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TGI ``generate_stream`` endpoint.

    Args:
        request_func_input: Prompt, endpoint URL and generation settings.
            The URL must end with ``generate_stream`` and beam search must
            be disabled.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success it holds the generated text,
        time to first token, inter-token latencies and total latency. On
        failure ``success`` is False and ``error`` holds the HTTP reason
        phrase, a short description, or a formatted traceback.

    Raises:
        AssertionError: If the URL suffix is wrong or beam search is
            requested.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Last decoded event. Stays None when the stream carries no data
        # events, so that case is reported explicitly instead of surfacing
        # as a NameError on an unbound variable.
        data = None
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue
                        chunk_bytes = chunk_bytes.decode("utf-8")

                        # NOTE: Sometimes TGI returns a ping response without
                        # any data, we should skip it.
                        if chunk_bytes.startswith(":"):
                            continue
                        chunk = remove_prefix(chunk_bytes, "data:")

                        data = json.loads(chunk)
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # Reuse the reading taken above so that TTFT and
                            # the following inter-token latency share one
                            # reference point.
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if data is None:
                        output.error = "Stream ended without any data events."
                        output.success = False
                    else:
                        output.latency = most_recent_timestamp - st
                        # The text is read from the last event of the stream.
                        output.generated_text = data["generated_text"]
                        output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: a failed request is recorded, never propagated.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None explicitly: a progress bar's truthiness is
        # derived from its total, not from whether one was supplied.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` URL.

    Args:
        request_func_input: Prompt, endpoint URL and generation settings.
            The URL must end with ``generate_stream``, beam search must be
            disabled and ``best_of`` must be 1.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. The ``text_output`` of every streamed event
        is appended to ``generated_text``; timing fields are filled in as
        events arrive. On failure ``success`` is False and ``error`` holds
        the HTTP reason phrase or a formatted traceback.

    Raises:
        AssertionError: If the URL suffix is wrong, beam search is requested
            or ``best_of`` is not 1.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            # Reuse the reading taken above so that TTFT and
                            # the following inter-token latency share one
                            # reference point.
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp -
                                              most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: a failed request is recorded, never propagated.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None explicitly: a progress bar's truthiness is
        # derived from its total, not from whether one was supplied.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one non-streaming request to a DeepSpeed-MII endpoint.

    Args:
        request_func_input: Prompt, endpoint URL and generation settings.
            ``best_of`` must be 1 and beam search must be disabled.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success ``generated_text`` is the first
        entry of the response's ``text`` list and ``latency`` is the time
        until the JSON body was fully read. ``ttft`` is always the 0
        placeholder and ``itl`` stays empty because the response is not
        streamed. On failure ``success`` is False and ``error`` holds the
        HTTP reason phrase or a formatted traceback.

    Raises:
        AssertionError: If ``best_of`` is not 1 or beam search is requested.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompt": request_func_input.prompt,
            "max_tokens": request_func_input.output_len,
            "temperature": 0.01,  # deepspeed-mii does not accept 0.0 temp.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
        # will use 0 as placeholder.
        # See https://github.com/microsoft/DeepSpeed-MII/pull/311
        # Written as a float to match the declared type of the field.
        output.ttft = 0.0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as response:
                if response.status == 200:
                    parsed_resp = await response.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp["text"][0]
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: a failed request is recorded, never propagated.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        # Compare against None explicitly: a progress bar's truthiness is
        # derived from its total, not from whether one was supplied.
        if pbar is not None:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Completions endpoint.

    The bearer token is read from the ``OPENAI_API_KEY`` environment
    variable.

    Args:
        request_func_input: Prompt, endpoint URL, model name and generation
            settings. The URL must end with ``v1/completions`` and beam
            search must be disabled.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success it holds the concatenated text
        of all streamed chunks plus timing data. ``latency`` is measured up
        to the ``[DONE]`` sentinel, or up to the last chunk that carried
        text when the stream ends without one. On failure ``success`` is
        False and ``error`` holds the HTTP reason phrase or a formatted
        traceback.

    Raises:
        AssertionError: If the URL suffix is wrong or beam search is
            requested.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/completions"
    ), "OpenAI Completions API URL must end with 'v1/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Bound only when the "[DONE]" sentinel arrives; kept as None until
        # then so a stream that ends without it does not hit an unbound name.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            data = json.loads(chunk)

                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    # Reuse the reading taken above so TTFT
                                    # and the next inter-token latency share
                                    # one reference point.
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                # NOTE: Some completion API might have a last
                                # usage summary response without a token so we
                                # do not want to include as inter-token-latency
                                elif data.get("usage", None) is None:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    if latency is None:
                        # No "[DONE]" seen: fall back to the arrival time of
                        # the last chunk that carried text.
                        latency = most_recent_timestamp - st
                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: a failed request is recorded, never propagated.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    # Compare against None explicitly: a progress bar's truthiness is
    # derived from its total, not from whether one was supplied.
    if pbar is not None:
        pbar.update(1)
    return output


299
300
301
302
303
304
305
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to an OpenAI-style Chat Completions URL.

    The prompt is sent as a single ``user`` message and the bearer token is
    read from the ``OPENAI_API_KEY`` environment variable.

    Args:
        request_func_input: Prompt, endpoint URL, model name and generation
            settings. The URL must end with ``v1/chat/completions`` and beam
            search must be disabled.
        pbar: Optional progress bar; advanced by one when the request has
            finished, whether it succeeded or not.

    Returns:
        A ``RequestFuncOutput``. On success it holds the concatenated
        ``delta.content`` of all streamed chunks plus timing data.
        ``latency`` is measured up to the ``[DONE]`` sentinel, or up to the
        last received chunk when the stream ends without one. On failure
        ``success`` is False and ``error`` holds the HTTP reason phrase or
        a formatted traceback.

    Raises:
        AssertionError: If the URL suffix is wrong or beam search is
            requested.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "v1/chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "messages": [
                {
                    "role": "user",
                    "content": request_func_input.prompt,
                },
            ],
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        # Bound only when the "[DONE]" sentinel arrives; kept as None until
        # then so a stream that ends without it does not hit an unbound name.
        latency = None
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"),
                                              "data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            timestamp = time.perf_counter()
                            data = json.loads(chunk)

                            delta = data["choices"][0]["delta"]
                            if delta.get("content", None):
                                # First token
                                if ttft == 0.0:
                                    # Reuse the reading taken above so TTFT
                                    # and the next inter-token latency share
                                    # one reference point.
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp -
                                                      most_recent_timestamp)

                                generated_text += delta["content"]

                            # Updated for every parsed chunk, including ones
                            # whose delta carries no content.
                            most_recent_timestamp = timestamp

                    if latency is None:
                        # No "[DONE]" seen: fall back to the arrival time of
                        # the last parsed chunk.
                        latency = most_recent_timestamp - st
                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            # Best effort: a failed request is recorded, never propagated.
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    # Compare against None explicitly: a progress bar's truthiness is
    # derived from its total, not from whether one was supplied.
    if pbar is not None:
        pbar.update(1)
    return output


383
384
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
385
386
387
388
389
390
def remove_prefix(text: str, prefix: str) -> str:
    """Return *text* without a leading *prefix*.

    When *text* does not start with *prefix* it is returned unchanged. At
    most one occurrence of the prefix is stripped.
    """
    return text[len(prefix):] if text.startswith(prefix) else text


391
392
# Lookup table from backend name to the coroutine function that issues and
# times one request against it. "vllm", "lmdeploy" and "openai" all share
# the OpenAI Completions implementation.
ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
    "tensorrt-llm": async_request_trt_llm,
}