# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py
"""
Benchmark online serving.

Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompts 10

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input-len 1024 --random-output-len 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input-len 4096 --random-output-len 1024 --random-range-ratio 0.125 --multi
"""

import argparse
import asyncio
import json
import os
import random
import resource
import sys
import time
import traceback
import warnings
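# Plain ArgumentParser aliased under the FlexibleArgumentParser name, presumably
# for parity with the upstream vLLM benchmark utilities this file is adapted from.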
from argparse import ArgumentParser as FlexibleArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from typing import AsyncGenerator, List, Optional, Tuple, Union

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

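# Generous 6-hour client timeout so long benchmark runs are not cut off early.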
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    error: str = ""
    output_len: int = 0


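# Backport of str.removeprefix, which is only available on Python 3.9+.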
def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


# TensorRT-LLM does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.000001,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "min_length": request_func_input.output_len,
            "end_id": 1048576,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    output.output_len = request_func_input.output_len

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


# ignore_eos is set to True by default, so the requested output length is always generated
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": True,
        }
        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some completion API might have a last
                            # usage summary response without a token so we
                            # want to check a token was generated
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase
                                output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


def get_model(pretrained_model_name_or_path: str) -> str:
    if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
        import huggingface_hub.constants
        from modelscope import snapshot_download

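        # Only the tokenizer/config are needed here, so skip downloading weights.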
        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
        )

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    if pretrained_model_name_or_path is not None and not os.path.exists(
        pretrained_model_name_or_path
    ):
        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True
    )


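# sglang, vllm, and lmdeploy all speak the OpenAI-compatible completions API;
# TensorRT-LLM uses its own Triton generate_stream endpoint.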
ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_openai_completions,
    "vllm": async_request_openai_completions,
    "lmdeploy": async_request_openai_completions,
    "trt": async_request_trt_llm,
}


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    total_output_retokenized: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    output_throughput_retokenized: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p99_itl_ms: float
    mean_e2e_latency_ms: float
    median_e2e_latency_ms: float


default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"


def download_sharegpt_dataset(path):
    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"

    print(f"Downloading dataset from {url}")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        block_size = 8192

        with open(path, "wb") as f, tqdm(
            desc="Downloading",
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar:
            for data in response.iter_content(block_size):
                size = f.write(data)
                progress_bar.update(size)

        print(f"Dataset downloaded and saved to {path}")
    except requests.RequestException as e:
        raise Exception(f"Failed to download dataset: {e}")


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Download sharegpt if necessary
    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
        download_sharegpt_dataset(default_sharegpt_path)
        dataset_path = default_sharegpt_path
    else:
        dataset_path = (
            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
        )

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[Tuple[str, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    return filtered_dataset


def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
) -> List[Tuple[str, int, int]]:

    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1),
        input_len + 1,
        size=num_prompts,
    )
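    # Note: unlike input_lens above, the low bound is not clamped to 1, so a
    # range_ratio of 0 can sample zero-length outputs.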
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )

    # Toggle: sample token ids from ShareGPT text (True) or from random
    # integers (False, see the else branch below).
    use_sharegpt_text = True
    if use_sharegpt_text:
        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

        # Download sharegpt if necessary
        if not os.path.isfile(dataset_path) and not os.path.isfile(
            default_sharegpt_path
        ):
            download_sharegpt_dataset(default_sharegpt_path)
            dataset_path = default_sharegpt_path
        else:
            dataset_path = (
                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
            )

        # Load the dataset.
        with open(dataset_path) as f:
            dataset = json.load(f)
        # Filter out the conversations with less than 2 turns.
        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
        # Only keep the first two turns of each conversation.
        dataset = [
            (data["conversations"][0]["value"], data["conversations"][1]["value"])
            for data in dataset
        ]

        # Shuffle the dataset.
        random.shuffle(dataset)

        # Filter out sequences that are too long or too short
        input_requests: List[Tuple[str, int, int]] = []
        for i in range(num_prompts):
            # Tokenize the prompts and completions.
            prompt = dataset[i][0]
            prompt_token_ids = tokenizer(prompt).input_ids
            prompt_len = len(prompt_token_ids)

            if prompt_len > input_lens[i]:
                input_ids = prompt_token_ids[: input_lens[i]]
            else:
                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
            prompt = tokenizer.decode(input_ids)
            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
    else:
        # Sample token ids from random integers. This can cause some NaN issues.
        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
        input_requests = []
        for i in range(num_prompts):
            prompt = tokenizer.decode(
                [
                    (offsets[i] + i + j) % tokenizer.vocab_size
                    for j in range(input_lens[i])
                ]
            )
            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

    print(f"#Input tokens: {np.sum(input_lens)}")
    print(f"#Output tokens: {np.sum(output_lens)}")
    return input_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []
    for i in range(len(outputs)):
        if outputs[i].success:
            output_len = outputs[i].output_len
            output_lens.append(output_len)
            retokenized_output_len = len(
                tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
            )
            retokenized_output_lens.append(retokenized_output_len)
            total_input += input_requests[i][1]
            if output_len > 1:
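                # TPOT spreads the latency after the first token over the
                # remaining output_len - 1 tokens.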
                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)

            e2e_latencies.append(outputs[i].latency)

            completed += 1
        else:
            output_lens.append(0)
            retokenized_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=sum(output_lens),
        total_output_retokenized=sum(retokenized_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(output_lens) / dur_s,
        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0)
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies or 0)
        * 1000,  # e2e_latencies is empty if all requests failed
        median_e2e_latency_ms=np.median(e2e_latencies or 0) * 1000,
    )

    return metrics, output_lens


async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
    disable_tqdm: bool,
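    # Accepted for symmetry with the CLI flags but currently unused below.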
    enable_multi: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len = input_requests[0]
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    if pbar is not None:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print(
        "{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
    )
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("=" * 50)

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        result = {
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": request_rate,
            "total_input": metrics.total_input,
            "total_output": metrics.total_output,
            "total_output_retokenized": metrics.total_output_retokenized,
            "mean_e2e_latency": metrics.mean_e2e_latency_ms,
            "median_e2e_latency": metrics.median_e2e_latency_ms,
            "median_ttft": metrics.median_ttft_ms,
            "median_itl": metrics.median_itl_ms,
            "output_token_throughput": metrics.output_throughput,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            "benchmark_duration": benchmark_duration,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)
        result = None

    # Determine output file name
    if args.output_file:
        output_file_name = args.output_file
    else:
        now = datetime.now().strftime("%m%d")
        if args.dataset_name == "random":
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
        else:
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"

    # Append results to a JSONL file (skipped if the metrics could not be computed)
    if result is not None:
        with open(output_file_name, "a") as file:
            file.write(json.dumps(result) + "\n")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
    }
    return result


def parse_request_rate_range(request_rate_range):
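    """Parse --request-rate-range as either `start,stop,step` (exactly three
    comma-separated values) or an explicit list of request rates."""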
    if len(request_rate_range.split(",")) == 3:
        start, stop, step = map(int, request_rate_range.split(","))
        return list(range(start, stop, step))
    else:
        return list(map(int, request_rate_range.split(",")))


def check_chat_template(model_path):
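    """Heuristically check whether the tokenizer ships a chat template,
    i.e. whether this looks like a Chat/Instruct model."""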
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        return "chat_template" in tokenizer.init_kwargs
    except Exception as e:
        print(f"Failed to load tokenizer config: {e}")
        return False


def fire(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.port is None:
        args.port = {
            "sglang": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
            "trt": 8000,
        }.get(args.backend, 30000)

    api_url = (
        f"{args.base_url}/v1/completions"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/completions"
    )
    model_url = (
        f"{args.base_url}/v1/models"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)

    if args.model is None:
        try:
            response = requests.get(model_url)
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    if not check_chat_template(args.model):
        print(
            "\nWARNING: It is recommended to benchmark with a `Chat` or `Instruct` model.\n"
            "A base model may generate gibberish, which can make the tokenizer's output token counts unreliable.\n"
        )

    print(f"{args}\n")

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    tokenizer = get_tokenizer(tokenizer_id)

    if args.dataset_name == "sharegpt":
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
        )
    elif args.dataset_name == "random":
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
        )
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")

    if args.multi:
        request_rates = parse_request_rate_range(args.request_rate_range)

        for rate in request_rates:
            asyncio.run(
                benchmark(
                    backend=backend,
                    api_url=api_url,
                    model_id=model_id,
                    tokenizer=tokenizer,
                    input_requests=input_requests,
                    request_rate=rate,
                    disable_tqdm=args.disable_tqdm,
                    enable_multi=args.multi,
                )
            )
    else:
        asyncio.run(
            benchmark(
                backend=backend,
                api_url=api_url,
                model_id=model_id,
                tokenizer=tokenizer,
                input_requests=input_requests,
                request_rate=args.request_rate,
                disable_tqdm=args.disable_tqdm,
                enable_multi=args.multi,
            )
        )


# Local copy of the ulimit helper, to avoid depending on SGLang's internal components
def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Failed to set RLIMIT_NOFILE: {e}")


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput."
    )
    parser.add_argument(
        "--backend",
        type=str,
        required=True,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        help="The inference engine backend to benchmark against.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, defaults to the standard port for the chosen backend.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "random"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path", type=str, default="", help="Path to the dataset."
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, it is fetched from the server's /v1/models endpoint.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, the model path is used.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output lengths: lengths are drawn "
        "uniformly from [range_ratio * len, len]. Used only for the random dataset.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, all requests are sent at time 0. "
        "Otherwise, request arrival times follow a Poisson process. Default is inf.",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed. Default is 0.")
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--multi",
        action="store_true",
        help="Use request rate range rather than single value.",
    )
    parser.add_argument(
        "--request-rate-range",
        type=str,
        default="2,34,2",
        help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also accepts a comma-separated list of explicit request rates, provided the list does not contain exactly three values.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument(
        "--disable-stream",
        action="store_true",
        help="Disable streaming mode.",
    )

    set_ulimit()

    args = parser.parse_args()
    fire(args)