bench_serving.py 35.4 KB
Newer Older
zhyncs's avatar
zhyncs committed
1
2
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
3

Ying Sheng's avatar
Ying Sheng committed
4
"""
5
Benchmark online serving with dynamic requests.
Ying Sheng's avatar
Ying Sheng committed
6
7

Usage:
8
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
Ying Sheng's avatar
Ying Sheng committed
9

10
11
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
Ying Sheng's avatar
Ying Sheng committed
12
"""
zhyncs's avatar
zhyncs committed
13
14
15
16
17
18
19
20
21
22
23

import argparse
import asyncio
import json
import os
import random
import resource
import sys
import time
import traceback
import warnings
24
from argparse import ArgumentParser
zhyncs's avatar
zhyncs committed
25
from dataclasses import dataclass, field
26
from datetime import datetime
27
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
zhyncs's avatar
zhyncs committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

# Total timeout applied to every aiohttp session opened by the request
# helpers below: six hours, expressed in seconds.
_BENCH_SESSION_TIMEOUT_S = 6 * 60 * 60
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=_BENCH_SESSION_TIMEOUT_S)

# `args` is assigned by run_benchmark(); the request helpers read flags from it.
global args

zhyncs's avatar
zhyncs committed
44
45
46
47
48
49
50
51

@dataclass
class RequestFuncInput:
    """Everything a request function needs to issue one benchmark request."""

    # Prompt text sent to the server.
    prompt: str
    # Endpoint the request is POSTed to.
    api_url: str
    # Prompt length in tokens, as computed by the dataset sampler.
    prompt_len: int
    # Requested number of output tokens (sent as the max_tokens field).
    output_len: int
    # Model name placed in the OpenAI-style payload.
    model: str
    # Extra JSON fields merged last into (and able to override) the payload.
    extra_request_body: Dict[str, Any]
zhyncs's avatar
zhyncs committed
53
54
55
56
57
58
59
60
61
62
63


@dataclass
class RequestFuncOutput:
    """Result and timing of a single benchmark request.

    Every field has a default so request functions can create an empty
    instance and fill it in as the response streams back.
    """

    # Concatenated text returned by the server.
    generated_text: str = ""
    # False until the request completed without error.
    success: bool = False
    # Seconds from sending the request to the last received chunk.
    latency: float = 0.0
    # Time to first token, in seconds.
    ttft: float = 0.0
    # Inter-token latencies in seconds; a fresh list per instance.
    itl: List[float] = field(default_factory=list)
    # Prompt length in tokens, copied from the request input.
    prompt_len: int = 0
    # Error description (reason phrase or formatted traceback) on failure.
    error: str = ""
    # Output length credited to this request, in tokens.
    output_len: int = 0
zhyncs's avatar
zhyncs committed
65
66
67
68
69
70


def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` without a leading ``prefix``; unchanged if it is absent."""
    if not text.startswith(prefix):
        return text
    return text[len(prefix) :]


71
72
73
74
75
76
77
78
79
80
81
82
83
# TensorRT-LLM does not support ignore_eos; see
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` endpoint.

    Unless ``args.disable_ignore_eos`` is set, ``min_length`` and a very large
    ``end_id`` are added to the payload as a workaround for the missing
    ignore_eos support (presumably an id no token maps to, so generation only
    stops at ``max_tokens``).

    Args:
        request_func_input: prompt, lengths, URL and extra payload fields.
        pbar: optional progress bar, advanced once when the request finishes.

    Returns:
        A RequestFuncOutput with TTFT, inter-token latencies and total latency
        on success, or ``success=False`` with ``error`` set otherwise.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.000001,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "min_length": request_func_input.output_len,
            "end_id": 1048576,
            **request_func_input.extra_request_body,
        }
        if args.disable_ignore_eos:
            del payload["min_length"]
            del payload["end_id"]
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        if ttft == 0.0:
                            # First token: reuse the same clock reading that
                            # the inter-token bookkeeping below is based on.
                            ttft = timestamp - st
                            output.ttft = ttft
                        else:
                            # Decoding phase
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    if ttft == 0.0:
                        # A 200 response that streamed no chunks produced no
                        # text; counting it as a success would credit
                        # output_len tokens to the throughput metrics.
                        output.error = "Empty response stream"
                        output.success = False
                    else:
                        output.latency = most_recent_timestamp - st
                        output.success = True
                        output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


zhyncs's avatar
zhyncs committed
142
143
144
145
146
147
148
149
150
151
# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one request to an OpenAI-compatible ``completions`` endpoint.

    Streaming and ``ignore_eos`` are enabled unless ``args.disable_stream`` /
    ``args.disable_ignore_eos`` are set. ``extra_request_body`` is merged last
    and may override any default payload field. The ``OPENAI_API_KEY``
    environment variable is sent as a bearer token (the literal string
    ``None`` when it is unset).

    Args:
        request_func_input: prompt, lengths, model, URL and extra payload fields.
        pbar: optional progress bar, advanced once when the request finishes.

    Returns:
        A RequestFuncOutput with TTFT, inter-token latencies and end-to-end
        latency on success, or ``success=False`` with ``error`` set otherwise.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    prompt = request_func_input.prompt

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        # Stays None until the first non-empty chunk arrives, so an empty body
        # is reported as a clear error instead of failing with an
        # UnboundLocalError traceback.
        latency = None
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion API might have a last
                        # usage summary response without a token so we
                        # want to check a token was generated
                        text = data["choices"][0]["text"]
                        if text:
                            timestamp = time.perf_counter()
                            if ttft == 0.0:
                                # First token
                                ttft = timestamp - st
                                output.ttft = ttft
                            else:
                                # Decoding phase
                                output.itl.append(timestamp - most_recent_timestamp)

                            most_recent_timestamp = timestamp
                            generated_text += text

                    if latency is None:
                        output.error = "Empty response body"
                        output.success = False
                    else:
                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                        output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


225
async def async_request_gserver(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Request function for the "gserver" backend.

    Not implemented yet: every call raises NotImplementedError.
    """
    raise NotImplementedError


zhyncs's avatar
zhyncs committed
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model id to a path, downloading from ModelScope if enabled.

    When the ``SGLANG_USE_MODELSCOPE`` environment variable is ``"true"``
    (case-insensitive), the id is fetched with ModelScope's
    ``snapshot_download``, skipping weight files. The local snapshot path is
    returned. Otherwise the input is returned unchanged.
    """
    modelscope_enabled = (
        os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true"
    )
    if not modelscope_enabled:
        return pretrained_model_name_or_path

    import huggingface_hub.constants
    from modelscope import snapshot_download

    # Only tokenizer/config files are needed, so weight files are excluded.
    return snapshot_download(
        model_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
    )


def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load a tokenizer from a local path or a model id.

    Paths ending in ``.json`` or ``.model`` are handed to sglang's own
    ``hf_transformers_utils.get_tokenizer``. Any other name that is not an
    existing local path is first resolved through get_model() (which may
    download from ModelScope). The result is then loaded with
    ``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``.
    """
    name = pretrained_model_name_or_path
    if name.endswith((".json", ".model")):
        from sglang.srt.hf_transformers_utils import (
            get_tokenizer as _srt_get_tokenizer,
        )

        return _srt_get_tokenizer(name)

    needs_resolution = name is not None and not os.path.exists(name)
    if needs_resolution:
        name = get_model(name)
    return AutoTokenizer.from_pretrained(name, trust_remote_code=True)


# Maps each `--backend` name to the coroutine that issues one request to it.
# sglang, vllm and lmdeploy are all driven through the OpenAI-compatible
# completions request function.
ASYNC_REQUEST_FUNCS = {
    **dict.fromkeys(
        ("sglang", "vllm", "lmdeploy"), async_request_openai_completions
    ),
    "trt": async_request_trt_llm,
    "gserver": async_request_gserver,
}


@dataclass
class BenchmarkMetrics:
    """Aggregate results of one benchmark run, as built by calculate_metrics().

    Token counts are totals over successful requests. Throughputs divide those
    totals by the run duration in seconds. Every ``*_ms`` field is a latency
    statistic in milliseconds.
    """

    # Request and token counts.
    completed: int
    total_input: int
    total_output: int
    total_output_retokenized: int
    # Rates over the whole run, per second.
    request_throughput: float
    input_throughput: float
    output_throughput: float
    output_throughput_retokenized: float
    # Time to first token.
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    # Time per output token, excluding the first token.
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    # Inter-token latency.
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p99_itl_ms: float
    # End-to-end request latency.
    mean_e2e_latency_ms: float
    median_e2e_latency_ms: float
zhyncs's avatar
zhyncs committed
299
300


Lianmin Zheng's avatar
Lianmin Zheng committed
301
# ShareGPT conversation dump used by both the ShareGPT and the random-length
# samplers; fetched on demand into /tmp by download_and_cache_file().
SHAREGPT_URL = (
    "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered"
    "/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
)
Lianmin Zheng's avatar
Lianmin Zheng committed
302
303


Lianmin Zheng's avatar
Lianmin Zheng committed
304
305
306
307
def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Download ``url`` to ``filename`` unless that file already exists.

    Args:
        url: file to fetch with an HTTP GET.
        filename: destination path; defaults to ``/tmp/<last URL segment>``.

    Returns:
        The path of the cached file.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Check if the cache file already exists
    if os.path.exists(filename):
        return filename

    print(f"Downloading from {url} to {filename}")

    # Stream into a sibling ".part" file and move it into place only after the
    # whole body has arrived. An interrupted download then cannot leave a
    # truncated file that the existence check above would later accept as a
    # valid cache entry.
    partial_filename = f"{filename}.part"
    chunk_size = 1024  # Download in chunks of 1KB

    # Using the response as a context manager releases the connection even if
    # writing fails part-way.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()  # Check for request errors

        # Total size of the file in bytes (0 when the server omits it)
        total_size = int(response.headers.get("content-length", 0))

        try:
            with open(partial_filename, "wb") as f, tqdm(
                desc=filename,
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
                    bar.update(len(chunk))
        except BaseException:
            # Best-effort removal of the incomplete download, then re-raise.
            try:
                os.remove(partial_filename)
            except OSError:
                pass
            raise

    os.replace(partial_filename, filename)
    return filename
Lianmin Zheng's avatar
Lianmin Zheng committed
336
337


zhyncs's avatar
zhyncs committed
338
339
340
341
342
343
344
345
346
def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    """Sample up to ``num_requests`` requests from a ShareGPT conversation dump.

    The dataset is downloaded to the cache if ``dataset_path`` is not an
    existing file. The first human turn of each shuffled conversation is the
    prompt. The second turn's token count is the output length, unless
    ``fixed_output_len`` overrides it. Prompts or outputs shorter than 4 tokens
    are dropped, as are prompts longer than 1024 tokens. Without a fixed output
    length, pairs longer than 2048 tokens in total are also dropped.

    Returns:
        Tuples ``(prompt, prompt_len, output_len)``. Fewer than
        ``num_requests`` are returned if not enough conversations qualify.

    Raises:
        ValueError: if ``fixed_output_len`` is smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path):
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        prompt_len = len(tokenizer.encode(prompt))
        if fixed_output_len is None:
            output_len = len(tokenizer.encode(completion))
        else:
            # The completion is only tokenized to measure its length, so skip
            # that (potentially expensive) work when the length is fixed.
            output_len = fixed_output_len

        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or (
            prompt_len + output_len > 2048 and fixed_output_len is None
        ):
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


393
394
395
396
397
398
def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
) -> List[Tuple[str, int, int]]:
    """Build ``num_prompts`` requests with randomized token lengths.

    Input lengths are drawn uniformly from
    ``[max(int(input_len * range_ratio), 1), input_len]``. Output lengths are
    drawn from ``[int(output_len * range_ratio), output_len]``.

    Prompt text comes from the first turn of shuffled ShareGPT conversations,
    downloaded to the cache if ``dataset_path`` is not an existing file. Each
    prompt's token ids are truncated, or repeated and then truncated, to the
    sampled input length and decoded back to text. Conversations are reused
    cyclically when ``num_prompts`` exceeds the dataset size. Prompts that
    tokenize to nothing are skipped.

    Returns:
        ``num_prompts`` tuples ``(prompt, input_len, output_len)``.

    Raises:
        ValueError: if no conversation yields a non-empty tokenized prompt.
    """
    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )

    # Prompts are built from ShareGPT token ids repeated/truncated to the
    # sampled lengths. (Decoding purely random token ids was found to cause
    # NaN issues, so that alternative is not offered.)

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path):
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    input_requests: List[Tuple[str, int, int]] = []
    cursor = 0  # next dataset position to draw a prompt from (wraps around)
    for i in range(num_prompts):
        target_len = int(input_lens[i])

        # Take the next conversation whose prompt has at least one token: an
        # empty prompt cannot be repeated up to target_len. At most one full
        # pass over the dataset is attempted per request.
        prompt_token_ids: List[int] = []
        for _ in range(len(dataset)):
            prompt_token_ids = tokenizer.encode(dataset[cursor % len(dataset)][0])
            cursor += 1
            if prompt_token_ids:
                break
        if not prompt_token_ids:
            raise ValueError(
                f"No non-empty prompt available in dataset {dataset_path!r}."
            )
        prompt_len = len(prompt_token_ids)

        if prompt_len > target_len:
            input_ids = prompt_token_ids[:target_len]
        else:
            # Ceil-divide so the repeated ids cover at least target_len tokens.
            ratio = (target_len + prompt_len - 1) // prompt_len
            input_ids = (prompt_token_ids * ratio)[:target_len]
        prompt = tokenizer.decode(input_ids)
        input_requests.append((prompt, target_len, int(output_lens[i])))

    print(f"#Input tokens: {np.sum(input_lens)}")
    print(f"#Output tokens: {np.sum(output_lens)}")
    return input_requests


zhyncs's avatar
zhyncs committed
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    """Yield requests in order, separated by exponentially distributed gaps.

    Args:
        input_requests: requests to emit, in order.
        request_rate: mean arrival rate in requests per second;
            ``float("inf")`` emits all requests back to back without sleeping.

    A gap is sampled before every request except the first. No sleep follows
    the final request, which would otherwise only delay the caller and
    inflate the measured benchmark duration.
    """
    infinite_rate = request_rate == float("inf")
    for index, request in enumerate(input_requests):
        if index > 0 and not infinite_rate:
            # Sample the request interval from the exponential distribution.
            interval = np.random.exponential(1.0 / request_rate)
            await asyncio.sleep(interval)
        yield request


def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
    """Aggregate per-request outputs into a BenchmarkMetrics summary.

    Args:
        input_requests: the ``(prompt, prompt_len, output_len)`` tuples that
            were sent, index-aligned with ``outputs``.
        outputs: per-request results returned by the request functions.
        dur_s: wall-clock duration of the benchmark run, in seconds.
        tokenizer: re-tokenizes generated text for the ``*_retokenized``
            counters.
        backend: backend name (not used by the computation).

    Returns:
        ``(metrics, output_lens)``, where ``output_lens[i]`` is the output
        length credited to request ``i`` (0 for failed requests).
    """
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []
    for i, output in enumerate(outputs):
        if not output.success:
            output_lens.append(0)
            retokenized_output_lens.append(0)
            continue

        output_len = output.output_len
        output_lens.append(output_len)
        retokenized_output_lens.append(
            len(tokenizer.encode(output.generated_text, add_special_tokens=False))
        )
        total_input += input_requests[i][1]
        if output_len > 1:
            # TPOT excludes the first token, whose time is already in TTFT.
            tpots.append((output.latency - output.ttft) / (output_len - 1))
        itls += output.itl
        ttfts.append(output.ttft)
        e2e_latencies.append(output.latency)
        completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )
    # Every `xs or 0` below keeps numpy from warning and returning NaN on an
    # empty list. The lists are empty when all requests failed; ttfts/itls can
    # also be empty if the backend does not stream.
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(output_lens),
        total_output_retokenized=sum(retokenized_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(output_lens) / dur_s,
        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies or 0) * 1000,
        median_e2e_latency_ms=np.median(e2e_latencies or 0) * 1000,
    )

    return metrics, output_lens
zhyncs's avatar
zhyncs committed
554
555
556
557
558
559
560
561
562
563


async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
    extra_request_body: Dict[str, Any],
):
    """Benchmark one backend at one request rate and return the raw results.

    Steps:
    1. Send ``input_requests[0]`` once as a sanity check.
    2. Dispatch every request as its own task, paced by get_request()
       (exponentially distributed gaps at ``request_rate`` requests/s).
    3. Wait for all tasks and print a summary table.
    4. Append a one-line JSON summary to ``args.output_file`` (or an
       auto-generated ``*.jsonl`` name).
    5. Return a dict with aggregate metrics plus per-request details.

    Raises:
        ValueError: for an unknown ``backend`` or a failed sanity-check request.
    """
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        extra_request_body=extra_request_body,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            extra_request_body=extra_request_body,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

    _print_benchmark_report(backend, request_rate, benchmark_duration, metrics)

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        summary = {
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": request_rate,
            "total_input_tokens": metrics.total_input,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "median_itl_ms": metrics.median_itl_ms,
            "output_throughput": metrics.output_throughput,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            "duration": benchmark_duration,
            "completed": metrics.completed,
        }

        # Determine output file name
        if args.output_file:
            output_file_name = args.output_file
        else:
            now = datetime.now().strftime("%m%d")
            if args.dataset_name == "random":
                output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
            else:
                output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"

        # Append results to a JSONL file. This only happens when a summary was
        # built; previously the write ran unconditionally and hit an unbound
        # name whenever the check above failed.
        with open(output_file_name, "a") as file:
            file.write(json.dumps(summary) + "\n")
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
    }
    return result


def _print_benchmark_report(backend, request_rate, benchmark_duration, metrics):
    """Print the human-readable result table for one benchmark run."""
    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)


751
def parse_request_rate_range(request_rate_range):
    """Parse the ``--request-rate-range`` string into a list of request rates.

    Exactly three comma-separated values are read as ``start,stop,step`` and
    expanded with ``range`` (``stop`` is exclusive), so ``"2,34,2"`` yields
    ``[2, 4, ..., 32]``. Any other number of values is taken as an explicit
    list of rates, e.g. ``"1,2,4,8"``.

    Args:
        request_rate_range: The raw comma-separated string.

    Returns:
        A list of request rates. ``start,stop,step`` values must be integers.
        Explicit-list entries may be integers or floats (e.g. ``"0.5,1,2"``);
        integral entries are returned as ``int``.

    Raises:
        ValueError: If a value is not a number, a ``start,stop,step`` value
            is not an integer, or ``step`` is zero.
    """
    parts = request_rate_range.split(",")
    if len(parts) == 3:
        start, stop, step = map(int, parts)
        return list(range(start, stop, step))

    def _to_rate(token):
        # Keep integral rates as int (as before); fall back to float so that
        # fractional rates such as "0.5" are accepted too.
        try:
            return int(token)
        except ValueError:
            return float(token)

    return [_to_rate(part) for part in parts]
757
758


759
760
761
762
763
764
765
766
767
def check_chat_template(model_path):
    """Return whether the tokenizer at ``model_path`` defines a chat template.

    The tokenizer is loaded with ``AutoTokenizer.from_pretrained`` (remote code
    allowed) and ``"chat_template"`` is looked up in its ``init_kwargs``. Any
    exception while loading or inspecting it is printed and treated as "no
    chat template", so this never raises.

    Args:
        model_path: Name or path passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``True`` if a chat template is configured, otherwise ``False``.
    """
    try:
        loaded = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        has_template = "chat_template" in loaded.init_kwargs
    except Exception as err:
        print(f"Fail to load tokenizer config with error={err}")
        return False
    return has_template


768
769
770
771
def run_benchmark(args_: argparse.Namespace):
    """Resolve the target server, build the request set, and run the benchmark.

    Steps:
    - Store ``args_`` in the module-level ``args``; presumably other functions
      in this module read it (TODO confirm).
    - Raise the open-file limit and seed ``random`` and NumPy.
    - Resolve the endpoint URLs and the model name.
    - Sample the dataset.
    - Run ``benchmark`` once at ``--request-rate``, or once per rate from
      ``--request-rate-range`` when ``--multi`` is set.

    Args:
        args_: Parsed command-line arguments.

    Returns:
        In single-rate mode, whatever ``benchmark`` returns. In ``--multi``
        mode, a list holding one ``benchmark`` result per request rate, in
        the order the rates were given.

    Raises:
        ValueError: If ``args_.dataset_name`` is neither ``"sharegpt"`` nor
            ``"random"``.
        SystemExit: If no model name can be determined, or if the ``trt``
            backend is used without ``--model``.
    """
    global args
    args = args_

    # Set global environments
    set_ulimit()

    random.seed(args.seed)
    np.random.seed(args.seed)

    extra_request_body = {}
    if args.extra_request_body:
        extra_request_body = json.loads(args.extra_request_body)

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
            "trt": 8000,
            "gserver": 9988,
        }.get(args.backend, 30000)

    api_url = (
        f"{args.base_url}/v1/completions"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/completions"
    )
    model_url = (
        f"{args.base_url}/v1/models"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)
    elif args.backend == "gserver":
        # gserver gets the bare base URL or host:port, with no scheme or path.
        api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
        args.model = args.model or "default"

    # Get model name
    if args.model is None:
        # Bounded wait so an unreachable host fails instead of hanging forever.
        model_list_timeout_s = 60
        try:
            response = requests.get(model_url, timeout=model_list_timeout_s)
            # Report HTTP errors as such, not as a JSON decode failure.
            response.raise_for_status()
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Inspect the tokenizer that will actually be used. When --tokenizer is
    # given, --model may be a served model name that cannot be loaded.
    if not check_chat_template(tokenizer_id):
        print(
            "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
            "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
        )

    print(f"{args}\n")

    # Read dataset
    tokenizer = get_tokenizer(tokenizer_id)

    if args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )
    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    if not args.multi:
        return asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
                model_id=model_id,
                tokenizer=tokenizer,
                input_requests=input_requests,
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                extra_request_body=extra_request_body,
            )
        )

    # Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
    request_rates = parse_request_rate_range(args.request_rate_range)

    # Collect every run's result instead of discarding it.
    results = []
    for rate in request_rates:
        results.append(
            asyncio.run(
                benchmark(
                    backend=backend,
                    api_url=api_url,
                    model_id=model_id,
                    tokenizer=tokenizer,
                    input_requests=input_requests,
                    request_rate=rate,
                    disable_tqdm=args.disable_tqdm,
                    extra_request_body=extra_request_body,
                )
            )
        )
    return results
zhyncs's avatar
zhyncs committed
896
897
898
899
900
901
902
903
904
905
906
907
908
909


def set_ulimit(target_soft_limit=65535):
    """Best-effort raise of the soft open-file-descriptor limit (RLIMIT_NOFILE).

    Behaviour:
    - The soft limit is raised toward ``target_soft_limit``, clamped to the
      current hard limit, because ``setrlimit`` rejects a soft limit above it.
    - The limit is never lowered. Nothing happens if the current soft limit is
      unlimited or already at least ``target_soft_limit``.
    - Failures are printed rather than raised.

    Args:
        target_soft_limit: Desired minimum soft limit.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # RLIM_INFINITY can be a negative sentinel (e.g. -1 on Linux). Comparing
    # it numerically would wrongly try to lower an unlimited soft limit.
    if current_soft == resource.RLIM_INFINITY or current_soft >= target_soft_limit:
        return

    new_soft = target_soft_limit
    if current_hard != resource.RLIM_INFINITY:
        # An unprivileged process may raise its soft limit only up to the hard limit.
        new_soft = min(new_soft, current_hard)
    if new_soft <= current_soft:
        return

    try:
        resource.setrlimit(resource_type, (new_soft, current_hard))
    except (ValueError, OSError) as e:
        print(f"Fail to set RLIMIT_NOFILE: {e}")


if __name__ == "__main__":
    # Command-line entry point: parse the options below and hand them to
    # run_benchmark.
    parser = ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "random"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path", type=str, default="", help="Path to the dataset."
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, using the model conf.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use request rate range rather than single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--disable-stream",
        action="store_true",
        help="Disable streaming mode.",
    )
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--extra-request-body",
        metavar='{"key1": "value1", "key2": "value2"}',
        type=str,
        # The two adjacent literals are concatenated; the trailing space keeps
        # "specify" and "additional" from running together.
        help="Append given JSON object to the request payload. You can use this to specify "
        "additional generate params like sampling params.",
    )
    args = parser.parse_args()
    run_benchmark(args)