# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py

"""
5
Benchmark online serving with dynamic requests.
Ying Sheng's avatar
Ying Sheng committed
6
7

Usage:
8
python3 -m sglang.bench_serving --backend sglang --num-prompt 10
Ying Sheng's avatar
Ying Sheng committed
9

10
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
Ying Sheng's avatar
Ying Sheng committed
11
"""
zhyncs's avatar
zhyncs committed
12
13
14

import argparse
import asyncio
15
16
import base64
import io
zhyncs's avatar
zhyncs committed
17
18
import json
import os
19
import pickle
zhyncs's avatar
zhyncs committed
20
21
22
23
24
25
import random
import resource
import sys
import time
import traceback
import warnings
26
from argparse import ArgumentParser
zhyncs's avatar
zhyncs committed
27
from dataclasses import dataclass, field
28
from datetime import datetime
29
from json import JSONDecodeError
30
from pathlib import Path
31
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
zhyncs's avatar
zhyncs committed
32
33
34
35
36
37
38
39
40
41
42
43

import aiohttp
import numpy as np
import requests
from tqdm.asyncio import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

44
ASSISTANT_SUFFIX = "Assistant:"
zhyncs's avatar
zhyncs committed
45

46
47
# A `global` statement at module scope has no runtime effect. It only signals
# that `args` is a module global, read by the request functions below
# (e.g. args.disable_stream, args.disable_ignore_eos, args.return_logprob).
# NOTE(review): the assignment of `args` is not visible here; presumably done
# by the CLI entry point before any request is sent.
global args

zhyncs's avatar
zhyncs committed
48

Yineng Zhang's avatar
Yineng Zhang committed
49
50
51
52
53
54
# don't want to import sglang package here
def _get_bool_env_var(name: str, default: str = "false") -> bool:
    value = os.getenv(name, default)
    return value.lower() in ("true", "1")


55
56
57
58
59
60
61
62
63
64
65
66
67
def _create_bench_client_session():
    """Return an ``aiohttp.ClientSession`` tuned for long benchmark runs.

    The session uses a 6-hour total timeout and a 10 MB read buffer. Under
    heavy load the default 64 KB read buffer can fill up before the asyncio
    task reads the content, so a much larger buffer is used.
    """
    six_hours_in_seconds = 6 * 60 * 60
    ten_megabytes = 10 * 1024**2
    return aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=six_hours_in_seconds),
        read_bufsize=ten_megabytes,
    )


zhyncs's avatar
zhyncs committed
68
69
70
71
72
73
74
@dataclass
class RequestFuncInput:
    """Everything an ``async_request_*`` function needs to issue one request."""

    # Prompt to send. Annotated as str, but the sglang backend also forwards a
    # non-str prompt (sent as "input_ids").
    prompt: str
    # Full endpoint URL the request is POSTed to.
    api_url: str
    # Prompt length, copied into the result by RequestFuncOutput.init_new.
    prompt_len: int
    # Requested number of output tokens (max_tokens / max_new_tokens).
    output_len: int
    # Model name placed in OpenAI-style payloads.
    model: str
    # Forwarded as "lora_path" by the sglang backend.
    lora_name: str
    # Optional list of image URLs / encoded images for multimodal requests.
    image_data: Optional[List[str]]
    # Extra keys merged last into every request payload.
    extra_request_body: Dict[str, Any]
    # Optional timestamp; not read in this part of the file.
    timestamp: Optional[float] = field(default=None)
zhyncs's avatar
zhyncs committed
79
80
81
82
83
84
85
86
87
88
89


@dataclass
class RequestFuncOutput:
    """Result and timing measurements collected for a single request."""

    generated_text: str = ""
    success: bool = False
    # End-to-end latency in seconds (difference of time.perf_counter values).
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    prompt_len: int = 0
    # Failure reason or formatted traceback; left empty on success.
    error: str = ""
    output_len: int = 0

    @staticmethod
    def init_new(request_func_input: RequestFuncInput):
        """Create a blank result that carries over the input's ``prompt_len``."""
        return RequestFuncOutput(prompt_len=request_func_input.prompt_len)

zhyncs's avatar
zhyncs committed
98
99
100
101
102

def remove_prefix(text: str, prefix: str) -> str:
    """Return *text* with a leading *prefix* removed.

    *text* is returned unchanged when it does not start with *prefix*.
    """
    if not text.startswith(prefix):
        return text
    return text[len(prefix) :]


103
104
105
106
def remove_suffix(text: str, suffix: str) -> str:
    """Return *text* with a trailing *suffix* removed.

    *text* is returned unchanged when it does not end with *suffix*, or when
    *suffix* is empty.
    """
    # Guard the empty suffix: every string ends with "", and text[:-0] is
    # text[:0] == "", which would wrongly erase the whole string.
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


107
108
109
110
111
112
113
114
def get_auth_headers() -> Dict[str, str]:
    """Build HTTP auth headers from the ``OPENAI_API_KEY`` environment variable.

    Returns a Bearer ``Authorization`` header when the key is set and non-empty,
    otherwise an empty dict.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    return {"Authorization": f"Bearer {api_key}"} if api_key else {}


115
# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one streaming request to a TensorRT-LLM ``generate_stream`` endpoint.

    TensorRT-LLM has no ``ignore_eos`` option. Instead, the payload sets
    ``min_length`` to the requested output length and uses an unusual
    ``end_id`` (presumably outside any vocabulary — TODO confirm). Both keys
    are dropped when ``args.disable_ignore_eos`` is set.

    Args:
        request_func_input: Prompt, URL and length settings for the request.
        pbar: Optional tqdm progress bar, advanced by one when done.

    Returns:
        RequestFuncOutput with text, TTFT, inter-token latencies and latency.
        ``output_len`` is the requested length, not a server-reported count.
        Errors are captured in ``error`` with ``success=False``.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with _create_bench_client_session() as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.000001,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "min_length": request_func_input.output_len,
            "end_id": 1048576,
            **request_func_input.extra_request_body,
        }
        if args.disable_ignore_eos:
            # Let the model stop at its natural end-of-sequence.
            for eos_override_key in ("min_length", "end_id"):
                del payload[eos_override_key]
        output = RequestFuncOutput.init_new(request_func_input)

        ttft = 0.0
        start = time.perf_counter()
        last_event = start
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for raw_line in response.content:
                        raw_line = raw_line.strip()
                        if not raw_line:
                            continue

                        event = json.loads(
                            remove_prefix(raw_line.decode("utf-8"), "data:")
                        )
                        output.generated_text += event["text_output"]
                        now = time.perf_counter()
                        if ttft == 0.0:
                            # First token
                            ttft = now - start
                            output.ttft = ttft
                        else:
                            # Decoding phase
                            output.itl.append(now - last_event)
                        last_event = now

                    output.latency = last_event - start
                    output.success = True
                    output.output_len = request_func_input.output_len
        except Exception:
            output.success = False
            output.error = "".join(traceback.format_exception(*sys.exc_info()))

        if pbar:
            pbar.update(1)
        return output


zhyncs's avatar
zhyncs committed
185
186
187
188
189
190
191
192
193
194
# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one request to an OpenAI-compatible Completions endpoint.

    The payload requests greedy decoding (``temperature=0``). ``ignore_eos``
    is on unless ``args.disable_ignore_eos`` is set, and streaming follows
    ``args.disable_stream``. ``extra_request_body`` entries are merged last and
    override these defaults.

    Args:
        request_func_input: Prompt, target URL/model and length settings.
        pbar: Optional tqdm progress bar, advanced by one when done.

    Returns:
        RequestFuncOutput with:
            - generated text, TTFT, inter-token latencies and latency;
            - the completion token count from the server's
              ``usage.completion_tokens`` when reported (on any chunk),
              otherwise the requested ``output_len``.
        Errors are captured in ``error`` with ``success=False`` instead of
        being raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    prompt = request_func_input.prompt

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "prompt": prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion APIs end with a usage summary
                        # chunk that has no token and may even carry an empty
                        # ``choices`` list; indexing it blindly would raise and
                        # mark the whole request as failed.
                        choices = data.get("choices") or []
                        text = (choices[0].get("text") or "") if choices else ""
                        if text:
                            timestamp = time.perf_counter()
                            # First token
                            if ttft == 0.0:
                                ttft = timestamp - st
                                output.ttft = ttft
                            # Decoding phase
                            else:
                                output.itl.append(timestamp - most_recent_timestamp)
                            most_recent_timestamp = timestamp
                            generated_text += text

                        # Read usage on every chunk, including a text-less
                        # final summary chunk.
                        output_len = (data.get("usage") or {}).get(
                            "completion_tokens", output_len
                        )

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Makes a request to the OpenAI Chat Completions API.

    Handles both streaming and non-streaming responses, including support
    for image data in messages. Calculates and returns various performance
    metrics.

    Args:
        request_func_input: Input parameters for the request. When
            ``image_data`` is non-empty, each entry is sent as an
            ``image_url`` content item ahead of the text prompt.
        pbar: Optional tqdm progress bar to update.

    Returns:
        RequestFuncOutput: Output of the request, including generated text,
                           latency, TTFT, ITL, and success status. Failures
                           (non-200 status or any exception) are reported via
                           ``success=False`` and ``error`` rather than raised.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    if request_func_input.image_data:
        # Build multi-image content: a list of image_url entries followed by the text
        content_items = [
            {
                "type": "image_url",
                "image_url": {"url": img_url},
            }
            for img_url in request_func_input.image_data
        ]
        content_items.append({"type": "text", "text": request_func_input.prompt})
        messages = [
            {
                "role": "user",
                "content": content_items,
            },
        ]
    else:
        messages = [{"role": "user", "content": request_func_input.prompt}]

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    if args.disable_stream:
                        # Non-streaming response
                        response_json = await response.json()
                        message = response_json["choices"][0]["message"]
                        # ``content`` may be JSON null; keep generated_text a str.
                        output.generated_text = message.get("content") or ""
                        output.success = True
                        output.latency = time.perf_counter() - st
                        # For non-streaming, TTFT = total latency
                        output.ttft = output.latency
                        # ``usage`` may be absent or explicitly null.
                        output.output_len = (response_json.get("usage") or {}).get(
                            "completion_tokens", output_len
                        )
                    else:
                        # Streaming response
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                            latency = time.perf_counter() - st
                            if chunk == "[DONE]":
                                continue

                            data = json.loads(chunk)

                            # A usage-only final chunk carries "choices": [];
                            # fall back to an empty choice instead of raising
                            # IndexError (which would fail the whole request).
                            choices = data.get("choices") or [{}]
                            delta = choices[0].get("delta") or {}
                            content = delta.get("content") or ""

                            if content:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft
                                # Decoding phase
                                else:
                                    output.itl.append(
                                        timestamp - most_recent_timestamp
                                    )
                                most_recent_timestamp = timestamp
                                generated_text += content

                            # Check for usage info in final chunk
                            output_len = (data.get("usage") or {}).get(
                                "completion_tokens", output_len
                            )

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                        output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


406
407
408
409
410
411
412
413
async def async_request_truss(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one Completions-style request to a Truss endpoint.

    The payload mirrors the OpenAI Completions one: greedy decoding,
    ``ignore_eos`` unless ``args.disable_ignore_eos``, and streaming unless
    ``args.disable_stream``. ``extra_request_body`` is merged last.

    Args:
        request_func_input: Prompt, URL/model and length settings.
        pbar: Optional tqdm progress bar, advanced by one when done.

    Returns:
        RequestFuncOutput with text, TTFT, inter-token latencies and latency.
        ``output_len`` is the requested length (no server usage is read).
        Errors are captured in ``error`` with ``success=False``.
    """
    api_url = request_func_input.api_url

    prompt = request_func_input.prompt

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "prompt": prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)

                        # NOTE: Some completion APIs end with a usage summary
                        # chunk that has no token and possibly an empty
                        # ``choices`` list; guard both so it is simply skipped.
                        choices = data.get("choices") or []
                        text = (choices[0].get("text") or "") if choices else ""
                        if text:
                            timestamp = time.perf_counter()
                            # First token
                            if ttft == 0.0:
                                ttft = timestamp - st
                                output.ttft = ttft
                            # Decoding phase
                            else:
                                output.itl.append(timestamp - most_recent_timestamp)
                            most_recent_timestamp = timestamp
                            generated_text += text

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


484
485
486
487
488
489
490
async def async_request_sglang_generate(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Send one request to sglang's native generate API and time the stream.

    A ``str`` prompt is sent as ``"text"``; anything else is sent as
    ``"input_ids"``. Each streamed chunk carries the cumulative ``text`` and
    ``meta_info.completion_tokens``. When one chunk adds several tokens, the
    gap since the previous chunk is split evenly across them in ``itl``.

    Args:
        request_func_input: Prompt, URL, length, LoRA and image settings.
        pbar: Optional tqdm progress bar, advanced by one when done.

    Returns:
        RequestFuncOutput with final text, TTFT, per-token ITL, latency and the
        server-reported completion token count. Errors are captured in
        ``error`` with ``success=False`` and also printed.
    """
    api_url = request_func_input.api_url
    prompt = request_func_input.prompt

    async with _create_bench_client_session() as session:
        prompt_key = "text" if isinstance(prompt, str) else "input_ids"
        payload = {
            prompt_key: prompt,
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": request_func_input.output_len,
                "ignore_eos": not args.disable_ignore_eos,
            },
            "stream": not args.disable_stream,
            "lora_path": request_func_input.lora_name,
            "return_logprob": args.return_logprob,
            "logprob_start_len": -1,
            **request_func_input.extra_request_body,
        }

        # Add image data if available (list of image urls/base64)
        if request_func_input.image_data:
            payload["image_data"] = request_func_input.image_data

        headers = get_auth_headers()
        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        last_output_len = 0
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status != 200:
                    output.error = response.reason or ""
                    output.success = False
                else:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            continue

                        data = json.loads(chunk)
                        # Skip chunks without generated text (e.g. a trailing
                        # summary).
                        if "text" not in data or not data["text"]:
                            continue

                        timestamp = time.perf_counter()
                        generated_text = data["text"]
                        output_len = data["meta_info"]["completion_tokens"]

                        if ttft == 0.0:
                            # First token
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                        else:
                            # Decoding phase: spread the gap over the new tokens.
                            num_new_tokens = output_len - last_output_len
                            if num_new_tokens == 0:
                                continue
                            per_token_gap = (
                                timestamp - most_recent_timestamp
                            ) / num_new_tokens
                            output.itl.extend([per_token_gap] * num_new_tokens)

                        most_recent_timestamp = timestamp
                        last_output_len = output_len

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = output_len
        except Exception:
            output.success = False
            output.error = "".join(traceback.format_exception(*sys.exc_info()))
            print(f"{output.error=}")

    if pbar:
        pbar.update(1)
    return output


581
async def async_request_gserver(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Placeholder request function for the "gserver" backend.

    Raises:
        NotImplementedError: always; this backend is not implemented.
    """
    raise NotImplementedError


588
async def async_request_profile(api_url: str) -> RequestFuncOutput:
    """POST (without a body) to *api_url* and report whether it returned HTTP 200.

    Only ``success`` and ``error`` of the returned RequestFuncOutput are set.
    Any exception is captured as a formatted traceback in ``error``.
    """
    async with _create_bench_client_session() as session:
        result = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                result.success = response.status == 200
                if not result.success:
                    result.error = response.reason or ""
        except Exception:
            result.success = False
            result.error = "".join(traceback.format_exception(*sys.exc_info()))

    return result


zhyncs's avatar
zhyncs committed
606
def get_model(pretrained_model_name_or_path: str) -> str:
    """Resolve a model id to a local snapshot path when ModelScope is enabled.

    When the ``SGLANG_USE_MODELSCOPE`` environment variable is truthy, the
    model is fetched with ModelScope's ``snapshot_download`` and the local
    path is returned. Truthy means "true" or "1", the same rule as the file's
    ``_get_bool_env_var`` helper. Files matching ``.pt`` / ``.safetensors`` /
    ``.bin`` patterns are excluded from the download. Otherwise the input is
    returned unchanged.
    """
    # Use the shared boolean-env parser so "1" is accepted as well as "true".
    if not _get_bool_env_var("SGLANG_USE_MODELSCOPE"):
        return pretrained_model_name_or_path

    # Imported lazily: these packages are only needed on the ModelScope path.
    import huggingface_hub.constants
    from modelscope import snapshot_download

    model_path = snapshot_download(
        model_id=pretrained_model_name_or_path,
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
    )
    return model_path


def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Load the tokenizer for a model id, local directory or tokenizer file.

    Paths ending in ``.json`` or ``.model`` are delegated to sglang's own
    ``get_tokenizer``. Otherwise, a name that is not an existing local path is
    first resolved through ``get_model`` (ModelScope support), then loaded
    with ``AutoTokenizer.from_pretrained(..., trust_remote_code=True)``.
    """
    assert (
        pretrained_model_name_or_path is not None
        and pretrained_model_name_or_path != ""
    )

    if pretrained_model_name_or_path.endswith((".json", ".model")):
        from sglang.srt.hf_transformers_utils import (
            get_tokenizer as _sglang_get_tokenizer,
        )

        return _sglang_get_tokenizer(pretrained_model_name_or_path)

    tokenizer_source = pretrained_model_name_or_path
    if not os.path.exists(tokenizer_source):
        tokenizer_source = get_model(tokenizer_source)
    return AutoTokenizer.from_pretrained(tokenizer_source, trust_remote_code=True)


644
def get_dataset(args, tokenizer):
    """Load or synthesize the benchmark requests selected by ``args.dataset_name``.

    Supported names are "sharegpt", "random-image", "generated-shared-prefix",
    "mmmu", "mooncake", and any other name starting with "random". Only the
    plain "random*" datasets honour ``args.tokenize_prompt`` (the sampler then
    receives ``return_text=False``); every other dataset rejects it.

    For "mooncake", the raw trace records (one JSON object per non-blank line)
    are returned unchanged, truncated to ``args.num_prompts``. The trace is
    downloaded from ``MOONCAKE_DATASET_URL`` when the local file is missing.

    Raises:
        ValueError: if ``args.dataset_name`` is not recognised.
        AssertionError: if ``tokenize_prompt`` is requested for a dataset that
            does not support it.
    """
    tokenize_prompt = getattr(args, "tokenize_prompt", False)
    if args.dataset_name == "sharegpt":
        assert not tokenize_prompt, "sharegpt does not support --tokenize-prompt"
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
            context_len=args.sharegpt_context_len,
            prompt_suffix=args.prompt_suffix,
            apply_chat_template=args.apply_chat_template,
        )
    elif args.dataset_name.startswith("random") and args.dataset_name != "random-image":
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
            random_sample=args.dataset_name == "random",
            return_text=not tokenize_prompt,
        )
    elif args.dataset_name == "random-image":
        assert not tokenize_prompt, "random-image does not support --tokenize-prompt"
        input_requests = sample_random_image_requests(
            num_requests=args.num_prompts,
            num_images=args.random_image_num_images,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            apply_chat_template=args.apply_chat_template,
            image_resolution=args.random_image_resolution,
        )
    elif args.dataset_name == "generated-shared-prefix":
        assert (
            not tokenize_prompt
        ), "generated-shared-prefix does not support --tokenize-prompt"
        input_requests = sample_generated_shared_prefix_requests(
            num_groups=args.gsp_num_groups,
            prompts_per_group=args.gsp_prompts_per_group,
            system_prompt_len=args.gsp_system_prompt_len,
            question_len=args.gsp_question_len,
            output_len=args.gsp_output_len,
            tokenizer=tokenizer,
            args=args,
        )
    elif args.dataset_name == "mmmu":
        assert not tokenize_prompt, "mmmu does not support --tokenize-prompt"
        input_requests = sample_mmmu_requests(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.random_output_len,
            apply_chat_template=args.apply_chat_template,
            random_sample=True,
        )
    elif args.dataset_name == "mooncake":
        # For mooncake, we don't generate the prompts here.
        # We just load the raw trace data. The async generator will handle the rest.
        local_path = args.dataset_path or os.path.join(
            "/tmp", args.mooncake_workload + "_trace.jsonl"
        )

        if not os.path.exists(local_path):
            download_and_cache_file(
                MOONCAKE_DATASET_URL[args.mooncake_workload], local_path
            )

        # JSON Lines is UTF-8; don't depend on the platform's default encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            all_requests_data = [json.loads(line) for line in f if line.strip()]

        # Limit the number of requests based on --num-prompts
        input_requests = all_requests_data[: args.num_prompts]
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    return input_requests


zhyncs's avatar
zhyncs committed
721
# Backend name -> coroutine that sends one benchmark request to that backend.
# Several backends share the OpenAI-compatible request functions.
ASYNC_REQUEST_FUNCS = {
    # sglang: native generate API and its OpenAI-compatible endpoints
    "sglang": async_request_sglang_generate,
    "sglang-native": async_request_sglang_generate,
    "sglang-oai": async_request_openai_completions,
    "sglang-oai-chat": async_request_openai_chat_completions,
    # other engines reached through OpenAI-compatible endpoints
    "vllm": async_request_openai_completions,
    "vllm-chat": async_request_openai_chat_completions,
    "lmdeploy": async_request_openai_completions,
    "lmdeploy-chat": async_request_openai_chat_completions,
    # engine-specific request protocols
    "trt": async_request_trt_llm,
    "gserver": async_request_gserver,
    "truss": async_request_truss,
}


@dataclass
class BenchmarkMetrics:
    """Aggregate results of one benchmark run.

    Fields ending in ``_ms`` are latencies in milliseconds. TTFT = time to first
    token, TPOT = time per output token after the first, ITL = inter-token
    latency, E2E = end-to-end request latency. Throughput fields are presumably
    normalized by the benchmark wall-clock duration — TODO confirm against
    ``calculate_metrics``.
    """

    # Number of requests that finished successfully.
    completed: int
    # Token totals over successful requests.
    total_input: int
    total_output: int
    # Output tokens counted by re-tokenizing the generated text client-side.
    total_output_retokenized: int

    # Throughput figures; *_retokenized variants use the re-tokenized counts.
    request_throughput: float
    input_throughput: float
    output_throughput: float
    output_throughput_retokenized: float
    total_throughput: float
    total_throughput_retokenized: float

    # Time to first token.
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    # Time per output token, excluding the first token.
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    # Inter-token latency, pooled over all successful requests.
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p95_itl_ms: float
    p99_itl_ms: float
    max_itl_ms: float
    # End-to-end request latency.
    mean_e2e_latency_ms: float
    median_e2e_latency_ms: float
    std_e2e_latency_ms: float
    p99_e2e_latency_ms: float
    # NOTE(review): presumably the average number of in-flight requests — confirm
    # against calculate_metrics.
    concurrency: float
zhyncs's avatar
zhyncs committed
767
768


Lianmin Zheng's avatar
Lianmin Zheng committed
769
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
770
771
772
773
774
775
MOONCAKE_DATASET_URL = {
    "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
    "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
    "synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
    "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
}
Lianmin Zheng's avatar
Lianmin Zheng committed
776
777


Lianmin Zheng's avatar
Lianmin Zheng committed
778
779
780
781
def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Download ``url`` to ``filename`` unless a valid cached copy already exists.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path. Defaults to ``/tmp/<last URL path component>``.
            Missing parent directories are created.

    Returns:
        The path of the (possibly pre-existing) cached file.

    Raises:
        requests.HTTPError: If the server answers with an error status.

    The payload is streamed into ``<filename>.part`` and only renamed onto
    ``filename`` once the download has completed. An interrupted download
    therefore never leaves a truncated file at ``filename`` that a later run
    (e.g. one that only checks ``os.path.exists``) would mistake for a
    complete cache.
    """
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Reuse an existing file only if it parses as JSON.
    if is_file_valid_json(filename):
        return filename

    print(f"Downloading from {url} to {filename}")

    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    tmp_filename = f"{filename}.part"
    chunk_size = 1024  # Download in chunks of 1KB

    # Stream the response to show the progress bar. The `with` releases the
    # connection even if writing fails part-way through.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()  # Check for request errors

        # Total size of the file in bytes (0 if the server does not say)
        total_size = int(response.headers.get("content-length", 0))

        try:
            with open(tmp_filename, "wb") as f, tqdm(
                desc=filename,
                total=total_size,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if not chunk:  # skip empty keep-alive chunks
                        continue
                    f.write(chunk)
                    bar.update(len(chunk))
            # Atomic on the same filesystem: readers see either the old state
            # or the complete file, never a partial one.
            os.replace(tmp_filename, filename)
        finally:
            # Only present if the download or rename did not complete.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)

    return filename
Lianmin Zheng's avatar
Lianmin Zheng committed
810
811


812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
def is_file_valid_json(path):
    """Return True if ``path`` is an existing regular file whose content parses as JSON.

    This is used as a cache check. A missing path, a directory, or a file that
    cannot be decoded as text or parsed as JSON (e.g. a truncated download) is
    reported as invalid rather than raising, so the caller can re-download it.

    Note: a JSON Lines file with more than one record is not a single JSON
    document, so it is reported as invalid.
    """
    if not os.path.isfile(path):
        return False

    # TODO can fuse into the real file open later
    try:
        with open(path) as f:
            json.load(f)
    except (JSONDecodeError, UnicodeDecodeError) as e:
        # UnicodeDecodeError: the bytes are not valid text in the file's
        # encoding, which makes the file just as unusable as malformed JSON.
        print(
            f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
        )
        return False
    return True


828
829
830
831
832
@dataclass
class DatasetRow:
    """One benchmark request produced by a dataset sampler."""

    # Prompt sent to the server; already chat-templated when the sampler applied
    # a chat template. NOTE(review): sample_random_requests(return_text=False)
    # stores a list of token ids here despite the `str` annotation.
    prompt: str
    # Prompt length in tokens, as measured (or, for the random sampler,
    # targeted) by the sampler.
    prompt_len: int
    # Target output length in tokens for this request.
    output_len: int
    # Images attached to the request; the MMMU and random-image samplers store
    # base64 data URIs here.
    image_data: Optional[List[str]] = None
    # Trace arrival time in milliseconds. get_request() uses it when replaying
    # with use_trace_timestamps=True.
    timestamp: Optional[float] = None


async def get_mooncake_request_over_time(
    input_requests: List[Dict],
    tokenizer: PreTrainedTokenizerBase,
    slowdown_factor: float,
    num_rounds: int,
) -> AsyncGenerator[DatasetRow, None]:
    """Replay a Mooncake trace as multi-round chat sessions.

    Each trace record starts one session at its ``timestamp``. The timestamp is
    in milliseconds, taken relative to the earliest record and scaled by
    ``slowdown_factor``. Once a session starts, its ``num_rounds`` requests are
    yielded back-to-back. Round ``k`` carries the whole chat history of the
    earlier rounds, with a placeholder assistant reply standing in for each real
    response.

    Args:
        input_requests: Raw trace records. Reads ``timestamp`` (required) and,
            optionally, ``hash_ids`` and ``output_length`` (default 256). The
            list is sorted in place by ``timestamp``.
        tokenizer: Used to apply the chat template and to count prompt tokens.
        slowdown_factor: Multiplier on inter-arrival gaps (>1 slows replay down).
        num_rounds: Number of chat rounds generated for each session.

    Yields:
        One ``DatasetRow`` per round. Its ``timestamp`` is set to the session's
        trace timestamp.
    """
    if not input_requests:
        return

    # Sessions must start in trace order. The sort is in place, as before.
    input_requests.sort(key=lambda r: r["timestamp"])

    start_time = time.perf_counter()
    trace_start_time_ms = input_requests[0]["timestamp"]

    # Filler appended after every hash id (shorter for multi-round); computed once.
    filler = " ".join(["hi"] * 128)

    for record in input_requests:
        # Calculate when this entire session should start
        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
        target_arrival_time_s = relative_arrival_time_s * slowdown_factor

        current_elapsed_time_s = time.perf_counter() - start_time
        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
        if sleep_duration_s > 0:
            await asyncio.sleep(sleep_duration_s)

        # Base user query built from the hash ids. A single join replaces the
        # repeated `+=`; the resulting text is byte-identical to the old one.
        hash_ids = record.get("hash_ids") or []
        user_query_base = "".join(f"{hash_id}{filler}" for hash_id in hash_ids)
        user_query_base += "Tell me a story based on this context."

        output_len_per_round = record.get("output_length", 256)
        # Placeholder assistant reply used as context for the next round, since
        # the real response is not known here.
        placeholder_response = " ".join(["story"] * output_len_per_round)
        chat_history = []

        # Generate all rounds of the session as a burst, simulating a user in a
        # multi-turn conversation.
        for i in range(num_rounds):
            chat_history.append(
                {"role": "user", "content": f"Round {i+1}: {user_query_base}"}
            )

            # Form the full prompt from the history.
            try:
                full_prompt_text = tokenizer.apply_chat_template(
                    chat_history, tokenize=False, add_generation_prompt=True
                )
            except Exception:
                # No usable chat template: fall back to a plain transcript.
                full_prompt_text = "\n".join(
                    f"{msg['role']}: {msg['content']}" for msg in chat_history
                )

            prompt_len = len(tokenizer.encode(full_prompt_text))

            yield DatasetRow(
                prompt=full_prompt_text,
                prompt_len=prompt_len,
                output_len=output_len_per_round,
                timestamp=record["timestamp"],
            )

            chat_history.append({"role": "assistant", "content": placeholder_response})
908
909


910
911
912
913
def sample_mmmu_requests(
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    apply_chat_template: bool = True,
    random_sample: bool = True,
) -> List[DatasetRow]:
    """
    Sample requests from the MMMU dataset (Math subset, test split) using HuggingFace datasets.

    Args:
        num_requests: Number of requests to sample.
        tokenizer: Tokenizer used to apply the chat template and count prompt tokens.
        fixed_output_len: If provided, use this fixed output length for all requests;
            otherwise every request uses 256.
        apply_chat_template: Whether to apply the chat template to the prompt.
        random_sample: Whether to randomly sample or take the first N.

    Returns:
        List of ``DatasetRow``. Each row carries its ``image_1`` picture as a
        base64 PNG data URI in ``image_data``. Examples without an image, or that
        fail to process, are skipped, so the list may be shorter than
        ``num_requests``.

    Raises:
        ImportError: If ``datasets`` or ``pybase64`` is not installed.
        ValueError: If the MMMU dataset cannot be loaded.
    """
    try:
        import pybase64
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError(
            "Please install datasets and pybase64: pip install datasets pybase64"
        ) from e

    print("Loading MMMU dataset from HuggingFace...")

    try:
        print("Attempting to load MMMU Math dataset...")
        mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
    except Exception as e:
        print(f"Failed to load MMMU Math dataset: {e}")
        raise ValueError(f"Failed to load MMMU dataset: {e}") from e
    else:
        print(
            f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
        )

    # Sample from the dataset
    if len(mmmu_dataset) > num_requests:
        if random_sample:
            indices = random.sample(range(len(mmmu_dataset)), num_requests)
        else:
            # Take first N (num_requests < len here, so no clamping is needed)
            indices = range(num_requests)
        sample_dataset = mmmu_dataset.select(indices)
    else:
        print(f"Dataset has less than {num_requests} examples, using all examples")
        sample_dataset = mmmu_dataset

    print(f"Selected {len(sample_dataset)} examples for benchmarking")

    filtered_dataset = []
    output_len = fixed_output_len if fixed_output_len is not None else 256

    for i, example in enumerate(sample_dataset):
        try:
            image = example.get("image_1")
            # Only examples with an image object that can be saved are usable.
            if image is None or not hasattr(image, "save"):
                continue

            # Convert RGBA images to RGB before encoding
            if image.mode == "RGBA":
                image = image.convert("RGB")

            # Encode image to base64 (save as PNG to support palette/alpha modes)
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
            image_data = f"data:image/png;base64,{img_str}"

            question = example.get("question")
            prompt = f"Question: {question}\n\nAnswer: "

            if apply_chat_template:
                try:
                    prompt = tokenizer.apply_chat_template(
                        [
                            {
                                "role": "user",
                                "content": [
                                    {
                                        "type": "image_url",
                                        "image_url": {"url": image_data},
                                    },
                                    {"type": "text", "text": prompt},
                                ],
                            }
                        ],
                        add_generation_prompt=True,
                        tokenize=False,
                    )
                except Exception as e:
                    # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
                    print(
                        f"Error applying chat template: {e}, fallback to <image> tag"
                    )
                    prompt = f"<image>{prompt}"

            # Count tokens of the prompt text only; the image itself travels
            # separately in `image_data`.
            prompt_len = len(tokenizer.encode(prompt))

            filtered_dataset.append(
                DatasetRow(
                    prompt=prompt,
                    prompt_len=prompt_len,
                    output_len=output_len,
                    image_data=[image_data],
                )
            )
        except Exception as e:
            # Best effort: one malformed example should not abort the benchmark.
            print(f"Error processing example {i}: {e}")

    print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
    return filtered_dataset


zhyncs's avatar
zhyncs committed
1041
1042
1043
1044
1045
def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    context_len: Optional[int] = None,
    prompt_suffix: Optional[str] = "",
    apply_chat_template=False,
) -> List[DatasetRow]:
    """Sample (prompt, completion) pairs from a ShareGPT-format JSON file.

    Conversations with at least two turns contribute their first turn as the
    prompt and their second as the reference completion. Pairs are shuffled
    and then filtered until ``num_requests`` rows are collected.

    Args:
        dataset_path: Path to the JSON file. If empty or None, ShareGPT is
            downloaded and cached.
        num_requests: Maximum number of rows to return.
        tokenizer: Used for chat templating and token counting.
        fixed_output_len: If given, used as every row's output length instead of
            the completion's token count. Must be >= 4.
        context_len: If given, drop pairs whose prompt + output exceeds it.
        prompt_suffix: If non-empty, inserted before a trailing ``ASSISTANT_SUFFIX``.
        apply_chat_template: Wrap each prompt in the tokenizer's chat template.

    Returns:
        Up to ``num_requests`` rows.

    Raises:
        ValueError: If ``fixed_output_len`` is smaller than 4.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Download ShareGPT when no dataset path was given.
    if not dataset_path:
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)

    # Keep the first two turns of every conversation that has at least two.
    dataset = [
        (turns[0]["value"], turns[1]["value"])
        for turns in (
            data.get("conversations", data.get("conversation", []))
            for data in dataset
        )
        if len(turns) >= 2
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[DatasetRow] = []
    for prompt, completion in dataset:
        if len(filtered_dataset) == num_requests:
            break

        if prompt_suffix:
            prompt = (
                remove_suffix(prompt, ASSISTANT_SUFFIX)
                + prompt_suffix
                + ASSISTANT_SUFFIX
            )

        if apply_chat_template:
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
            # Strip the BOS text from the templated prompt, presumably so it
            # is not duplicated when the prompt is tokenized again. Some
            # tokenizers define no BOS token (bos_token is None), and
            # str.replace(None, "") raises TypeError, so guard the call.
            if tokenizer.bos_token:
                prompt = prompt.replace(tokenizer.bos_token, "")

        prompt_len = len(tokenizer.encode(prompt))
        # Tokenize the completion only when its length is actually needed.
        output_len = (
            fixed_output_len
            if fixed_output_len is not None
            else len(tokenizer.encode(completion))
        )

        if prompt_len < 2 or output_len < 2:
            # Prune too short sequences.
            continue

        if context_len and prompt_len + output_len > context_len:
            # Prune too long sequences.
            continue

        filtered_dataset.append(
            DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
        )

    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
    return filtered_dataset


1127
1128
1129
1130
1131
1132
def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
    random_sample: bool = True,
    return_text: bool = True,
) -> List[DatasetRow]:
    """Generate requests with randomly drawn input and output lengths.

    Input lengths are drawn uniformly from
    ``[max(int(input_len * range_ratio), 1), input_len]`` and output lengths from
    ``[int(output_len * range_ratio), output_len]``.

    Args:
        input_len: Maximum prompt length in tokens.
        output_len: Maximum output length in tokens.
        num_prompts: Number of requests to generate.
        range_ratio: Lower bound of each length range, as a fraction of its maximum.
        tokenizer: Used to encode ShareGPT prompts and to decode token ids.
        dataset_path: ShareGPT JSON file. It is downloaded when missing or invalid.
        random_sample: If True, prompt tokens come from ShareGPT prompts,
            repeated or truncated to the drawn length. Otherwise they are
            synthetic consecutive token ids.
        return_text: If True, decode the ids to text; otherwise ``prompt`` holds
            the raw token-id list.

    Returns:
        The generated rows. In ShareGPT mode this can be fewer than
        ``num_prompts`` if the dataset runs out of non-empty prompts.
    """
    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )

    input_requests: List[DatasetRow] = []

    if random_sample:
        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

        # Download sharegpt if necessary
        if not dataset_path or not is_file_valid_json(dataset_path):
            dataset_path = download_and_cache_file(SHAREGPT_URL)

        # Load the dataset.
        with open(dataset_path) as f:
            dataset = json.load(f)

        # Keep the first two turns of every conversation that has at least two.
        dataset = [
            (turns[0]["value"], turns[1]["value"])
            for turns in (
                data.get("conversations", data.get("conversation", []))
                for data in dataset
            )
            if len(turns) >= 2
        ]
        # Shuffle the dataset.
        random.shuffle(dataset)

        for prompt, _completion in dataset:
            i = len(input_requests)
            if i == num_prompts:
                break

            prompt_token_ids = tokenizer.encode(prompt)
            prompt_len = len(prompt_token_ids)

            # Skip empty prompt
            if prompt_len == 0:
                continue

            target_len = int(input_lens[i])
            if prompt_len > target_len:
                input_ids = prompt_token_ids[:target_len]
            else:
                # Repeat the prompt enough times (ceil division) then truncate.
                ratio = (target_len + prompt_len - 1) // prompt_len
                input_ids = (prompt_token_ids * ratio)[:target_len]

            input_content = input_ids
            if return_text:
                input_content = tokenizer.decode(input_content)
            input_requests.append(
                DatasetRow(
                    prompt=input_content,
                    prompt_len=target_len,
                    output_len=int(output_lens[i]),
                )
            )
    else:
        # Sample token ids from random integers. This can cause some NaN issues.
        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
        for i in range(num_prompts):
            input_content = [
                (offsets[i] + i + j) % tokenizer.vocab_size
                for j in range(input_lens[i])
            ]
            if return_text:
                input_content = tokenizer.decode(input_content)
            input_requests.append(
                DatasetRow(
                    prompt=input_content,
                    prompt_len=int(input_lens[i]),
                    output_len=int(output_lens[i]),
                )
            )

    if len(input_requests) < num_prompts:
        print(
            f"Warning: only {len(input_requests)} of {num_prompts} requested prompts could be generated"
        )

    # Report totals of the rows actually produced. Summing all of input_lens /
    # output_lens over-counts when the ShareGPT branch ran short.
    print(f"#Input tokens: {sum(row.prompt_len for row in input_requests)}")
    print(f"#Output tokens: {sum(row.output_len for row in input_requests)}")
    return input_requests


1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
def parse_random_image_resolution(image_resolution: str) -> Tuple[int, int]:
    """Parse an image resolution string into ``(width, height)``.

    Matching ignores case and surrounding whitespace. Accepted forms:

    * the presets ``'4k'`` (3840x2160), ``'1080p'`` (1920x1080),
      ``'720p'`` (1280x720) and ``'360p'`` (640x360);
    * a custom ``'<height>x<width>'`` string of positive integers, e.g.
      ``'1080x1920'`` means height=1080, width=1920 and returns ``(1920, 1080)``.

    Raises:
        ValueError: If the string is neither a preset nor a valid custom size.
    """
    resolution_to_size = {
        "4k": (3840, 2160),
        "1080p": (1920, 1080),
        "720p": (1280, 720),
        "360p": (640, 360),
    }

    # Normalize before the preset lookup so e.g. "4K" or " 720p " also match.
    res = image_resolution.strip().lower()
    if res in resolution_to_size:
        return resolution_to_size[res]

    parts = [part.strip() for part in res.split("x")]
    # isdecimal (unlike isdigit) only accepts characters int() can parse.
    if len(parts) == 2 and all(part.isdecimal() for part in parts):
        height, width = int(parts[0]), int(parts[1])
        if height > 0 and width > 0:
            return (width, height)

    raise ValueError(
        f"Unsupported random-image resolution: {image_resolution}. "
        "Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920)."
    )


def sample_random_image_requests(
    num_requests: int,
    num_images: int,
    input_len: int,
    output_len: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    apply_chat_template: bool = True,
    image_resolution: str = "1080p",
) -> List[DatasetRow]:
    """Generate requests with random images.

    - Each request includes ``num_images`` random-noise JPEG images, passed as
      base64 data URIs in ``image_data``.
    - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360),
      or custom 'heightxwidth' (e.g., 1080x1920).
    - Text lengths follow the 'random' dataset sampling rule. ``prompt_len``
      counts the tokens of the (optionally chat-templated) text prompt; the
      image payloads are not part of it.

    Raises:
        ImportError: If Pillow or pybase64 is not installed.
        ValueError: If ``image_resolution`` cannot be parsed.
    """
    try:
        import pybase64
        from PIL import Image
    except ImportError as e:
        raise ImportError(
            "Please install Pillow and pybase64 to generate random images: "
            "pip install pillow pybase64"
        ) from e

    # Parse resolution (supports presets and 'heightxwidth')
    width, height = parse_random_image_resolution(image_resolution)

    # Check for potentially problematic combinations and warn user
    if width * height >= 1920 * 1080 and num_images * num_requests >= 100:
        warnings.warn(
            f"High resolution ({width}x{height}) with {num_images * num_requests} total images "
            f"may take a long time. Consider reducing resolution or image count.",
            UserWarning,
            stacklevel=2,
        )

    # Sample text lengths
    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1), input_len + 1, size=num_requests
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio), output_len + 1, size=num_requests
    )

    def _gen_random_image_data_uri(width: int = width, height: int = height) -> str:
        """Return a random-noise RGB JPEG of the given size as a data URI."""
        arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)
        # A uint8 (H, W, 3) array is inferred as RGB. Passing `mode=` explicitly
        # is deprecated since Pillow 11.3.
        img = Image.fromarray(arr)
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85)
        encoded = pybase64.b64encode(buf.getvalue()).decode("utf-8")
        return f"data:image/jpeg;base64,{encoded}"

    dataset: List[DatasetRow] = []
    for i in range(num_requests):
        # Generate text prompt
        text_prompt = gen_prompt(tokenizer, int(input_lens[i]))

        # Generate image list
        images = [_gen_random_image_data_uri() for _ in range(num_images)]

        prompt_str = text_prompt
        if apply_chat_template:
            try:
                content_items = [
                    {"type": "image_url", "image_url": {"url": img_url}}
                    for img_url in images
                ]
                content_items.append({"type": "text", "text": text_prompt})
                prompt_str = tokenizer.apply_chat_template(
                    [{"role": "user", "content": content_items}],
                    add_generation_prompt=True,
                    tokenize=False,
                )
            except Exception:
                # Some tokenizers do not support list content; fall back to a placeholder in the text
                prompt_str = f"<image>{text_prompt}"

        prompt_token_len = len(tokenizer.encode(prompt_str))

        dataset.append(
            DatasetRow(
                prompt=prompt_str,
                prompt_len=prompt_token_len,
                output_len=int(output_lens[i]),
                image_data=images,
            )
        )

    print(f"#Input tokens: {np.sum([x.prompt_len for x in dataset])}")
    print(f"#Output tokens: {np.sum([x.output_len for x in dataset])}")
    return dataset


1356
1357
1358
1359
1360
1361
1362
def gen_prompt(tokenizer, token_num):
    """Decode ``token_num`` token ids drawn uniformly, with replacement, from
    the tokenizer's vocabulary, and return the resulting text."""
    vocab_ids = [*tokenizer.get_vocab().values()]
    return tokenizer.decode(random.choices(vocab_ids, k=token_num))


1363
1364
1365
1366
1367
1368
def get_gen_prefix_cache_path(args, tokenizer):
    """Return the pickle cache path for a generated-shared-prefix dataset.

    The file lives under ``~/.cache/sglang/benchmark``; this function does not
    create the directory. The filename encodes every ``gsp_*`` generation
    parameter plus the tokenizer identity, so runs with different settings or
    tokenizers never share a cache entry.

    Args:
        args: Namespace providing ``gsp_num_groups``, ``gsp_prompts_per_group``,
            ``gsp_system_prompt_len``, ``gsp_question_len`` and ``gsp_output_len``.
        tokenizer: Tokenizer the prompts are generated with. Its
            ``name_or_path`` is used when present.

    Returns:
        ``pathlib.Path`` of the cache file.
    """
    import re

    cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"

    # The class name alone is not unique: different models can share one
    # tokenizer class but have different vocabularies, and would then reuse
    # each other's random-token prompts. Include the tokenizer's name_or_path,
    # sanitized for use in a filename, when it is available.
    tokenizer_id = str(getattr(tokenizer, "name_or_path", "") or "")
    tokenizer_tag = re.sub(r"[^A-Za-z0-9._-]+", "_", tokenizer_id).strip("_")
    if tokenizer_tag:
        tokenizer_tag += "_"

    # Create a unique cache filename based on the generation parameters
    cache_key = (
        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
        f"{tokenizer_tag}{tokenizer.__class__.__name__}.pkl"
    )
    return cache_dir / cache_key


1376
1377
1378
1379
1380
1381
1382
def sample_generated_shared_prefix_requests(
    num_groups: int,
    prompts_per_group: int,
    system_prompt_len: int,
    question_len: int,
    output_len: int,
    tokenizer: PreTrainedTokenizerBase,
    args: argparse.Namespace,
) -> List[DatasetRow]:
    """Generate benchmark requests in which groups of prompts share a system prompt.

    Generation:

    * ``num_groups`` random system prompts of ``system_prompt_len`` tokens are
      created.
    * Each is joined with ``prompts_per_group`` random questions of
      ``question_len`` tokens.
    * Every request asks for ``output_len`` tokens.
    * The requests are shuffled.

    The result is pickled to ``get_gen_prefix_cache_path(args, tokenizer)`` and
    reused by later runs with the same settings. A cache file that cannot be
    loaded is regenerated instead of aborting the run.

    Security note: the cache is read with ``pickle.load``. Only point it at a
    cache directory you trust.
    """
    cache_path = get_gen_prefix_cache_path(args, tokenizer)

    # Try to load from cache first
    if cache_path.exists():
        print(f"\nLoading cached generated input data from {cache_path}")
        try:
            with open(cache_path, "rb") as f:
                return pickle.load(f)
        except Exception as e:
            # A corrupted or incompatible cache (e.g. from an interrupted
            # write) is not fatal: regenerate it and overwrite the file.
            print(f"Failed to load cached data ({e!r}); regenerating")

    print("\nGenerating new input data...")

    # Generate system prompts for each group, then all questions. The order of
    # random draws matches the original implementation.
    system_prompts = [gen_prompt(tokenizer, system_prompt_len) for _ in range(num_groups)]
    questions = [
        gen_prompt(tokenizer, question_len)
        for _ in range(num_groups * prompts_per_group)
    ]

    # Combine system prompts with questions
    input_requests = []
    total_input_tokens = 0
    total_output_tokens = 0

    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
        system_prompt = system_prompts[group_idx]
        for prompt_idx in tqdm(
            range(prompts_per_group), desc="Generating questions", leave=False
        ):
            question = questions[group_idx * prompts_per_group + prompt_idx]
            full_prompt = f"{system_prompt}\n\n{question}"
            prompt_len = len(tokenizer.encode(full_prompt))

            input_requests.append(
                DatasetRow(
                    prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
                )
            )
            total_input_tokens += prompt_len
            total_output_tokens += output_len

    # Shuffle questions
    random.shuffle(input_requests)

    # Average lengths; guard against empty lists (num_groups or
    # prompts_per_group of 0) which would otherwise divide by zero.
    avg_system_prompt_len = (
        sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts)
        if system_prompts
        else 0.0
    )
    avg_question_len = (
        sum(len(tokenizer.encode(q)) for q in questions) / len(questions)
        if questions
        else 0.0
    )

    # Print statistics
    print(f"\nGenerated shared prefix dataset statistics:")
    print(f"Number of groups: {num_groups}")
    print(f"Prompts per group: {prompts_per_group}")
    print(f"Total prompts: {len(input_requests)}")
    print(f"Total input tokens: {total_input_tokens}")
    print(f"Total output tokens: {total_output_tokens}")
    print(f"Average system prompt length: {avg_system_prompt_len:.1f} tokens")
    print(f"Average question length: {avg_question_len:.1f} tokens\n")

    # Save to cache. Write to a temporary file and rename it into place, so an
    # interrupted run never leaves a truncated cache behind.
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Caching generated input data to {cache_path}")
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    with open(tmp_path, "wb") as f:
        pickle.dump(input_requests, f)
    tmp_path.replace(cache_path)

    return input_requests


zhyncs's avatar
zhyncs committed
1456
async def get_request(
    input_requests: List[DatasetRow],
    request_rate: float,
    use_trace_timestamps: bool = False,
    slowdown_factor: float = 1.0,
) -> AsyncGenerator[DatasetRow, None]:
    """Yield benchmark requests paced like real traffic.

    There are two pacing modes:

    * ``use_trace_timestamps=False``: requests are emitted in list order. After
      each one, the generator sleeps for an exponentially distributed gap with
      mean ``1 / request_rate`` seconds. With ``request_rate == inf`` there is no
      gap at all.
    * ``use_trace_timestamps=True``: ``input_requests`` is first sorted in place
      by ``timestamp`` (milliseconds). Each request is then emitted at its
      offset from the earliest timestamp, multiplied by ``slowdown_factor``.
    """
    if not use_trace_timestamps:
        for req in input_requests:
            yield req
            if request_rate == float("inf"):
                # Infinite rate: send the next request immediately.
                continue
            # Exponential inter-arrival gaps give a Poisson arrival process.
            await asyncio.sleep(np.random.exponential(1.0 / request_rate))
        return

    print(
        f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
    )
    # Replay in timestamp order (mutates the caller's list).
    input_requests.sort(key=lambda r: r.timestamp)

    replay_start = time.perf_counter()
    first_ts_ms = input_requests[0].timestamp if input_requests else 0

    for req in input_requests:
        offset_s = (req.timestamp - first_ts_ms) / 1000.0
        due_at = replay_start + (offset_s * slowdown_factor)
        wait_s = due_at - time.perf_counter()
        if wait_s > 0:
            await asyncio.sleep(wait_s)
        yield req
zhyncs's avatar
zhyncs committed
1494
1495
1496


def calculate_metrics(
    input_requests: List[DatasetRow],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
) -> Tuple[BenchmarkMetrics, List[int]]:
    """Aggregate per-request results into summary benchmark metrics.

    Args:
        input_requests: The requests that were sent. Not read here; kept for
            interface compatibility.
        outputs: One result per request. A failed request (``success`` is false)
            counts as zero output tokens and is excluded from latency statistics.
        dur_s: Wall-clock duration of the timed run, in seconds. Used as the
            denominator of every throughput figure.
        tokenizer: Re-tokenizes each generated text to produce the
            "retokenized" output counts alongside the server-reported ones.
        backend: Backend name. Not read here; kept for interface compatibility.

    Returns:
        ``(metrics, output_lens)``, where ``output_lens[i]`` is the
        server-reported output length of ``outputs[i]`` (0 if it failed).
    """
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []

    for output in outputs:
        if not output.success:
            output_lens.append(0)
            retokenized_output_lens.append(0)
            continue

        output_len = output.output_len
        output_lens.append(output_len)
        retokenized_output_lens.append(
            len(tokenizer.encode(output.generated_text, add_special_tokens=False))
        )
        total_input += output.prompt_len
        # TPOT excludes the first token, whose latency is already in TTFT.
        if output_len > 1:
            tpots.append((output.latency - output.ttft) / (output_len - 1))
        itls += output.itl
        ttfts.append(output.ttft)
        e2e_latencies.append(output.latency)
        completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )

    total_output = sum(output_lens)
    total_output_retokenized = sum(retokenized_output_lens)

    # `xs or 0` substitutes a scalar 0 for an empty list (e.g. when every
    # request failed). Reductions such as np.max cannot operate on an empty
    # sequence, and this lets an all-failed run still produce a zeroed report
    # instead of crashing.
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        total_output_retokenized=total_output_retokenized,
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=total_output / dur_s,
        output_throughput_retokenized=total_output_retokenized / dur_s,
        total_throughput=(total_input + total_output) / dur_s,
        total_throughput_retokenized=(total_input + total_output_retokenized)
        / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        max_itl_ms=np.max(itls or 0) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies or 0) * 1000,
        median_e2e_latency_ms=np.median(e2e_latencies or 0) * 1000,
        std_e2e_latency_ms=np.std(e2e_latencies or 0) * 1000,
        p99_e2e_latency_ms=np.percentile(e2e_latencies or 0, 99) * 1000,
        # Sum of per-request latencies over wall time = average number of
        # in-flight requests.
        concurrency=np.sum(e2e_latencies) / dur_s,
    )

    return metrics, output_lens
zhyncs's avatar
zhyncs committed
1573
1574
1575
1576
1577


def _build_warmup_request(
    input_requests: List[Any], tokenizer: PreTrainedTokenizerBase
) -> DatasetRow:
    """Return the request used to warm up the server before the timed run.

    For the ``mooncake`` dataset, ``input_requests`` holds raw dict records.
    A temporary ``DatasetRow`` is built from the first record's ``hash_ids``
    and ``output_length``. For every other dataset the first row is reused.

    Raises:
        ValueError: if ``input_requests`` is empty.
    """
    if not input_requests:
        raise ValueError("No input requests were generated; cannot run the benchmark.")

    if args.dataset_name != "mooncake":
        return input_requests[0]

    warmup_record = input_requests[0]
    # Build the prompt from hash_ids, just like the Mooncake request generator.
    prompt_text = "".join(
        f"{hash_id}" + " ".join(["hi"] * 512)
        for hash_id in warmup_record.get("hash_ids", [])
    )
    prompt_text += "Can you tell me a detailed story in 1000 words?"
    return DatasetRow(
        prompt=prompt_text,
        prompt_len=len(tokenizer.encode(prompt_text)),
        output_len=warmup_record.get("output_length", 32),
        image_data=None,  # Mooncake doesn't have image data
    )


def _fetch_accept_length(backend: str, base_url: str) -> Optional[float]:
    """Return the server-reported ``avg_spec_accept_length``, or None.

    The lookup is attempted only for SGLang backends. None is returned on a
    non-200 response or when the payload lacks the expected fields.
    """
    if "sglang" not in backend:
        return None
    server_info = requests.get(base_url + "/get_server_info", headers=get_auth_headers())
    if server_info.status_code != 200:
        return None
    server_info_json = server_info.json()
    # Some deployments nest the info under "decode" (presumably prefill/decode
    # disaggregation); use the first decode server's entry.
    if "decode" in server_info_json:
        server_info_json = server_info_json["decode"][0]
    internal_states = server_info_json.get("internal_states") or [{}]
    return internal_states[0].get("avg_spec_accept_length", None)


def _resolve_output_file_name() -> str:
    """Return ``--output-file`` if set, else a name derived from the run configuration."""
    if args.output_file:
        return args.output_file
    now = datetime.now().strftime("%m%d")
    if args.dataset_name == "random-image":
        return (
            f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_"
            f"{args.random_output_len}_{args.random_image_num_images}imgs_"
            f"{args.random_image_resolution}.jsonl"
        )
    if args.dataset_name.startswith("random"):
        return f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
    return f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"


async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[DatasetRow],
    request_rate: float,
    max_concurrency: Optional[int],
    disable_tqdm: bool,
    lora_names: List[str],
    extra_request_body: Dict[str, Any],
    profile: bool,
    pd_separated: bool = False,
    flush_cache: bool = False,
    warmup_requests: int = 1,
    use_trace_timestamps: bool = False,
    mooncake_slowdown_factor=1.0,
    mooncake_num_rounds=1,
):
    """Run the online-serving benchmark against a live server.

    Phases, in order:

    1. Warmup.
    2. Optional cache flush.
    3. Optional profiler start.
    4. Timed run of all requests.
    5. Optional profiler stop.
    6. Metric computation.
    7. Console report.
    8. Append one JSON line to the output file.

    Args:
        backend: Key into ``ASYNC_REQUEST_FUNCS`` selecting the request function.
        api_url: Endpoint that generation requests are sent to.
        base_url: Server root, used for ``/flush_cache``, ``/start_profile``,
            ``/stop_profile`` and ``/get_server_info``.
        model_id: Model name placed in every request.
        tokenizer: Used for warmup prompt length and output retokenization.
        input_requests: Dataset rows. For ``mooncake`` these are raw dict records.
        request_rate: Poisson arrival rate in req/s (``inf`` = no pacing).
        max_concurrency: Cap on in-flight requests; falsy means unlimited.
        disable_tqdm: Suppress the progress bar.
        lora_names: Optional LoRA adapters; each request picks one at random.
        extra_request_body: Extra fields merged into each request payload.
        profile: Start and stop the server profiler around the timed run.
        pd_separated: Accepted for interface compatibility; not read here.
        flush_cache: Flush the server cache before the timed run.
        warmup_requests: Number of concurrent warmup requests (0 disables).
        use_trace_timestamps: Replay requests at their recorded timestamps
            instead of Poisson pacing.
        mooncake_slowdown_factor: Multiplier on trace time gaps.
        mooncake_num_rounds: Rounds per session for the Mooncake scheduler.

    Returns:
        The run configuration and summary metrics, merged with per-request
        details (input/output lengths, TTFTs, ITLs, generated texts, errors).

    Raises:
        ValueError: for an unknown backend, empty ``input_requests``, or when
            every warmup request failed.
    """
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    # Limit concurrency
    # From https://github.com/vllm-project/vllm/pull/9390
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    # Warmup
    print(f"Starting warmup with {warmup_requests} sequences...")
    test_request = _build_warmup_request(input_requests, tokenizer)
    lora_name = lora_names[0] if lora_names else None

    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_request.prompt,
        api_url=api_url,
        prompt_len=test_request.prompt_len,
        output_len=min(test_request.output_len, 32),
        lora_name=lora_name,
        image_data=test_request.image_data,
        extra_request_body=extra_request_body,
    )

    warmup_tasks = [
        asyncio.create_task(request_func(request_func_input=test_input))
        for _ in range(warmup_requests)
    ]
    warmup_outputs = await asyncio.gather(*warmup_tasks)

    # At least one warmup request must succeed.
    if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
        raise ValueError(
            "Warmup failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {warmup_outputs[0].error}"
        )
    print(
        f"Warmup completed with {warmup_requests} sequences. Starting main benchmark run..."
    )

    # Flush cache
    if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
        requests.post(base_url + "/flush_cache", headers=get_auth_headers())

    # Pause before the timed run. Use a non-blocking sleep because we are
    # inside the event loop.
    await asyncio.sleep(1.0)

    # Start profiler
    if profile:
        print("Starting profiler...")
        profile_output = await async_request_profile(
            api_url=base_url + "/start_profile"
        )
        if profile_output.success:
            print("Profiler started")

    # Run all requests
    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    pbar_total = len(input_requests)
    if (
        backend == "sglang" and args.dataset_name == "mooncake"
    ):  # Assuming mooncake is mainly for sglang or similar backends
        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
        request_generator = get_mooncake_request_over_time(
            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
        )
        print(
            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
        )
        pbar_total *= mooncake_num_rounds
    else:
        # Forward the trace-replay settings so that use_trace_timestamps takes
        # effect. mooncake_slowdown_factor is the only slowdown knob available.
        request_generator = get_request(
            input_requests,
            request_rate,
            use_trace_timestamps=use_trace_timestamps,
            slowdown_factor=mooncake_slowdown_factor,
        )

    pbar = None if disable_tqdm else tqdm(total=pbar_total)
    async for request in request_generator:
        if lora_names:
            lora_name = lora_names[random.randint(0, len(lora_names) - 1)]
        else:
            lora_name = None

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=request.prompt,
            api_url=api_url,
            prompt_len=request.prompt_len,
            output_len=request.output_len,
            lora_name=lora_name,
            image_data=request.image_data,
            extra_request_body=extra_request_body,
            timestamp=request.timestamp,
        )
        tasks.append(
            asyncio.create_task(
                limited_request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
    # Stop the clock as soon as the requests finish. Profiler shutdown and the
    # server-info query below must not inflate the measured duration.
    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Stop profiler
    if profile:
        print("Stopping profiler...")
        profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
        if profile_output.success:
            print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    accept_length = _fetch_accept_length(backend, base_url)

    # Compute metrics and print results
    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print(
        "{:<40} {:<10}".format(
            "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
        )
    )
    print(
        "{:<40} {:<10}".format(
            "Max request concurrency:",
            max_concurrency if max_concurrency else "not set",
        )
    )
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", metrics.total_throughput
        )
    )
    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
    if accept_length:
        print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
    print("=" * 50)

    # calculate_metrics always fills these fields with numbers, so the result
    # is built unconditionally. The old "is not None" guard was dead code, and
    # its else branch left `result` undefined.
    result = {
        # Arguments
        "backend": args.backend,
        "dataset_name": args.dataset_name,
        "request_rate": "trace" if use_trace_timestamps else request_rate,
        "max_concurrency": max_concurrency,
        "sharegpt_output_len": args.sharegpt_output_len,
        "random_input_len": args.random_input_len,
        "random_output_len": args.random_output_len,
        "random_range_ratio": args.random_range_ratio,
        # Results
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "total_output_tokens_retokenized": metrics.total_output_retokenized,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
        "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
        "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "std_ttft_ms": metrics.std_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "std_tpot_ms": metrics.std_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms,
        "mean_itl_ms": metrics.mean_itl_ms,
        "median_itl_ms": metrics.median_itl_ms,
        "std_itl_ms": metrics.std_itl_ms,
        "p95_itl_ms": metrics.p95_itl_ms,
        "p99_itl_ms": metrics.p99_itl_ms,
        "concurrency": metrics.concurrency,
        "accept_length": accept_length,
    }

    output_file_name = _resolve_output_file_name()

    result_details = {
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }

    # Append results to a JSONL file
    with open(output_file_name, "a") as file:
        result_for_dump = (result | result_details) if args.output_details else result
        file.write(json.dumps(result_for_dump) + "\n")

    return result | result_details
zhyncs's avatar
zhyncs committed
1921
1922


1923
1924
1925
1926
1927
1928
1929
1930
1931
def check_chat_template(model_path):
    """Tell whether the tokenizer for ``model_path`` declares a chat template.

    The tokenizer is loaded with ``trust_remote_code=True``. Any failure while
    loading or inspecting it is printed and treated as "no chat template"
    (returns False).
    """
    try:
        loaded = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        has_template = "chat_template" in loaded.init_kwargs
    except Exception as err:
        print(f"Fail to load tokenizer config with error={err}")
        has_template = False
    return has_template


1932
1933
1934
1935
1936
1937
def set_global_args(args_: argparse.Namespace):
    """Install ``args_`` as this module's ``args`` global, read by the other helpers."""
    globals()["args"] = args_


1938
1939
1940
1941
def run_benchmark(args_: argparse.Namespace):
    """Resolve configuration from ``args_`` and run :func:`benchmark` to completion.

    Steps, in order:

    1. Store ``args_`` as the module-level ``args`` read by other helpers here.
    2. Fill in the optional attributes that a partial namespace (e.g. a
       ``SimpleNamespace`` built programmatically) may lack.
    3. Resolve the server URLs and the model name.
    4. Load the tokenizer and dataset.
    5. Run the benchmark coroutine.

    Returns:
        The result dict returned by :func:`benchmark`.

    Raises:
        ValueError: if no API endpoint is known for ``args_.backend``.
        SystemExit: if a required model name cannot be determined, or the
            model list cannot be fetched from the server.
    """
    global args
    args = args_

    # Defaults for optional fields, so partial namespaces (SimpleNamespace)
    # are accepted as well as full argparse results.
    optional_defaults = {
        "max_concurrency": None,
        "warmup_requests": 1,
        "output_details": False,
        "tokenize_prompt": False,
        "use_trace_timestamps": False,
        "mooncake_slowdown_factor": 1.0,
        "mooncake_num_rounds": 1,
        "flush_cache": False,
    }
    for attr_name, default in optional_defaults.items():
        if not hasattr(args, attr_name):
            setattr(args, attr_name, default)

    print(f"benchmark_args={args}")

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    extra_request_body = {}
    if args.extra_request_body:
        extra_request_body = json.loads(args.extra_request_body)

    if args.tokenize_prompt:
        assert (
            args.backend == "sglang"
        ), "`--tokenize-prompt` only compatible with `--backend sglang` currently"

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
            "sglang-native": 30000,
            "sglang-oai": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
            "trt": 8000,
            "gserver": 9988,
            "truss": 8080,
        }.get(args.backend, 30000)

    # Server root: an explicit --base-url takes precedence over host/port.
    server_root = (
        args.base_url if args.base_url else f"http://{args.host}:{args.port}"
    )
    model_url = f"{server_root}/v1/models"

    if args.backend in ["sglang", "sglang-native"]:
        api_url = f"{server_root}/generate"
    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
        api_url = f"{server_root}/v1/completions"
    elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
        api_url = f"{server_root}/v1/chat/completions"
    elif args.backend == "trt":
        api_url = f"{server_root}/v2/models/ensemble/generate_stream"
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)
    elif args.backend == "gserver":
        # Unlike the other backends, gserver is addressed without an http:// scheme.
        api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
        args.model = args.model or "default"
    elif args.backend == "truss":
        api_url = f"{server_root}/v1/models/model:predict"
    else:
        # Without this branch api_url would be unbound and the benchmark would
        # fail later with an opaque UnboundLocalError.
        raise ValueError(f"No API endpoint is defined for backend: {args.backend}")
    base_url = server_root

    # Get model name
    if args.model is None:
        if args.backend == "truss":
            print(
                "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
            )
            sys.exit(1)
        try:
            response = requests.get(model_url, headers=get_auth_headers())
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    if not check_chat_template(args.model):
        print(
            "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
            "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
        )

    print(f"{args}\n")

    # Read dataset
    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer = get_tokenizer(tokenizer_id)
    input_requests = get_dataset(args, tokenizer)

    return asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            max_concurrency=args.max_concurrency,
            disable_tqdm=args.disable_tqdm,
            lora_names=args.lora_name,
            extra_request_body=extra_request_body,
            profile=args.profile,
            pd_separated=args.pd_separated,
            flush_cache=args.flush_cache,
            warmup_requests=args.warmup_requests,
            use_trace_timestamps=args.use_trace_timestamps,
            mooncake_slowdown_factor=args.mooncake_slowdown_factor,
            mooncake_num_rounds=args.mooncake_num_rounds,
        )
    )
zhyncs's avatar
zhyncs committed
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117


def set_ulimit(target_soft_limit=65535):
    """Best-effort raise of the soft RLIMIT_NOFILE (open file descriptors).

    The soft limit is only ever raised, never lowered. An unlimited soft limit
    is left alone. The new value is capped at the current hard limit, because
    a process cannot raise its soft limit above the hard limit; without the
    cap, any target above the hard limit would simply fail. Failures are
    printed and otherwise ignored.
    """
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    # An unlimited soft limit must not be "raised" to a finite value.
    if current_soft == resource.RLIM_INFINITY:
        return

    new_soft = target_soft_limit
    if current_hard != resource.RLIM_INFINITY:
        new_soft = min(new_soft, current_hard)

    if current_soft < new_soft:
        try:
            resource.setrlimit(resource_type, (new_soft, current_hard))
        except (ValueError, OSError) as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")


2118
2119
2120
2121
2122
2123
2124
class LoRAPathAction(argparse.Action):
    """argparse action that stores the given LoRA adapter names as a fresh list.

    Each invocation replaces any previous value of the destination with a new
    list containing the parsed values in order.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Copy into a new list so the stored value is always a list object,
        # independent of the container argparse handed us.
        collected = list(values)
        setattr(namespace, self.dest, collected)


if __name__ == "__main__":
    # Command-line entry point: declare all benchmark options, parse argv into
    # the module-level name ``args``, and hand it to ``run_benchmark``.
    parser = ArgumentParser(description="Benchmark the online serving throughput.")

    # --- Target server / backend ---
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Backend to benchmark, depending on the LLM Inference Engine. Default is sglang.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )

    # --- Dataset selection ---
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=[
            "sharegpt",
            "random",
            "random-ids",
            "generated-shared-prefix",
            "mmmu",
            "random-image",
            "mooncake",
        ],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path", type=str, default="", help="Path to the dataset."
    )

    # --- Model / tokenizer ---
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, using the model conf.",
    )

    # --- Workload size and shape ---
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--sharegpt-context-len",
        type=int,
        default=None,
        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        type=int,
        default=1024,
        help="Number of output tokens per request, used only for random dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
        "used only for random dataset.",
    )

    # random-image dataset args
    parser.add_argument(
        "--random-image-num-images",
        type=int,
        default=1,
        help="Number of images per request (only available with the random-image dataset)",
    )
    parser.add_argument(
        "--random-image-resolution",
        type=str,
        default="1080p",
        help=(
            "Resolution of random images for random-image dataset. "
            "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)."
        ),
    )

    # --- Request scheduling ---
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument(
        "--use-trace-timestamps",
        action="store_true",
        help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )

    # --- Output / reporting ---
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument(
        "--output-details", action="store_true", help="Output details of benchmarking."
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )

    # --- Request payload options ---
    parser.add_argument(
        "--disable-stream",
        action="store_true",
        help="Disable streaming mode.",
    )
    parser.add_argument(
        "--return-logprob",
        action="store_true",
        help="Return logprob.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--extra-request-body",
        metavar='{"key1": "value1", "key2": "value2"}',
        type=str,
        # The trailing space keeps the two concatenated literals from running
        # together ("specify additional", not "specifyadditional").
        help="Append given JSON object to the request payload. You can use this to specify "
        "additional generate params like sampling params.",
    )
    parser.add_argument(
        "--apply-chat-template",
        action="store_true",
        help="Apply chat template",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--lora-name",
        type=str,
        nargs="*",
        default=None,
        action=LoRAPathAction,
        help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
    )
    parser.add_argument(
        "--prompt-suffix",
        type=str,
        default="",
        help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
    )

    # --- Server mode / run control ---
    parser.add_argument(
        "--pd-separated",
        action="store_true",
        help="Benchmark PD disaggregation server",
    )
    parser.add_argument(
        "--flush-cache",
        action="store_true",
        help="Flush the cache before running the benchmark",
    )
    parser.add_argument(
        "--warmup-requests",
        type=int,
        default=1,
        help="Number of warmup requests to run before the benchmark",
    )
    parser.add_argument(
        "--tokenize-prompt",
        action="store_true",
        help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
    )

    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
    group.add_argument(
        "--gsp-num-groups",
        type=int,
        default=64,
        help="Number of system prompt groups for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-prompts-per-group",
        type=int,
        default=16,
        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-system-prompt-len",
        type=int,
        default=2048,
        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-question-len",
        type=int,
        default=128,
        help="Target length in tokens for questions in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-output-len",
        type=int,
        default=256,
        help="Target length in tokens for outputs in generated-shared-prefix dataset",
    )

    mooncake_group = parser.add_argument_group("mooncake dataset arguments")
    mooncake_group.add_argument(
        "--mooncake-slowdown-factor",
        type=float,
        default=1.0,
        help="Slowdown factor for replaying the mooncake trace. "
        "A value of 2.0 means the replay is twice as slow. "
        "NOTE: --request-rate is IGNORED in mooncake mode.",
    )
    mooncake_group.add_argument(
        "--mooncake-num-rounds",
        type=int,
        default=1,
        help="Number of conversation rounds for each session in the mooncake dataset. "
        "A value > 1 will enable true multi-turn session benchmarking.",
    )
    mooncake_group.add_argument(
        "--mooncake-workload",
        type=str,
        default="conversation",
        choices=[
            "mooncake",
            "conversation",
            "synthetic",
            "toolagent",
        ],
        help="Underlying workload for the mooncake dataset.",
    )

    # Parsed at module level (not inside a function) so the result is bound to
    # the module-global name ``args``.
    args = parser.parse_args()
    run_benchmark(args)