# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/backend_request_func.py
# Adapted from https://github.com/vllm-project/vllm/blob/6366efc67b0aedd2c1721c14385370e50b297fb3/benchmarks/benchmark_serving.py

"""
Benchmark online serving with dynamic requests.

Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompts 10

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
"""

import argparse
import asyncio
import io
import json
import os
import pickle
import random
import resource
import sys
import time
import traceback
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from json import JSONDecodeError
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

import aiohttp
import numpy as np
import pybase64
import requests
from datasets import load_dataset
from PIL import Image
from tqdm.asyncio import tqdm
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
)

ASSISTANT_SUFFIX = "Assistant:"

global args


# don't want to import sglang package here
def _get_bool_env_var(name: str, default: str = "false") -> bool:
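    """Read a boolean environment variable; only "true" or "1" (case-insensitive) count as True."""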
    value = os.getenv(name, default)
    return value.lower() in ("true", "1")


def _create_bench_client_session():
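    """Create an aiohttp ClientSession with a long timeout and a large read buffer for benchmarking."""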
    # Under heavy load, the read buffer can fill up before the aio thread reads
    # the content, so we increase read_bufsize from the default 64 KB to 10 MB.
    # Define constants for timeout and buffer size for clarity and maintainability
    BENCH_AIOHTTP_TIMEOUT_SECONDS = 6 * 60 * 60  # 6 hours
    BENCH_AIOHTTP_READ_BUFSIZE_BYTES = 10 * 1024**2  # 10 MB

    aiohttp_timeout = aiohttp.ClientTimeout(total=BENCH_AIOHTTP_TIMEOUT_SECONDS)
    return aiohttp.ClientSession(
        timeout=aiohttp_timeout, read_bufsize=BENCH_AIOHTTP_READ_BUFSIZE_BYTES
    )


@dataclass
class RequestFuncInput:
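    """Input for a single benchmark request sent to a backend."""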
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    lora_name: str
    image_data: Optional[List[str]]
    extra_request_body: Dict[str, Any]
    timestamp: Optional[float] = None


@dataclass
class RequestFuncOutput:
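    """Result of a single benchmark request, including streaming timing metrics."""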
    generated_text: str = ""
    success: bool = False
    latency: float = 0.0
    ttft: float = 0.0  # Time to first token
    itl: List[float] = field(default_factory=list)  # List of inter-token latencies
    text_chunks: List[str] = field(default_factory=list)
    prompt_len: int = 0
    error: str = ""
    output_len: int = 0

    @staticmethod
    def init_new(request_func_input: RequestFuncInput):
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        return output


def remove_prefix(text: str, prefix: str) -> str:
    return text[len(prefix) :] if text.startswith(prefix) else text


def remove_suffix(text: str, suffix: str) -> str:
    return text[: -len(suffix)] if text.endswith(suffix) else text


def get_auth_headers() -> Dict[str, str]:
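    """Build an Authorization header from OPENAI_API_KEY (as a Bearer token) or API_KEY, if set."""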
    openai_api_key = os.environ.get("OPENAI_API_KEY")
    if openai_api_key:
        return {"Authorization": f"Bearer {openai_api_key}"}
    else:
        api_key = os.environ.get("API_KEY")
        if api_key:
            return {"Authorization": f"{api_key}"}
        return {}


# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
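    """Send one streaming request to a TensorRT-LLM (Triton) generate_stream endpoint and record metrics."""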
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with _create_bench_client_session() as session:
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.000001,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
            "min_length": request_func_input.output_len,
            "end_id": 1048576,
            **request_func_input.extra_request_body,
        }
        if args.disable_ignore_eos:
            del payload["min_length"]
            del payload["end_id"]
        output = RequestFuncOutput.init_new(request_func_input)

        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data:")

                        data = json.loads(chunk)
                        output.generated_text += data["text_output"]
                        timestamp = time.perf_counter()
                        # First token
                        if ttft == 0.0:
                            ttft = timestamp - st
                            output.ttft = ttft

                        # Decoding phase
                        else:
                            output.itl.append(timestamp - most_recent_timestamp)

                        most_recent_timestamp = timestamp

                    output.latency = most_recent_timestamp - st
                    output.success = True
                    output.output_len = request_func_input.output_len

                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

        if pbar:
            pbar.update(1)
        return output


# set ignore_eos True by default
async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
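    """Send one request to an OpenAI-compatible Completions endpoint and record streaming metrics."""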
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "completions"
    ), "OpenAI Completions API URL must end with 'completions'."

    prompt = request_func_input.prompt

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "prompt": prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }

        # hack to accommodate different LoRA conventions between SGLang and vLLM.
        if request_func_input.lora_name:
            payload["model"] = request_func_input.lora_name
            payload["lora_path"] = request_func_input.lora_name

        if request_func_input.image_data:
            payload.update({"image_data": request_func_input.image_data})

        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some Completions APIs send a final usage-summary
                            # chunk that contains no token, so check that a token was
                            # actually generated.
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.text_chunks.append(
                                        data["choices"][0]["text"]
                                    )
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]
                                output_len = (data.get("usage") or {}).get(
                                    "completion_tokens", output_len
                                )

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_openai_chat_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    """Makes a request to the OpenAI Chat Completions API.

    Handles both streaming and non-streaming responses, including support
    for image data in messages. Calculates and returns various performance
    metrics.

    Args:
        request_func_input: Input parameters for the request.
        pbar: Optional tqdm progress bar to update.

    Returns:
        RequestFuncOutput: Output of the request, including generated text,
                           latency, TTFT, ITL, and success status.
    """
    api_url = request_func_input.api_url
    assert api_url.endswith(
        "chat/completions"
    ), "OpenAI Chat Completions API URL must end with 'chat/completions'."

    if request_func_input.image_data:
        # Build multi-image content: a list of image_url entries followed by the text
        content_items = [
            {
                "type": "image_url",
                "image_url": {"url": img_url},
            }
            for img_url in request_func_input.image_data
        ]
        content_items.append({"type": "text", "text": request_func_input.prompt})
        messages = [
            {
                "role": "user",
                "content": content_items,
            },
        ]
    else:
        messages = [{"role": "user", "content": request_func_input.prompt}]

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "messages": messages,
            "temperature": 0.0,
            "max_completion_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }

        # hack to accommodate different LoRA conventions between SGLang and vLLM.
        if request_func_input.lora_name:
            payload["model"] = request_func_input.lora_name
            payload["lora_path"] = request_func_input.lora_name

        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    if args.disable_stream:
                        # Non-streaming response
                        response_json = await response.json()
                        output.generated_text = response_json["choices"][0]["message"][
                            "content"
                        ]
                        output.success = True
                        output.latency = time.perf_counter() - st
                        output.ttft = (
                            output.latency
                        )  # For non-streaming, TTFT = total latency
                        output.output_len = response_json.get("usage", {}).get(
                            "completion_tokens", output_len
                        )
                    else:
                        # Streaming response
                        async for chunk_bytes in response.content:
                            chunk_bytes = chunk_bytes.strip()
                            if not chunk_bytes:
                                continue

                            chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                            latency = time.perf_counter() - st
                            if chunk == "[DONE]":
                                pass
                            else:
                                data = json.loads(chunk)

                                # Check if this chunk contains content
                                delta = data.get("choices", [{}])[0].get("delta", {})
                                content = delta.get("content", "")

                                if content:
                                    timestamp = time.perf_counter()
                                    # First token
                                    if ttft == 0.0:
                                        ttft = timestamp - st
                                        output.ttft = ttft

                                    # Decoding phase
                                    else:
                                        output.itl.append(
                                            timestamp - most_recent_timestamp
                                        )

                                    most_recent_timestamp = timestamp
                                    generated_text += content

                                # Check for usage info in final chunk
                                output_len = (data.get("usage") or {}).get(
                                    "completion_tokens", output_len
                                )

                        output.generated_text = generated_text
                        output.success = True
                        output.latency = latency
                        output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_truss(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
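    """Send one request to a Truss endpoint using an OpenAI-style streaming payload."""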
    api_url = request_func_input.api_url

    prompt = request_func_input.prompt

    async with _create_bench_client_session() as session:
        payload = {
            "model": request_func_input.model,
            "prompt": prompt,
            "temperature": 0.0,
            "best_of": 1,
            "max_tokens": request_func_input.output_len,
            "stream": not args.disable_stream,
            "ignore_eos": not args.disable_ignore_eos,
            **request_func_input.extra_request_body,
        }
        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some Completions APIs send a final usage-summary
                            # chunk that contains no token, so check that a token was
                            # actually generated.
                            if data["choices"][0]["text"]:
                                timestamp = time.perf_counter()
                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text += data["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = request_func_input.output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    if pbar:
        pbar.update(1)
    return output


async def async_request_sglang_generate(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
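    """Send one request to SGLang's native generate endpoint and record streaming metrics."""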
    api_url = request_func_input.api_url
    prompt = request_func_input.prompt

    async with _create_bench_client_session() as session:
        payload = {
            ("text" if isinstance(prompt, str) else "input_ids"): prompt,
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": request_func_input.output_len,
                "ignore_eos": not args.disable_ignore_eos,
            },
            "stream": not args.disable_stream,
            "lora_path": request_func_input.lora_name,
            "return_logprob": args.return_logprob,
            "logprob_start_len": -1,
            **request_func_input.extra_request_body,
        }

        # Add image data if available (list of image urls/base64)
        if request_func_input.image_data:
            payload["image_data"] = request_func_input.image_data

        headers = get_auth_headers()

        output = RequestFuncOutput.init_new(request_func_input)

        generated_text = ""
        output_len = request_func_input.output_len
        ttft = 0.0
        st = time.perf_counter()
        most_recent_timestamp = st
        last_output_len = 0
        try:
            async with session.post(
                url=api_url, json=payload, headers=headers
            ) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            # NOTE: Some Completions APIs send a final usage-summary
                            # chunk that contains no token, so check that a token was
                            # actually generated.
                            if "text" in data and data["text"]:
                                timestamp = time.perf_counter()
                                generated_text = data["text"]
                                output_len = data["meta_info"]["completion_tokens"]

                                # First token
                                if ttft == 0.0:
                                    ttft = time.perf_counter() - st
                                    output.ttft = ttft

                                # Decoding phase
                                else:
                                    num_new_tokens = output_len - last_output_len
                                    if num_new_tokens == 0:
                                        continue
                                    chunk_gap = timestamp - most_recent_timestamp
                                    adjust_itl = chunk_gap / num_new_tokens
                                    output.itl.extend([adjust_itl] * num_new_tokens)

                                most_recent_timestamp = timestamp
                                last_output_len = output_len

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                    output.output_len = output_len
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))
            print(f"{output.error=}")

    if pbar:
        pbar.update(1)
    return output


async def async_request_gserver(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    raise NotImplementedError()


async def async_request_profile(api_url: str) -> RequestFuncOutput:
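    """POST to a profiler control endpoint (e.g. /start_profile or /stop_profile) and report success."""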
    async with _create_bench_client_session() as session:
        output = RequestFuncOutput()
        try:
            async with session.post(url=api_url) as response:
                if response.status == 200:
                    output.success = True
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception:
            output.success = False
            exc_info = sys.exc_info()
            output.error = "".join(traceback.format_exception(*exc_info))

    return output


def _build_profile_urls(
    profile_prefill_url: Optional[List[str]],
    profile_decode_url: Optional[List[str]],
) -> List[Tuple[str, str]]:
    """Build profile URLs list from prefill/decode URL arguments.

    Returns:
        List of (worker_type, url) tuples. e.g., [("Prefill-0", "http://..."), ("Decode-0", "http://...")]
    """
    profile_urls = []
    if profile_prefill_url:
        for idx, url in enumerate(profile_prefill_url):
            profile_urls.append((f"Prefill-{idx}", url))
    if profile_decode_url:
        for idx, url in enumerate(profile_decode_url):
            profile_urls.append((f"Decode-{idx}", url))
    return profile_urls


async def _call_profile_pd(profile_urls: List[Tuple[str, str]], mode: str) -> None:
    """Call profile endpoint (start/stop) on PD separated workers.

    Args:
        profile_urls: List of (worker_type, url) tuples
        mode: "start" or "stop"
    """
    endpoint = "/start_profile" if mode == "start" else "/stop_profile"
    action = "Starting" if mode == "start" else "Stopping"
    action_past = "started" if mode == "start" else "stopped"

    print(f"{action} profiler...")

    for worker_type, url in profile_urls:
        profile_output = await async_request_profile(api_url=url + endpoint)
        if profile_output.success:
            print(f"Profiler {action_past} for {worker_type} worker at {url}")
        else:
            print(
                f"Failed to {mode} profiler for {worker_type} worker at {url}: {profile_output.error}"
            )


def get_model(pretrained_model_name_or_path: str) -> str:
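    """Resolve a model name to a local path, downloading via ModelScope when SGLANG_USE_MODELSCOPE is set."""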
    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() == "true":
        import huggingface_hub.constants
        from modelscope import snapshot_download

        model_path = snapshot_download(
            model_id=pretrained_model_name_or_path,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
        )

        return model_path
    return pretrained_model_name_or_path


def get_tokenizer(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
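    """Load a tokenizer; raw .json/.model tokenizer files are delegated to sglang's own helper."""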
    assert (
        pretrained_model_name_or_path is not None
        and pretrained_model_name_or_path != ""
    )
    if pretrained_model_name_or_path.endswith(
        ".json"
    ) or pretrained_model_name_or_path.endswith(".model"):
        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

        return get_tokenizer(pretrained_model_name_or_path)

    if pretrained_model_name_or_path is not None and not os.path.exists(
        pretrained_model_name_or_path
    ):
        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
    return AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True
    )


def get_processor(
    pretrained_model_name_or_path: str,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
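    """Load a multimodal processor; raw .json/.model tokenizer files are delegated to sglang's own helper."""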
    assert (
        pretrained_model_name_or_path is not None
        and pretrained_model_name_or_path != ""
    )
    if pretrained_model_name_or_path.endswith(
        ".json"
    ) or pretrained_model_name_or_path.endswith(".model"):
        from sglang.srt.utils.hf_transformers_utils import get_processor

        return get_processor(pretrained_model_name_or_path)

    if pretrained_model_name_or_path is not None and not os.path.exists(
        pretrained_model_name_or_path
    ):
        pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
    return AutoProcessor.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=True
    )


def get_dataset(args, tokenizer, model_id=None):
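    """Build the list of benchmark requests for the dataset selected via --dataset-name."""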
    tokenize_prompt = getattr(args, "tokenize_prompt", False)
    if args.dataset_name == "sharegpt":
        assert not tokenize_prompt
        input_requests = sample_sharegpt_requests(
            dataset_path=args.dataset_path,
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            fixed_output_len=args.sharegpt_output_len,
            context_len=args.sharegpt_context_len,
            prompt_suffix=args.prompt_suffix,
            apply_chat_template=args.apply_chat_template,
        )
    elif args.dataset_name.startswith("random"):
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            num_prompts=args.num_prompts,
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
            random_sample=args.dataset_name == "random",
            return_text=not tokenize_prompt,
        )
    elif args.dataset_name == "image":
        processor = get_processor(model_id)
        input_requests = sample_image_requests(
            num_requests=args.num_prompts,
            image_count=args.image_count,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            range_ratio=args.random_range_ratio,
            processor=processor,
            image_content=args.image_content,
            image_format=args.image_format,
            image_resolution=args.image_resolution,
            backend=args.backend,
        )
    elif args.dataset_name == "generated-shared-prefix":
        assert not tokenize_prompt
        input_requests = sample_generated_shared_prefix_requests(
            num_groups=args.gsp_num_groups,
            prompts_per_group=args.gsp_prompts_per_group,
            system_prompt_len=args.gsp_system_prompt_len,
            question_len=args.gsp_question_len,
            output_len=args.gsp_output_len,
            tokenizer=tokenizer,
            args=args,
        )
    elif args.dataset_name == "mmmu":
        processor = get_processor(model_id)
        input_requests = sample_mmmu_requests(
            num_requests=args.num_prompts,
            processor=processor,
            backend=args.backend,
            fixed_output_len=args.random_output_len,
            random_sample=True,
        )
    elif args.dataset_name == "mooncake":
        # For mooncake, we don't generate the prompts here.
        # We just load the raw trace data. The async generator will handle the rest.
        if not args.dataset_path:
            local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
        else:
            local_path = args.dataset_path

        if not os.path.exists(local_path):
            download_and_cache_file(
                MOONCAKE_DATASET_URL[args.mooncake_workload], local_path
            )

        with open(local_path, "r") as f:
            all_requests_data = [json.loads(line) for line in f if line.strip()]

        # Limit the number of requests based on --num-prompts
        input_requests = all_requests_data[: args.num_prompts]
    else:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    return input_requests


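# Map each supported --backend value to the request function that drives it.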
ASYNC_REQUEST_FUNCS = {
    "sglang": async_request_sglang_generate,
    "sglang-native": async_request_sglang_generate,
    "sglang-oai": async_request_openai_completions,
    "sglang-oai-chat": async_request_openai_chat_completions,
    "vllm": async_request_openai_completions,
    "vllm-chat": async_request_openai_chat_completions,
    "lmdeploy": async_request_openai_completions,
    "lmdeploy-chat": async_request_openai_chat_completions,
    "trt": async_request_trt_llm,
    "gserver": async_request_gserver,
    "truss": async_request_truss,
}


@dataclass
class BenchmarkMetrics:
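    """Aggregate statistics computed over all completed benchmark requests."""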
    completed: int
    total_input: int
    total_input_text: int
    total_input_vision: int
    total_output: int
    total_output_retokenized: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    output_throughput_retokenized: float
    total_throughput: float
    total_throughput_retokenized: float
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    p95_itl_ms: float
    p99_itl_ms: float
    max_itl_ms: float
    mean_e2e_latency_ms: float
    median_e2e_latency_ms: float
    std_e2e_latency_ms: float
    p99_e2e_latency_ms: float
    concurrency: float


SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
MOONCAKE_DATASET_URL = {
    "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
    "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
    "synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
    "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
}


def download_and_cache_file(url: str, filename: Optional[str] = None):
    """Read and cache a file from a url."""
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])

    # Check if the cache file already exists
    if is_file_valid_json(filename):
        return filename

    print(f"Downloading from {url} to {filename}")

    # Stream the response to show the progress bar
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Check for request errors

    # Total size of the file in bytes
    total_size = int(response.headers.get("content-length", 0))
    chunk_size = 1024  # Download in chunks of 1KB

    # Use tqdm to display the progress bar
    with open(filename, "wb") as f, tqdm(
        desc=filename,
        total=total_size,
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for chunk in response.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            bar.update(len(chunk))

    return filename


def is_file_valid_json(path):
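    """Return True if the path is an existing file that contains valid JSON."""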
    if not os.path.isfile(path):
        return False

    # TODO can fuse into the real file open later
    try:
        with open(path) as f:
            json.load(f)
        return True
    except JSONDecodeError as e:
        print(
            f"{path} exists but JSON loading failed ({e=}); treating it as an invalid file"
        )
        return False


@dataclass
class DatasetRow:
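    """A single benchmark request: prompt text, token lengths, and optional image data."""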
    prompt: str
    prompt_len: int
    output_len: int
    text_prompt_len: Optional[int] = None
    vision_prompt_len: Optional[int] = None
    image_data: Optional[List[str]] = None
    timestamp: Optional[float] = None

    def __post_init__(self):
        if self.text_prompt_len is None:
            self.text_prompt_len = self.prompt_len
        if self.vision_prompt_len is None:
            self.vision_prompt_len = 0


async def get_mooncake_request_over_time(
    input_requests: List[Dict],
    tokenizer: PreTrainedTokenizerBase,
    slowdown_factor: float,
    num_rounds: int,
) -> AsyncGenerator[DatasetRow, None]:
    """
    An async generator that yields requests based on the timestamps in the Mooncake trace file,
    with support for multi-round sessions.
    """
    if not input_requests:
        return

    input_requests.sort(key=lambda r: r["timestamp"])

    start_time = time.perf_counter()
    trace_start_time_ms = input_requests[0]["timestamp"]

    for record in input_requests:
        # Calculate when this entire session should start
        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
        target_arrival_time_s = relative_arrival_time_s * slowdown_factor

        current_elapsed_time_s = time.perf_counter() - start_time
        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
        if sleep_duration_s > 0:
            await asyncio.sleep(sleep_duration_s)

        # Once the session starts, generate all rounds for it as a burst
        # This simulates a user engaging in a multi-turn conversation

        # Base user query constructed from hash_ids
        user_query_base = ""
        hash_ids = record.get("hash_ids", [])
        for hash_id in hash_ids:
            user_query_base += f"{hash_id}" + " ".join(
                ["hi"] * 128
            )  # Shorter for multi-round
        user_query_base += "Tell me a story based on this context."

        output_len_per_round = record.get("output_length", 256)
        chat_history = []

        for i in range(num_rounds):
            # Add user query for the current round
            chat_history.append(
                {"role": "user", "content": f"Round {i + 1}: {user_query_base}"}
            )

            # Form the full prompt from history
            try:
                full_prompt_text = tokenizer.apply_chat_template(
                    chat_history, tokenize=False, add_generation_prompt=True
                )
            except Exception:
                full_prompt_text = "\n".join(
                    [f"{msg['role']}: {msg['content']}" for msg in chat_history]
                )

            prompt_len = len(tokenizer.encode(full_prompt_text))

            yield DatasetRow(
                prompt=full_prompt_text,
                prompt_len=prompt_len,
                output_len=output_len_per_round,
            )

            # Add a placeholder assistant response for the next round's context
            # We use a placeholder because we don't know the real response
            placeholder_response = " ".join(["story"] * output_len_per_round)
            chat_history.append({"role": "assistant", "content": placeholder_response})


def sample_mmmu_requests(
    num_requests: int,
    processor: AutoProcessor | AutoTokenizer,
    backend: str = "sglang",
    fixed_output_len: Optional[int] = None,
    random_sample: bool = True,
) -> List[DatasetRow]:
    """
    Sample requests from the MMMU dataset using HuggingFace datasets.

    Args:
        num_requests: Number of requests to sample.
        fixed_output_len: If provided, use this fixed output length for all requests.
        random_sample: Whether to randomly sample or take the first N.

    Returns:
        List of DatasetRow entries containing the prompts, token lengths, and encoded images.
    """
    print("Loading MMMU dataset from HuggingFace...")

    try:
        print("Attempting to load MMMU Math dataset...")
        mmmu_dataset = load_dataset("MMMU/MMMU", "Math", split="test")
        print(
            f"Successfully loaded MMMU Math dataset from HuggingFace with {len(mmmu_dataset)} examples"
        )
    except Exception as e:
        print(f"Failed to load MMMU Math dataset: {e}")
        raise ValueError(f"Failed to load MMMU dataset: {e}")

    # Sample from the dataset
    if len(mmmu_dataset) > num_requests:
        if random_sample:
            # Random sample
            indices = random.sample(range(len(mmmu_dataset)), num_requests)
            sample_dataset = mmmu_dataset.select(indices)
        else:
            # Take first N
            sample_dataset = mmmu_dataset.select(
                range(min(num_requests, len(mmmu_dataset)))
            )
    else:
        print(f"Dataset has less than {num_requests} examples, using all examples")
        sample_dataset = mmmu_dataset

    print(f"Selected {len(sample_dataset)} examples for benchmarking")

    # Create prompts
    filtered_dataset = []

    for i, example in enumerate(sample_dataset):
        try:
            # Extract image_1
            image = example.get("image_1")

            if image is not None:
                if hasattr(image, "save"):
                    # Convert RGBA images to RGB before encoding
                    if image.mode == "RGBA":
                        image = image.convert("RGB")

                    # Encode image to base64 (save as PNG to support palette/alpha modes)
                    buffered = io.BytesIO()
                    image.save(buffered, format="PNG")
                    img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
                    image_data = f"data:image/png;base64,{img_str}"
                else:
                    continue

                # Extract the question
                question = example.get("question")

                # Construct the prompt
                text_prompt = f"Question: {question}\n\nAnswer: "
                output_len = fixed_output_len if fixed_output_len is not None else 256
                data_row = create_mm_data_row(
                    text_prompt, [image], [image_data], output_len, processor, backend
                )
                filtered_dataset.append(data_row)

        except Exception as e:
            print(f"Error processing example {i}: {e}")

    print(f"\nCreated {len(filtered_dataset)} MMMU prompts")
    return filtered_dataset


def sample_sharegpt_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
    context_len: Optional[int] = None,
    prompt_suffix: Optional[str] = "",
    apply_chat_template=False,
) -> List[DatasetRow]:
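    """Sample prompts from the ShareGPT dataset, pruning sequences that are too short or too long."""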
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Download sharegpt if necessary
    if not is_file_valid_json(dataset_path) and dataset_path == "":
        dataset_path = download_and_cache_file(SHAREGPT_URL)

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)

    # Filter out the conversations with less than 2 turns.
    dataset = [
        data
        for data in dataset
        if len(data.get("conversations", data.get("conversation", []))) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (
            data.get("conversations", data.get("conversation", []))[0]["value"],
            data.get("conversations", data.get("conversation", []))[1]["value"],
        )
        for data in dataset
    ]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short
    filtered_dataset: List[DatasetRow] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        if prompt_suffix:
            prompt = (
                remove_suffix(prompt, ASSISTANT_SUFFIX)
                + prompt_suffix
                + ASSISTANT_SUFFIX
            )

        if apply_chat_template:
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True,
                tokenize=False,
            )
            if tokenizer.bos_token:
                prompt = prompt.replace(tokenizer.bos_token, "")

        prompt_token_ids = tokenizer.encode(prompt)
        completion = dataset[i][1]
        completion_token_ids = tokenizer.encode(completion)
        prompt_len = len(prompt_token_ids)
        output_len = (
            len(completion_token_ids) if fixed_output_len is None else fixed_output_len
        )

        if prompt_len < 2 or output_len < 2:
            # Prune too short sequences.
            continue

        if context_len and prompt_len + output_len > context_len:
            # Prune too long sequences.
            continue

        filtered_dataset.append(
            DatasetRow(
                prompt=prompt,
                prompt_len=prompt_len,
                output_len=output_len,
            )
        )

    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
    return filtered_dataset


def sample_random_requests(
    input_len: int,
    output_len: int,
    num_prompts: int,
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
    random_sample: bool = True,
    return_text: bool = True,
) -> List[DatasetRow]:
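    """Generate prompts with randomized lengths, sampling text from ShareGPT or from random token ids."""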
    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1),
        input_len + 1,
        size=num_prompts,
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio),
        output_len + 1,
        size=num_prompts,
    )

    if random_sample:
        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

        # Download sharegpt if necessary
        if not is_file_valid_json(dataset_path):
            dataset_path = download_and_cache_file(SHAREGPT_URL)

        # Load the dataset.
        with open(dataset_path) as f:
            dataset = json.load(f)
        # Filter out the conversations with less than 2 turns.
        dataset = [
            data
            for data in dataset
            if len(data.get("conversations", data.get("conversation", []))) >= 2
        ]
        # Only keep the first two turns of each conversation.
        dataset = [
            (
                data.get("conversations", data.get("conversation", []))[0]["value"],
                data.get("conversations", data.get("conversation", []))[1]["value"],
            )
            for data in dataset
        ]
        # Shuffle the dataset.
        random.shuffle(dataset)

        # Filter out sequences that are too long or too short
        input_requests: List[DatasetRow] = []
        for data in dataset:
            i = len(input_requests)
            if i == num_prompts:
                break

            # Tokenize the prompts and completions.
            prompt = data[0]
            prompt_token_ids = tokenizer.encode(prompt)
            prompt_len = len(prompt_token_ids)

            # Skip empty prompt
            if prompt_len == 0:
                continue

            if prompt_len > input_lens[i]:
                input_ids = prompt_token_ids[: input_lens[i]]
            else:
                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
            input_content = input_ids
            if return_text:
                input_content = tokenizer.decode(input_content)
            input_requests.append(
                DatasetRow(
                    prompt=input_content,
                    prompt_len=int(input_lens[i]),
                    output_len=int(output_lens[i]),
                )
            )
    else:
        # Sample token ids from random integers. This can cause some NaN issues.
        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
        input_requests = []
        for i in range(num_prompts):
            input_content = [
                (offsets[i] + i + j) % tokenizer.vocab_size
                for j in range(input_lens[i])
            ]
            if return_text:
                input_content = tokenizer.decode(input_content)
            input_requests.append(
                DatasetRow(
                    prompt=input_content,
                    prompt_len=int(input_lens[i]),
                    output_len=int(output_lens[i]),
                )
            )

    print(f"#Input tokens: {np.sum(input_lens)}")
    print(f"#Output tokens: {np.sum(output_lens)}")
    return input_requests


def parse_image_resolution(image_resolution: str) -> Tuple[int, int]:
    """Parse image resolution into (width, height).

    Supports presets '4k', '1080p', '720p', '360p' and custom 'heightxwidth' format
    (e.g., '1080x1920' means height=1080, width=1920).
    """
    resolution_to_size = {
        "4k": (3840, 2160),
        "1080p": (1920, 1080),
        "720p": (1280, 720),
        "360p": (640, 360),
    }
    if image_resolution in resolution_to_size:
        return resolution_to_size[image_resolution]

    res = image_resolution.strip().lower()
    if "x" in res:
        parts = res.split("x")
        if len(parts) == 2 and parts[0].isdigit() and parts[1].isdigit():
            height = int(parts[0])
            width = int(parts[1])
            if height > 0 and width > 0:
                return (width, height)

    raise ValueError(
        f"Unsupported image resolution: {image_resolution}. "
        "Choose from 4k, 1080p, 720p, 360p, or provide custom 'heightxwidth' (e.g., 1080x1920)."
    )

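# Illustrative examples for parse_image_resolution (a sketch following the preset
# table and the 'heightxwidth' rule above; not executed here):
#   parse_image_resolution("720p")      -> (1280, 720)
#   parse_image_resolution("1080x1920") -> (1920, 1080)   # height=1080, width=1920
#   parse_image_resolution("999")       -> raises ValueError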

def create_mm_data_row(
    text_prompt, images: list, images_base64, output_len, processor, backend
):
    try:
        if type(processor).__name__ == "Phi4MMProcessor":
            # <|endoftext10|> is the image token used in the phi-4-multimodal model.
            content_items = text_prompt.replace("image 1", "|endoftext10|")
        else:
            content_items = [
                {"type": "image", "image": {"url": image_base64}}
                for image_base64 in images_base64
            ]
            content_items.append({"type": "text", "text": text_prompt})
        prompt_str = processor.apply_chat_template(
            [{"role": "user", "content": content_items}],
            add_generation_prompt=True,
            tokenize=False,
        )
    except Exception as e:
        # Note (Xinyuan): This is a workaround for an issue where some tokenizers do not support content as a list. (e.g. InternVL)
        print(f"Error applying chat template: {e}, fallback to <image> tag")
        # Some tokenizers do not support list content; fall back to a placeholder in the text
        prompt_str = f"<image>{text_prompt}"

    # Calculate total tokens (text + vision)
    prompt_len = processor(
        text=[prompt_str],
        images=images,
        padding=False,
        return_tensors="pt",
    )["input_ids"].numel()

    # Calculate text-only tokens
    try:
        # Create text-only version of the prompt
        text_only_prompt = processor.apply_chat_template(
            [{"role": "user", "content": text_prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        text_prompt_len = processor(
            text=[text_only_prompt],
            padding=False,
            return_tensors="pt",
        )["input_ids"].numel()
    except Exception:
        # Fallback: just tokenize the text prompt directly
        tokenizer_to_use = (
            processor.tokenizer if hasattr(processor, "tokenizer") else processor
        )
        text_prompt_len = len(tokenizer_to_use.encode(text_prompt))

    # Vision tokens = total tokens - text tokens
    vision_prompt_len = prompt_len - text_prompt_len

    use_raw_prompt = backend in [
        "sglang-oai",
        "sglang-oai-chat",
        "vllm",
        "vllm-chat",
        "lmdeploy",
        "lmdeploy-chat",
    ]
    return DatasetRow(
        prompt=text_prompt if use_raw_prompt else prompt_str,
        prompt_len=prompt_len,
        output_len=output_len,
        text_prompt_len=text_prompt_len,
        vision_prompt_len=vision_prompt_len,
        image_data=images_base64,
    )

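# Token accounting sketch for create_mm_data_row (numbers are hypothetical):
# if the processor yields 1,250 input ids for text + images and 50 ids for the
# text-only chat prompt, then prompt_len=1250, text_prompt_len=50, and
# vision_prompt_len = 1250 - 50 = 1200. For the OpenAI-style backends the raw
# text prompt is sent as-is and the images travel separately as base64 data URLs.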

def sample_image_requests(
    num_requests: int,
    image_count: int,
    input_len: int,
    output_len: int,
    range_ratio: float,
    processor: AutoProcessor,
    image_content: str,
    image_format: str,
    image_resolution: str,
    backend: str,
) -> List[DatasetRow]:
    """Generate requests with images.

    - Each request includes ``image_count`` images.
    - Supported resolutions: 4k (3840x2160), 1080p (1920x1080), 720p (1280x720), 360p (640x360),
      or custom 'heightxwidth' (e.g., 1080x1920).
    - Text lengths follow the 'random' dataset sampling rule. ``prompt_len``
      counts the full processed prompt (text plus vision tokens); per-row
      ``text_prompt_len`` and ``vision_prompt_len`` break this down.
    """

    # Parse resolution (supports presets and 'heightxwidth')
    width, height = parse_image_resolution(image_resolution)

    # Check for potentially problematic combinations and warn user
    if width * height >= 1920 * 1080 and image_count * num_requests >= 100:
        warnings.warn(
            f"High resolution ({width}x{height}) with {image_count * num_requests} total images "
            f"may take a long time. Consider reducing resolution or image count.",
            UserWarning,
            stacklevel=2,
        )

    # Sample text lengths
    input_lens = np.random.randint(
        max(int(input_len * range_ratio), 1), input_len + 1, size=num_requests
    )
    output_lens = np.random.randint(
        int(output_len * range_ratio), output_len + 1, size=num_requests
    )

    def _gen_random_image_data_uri(
        width: int = width, height: int = height
    ) -> Tuple[Image.Image, str, int]:
        if image_content == "blank":
            # Generate blank white image
            arr = np.full((height, width, 3), 255, dtype=np.uint8)
        else:
            # Generate random colored image
            arr = (np.random.rand(height, width, 3) * 255).astype(np.uint8)
        img = Image.fromarray(arr)
        buf = io.BytesIO()
        img.save(buf, format=image_format, quality=85)
        encoded = pybase64.b64encode(buf.getvalue()).decode("utf-8")
        image_data = f"data:image/{image_format};base64,{encoded}"
        image_bytes = len(image_data.encode("utf-8"))
        return img, image_data, image_bytes

    dataset: List[DatasetRow] = []
    total_image_bytes = 0
    for i in range(num_requests):
        # Generate text prompt
        text_prompt = gen_prompt(processor.tokenizer, int(input_lens[i]))

        # Generate image list
        images, images_base64, images_bytes = zip(
            *[_gen_random_image_data_uri() for _ in range(image_count)]
        )
        total_image_bytes += sum(list(images_bytes))

        data_row = create_mm_data_row(
            text_prompt,
            list(images),
            list(images_base64),
            int(output_lens[i]),
            processor,
            backend,
        )

        dataset.append(data_row)

    print(f"#Input tokens: {np.sum([x.prompt_len for x in dataset])}")
    print(f"#Output tokens: {np.sum([x.output_len for x in dataset])}")
    print(
        f"\nCreated {len(dataset)} {image_content} {image_format} images with average {total_image_bytes // num_requests} bytes per request"
    )
    return dataset

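# Usage sketch for sample_image_requests (hypothetical model and sizes; the
# processor must match the multimodal model actually being benchmarked):
#   processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
#   rows = sample_image_requests(
#       num_requests=8, image_count=1, input_len=128, output_len=64,
#       range_ratio=1.0, processor=processor, image_content="blank",
#       image_format="jpeg", image_resolution="360p", backend="sglang",
#   )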

def gen_prompt(tokenizer, token_num):
    """Generate a random prompt of specified token length using tokenizer vocabulary."""
    all_available_tokens = list(tokenizer.get_vocab().values())
    selected_tokens = random.choices(all_available_tokens, k=token_num)
    return tokenizer.decode(selected_tokens)

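# Note: gen_prompt samples token ids uniformly from the vocabulary and decodes
# them, so re-encoding the returned string is not guaranteed to give exactly
# `token_num` tokens (decode/encode round-trips are not always length-preserving).
# Sketch: filler = gen_prompt(tokenizer, 32)  # roughly 32-token random prompt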

def get_gen_prefix_cache_path(args, tokenizer):
    """Create cache directory under ~/.cache/sglang/benchmark"""
    cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"

    # Create a unique cache filename based on the generation parameters
    cache_key = (
        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
        f"{tokenizer.__class__.__name__}.pkl"
    )
    return cache_dir / cache_key

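# Example of a resulting cache path with the default gsp-* arguments and a fast
# tokenizer (illustrative only):
#   ~/.cache/sglang/benchmark/gen_shared_prefix_64_16_2048_128_256_PreTrainedTokenizerFast.pkl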

def sample_generated_shared_prefix_requests(
    num_groups: int,
    prompts_per_group: int,
    system_prompt_len: int,
    question_len: int,
    output_len: int,
    tokenizer: PreTrainedTokenizerBase,
    args: argparse.Namespace,
) -> List[DatasetRow]:
    """Generate benchmark requests with shared system prompts using random tokens and caching."""
    cache_path = get_gen_prefix_cache_path(args, tokenizer)

    # Try to load from cache first
    if cache_path.exists():
        print(f"\nLoading cached generated input data from {cache_path}")
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    print("\nGenerating new input data...")

    # Generate system prompts for each group
    system_prompts = []
    for _ in range(num_groups):
        system_prompt = gen_prompt(tokenizer, system_prompt_len)
        system_prompts.append(system_prompt)

    # Generate questions
    questions = []
    for _ in range(num_groups * prompts_per_group):
        question = gen_prompt(tokenizer, question_len)
        questions.append(question)

    # Combine system prompts with questions
    input_requests = []
    total_input_tokens = 0
    total_output_tokens = 0

    for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
        system_prompt = system_prompts[group_idx]
        for prompt_idx in tqdm(
            range(prompts_per_group), desc="Generating questions", leave=False
        ):
            question = questions[group_idx * prompts_per_group + prompt_idx]
            full_prompt = f"{system_prompt}\n\n{question}"
            prompt_len = len(tokenizer.encode(full_prompt))

            input_requests.append(
                DatasetRow(
                    prompt=full_prompt,
                    prompt_len=prompt_len,
                    output_len=output_len,
                )
            )
            total_input_tokens += prompt_len
            total_output_tokens += output_len

    # Shuffle questions
    random.shuffle(input_requests)

    # Print statistics
    print(f"\nGenerated shared prefix dataset statistics:")
    print(f"Number of groups: {num_groups}")
    print(f"Prompts per group: {prompts_per_group}")
    print(f"Total prompts: {len(input_requests)}")
    print(f"Total input tokens: {total_input_tokens}")
    print(f"Total output tokens: {total_output_tokens}")
    print(
        f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
    )
    print(
        f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
    )

    # Save to cache
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Caching generated input data to {cache_path}")
    with open(cache_path, "wb") as f:
        pickle.dump(input_requests, f)

    return input_requests

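# Caveat (from the cache key above): cached shared-prefix data is keyed only by
# the gsp-* sizes and the tokenizer class, not by the random seed, so reruns
# with a different --seed but identical sizes will reuse the cached prompts.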

async def get_request(
    input_requests: List[DatasetRow],
    request_rate: float,
    use_trace_timestamps: bool = False,
    slowdown_factor: float = 1.0,
) -> AsyncGenerator[DatasetRow, None]:
    if use_trace_timestamps:
        print(
            f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
        )
        # Sort requests by timestamp for correct replay
        input_requests.sort(key=lambda r: r.timestamp)

        start_time = time.perf_counter()
        trace_start_time_ms = input_requests[0].timestamp if input_requests else 0

        for request in input_requests:
            trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
            target_arrival_time = start_time + (trace_time_s * slowdown_factor)

            sleep_duration = target_arrival_time - time.perf_counter()
            if sleep_duration > 0:
                await asyncio.sleep(sleep_duration)

            yield request
    else:
        input_requests_iter = iter(input_requests)
        for request in input_requests_iter:
            yield request

            if request_rate == float("inf"):
                # If the request rate is infinity, then we don't need to wait.
                continue

            # Sample the request interval from the exponential distribution.
            interval = np.random.exponential(1.0 / request_rate)
            # The next request will be sent after the interval.
            await asyncio.sleep(interval)

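# Arrival-process sketch: with --request-rate R, inter-arrival gaps are drawn
# from an exponential distribution with mean 1/R seconds (a Poisson process),
# e.g. R=4 req/s gives an average gap of 0.25 s. In trace mode the recorded
# timestamps are replayed instead, stretched by the slowdown factor.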

def calculate_metrics(
    input_requests: List[DatasetRow],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    backend: str,
    accept_length: Optional[float] = None,
) -> Tuple[BenchmarkMetrics, List[int]]:
    output_lens: List[int] = []
    retokenized_output_lens: List[int] = []
    total_input = 0
    total_input_text = 0
    total_input_vision = 0
    completed = 0
    itls: List[float] = []
    tpots: List[float] = []
    ttfts: List[float] = []
    e2e_latencies: List[float] = []
    retokenized_itls: List[float] = []

    use_retokenized_itl = (
        accept_length is not None
        and accept_length > 0
        and backend in ("sglang-oai", "sglang-oai-chat")
    )

    for i in range(len(outputs)):
        if outputs[i].success:
            output_len = outputs[i].output_len
            output_lens.append(output_len)
            retokenized_output_len = len(
                tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
            )
            retokenized_output_lens.append(retokenized_output_len)
            total_input += input_requests[i].prompt_len
            total_input_text += input_requests[i].text_prompt_len
            total_input_vision += input_requests[i].vision_prompt_len
            if output_len > 1:
                tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
            if use_retokenized_itl:
                for k, itl in enumerate(outputs[i].itl):
                    num_tokens = len(
                        tokenizer.encode(
                            outputs[i].text_chunks[k], add_special_tokens=False
                        )
                    )
                    adjusted_itl = itl / num_tokens
                    retokenized_itls.extend([adjusted_itl] * num_tokens)
            else:
                itls += outputs[i].itl
            ttfts.append(outputs[i].ttft)

            e2e_latencies.append(outputs[i].latency)

            completed += 1
        else:
            output_lens.append(0)
            retokenized_output_lens.append(0)

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )

    itls = retokenized_itls if use_retokenized_itl else itls
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_input_text=total_input_text,
        total_input_vision=total_input_vision,
        total_output=sum(output_lens),
        total_output_retokenized=sum(retokenized_output_lens),
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=sum(output_lens) / dur_s,
        output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
        total_throughput=(total_input + sum(output_lens)) / dur_s,
        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
        / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0)
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
        p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
        max_itl_ms=np.max(itls or 0) * 1000,
        mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
        median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
        std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
        p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
        concurrency=np.sum(e2e_latencies) / dur_s,
    )

    return metrics, output_lens

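# Metric definitions used above, with a hypothetical worked example:
#   TPOT = (e2e_latency - TTFT) / (output_len - 1)   # per request
#   e.g. latency=2.0 s, TTFT=0.5 s, 31 output tokens -> TPOT = 1.5/30 = 50 ms
# ITL is measured per streamed chunk; when a speculative-decoding accept length
# is reported, each chunk's latency is spread over its retokenized token count.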

async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[DatasetRow],
    request_rate: float,
    max_concurrency: Optional[int],
    disable_tqdm: bool,
    lora_names: List[str],
    extra_request_body: Dict[str, Any],
    profile: bool,
    pd_separated: bool = False,
    flush_cache: bool = False,
    warmup_requests: int = 1,
    use_trace_timestamps: bool = False,
    mooncake_slowdown_factor=1.0,
    mooncake_num_rounds=1,
    profile_prefill_url: Optional[List[str]] = None,
    profile_decode_url: Optional[List[str]] = None,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    # Limit concurrency
    # From https://github.com/vllm-project/vllm/pull/9390
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    # Warmup
    print(f"Starting warmup with {warmup_requests} sequences...")

    # Handle the data structure difference for the warmup request
    if args.dataset_name == "mooncake":
        # For mooncake, input_requests is a list of dicts.
        # We need to build a temporary DatasetRow for the warmup phase.
        warmup_record = input_requests[0]

        # Build prompt from hash_ids, just like in the async generator
        hash_ids = warmup_record.get("hash_ids", [])
        prompt_text = ""
        for hash_id in hash_ids:
            prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
        prompt_text += "Can you tell me a detailed story in 1000 words?"

        output_len = warmup_record.get("output_length", 32)
        prompt_len = len(tokenizer.encode(prompt_text))

        # Create a temporary DatasetRow object for warmup
        test_request = DatasetRow(
            prompt=prompt_text,
            prompt_len=prompt_len,
            output_len=output_len,
            image_data=None,  # Mooncake doesn't have image data
        )
    else:
        # For all other datasets, input_requests is a list of DatasetRow objects
        test_request = input_requests[0]

    if lora_names is not None and len(lora_names) != 0:
        lora_name = lora_names[0]
    else:
        lora_name = None

    # Create the test input once
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_request.prompt,
        api_url=api_url,
        prompt_len=test_request.prompt_len,
        output_len=min(test_request.output_len, 32),
        lora_name=lora_name,
        image_data=test_request.image_data,
        extra_request_body=extra_request_body,
    )

    # Run warmup requests
    warmup_tasks = []
    for _ in range(warmup_requests):
        warmup_tasks.append(
            asyncio.create_task(request_func(request_func_input=test_input))
        )

    warmup_outputs = await asyncio.gather(*warmup_tasks)

    # Check if at least one warmup request succeeded
    if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
        raise ValueError(
            "Warmup failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {warmup_outputs[0].error}"
        )
    else:
        print(
            f"Warmup completed with {args.warmup_requests} sequences. Starting main benchmark run..."
        )

    # Flush cache
    if ("sglang" in backend and _get_bool_env_var("SGLANG_IS_IN_CI")) or flush_cache:
        requests.post(base_url + "/flush_cache", headers=get_auth_headers())

    time.sleep(1.0)

    # Build profile URLs for PD separated mode (do this once at the beginning)
    pd_profile_urls = []
    if profile and pd_separated:
        pd_profile_urls = _build_profile_urls(profile_prefill_url, profile_decode_url)
        if not pd_profile_urls:
            print(
                "Warning: PD separated mode requires --profile-prefill-url or --profile-decode-url"
            )
            print("Skipping profiler start. Please specify worker URLs for profiling.")

    # Start profiler
    if profile:
        if pd_separated:
            if pd_profile_urls:
                await _call_profile_pd(pd_profile_urls, "start")
        else:
            print("Starting profiler...")
            profile_output = await async_request_profile(
                api_url=base_url + "/start_profile"
            )
            if profile_output.success:
                print("Profiler started")

    # Run all requests
    benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
    pbar_total = len(input_requests)
    if (
        backend == "sglang" and args.dataset_name == "mooncake"
    ):  # Assuming mooncake is mainly for sglang or similar backends
        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
        request_generator = get_mooncake_request_over_time(
            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
        )
        print(
            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
        )
        pbar_total *= args.mooncake_num_rounds
    else:
        request_generator = get_request(input_requests, request_rate)

    pbar = None if disable_tqdm else tqdm(total=pbar_total)
    async for request in request_generator:
        if lora_names is not None and len(lora_names) != 0:
            idx = random.randint(0, len(lora_names) - 1)
            lora_name = lora_names[idx]
        else:
            lora_name = None

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=request.prompt,
            api_url=api_url,
            prompt_len=request.prompt_len,
            output_len=request.output_len,
            lora_name=lora_name,
            image_data=request.image_data,
            extra_request_body=extra_request_body,
            timestamp=request.timestamp,
        )

        tasks.append(
            asyncio.create_task(
                limited_request_func(request_func_input=request_func_input, pbar=pbar)
            )
        )
    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

    # Stop profiler
    if profile:
        if pd_separated:
            if pd_profile_urls:
                await _call_profile_pd(pd_profile_urls, "stop")
        else:
            print("Stopping profiler...")
            profile_output = await async_request_profile(
                api_url=base_url + "/stop_profile"
            )
            if profile_output.success:
                print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    if "sglang" in backend:
        server_info = requests.get(
            base_url + "/get_server_info", headers=get_auth_headers()
        )
        if server_info.status_code == 200:
            server_info_json = server_info.json()
            if "decode" in server_info_json:
                server_info_json = server_info_json["decode"][0]
            if (
                "internal_states" in server_info_json
                and server_info_json["internal_states"]
            ):
                accept_length = server_info_json["internal_states"][0].get(
                    "avg_spec_accept_length", None
                )
            else:
                accept_length = None
        else:
            accept_length = None
    else:
        accept_length = None

    # Compute metrics and print results
    benchmark_duration = time.perf_counter() - benchmark_start_time
    metrics, output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        backend=backend,
        accept_length=accept_length,
    )

    print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Backend:", backend))
    print(
        "{:<40} {:<10}".format(
            "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
        )
    )
    print(
        "{:<40} {:<10}".format(
            "Max request concurrency:",
            max_concurrency if max_concurrency else "not set",
        )
    )
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total input text tokens:", metrics.total_input_text))
    print(
        "{:<40} {:<10}".format("Total input vision tokens:", metrics.total_input_vision)
    )
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10}".format(
            "Total generated tokens (retokenized):", metrics.total_output_retokenized
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Input token throughput (tok/s):", metrics.input_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total token throughput (tok/s):", metrics.total_throughput
        )
    )
    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
    if accept_length:
        print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Median E2E Latency (ms):", metrics.median_e2e_latency_ms
        )
    )
    print("{s:{c}^{n}}".format(s="Time to First Token", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
    print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
    print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
    print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
    print("=" * 50)

    if (
        metrics.median_ttft_ms is not None
        and metrics.mean_itl_ms is not None
        and metrics.output_throughput is not None
    ):
        result = {
            # Arguments
            "tag": getattr(args, "tag", None),
            "backend": args.backend,
            "dataset_name": args.dataset_name,
            "request_rate": "trace" if use_trace_timestamps else request_rate,
            "max_concurrency": max_concurrency,
            "sharegpt_output_len": args.sharegpt_output_len,
            "random_input_len": args.random_input_len,
            "random_output_len": args.random_output_len,
            "random_range_ratio": args.random_range_ratio,
            # Results
            "duration": benchmark_duration,
            "completed": metrics.completed,
            "total_input_tokens": metrics.total_input,
            "total_input_text_tokens": metrics.total_input_text,
            "total_input_vision_tokens": metrics.total_input_vision,
            "total_output_tokens": metrics.total_output,
            "total_output_tokens_retokenized": metrics.total_output_retokenized,
            "request_throughput": metrics.request_throughput,
            "input_throughput": metrics.input_throughput,
            "output_throughput": metrics.output_throughput,
            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
            "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
            "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms,
            "mean_ttft_ms": metrics.mean_ttft_ms,
            "median_ttft_ms": metrics.median_ttft_ms,
            "std_ttft_ms": metrics.std_ttft_ms,
            "p99_ttft_ms": metrics.p99_ttft_ms,
            "mean_tpot_ms": metrics.mean_tpot_ms,
            "median_tpot_ms": metrics.median_tpot_ms,
            "std_tpot_ms": metrics.std_tpot_ms,
            "p99_tpot_ms": metrics.p99_tpot_ms,
            "mean_itl_ms": metrics.mean_itl_ms,
            "median_itl_ms": metrics.median_itl_ms,
            "std_itl_ms": metrics.std_itl_ms,
            "p95_itl_ms": metrics.p95_itl_ms,
            "p99_itl_ms": metrics.p99_itl_ms,
            "concurrency": metrics.concurrency,
            "accept_length": accept_length,
        }
    else:
        print(f"Error running benchmark for request rate: {request_rate}")
        print("-" * 30)

    # Determine output file name
    if args.output_file:
        output_file_name = args.output_file
    else:
        now = datetime.now().strftime("%m%d")
        if args.dataset_name == "image":
            output_file_name = (
                f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_"
                f"{args.random_output_len}_{args.image_count}imgs_"
                f"{args.image_resolution}.jsonl"
            )
        elif args.dataset_name.startswith("random"):
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
        else:
            output_file_name = (
                f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
            )

    result_details = {
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }

    # Append results to a JSONL file
    with open(output_file_name, "a") as file:
        if args.output_details:
            result_for_dump = result | result_details
        else:
            result_for_dump = result
        file.write(json.dumps(result_for_dump) + "\n")

    return result | result_details

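# Reading results back (sketch; the JSONL file name below is hypothetical):
#   import json
#   with open("sglang_0101_1000_1024_1024.jsonl") as f:
#       runs = [json.loads(line) for line in f]
#   print(runs[-1]["output_throughput"], runs[-1]["median_ttft_ms"])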

def check_chat_template(model_path):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        return "chat_template" in tokenizer.init_kwargs
    except Exception as e:
        print(f"Fail to load tokenizer config with error={e}")
        return False

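# check_chat_template only inspects the tokenizer config for a `chat_template`
# entry; the result is used below purely to decide whether to print the
# chat/instruct-model warning, not to change how requests are built.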

def set_global_args(args_: argparse.Namespace):
    """Set the global args."""
    global args
    args = args_


def run_benchmark(args_: argparse.Namespace):
    global args
    args = args_

    # Set default value for max_concurrency if not present
    if not hasattr(args, "max_concurrency"):
        args.max_concurrency = None

    # Set default value for warmup_requests if not present
    if not hasattr(args, "warmup_requests"):
        args.warmup_requests = 1

    if not hasattr(args, "output_details"):
        args.output_details = False

    if not hasattr(args, "tokenize_prompt"):
        args.tokenize_prompt = False

    if not hasattr(args, "use_trace_timestamps"):
        args.use_trace_timestamps = False
    if not hasattr(args, "mooncake_slowdown_factor"):
        args.mooncake_slowdown_factor = 1.0

    if not hasattr(args, "mooncake_num_rounds"):
        args.mooncake_num_rounds = 1

    print(f"benchmark_args={args}")

    # Set global environments
    set_ulimit()
    random.seed(args.seed)
    np.random.seed(args.seed)

    extra_request_body = {}
    if args.extra_request_body:
        extra_request_body = json.loads(args.extra_request_body)

    if args.tokenize_prompt:
        assert (
            args.backend == "sglang"
        ), "`--tokenize-prompt` only compatible with `--backend sglang` currently"

    # Set url
    if args.port is None:
        args.port = {
            "sglang": 30000,
            "sglang-native": 30000,
            "sglang-oai": 30000,
            "lmdeploy": 23333,
            "vllm": 8000,
            "trt": 8000,
            "gserver": 9988,
            "truss": 8080,
        }.get(args.backend, 30000)

    model_url = (
        f"{args.base_url}/v1/models"
        if args.base_url
        else f"http://{args.host}:{args.port}/v1/models"
    )

    if args.backend in ["sglang", "sglang-native"]:
        api_url = (
            f"{args.base_url}/generate"
            if args.base_url
            else f"http://{args.host}:{args.port}/generate"
        )
    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
        api_url = (
            f"{args.base_url}/v1/completions"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/completions"
        )
    elif args.backend in ["sglang-oai-chat", "vllm-chat", "lmdeploy-chat"]:
        api_url = (
            f"{args.base_url}/v1/chat/completions"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/chat/completions"
        )
    elif args.backend == "trt":
        api_url = (
            f"{args.base_url}/v2/models/ensemble/generate_stream"
            if args.base_url
            else f"http://{args.host}:{args.port}/v2/models/ensemble/generate_stream"
        )
        if args.model is None:
            print("Please provide a model using `--model` when using `trt` backend.")
            sys.exit(1)
    elif args.backend == "gserver":
        api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
        args.model = args.model or "default"
    elif args.backend == "truss":
        api_url = (
            f"{args.base_url}/v1/models/model:predict"
            if args.base_url
            else f"http://{args.host}:{args.port}/v1/models/model:predict"
        )
    base_url = (
        f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
    )

    # Get model name
    if args.model is None:
        if args.backend == "truss":
            print(
                "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
            )
            sys.exit(1)
        try:
            response = requests.get(model_url, headers=get_auth_headers())
            model_list = response.json().get("data", [])
            args.model = model_list[0]["id"] if model_list else None
        except Exception as e:
            print(f"Failed to fetch model from {model_url}. Error: {e}")
            print(
                "Please specify the correct host and port using `--host` and `--port`."
            )
            sys.exit(1)

    if args.model is None:
        print("No model specified or found. Please provide a model using `--model`.")
        sys.exit(1)

    if not check_chat_template(args.model):
        print(
            "\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
            "Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
        )

    if args.dataset_name in ["image", "mmmu"]:
        args.apply_chat_template = True
        assert (
            not args.tokenize_prompt
        ), "`--tokenize-prompt` not compatible with image dataset"

    print(f"{args}\n")

    # Read dataset
    backend = args.backend
    model_id = args.served_model_name or args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer = get_tokenizer(tokenizer_id)
    input_requests = get_dataset(args, tokenizer, model_id)

    # compatible with SimpleNamespace
    if not hasattr(args, "flush_cache"):
        args.flush_cache = False

    return asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            request_rate=args.request_rate,
            max_concurrency=args.max_concurrency,
            disable_tqdm=args.disable_tqdm,
            lora_names=args.lora_name,
            extra_request_body=extra_request_body,
            profile=args.profile,
            pd_separated=args.pd_separated,
            flush_cache=args.flush_cache,
            warmup_requests=args.warmup_requests,
            use_trace_timestamps=args.use_trace_timestamps,
            mooncake_slowdown_factor=args.mooncake_slowdown_factor,
            mooncake_num_rounds=args.mooncake_num_rounds,
            profile_prefill_url=getattr(args, "profile_prefill_url", None),
            profile_decode_url=getattr(args, "profile_decode_url", None),
        )
    )

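# Programmatic usage sketch (hedged): run_benchmark() accepts any namespace-like
# object (argparse.Namespace or types.SimpleNamespace), but every CLI attribute
# read by run_benchmark(), get_dataset(), and benchmark() must be present on it,
# so the safest route is to reuse the argument parser defined under __main__
# below and override only the flags you care about.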

def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)

    if current_soft < target_soft_limit:
        try:
            resource.setrlimit(resource_type, (target_soft_limit, current_hard))
        except ValueError as e:
            print(f"Fail to set RLIMIT_NOFILE: {e}")

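# Why raise RLIMIT_NOFILE: each in-flight request holds at least one socket file
# descriptor, and the default soft limit on many Linux systems (commonly 1024)
# is easily exhausted at high --max-concurrency. The shell equivalent of the
# call above is roughly `ulimit -n 65535`.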

class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        setattr(namespace, self.dest, [])
        for lora_name in values:
            getattr(namespace, self.dest).append(lora_name)

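# Example (illustrative): `--lora-name adapter_a adapter_b` stores
# ["adapter_a", "adapter_b"] on args.lora_name; benchmark() then picks one of
# the names uniformly at random for each request.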

if __name__ == "__main__":
    parser = ArgumentParser(description="Benchmark the online serving throughput.")
    parser.add_argument(
        "--backend",
        type=str,
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
        default="sglang",
        help="Must specify a backend, depending on the LLM Inference Engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument(
        "--host", type=str, default="0.0.0.0", help="Default host is 0.0.0.0."
    )
    parser.add_argument(
        "--port",
        type=int,
        help="If not set, the default port is configured according to its default value for different LLM Inference Engines.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=[
            "sharegpt",
            "random",
            "random-ids",
            "generated-shared-prefix",
            "mmmu",
            "image",
            "mooncake",
        ],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path", type=str, default="", help="Path to the dataset."
    )
    parser.add_argument(
        "--model",
        type=str,
        help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
    )
    parser.add_argument(
        "--served-model-name",
        type=str,
        help="The name of the model as served by the serving service. If not set, this defaults to the value of --model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer. If not set, using the model conf.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process. Default is 1000.",
    )
    parser.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
    )
    parser.add_argument(
        "--sharegpt-context-len",
        type=int,
        default=None,
        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
    )
    parser.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random and image dataset.",
    )
    parser.add_argument(
        "--random-output-len",
        default=1024,
        type=int,
        help="Number of output tokens per request, used only for random and image dataset.",
    )
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range of sampled ratio of input/output length, "
2422
        "used only for random and image dataset.",
    )
    # image dataset args
    parser.add_argument(
        "--image-count",
        type=int,
        default=1,
        help="Number of images per request (only available with the image dataset)",
    )
    parser.add_argument(
        "--image-resolution",
        type=str,
        default="1080p",
        help=(
            "Resolution of images for image dataset. "
            "Supports presets 4k/1080p/720p/360p or custom 'heightxwidth' (e.g., 1080x1920)."
        ),
    )
    parser.add_argument(
        "--image-format",
        type=str,
        default="jpeg",
        help=("Format of images for image dataset. " "Supports jpeg and png."),
    )
    parser.add_argument(
        "--image-content",
        type=str,
        default="random",
        help=("Content for images for image dataset. " "Supports random and blank."),
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
    parser.add_argument(
        "--use-trace-timestamps",
        action="store_true",
        help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )
    parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
    parser.add_argument(
        "--output-details", action="store_true", help="Output details of benchmarking."
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--disable-stream",
        action="store_true",
        help="Disable streaming mode.",
    )
    parser.add_argument(
        "--return-logprob",
        action="store_true",
        help="Return logprob.",
    )
    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
    parser.add_argument(
        "--disable-ignore-eos",
        action="store_true",
        help="Disable ignoring EOS.",
    )
    parser.add_argument(
        "--extra-request-body",
        metavar='{"key1": "value1", "key2": "value2"}',
        type=str,
        help="Append given JSON object to the request payload. You can use this to specify"
        "additional generate params like sampling params.",
    )
    parser.add_argument(
        "--apply-chat-template",
        action="store_true",
        help="Apply chat template",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--lora-name",
        type=str,
        nargs="*",
        default=None,
        action=LoRAPathAction,
        help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
    )
    parser.add_argument(
        "--prompt-suffix",
        type=str,
        default="",
        help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
    )
    parser.add_argument(
        "--pd-separated",
        action="store_true",
        help="Benchmark PD disaggregation server",
    )

    # Create a mutually exclusive group for profiling URLs
    # In PD separated mode, prefill and decode workers must be profiled separately
    profile_url_group = parser.add_mutually_exclusive_group()
    profile_url_group.add_argument(
        "--profile-prefill-url",
        type=str,
        nargs="*",
        default=None,
        help="URL(s) of the prefill worker(s) for profiling in PD separated mode. "
        "Can specify multiple URLs: --profile-prefill-url http://localhost:30000 http://localhost:30001. "
        "NOTE: Cannot be used together with --profile-decode-url. "
        "In PD separated mode, prefill and decode workers must be profiled separately.",
    )
    profile_url_group.add_argument(
        "--profile-decode-url",
        type=str,
        nargs="*",
        default=None,
        help="URL(s) of the decode worker(s) for profiling in PD separated mode. "
        "Can specify multiple URLs: --profile-decode-url http://localhost:30010 http://localhost:30011. "
        "NOTE: Cannot be used together with --profile-prefill-url. "
        "In PD separated mode, prefill and decode workers must be profiled separately.",
    )
    parser.add_argument(
        "--flush-cache",
        action="store_true",
        help="Flush the cache before running the benchmark",
    )
    parser.add_argument(
        "--warmup-requests",
        type=int,
        default=1,
        help="Number of warmup requests to run before the benchmark",
    )
    parser.add_argument(
        "--tokenize-prompt",
        action="store_true",
        help="Use integer ids instead of string for inputs. Useful to control prompt lengths accurately",
    )

    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
    group.add_argument(
        "--gsp-num-groups",
        type=int,
        default=64,
        help="Number of system prompt groups for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-prompts-per-group",
        type=int,
        default=16,
        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-system-prompt-len",
        type=int,
        default=2048,
        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-question-len",
        type=int,
        default=128,
        help="Target length in tokens for questions in generated-shared-prefix dataset",
    )
    group.add_argument(
        "--gsp-output-len",
        type=int,
        default=256,
        help="Target length in tokens for outputs in generated-shared-prefix dataset",
    )
    mooncake_group = parser.add_argument_group("mooncake dataset arguments")
    mooncake_group.add_argument(
        "--mooncake-slowdown-factor",
        type=float,
        default=1.0,
        help="Slowdown factor for replaying the mooncake trace. "
        "A value of 2.0 means the replay is twice as slow. "
        "NOTE: --request-rate is IGNORED in mooncake mode.",
    )
    mooncake_group.add_argument(
        "--mooncake-num-rounds",
        type=int,
        default=1,
        help="Number of conversation rounds for each session in the mooncake dataset. "
        "A value > 1 will enable true multi-turn session benchmarking.",
    )
    mooncake_group.add_argument(
        "--mooncake-workload",
        type=str,
        default="conversation",
        choices=[
            "mooncake",
            "conversation",
            "synthetic",
            "toolagent",
        ],
        help="Underlying workload for the mooncake dataset.",
    )
    parser.add_argument(
        "--tag", type=str, default=None, help="The tag to be dumped to output."
    )
    args = parser.parse_args()
    run_benchmark(args)