benchmark_serving.py 48.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
r"""Benchmark online serving throughput.

On the server side, run one of the following commands:
    vLLM OpenAI API server
    vllm serve <your_model> \
        --swap-space 16 \
        --disable-log-requests

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --model <your_model> \
        --dataset-name sharegpt \
        --dataset-path <path to dataset> \
        --request-rate <request_rate> \ # By default <request_rate> is inf
        --num-prompts <num_prompts> # By default <num_prompts> is 1000

    When using the tgi backend, add
        --endpoint /generate_stream
    to the end of the command above.
"""
24

25
26
import argparse
import asyncio
27
import gc
28
import json
29
import os
30
31
import random
import time
32
import warnings
33
from collections.abc import AsyncGenerator, Iterable
34
35
from dataclasses import dataclass
from datetime import datetime
36
from typing import Any, Literal, Optional
37
38

import numpy as np
39
from tqdm.asyncio import tqdm
40
from transformers import PreTrainedTokenizerBase
41

42
43
44
45
46
47
48
from backend_request_func import (
    ASYNC_REQUEST_FUNCS,
    OPENAI_COMPATIBLE_BACKENDS,
    RequestFuncInput,
    RequestFuncOutput,
)

49
50
51
52
try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer
53

54
55
56
57
58
try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser

59
60
61
62
63
from benchmark_dataset import (
    AIMODataset,
    ASRDataset,
    BurstGPTDataset,
    ConversationDataset,
64
    CustomDataset,
65
66
67
68
69
70
71
72
73
74
    HuggingFaceDataset,
    InstructCoderDataset,
    MTBenchDataset,
    NextEditPredictionDataset,
    RandomDataset,
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
)
75
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
76

77
78
# Divisor used to turn a goodput SLO value given in milliseconds into
# seconds before it is compared with the per-request measurements.
MILLISECONDS_TO_SECONDS_CONVERSION = 1000

79
80
81
82
83
84
85

@dataclass
class BenchmarkMetrics:
    """Aggregated results of a single serving benchmark run.

    Every ``*_ms`` field is expressed in milliseconds. Each ``percentiles_*``
    field is a list of ``(percentile, value_in_ms)`` pairs.
    """

    # --- Request / token totals and throughput (per second of run time) ---
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    request_goodput: float
    output_throughput: float
    total_token_throughput: float

    # --- TTFT: time to first token ---
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]

    # --- TPOT: time per output token, excluding the first token ---
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    percentiles_tpot_ms: list[tuple[float, float]]

    # --- ITL: inter-token latency ---
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]

    # --- E2EL: end-to-end latency per request ---
    # Measured on the client side, from sending a request until the
    # complete response has been received.
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]
108
109


110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def _get_current_request_rate(
    ramp_up_strategy: Optional[Literal["linear", "exponential"]],
    ramp_up_start_rps: Optional[int],
    ramp_up_end_rps: Optional[int],
    request_index: int,
    total_requests: int,
    request_rate: float,
) -> float:
    if (
        ramp_up_strategy
        and ramp_up_start_rps is not None
        and ramp_up_end_rps is not None
    ):
        progress = request_index / max(total_requests - 1, 1)
        if ramp_up_strategy == "linear":
            increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
            return ramp_up_start_rps + increase
        elif ramp_up_strategy == "exponential":
            ratio = ramp_up_end_rps / ramp_up_start_rps
            return ramp_up_start_rps * (ratio**progress)
        else:
            raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
    return request_rate


135
async def get_request(
    input_requests: list[SampleRequest],
    request_rate: float,
    burstiness: float = 1.0,
    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
    ramp_up_start_rps: Optional[int] = None,
    ramp_up_end_rps: Optional[int] = None,
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
    """
    Asynchronously generates requests at a specified rate
    with OPTIONAL burstiness and OPTIONAL ramp-up strategy.

    Args:
        input_requests:
            A list of input requests, each represented as a SampleRequest.
            Any other iterable is converted to a list first.
        request_rate:
            The rate at which requests are generated (requests/s).
        burstiness (optional):
            The burstiness factor of the request generation.
            Only takes effect when request_rate is not inf.
            Default value is 1, which follows a Poisson process.
            Otherwise, the request intervals follow a gamma distribution.
            A lower burstiness value (0 < burstiness < 1) results
            in more bursty requests, while a higher burstiness value
            (burstiness > 1) results in a more uniform arrival of requests.
        ramp_up_strategy (optional):
            The ramp-up strategy. Can be "linear" or "exponential".
            If None, uses constant request rate (specified by request_rate).
        ramp_up_start_rps (optional):
            The starting request rate for ramp-up.
        ramp_up_end_rps (optional):
            The ending request rate for ramp-up.

    Yields:
        ``(request, current_request_rate)`` pairs, where
        ``current_request_rate`` is the rate in effect for that request.

    Raises:
        AssertionError: If ``burstiness`` is not positive.
    """
    assert burstiness > 0, (
        f"A positive burstiness factor is expected, but given {burstiness}."
    )

    # Convert to list to get length for ramp-up calculations
    if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
        input_requests = list(input_requests)

    total_requests = len(input_requests)

    for request_index, request in enumerate(input_requests):
        current_request_rate = _get_current_request_rate(
            ramp_up_strategy,
            ramp_up_start_rps,
            ramp_up_end_rps,
            request_index,
            total_requests,
            request_rate,
        )

        yield request, current_request_rate

        if current_request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        if current_request_rate == 0:
            # A linear ramp-up that starts at 0 RPS yields a rate of exactly
            # 0 for the first request. A zero rate has no finite interval
            # (the division below would raise ZeroDivisionError), and the
            # ramp only advances with the request index, so move on to the
            # next request without waiting.
            continue

        theta = 1.0 / (current_request_rate * burstiness)

        # Sample the request interval from the gamma distribution.
        # If burstiness is 1, it follows exponential distribution.
        interval = np.random.gamma(shape=burstiness, scale=theta)

        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


205
def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    selected_percentile_metrics: list[str],
    selected_percentiles: list[float],
    goodput_config_dict: dict[str, float],
) -> tuple[BenchmarkMetrics, list[int]]:
    """Aggregate per-request outputs into a ``BenchmarkMetrics`` summary.

    Args:
        input_requests: The requests that were sent; ``input_requests[i]``
            must correspond to ``outputs[i]``.
        outputs: The per-request results.
        dur_s: Benchmark duration in seconds, used as the divisor of every
            throughput figure.
        tokenizer: Used to count output tokens when an output does not
            report ``output_tokens`` itself.
        selected_percentile_metrics: Part of the interface; not read here.
        selected_percentiles: Percentiles to compute for each latency metric.
        goodput_config_dict: Service level objectives in milliseconds keyed
            by "ttft", "tpot" and/or "e2el". Empty disables goodput counting.

    Returns:
        A ``(metrics, actual_output_lens)`` tuple; ``actual_output_lens``
        has one entry per output, with 0 for failed requests.
    """
    actual_output_lens: list[int] = []
    total_input = 0
    completed = 0
    good_completed = 0
    itls: list[float] = []
    tpots: list[float] = []
    all_tpots: list[float] = []
    ttfts: list[float] = []
    e2els: list[float] = []

    for idx, output in enumerate(outputs):
        if not output.success:
            actual_output_lens.append(0)
            continue

        output_len = output.output_tokens
        if not output_len:
            # We use the tokenizer to count the number of output tokens
            # for some serving backends instead of looking at
            # len(output.itl) since multiple output tokens may be
            # bundled together
            # Note : this may inflate the output token count slightly
            encoded = tokenizer(output.generated_text, add_special_tokens=False)
            output_len = len(encoded.input_ids)

        actual_output_lens.append(output_len)
        total_input += input_requests[idx].prompt_len

        if output_len > 1:
            # Time spent after the first token, spread over the other tokens.
            tpot = (output.latency - output.ttft) / (output_len - 1)
            tpots.append(tpot)
        else:
            # Note: if output_len <= 1, we regard tpot as 0 for goodput
            tpot = 0
        all_tpots.append(tpot)

        itls.extend(output.itl)
        ttfts.append(output.ttft)
        e2els.append(output.latency)
        completed += 1

    if goodput_config_dict:
        # Per-request samples for each metric that may carry an SLO, in the
        # order in which they are checked.
        slo_candidates = (
            ("ttft", ttfts),
            ("tpot", all_tpots),
            ("e2el", e2els),
        )
        valid_metrics = []
        slo_values = []
        for slo_name, samples in slo_candidates:
            if slo_name in goodput_config_dict:
                valid_metrics.append(samples)
                slo_values.append(
                    goodput_config_dict[slo_name] / MILLISECONDS_TO_SECONDS_CONVERSION
                )

        # A request is "good" when every configured SLO is at least as
        # large as the value measured for that request.
        for req_metric in zip(*valid_metrics):
            if all(s >= r for s, r in zip(slo_values, req_metric)):
                good_completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )

    def summarize_ms(samples):
        # Mean / std / median / percentiles, scaled by 1000 for the *_ms
        # fields. An empty sample list is replaced by the scalar 0 so that
        # numpy returns 0 (e.g. ttfts is empty if streaming is not
        # supported by the backend).
        data = samples or 0
        return (
            np.mean(data) * 1000,
            np.std(data) * 1000,
            np.median(data) * 1000,
            [(p, np.percentile(data, p) * 1000) for p in selected_percentiles],
        )

    mean_ttft, std_ttft, median_ttft, pct_ttft = summarize_ms(ttfts)
    mean_tpot, std_tpot, median_tpot, pct_tpot = summarize_ms(tpots)
    mean_itl, std_itl, median_itl, pct_itl = summarize_ms(itls)
    mean_e2el, std_e2el, median_e2el, pct_e2el = summarize_ms(e2els)

    total_output = sum(actual_output_lens)

    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        request_throughput=completed / dur_s,
        request_goodput=good_completed / dur_s,
        output_throughput=total_output / dur_s,
        total_token_throughput=(total_input + total_output) / dur_s,
        mean_ttft_ms=mean_ttft,
        std_ttft_ms=std_ttft,
        median_ttft_ms=median_ttft,
        percentiles_ttft_ms=pct_ttft,
        mean_tpot_ms=mean_tpot,
        std_tpot_ms=std_tpot,
        median_tpot_ms=median_tpot,
        percentiles_tpot_ms=pct_tpot,
        mean_itl_ms=mean_itl,
        std_itl_ms=std_itl,
        median_itl_ms=median_itl,
        percentiles_itl_ms=pct_itl,
        mean_e2el_ms=mean_e2el,
        std_e2el_ms=std_e2el,
        median_e2el_ms=median_e2el,
        percentiles_e2el_ms=pct_e2el,
    )

    return metrics, actual_output_lens
321

322
323
324
325

async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    model_name: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: list[SampleRequest],
    logprobs: Optional[int],
    request_rate: float,
    burstiness: float,
    disable_tqdm: bool,
    profile: bool,
    selected_percentile_metrics: list[str],
    selected_percentiles: list[float],
    ignore_eos: bool,
    goodput_config_dict: dict[str, float],
    max_concurrency: Optional[int],
    lora_modules: Optional[Iterable[str]],
    extra_body: Optional[dict],
    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
    ramp_up_start_rps: Optional[int] = None,
    ramp_up_end_rps: Optional[int] = None,
):
    """Run the serving benchmark and return its results.

    The function sends one test request built from ``input_requests[0]``,
    optionally starts the server-side profiler, dispatches every request in
    ``input_requests`` at the pace dictated by ``get_request``, waits for
    all responses, optionally stops the profiler, and finally computes,
    prints and returns the metrics.

    Args:
        backend: Key into ``ASYNC_REQUEST_FUNCS`` selecting the request
            function.
        api_url: URL that the benchmark requests are sent to.
        base_url: Base URL; "/start_profile" and "/stop_profile" are
            appended to it when ``profile`` is set.
        model_id: Model placed in each request, unless ``lora_modules`` is
            given, in which case the chosen LoRA module is used instead.
        model_name: Model name placed in each request (same override).
        tokenizer: Forwarded to ``calculate_metrics``.
        input_requests: The requests to send.
        logprobs: Forwarded unchanged to every request.
        request_rate: Constant request rate used when no ramp-up is active.
        burstiness: Burstiness factor forwarded to ``get_request``.
        disable_tqdm: When true, no progress bar is created.
        profile: When true, profiler start/stop requests are sent around
            the main run.
        selected_percentile_metrics: Names among "ttft", "tpot", "itl" and
            "e2el" whose statistics are printed and added to the result.
        selected_percentiles: Percentiles computed for each latency metric.
        ignore_eos: Forwarded unchanged to every request.
        goodput_config_dict: Goodput service level objectives; empty
            disables goodput reporting.
        max_concurrency: Maximum number of in-flight requests; a falsy
            value means unlimited.
        lora_modules: When given, one module is picked at random for each
            request. ``random.choice`` needs a sequence, so this is
            expected to be indexable (e.g. a list).
        extra_body: Forwarded unchanged to every request.
        ramp_up_strategy: Optional ramp-up strategy forwarded to
            ``get_request``.
        ramp_up_start_rps: Starting RPS of the ramp-up.
        ramp_up_end_rps: Ending RPS of the ramp-up.

    Returns:
        A dict with the run duration, totals, throughput figures, raw
        per-request data and the statistics of every selected metric.

    Raises:
        ValueError: If ``backend`` is unknown or the initial test request
            fails.
    """
    if backend not in ASYNC_REQUEST_FUNCS:
        raise ValueError(f"Unknown backend: {backend}")
    request_func = ASYNC_REQUEST_FUNCS[backend]

    print("Starting initial single prompt test run...")
    test_prompt, test_prompt_len, test_output_len, test_mm_content = (
        input_requests[0].prompt,
        input_requests[0].prompt_len,
        input_requests[0].expected_output_len,
        input_requests[0].multi_modal_data,
    )

    assert test_mm_content is None or isinstance(test_mm_content, dict)
    test_input = RequestFuncInput(
        model=model_id,
        model_name=model_name,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        logprobs=logprobs,
        multi_modal_content=test_mm_content,
        ignore_eos=ignore_eos,
        extra_body=extra_body,
    )

    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    else:
        print("Initial test run completed. Starting main benchmark run...")

    if lora_modules:
        # For each input request, choose a LoRA module at random.
        lora_modules = iter(
            [random.choice(lora_modules) for _ in range(len(input_requests))]
        )

    if profile:
        print("Starting profiler...")
        profile_input = RequestFuncInput(
            model=model_id,
            model_name=model_name,
            prompt=test_prompt,
            api_url=base_url + "/start_profile",
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            logprobs=logprobs,
            multi_modal_content=test_mm_content,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
        )
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
            print("Profiler started")

    distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"

    if ramp_up_strategy is not None:
        print(
            f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase "
            f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over "
            "the duration of the benchmark."
        )
    else:
        print(f"Traffic request rate: {request_rate} RPS.")

    print(f"Burstiness factor: {burstiness} ({distribution})")
    print(f"Maximum request concurrency: {max_concurrency}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    # This can be used once the minimum Python version is 3.10 or higher,
    # and it will simplify the code in limited_request_func.
    #    semaphore = (asyncio.Semaphore(max_concurrency)
    #                 if max_concurrency else contextlib.nullcontext())
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar):
        # Send one request, holding a semaphore slot when concurrency is
        # limited.
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    benchmark_start_time = time.perf_counter()
    tasks: list[asyncio.Task] = []

    # Record the wall-clock time at which each integer RPS level is first
    # reached during a ramp-up.
    rps_change_events = []
    last_int_rps = -1
    if ramp_up_strategy is not None and ramp_up_start_rps is not None:
        last_int_rps = ramp_up_start_rps
        rps_change_events.append(
            {
                "rps": last_int_rps,
                "timestamp": datetime.now().isoformat(),
            }
        )

    async for request, current_request_rate in get_request(
        input_requests,
        request_rate,
        burstiness,
        ramp_up_strategy,
        ramp_up_start_rps,
        ramp_up_end_rps,
    ):
        if ramp_up_strategy is not None:
            current_int_rps = int(current_request_rate)
            if current_int_rps > last_int_rps:
                # Several integer levels may be crossed between two
                # consecutive requests; log each of them.
                timestamp = datetime.now().isoformat()
                for rps_val in range(last_int_rps + 1, current_int_rps + 1):
                    rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
                last_int_rps = current_int_rps

        prompt, prompt_len, output_len, mm_content = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
            request.multi_modal_data,
        )
        req_model_id, req_model_name = model_id, model_name
        if lora_modules:
            req_lora_module = next(lora_modules)
            req_model_id, req_model_name = req_lora_module, req_lora_module

        request_func_input = RequestFuncInput(
            model=req_model_id,
            model_name=req_model_name,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            logprobs=logprobs,
            multi_modal_content=mm_content,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
        )
        task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
        tasks.append(asyncio.create_task(task))
    outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

    # Take the duration as soon as the last response is in, so that the
    # profiler shutdown round-trip below does not inflate the duration and
    # deflate every throughput figure derived from it.
    benchmark_duration = time.perf_counter() - benchmark_start_time

    if profile:
        print("Stopping profiler...")
        profile_input = RequestFuncInput(
            model=model_id,
            prompt=test_prompt,
            api_url=base_url + "/stop_profile",
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            logprobs=logprobs,
        )
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
            print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentile_metrics=selected_percentile_metrics,
        selected_percentiles=selected_percentiles,
        goodput_config_dict=goodput_config_dict,
    )

    print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    if goodput_config_dict:
        print(
            "{:<40} {:<10.2f}".format(
                "Request goodput (req/s):", metrics.request_goodput
            )
        )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total Token throughput (tok/s):", metrics.total_token_throughput
        )
    )

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        # The key previously carried a stray trailing colon
        # ("request_goodput:"), unlike every other key in this dict.
        "request_goodput": metrics.request_goodput if goodput_config_dict else None,
        "output_throughput": metrics.output_throughput,
        "total_token_throughput": metrics.total_token_throughput,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }

    if rps_change_events:
        result["rps_change_events"] = rps_change_events

    def process_one_metric(
        # E.g., "ttft"
        metric_attribute_name: str,
        # E.g., "TTFT"
        metric_name: str,
        # E.g., "Time to First Token"
        metric_header: str,
    ):
        # This function prints and adds statistics of the specified
        # metric.
        if metric_attribute_name not in selected_percentile_metrics:
            return
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_attribute_name}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_attribute_name}_ms"),
            )
        )
        result[f"mean_{metric_attribute_name}_ms"] = getattr(
            metrics, f"mean_{metric_attribute_name}_ms"
        )
        result[f"median_{metric_attribute_name}_ms"] = getattr(
            metrics, f"median_{metric_attribute_name}_ms"
        )
        result[f"std_{metric_attribute_name}_ms"] = getattr(
            metrics, f"std_{metric_attribute_name}_ms"
        )
        for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
            # Whole-number percentiles are rendered without a decimal
            # point, e.g. "p99" rather than "p99.0".
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
            result[f"p{p_word}_{metric_attribute_name}_ms"] = value

    process_one_metric("ttft", "TTFT", "Time to First Token")
    process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
    process_one_metric("itl", "ITL", "Inter-token Latency")
    process_one_metric("e2el", "E2EL", "End-to-end Latency")

    print("=" * 50)

    return result
615
616


617
618
def check_goodput_args(args):
    """Parse and validate the goodput service level objectives in ``args``.

    Args:
        args: Object whose ``goodput`` attribute holds the "KEY:VALUE"
            pairs, or a falsy value when goodput is not requested.

    Returns:
        A dict mapping each objective name to its value; empty when
        ``args.goodput`` is falsy.

    Raises:
        ValueError: If a name is not one of "ttft", "tpot" or "e2el", or
            if a value is negative.
    """
    allowed_names = ["ttft", "tpot", "e2el"]

    if not args.goodput:
        return {}

    goodput_config_dict = parse_goodput(args.goodput)
    for slo_name, slo_val in goodput_config_dict.items():
        name_is_known = slo_name in allowed_names
        if not name_is_known:
            raise ValueError(
                f"Invalid metric name found, {slo_name}: {slo_val}. "
                "The service level objective name should be one of "
                f"{str(allowed_names)}. "
            )
        if slo_val < 0:
            raise ValueError(
                f"Invalid value found, {slo_name}: {slo_val}. "
                "The service level objective value should be "
                "non-negative."
            )
    return goodput_config_dict
637
638
639


def parse_goodput(slo_pairs):
    """Turn "KEY:VALUE" strings into a ``{key: float(value)}`` dict.

    Args:
        slo_pairs: Iterable of strings, each made of a metric name and a
            number separated by exactly one colon.

    Returns:
        A dict mapping every metric name to its value as a float. When a
        name occurs more than once, the last occurrence wins.

    Raises:
        argparse.ArgumentTypeError: If a pair does not split into exactly
            two parts or its value is not a valid float.
    """
    try:
        # Unpacking into two names rejects pairs with zero or several
        # colons; float() rejects non-numeric values. Both raise ValueError.
        return {
            name: float(raw_value)
            for name, raw_value in (pair.split(":") for pair in slo_pairs)
        }
    except ValueError as err:
        message = (
            "Invalid format found for service level objectives. "
            'Specify service level objectives for goodput as "KEY:VALUE" '
            "pairs, where the key is a metric name, and the value is a "
            "number in milliseconds."
        )
        raise argparse.ArgumentTypeError(message) from err
653
654


655
656
657
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any], file_name: str
) -> None:
    """Write the benchmark results in the PyTorch benchmark format.

    Args:
        args: Parsed command-line arguments, forwarded to
            ``convert_to_pytorch_benchmark_format``.
        results: Result dict of the benchmark run.
        file_name: Path of the regular result file; the records are written
            next to it with the extension replaced by ".pytorch.json".

    Nothing is written when the conversion yields no records.
    """
    metrics = [
        "median_ttft_ms",
        "mean_ttft_ms",
        "std_ttft_ms",
        "p99_ttft_ms",
        "mean_tpot_ms",
        "median_tpot_ms",
        "std_tpot_ms",
        "p99_tpot_ms",
        "median_itl_ms",
        "mean_itl_ms",
        "std_itl_ms",
        "p99_itl_ms",
    ]
    # These raw data might be useful, but they are rather big. They can be added
    # later if needed
    ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        # A metric is only present in `results` when it was selected for
        # the run (e.g. the p99 entries require the 99th percentile), so
        # skip absent ones instead of failing with a KeyError.
        metrics={k: [results[k]] for k in metrics if k in results},
        extra_info={
            k: results[k]
            for k in results
            if k not in metrics and k not in ignored_metrics
        },
    )
    if pt_records:
        # Use a distinct ".pytorch.json" suffix rather than the plain
        # result file name (the original note gives "we don't want CI to
        # pick it up" as the reason - TODO confirm).
        pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)
688
689


690
691
692
693
694
def _validate_ramp_up_args(args: argparse.Namespace) -> None:
    """Raise ValueError if the ramp-up options are inconsistent."""
    if args.ramp_up_strategy is None:
        return
    if args.request_rate != float("inf"):
        raise ValueError(
            "When using ramp-up, do not specify --request-rate. "
            "The request rate will be controlled by ramp-up parameters. "
            "Please remove the --request-rate argument."
        )
    if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
        raise ValueError(
            "When using --ramp-up-strategy, both --ramp-up-start-rps and "
            "--ramp-up-end-rps must be specified"
        )
    if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
        raise ValueError("Ramp-up start and end RPS must be non-negative")
    if args.ramp_up_start_rps > args.ramp_up_end_rps:
        raise ValueError("Ramp-up start RPS must be less than end RPS")
    if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0:
        raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")


def _parse_metadata(items) -> dict[str, str]:
    """Parse ``KEY=VALUE`` strings into a dict, splitting on the first ``=``."""
    parsed: dict[str, str] = {}
    for item in items or []:
        # partition() splits on the first "=" only, so values that contain
        # "=" themselves (e.g. "cmd=a=b") are kept intact.
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
        parsed[key.strip()] = value.strip()
    return parsed


def _sample_hf_requests(args: argparse.Namespace, tokenizer: Any, backend: str):
    """Pick the HuggingFaceDataset subclass for ``args.dataset_path`` and sample it.

    Note: this overwrites ``args.hf_split`` (and for VisionArena also
    ``args.hf_subset``) for the dataset paths that require a fixed split.
    """
    # all following datasets are implemented from the
    # HuggingFaceDataset base class
    if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = VisionArenaDataset
        args.hf_split = "train"
        args.hf_subset = None
    elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = InstructCoderDataset
        args.hf_split = "train"
    elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = MTBenchDataset
        args.hf_split = "train"
    elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = ConversationDataset
    elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
        dataset_class = AIMODataset
        args.hf_split = "train"
    elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS:  # noqa: E501
        dataset_class = NextEditPredictionDataset
        args.hf_split = "train"
    elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = ASRDataset
        args.hf_split = "train"
    else:
        supported_datasets = {
            dataset_name
            for cls in HuggingFaceDataset.__subclasses__()
            for dataset_name in cls.SUPPORTED_DATASET_PATHS
        }
        raise ValueError(
            f"Unsupported dataset path: {args.dataset_path}. "
            "Huggingface dataset only supports dataset_path"
            f" from one of following: {supported_datasets}. "
            "Please consider contributing if you would "
            "like to add support for additional dataset formats."
        )

    if dataset_class.IS_MULTIMODAL and backend not in [
        "openai-chat",
        "openai-audio",
    ]:
        # multi-modal benchmark is only available on OpenAI Chat backend.
        raise ValueError(
            "Multi-modal content is only supported on 'openai-chat' and "
            "'openai-audio' backend."
        )
    return dataset_class(
        dataset_path=args.dataset_path,
        dataset_subset=args.hf_subset,
        dataset_split=args.hf_split,
        random_seed=args.seed,
    ).sample(
        num_requests=args.num_prompts,
        tokenizer=tokenizer,
        output_len=args.hf_output_len,
    )


def _sample_input_requests(args: argparse.Namespace, tokenizer: Any, backend: str):
    """Build the list of benchmark requests for ``args.dataset_name``."""
    if args.dataset_name == "custom":
        dataset = CustomDataset(dataset_path=args.dataset_path)
        return dataset.sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.custom_output_len,
            skip_chat_template=args.custom_skip_chat_template,
        )

    if args.dataset_name == "sonnet":
        dataset = SonnetDataset(dataset_path=args.dataset_path)
        # For the "sonnet" dataset, formatting depends on the backend: the
        # chat backend gets the raw prompt, every other backend gets a prompt
        # already rendered through the tokenizer's chat template.
        apply_chat_template = backend != "openai-chat"
        if apply_chat_template:
            # `default_chat_template` no longer exists on recent Transformers
            # releases, so read it defensively instead of risking an
            # AttributeError when `chat_template` is unset.
            assert tokenizer.chat_template or getattr(
                tokenizer, "default_chat_template", None
            ), "Tokenizer/model must have chat template for sonnet dataset."
        return dataset.sample(
            num_requests=args.num_prompts,
            input_len=args.sonnet_input_len,
            output_len=args.sonnet_output_len,
            prefix_len=args.sonnet_prefix_len,
            tokenizer=tokenizer,
            return_prompt_formatted=apply_chat_template,
        )

    if args.dataset_name == "hf":
        return _sample_hf_requests(args, tokenizer, backend)

    # For datasets that follow a similar structure, use a mapping.
    dataset_mapping = {
        "sharegpt": lambda: ShareGPTDataset(
            random_seed=args.seed, dataset_path=args.dataset_path
        ).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            output_len=args.sharegpt_output_len,
        ),
        "burstgpt": lambda: BurstGPTDataset(
            random_seed=args.seed, dataset_path=args.dataset_path
        ).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
        "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            prefix_len=args.random_prefix_len,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            range_ratio=args.random_range_ratio,
        ),
    }
    # Look the sampler up first and call it outside of any KeyError handling,
    # so a KeyError raised while loading/sampling the dataset is not
    # misreported as an unknown dataset name.
    sampler = dataset_mapping.get(args.dataset_name)
    if sampler is None:
        raise ValueError(f"Unknown dataset: {args.dataset_name}")
    return sampler()


def _collect_sampling_params(args: argparse.Namespace) -> dict[str, Any]:
    """Gather the sampling options that were set into the request extra body."""
    sampling_params = {
        k: v
        for k, v in {
            "top_p": args.top_p,
            "top_k": args.top_k,
            "min_p": args.min_p,
            "temperature": args.temperature,
        }.items()
        if v is not None
    }

    # Sampling parameters are only supported by openai-compatible backend.
    if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
        raise ValueError(
            "Sampling parameters are only supported by openai-compatible backends."
        )

    sampling_params.setdefault("temperature", 0.0)  # Default to greedy decoding.

    if args.backend == "llama.cpp":
        # Disable prompt caching in llama.cpp backend
        sampling_params["cache_prompt"] = False

    return sampling_params


def _save_benchmark_results(
    args: argparse.Namespace,
    benchmark_result: dict[str, Any],
    *,
    backend: str,
    model_id: str,
    tokenizer_id: str,
    metadata: dict[str, str],
) -> None:
    """Write run configuration plus results to a JSON file (and PyTorch format)."""
    # Setup
    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
    result_json: dict[str, Any] = {
        "date": current_dt,
        "backend": backend,
        "model_id": model_id,
        "tokenizer_id": tokenizer_id,
        "num_prompts": args.num_prompts,
    }

    # Metadata (user-supplied keys may override the setup keys above)
    result_json.update(metadata)

    # Traffic
    result_json["request_rate"] = (
        args.request_rate if args.request_rate < float("inf") else "inf"
    )
    result_json["burstiness"] = args.burstiness
    result_json["max_concurrency"] = args.max_concurrency

    if args.ramp_up_strategy is not None:
        result_json["ramp_up_strategy"] = args.ramp_up_strategy
        result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
        result_json["ramp_up_end_rps"] = args.ramp_up_end_rps

    # Merge with benchmark result
    result_json = {**result_json, **benchmark_result}

    if not args.save_detailed:
        # Remove fields with too many data points
        for field in (
            "input_lens",
            "output_lens",
            "ttfts",
            "itls",
            "generated_texts",
            "errors",
        ):
            result_json.pop(field, None)
            benchmark_result.pop(field, None)

    # Save to file
    base_model_id = model_id.split("/")[-1]
    max_concurrency_str = (
        f"-concurrency{args.max_concurrency}"
        if args.max_concurrency is not None
        else ""
    )
    if args.ramp_up_strategy is not None:
        rate_str = (
            f"ramp-up-{args.ramp_up_strategy}-"
            f"{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps"
        )
    else:
        rate_str = f"{args.request_rate}qps"
    file_name = (
        f"{backend}-{rate_str}{max_concurrency_str}-{base_model_id}-{current_dt}.json"
    )
    if args.result_filename:
        file_name = args.result_filename
    if args.result_dir:
        os.makedirs(args.result_dir, exist_ok=True)
        file_name = os.path.join(args.result_dir, file_name)
    with open(
        file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
    ) as outfile:
        # Append a newline.
        if args.append_result and outfile.tell() != 0:
            outfile.write("\n")
        json.dump(result_json, outfile)
    save_to_pytorch_benchmark_format(args, result_json, file_name)


def main(args: argparse.Namespace):
    """Run the serving benchmark described by the parsed CLI arguments.

    Steps, in order: seed the RNGs, validate the ramp-up options, resolve the
    server URLs, load the tokenizer, sample requests from the selected
    dataset, collect goodput and sampling settings, run ``benchmark`` under
    ``asyncio.run``, and optionally save the results.

    Args:
        args: Namespace produced by ``create_argument_parser().parse_args()``.

    Raises:
        ValueError: On inconsistent ramp-up options, a missing or unknown
            dataset, an unsupported HF dataset path, multi-modal data on an
            unsupported backend, sampling parameters on a backend that is not
            OpenAI-compatible, or malformed ``--metadata`` entries (only
            checked when results are being saved).
    """
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    backend = args.backend
    model_id = args.model
    model_name = args.served_model_name
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer_mode = args.tokenizer_mode

    # Validate ramp-up arguments
    _validate_ramp_up_args(args)

    # Parse --metadata before the (potentially long) benchmark so that a
    # malformed entry fails fast instead of discarding finished results.
    save_requested = args.save_result or args.append_result
    metadata = _parse_metadata(args.metadata) if save_requested else {}

    if args.base_url is not None:
        base_url = f"{args.base_url}"
    else:
        base_url = f"http://{args.host}:{args.port}"
    api_url = f"{base_url}{args.endpoint}"

    tokenizer = get_tokenizer(
        tokenizer_id,
        tokenizer_mode=tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )

    if args.dataset_name is None:
        raise ValueError(
            "Please specify '--dataset-name' and the corresponding "
            "'--dataset-path' if required."
        )

    input_requests = _sample_input_requests(args, tokenizer, backend)
    goodput_config_dict = check_goodput_args(args)

    # Collect the sampling parameters.
    sampling_params = _collect_sampling_params(args)

    # Avoid GC processing "static" data - reduce pause times.
    gc.collect()
    gc.freeze()

    benchmark_result = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            model_name=model_name,
            tokenizer=tokenizer,
            input_requests=input_requests,
            logprobs=args.logprobs,
            request_rate=args.request_rate,
            burstiness=args.burstiness,
            disable_tqdm=args.disable_tqdm,
            profile=args.profile,
            selected_percentile_metrics=args.percentile_metrics.split(","),
            selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
            ignore_eos=args.ignore_eos,
            goodput_config_dict=goodput_config_dict,
            max_concurrency=args.max_concurrency,
            lora_modules=args.lora_modules,
            extra_body=sampling_params,
            ramp_up_strategy=args.ramp_up_strategy,
            ramp_up_start_rps=args.ramp_up_start_rps,
            ramp_up_end_rps=args.ramp_up_end_rps,
        )
    )

    # Save config and results to json
    if save_requested:
        _save_benchmark_results(
            args,
            benchmark_result,
            backend=backend,
            model_id=model_id,
            tokenizer_id=tokenizer_id,
            metadata=metadata,
        )
995
996


997
def create_argument_parser():
    """Build the command-line parser for the online serving benchmark.

    Returns:
        A ``FlexibleArgumentParser`` with the general options (backend,
        server address, dataset selection, traffic shaping, result saving,
        metrics, goodput), one argument group per dataset type, the sampling
        parameters group, and the tokenizer / LoRA / ramp-up options.
    """
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput."
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    # Use 127.0.0.1 here instead of localhost to force the use of ipv4
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/v1/completions",
        help="API endpoint.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default=None,
        help="Path to the sharegpt/sonnet dataset. "
        "Or the huggingface dataset ID if using HF dataset.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )

    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--logprobs",
        type=int,
        default=None,
        help=(
            "Number of logprobs-per-token to compute & return as part of "
            "the request. If unspecified, then either (1) if beam search "
            "is disabled, no logprobs are computed & a single dummy "
            "logprob is returned for each token; or (2) if beam search "
            "is enabled 1 logprob per token is computed"
        ),
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process or gamma distribution "
        "to synthesize the request arrival times.",
    )
    parser.add_argument(
        "--burstiness",
        type=float,
        default=1.0,
        help="Burstiness factor of the request generation. "
        "Only take effect when request_rate is not inf. "
        "Default value is 1, which follows Poisson process. "
        "Otherwise, the request intervals follow a gamma distribution. "
        "A lower burstiness value (0 < burstiness < 1) results in more "
        "bursty requests. A higher burstiness value (burstiness > 1) "
        "results in a more uniform arrival of requests.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from huggingface",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file",
    )
    parser.add_argument(
        "--save-detailed",
        action="store_true",
        help="When saving the results, whether to include per request "
        "information such as response, error, ttfts, tpots, etc.",
    )
    parser.add_argument(
        "--append-result",
        action="store_true",
        help="Append the benchmark result to the existing json file.",
    )
    parser.add_argument(
        "--metadata",
        metavar="KEY=VALUE",
        nargs="*",
        help="Key-value pairs (e.g., --metadata version=0.3.3 tp=1) "
        "for metadata of this run to be saved in the result JSON file "
        "for record keeping purposes.",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default=None,
        help="Specify directory to save benchmark json results. "
        "If not specified, results are saved in the current directory.",
    )
    parser.add_argument(
        "--result-filename",
        type=str,
        default=None,
        help="Specify the filename to save benchmark json results. "
        "If not specified, results will be saved in "
        "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
        " format.",
    )
    parser.add_argument(
        "--ignore-eos",
        action="store_true",
        help="Set ignore_eos flag when sending the benchmark request. "
        "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
    )
    parser.add_argument(
        "--percentile-metrics",
        type=str,
        default="ttft,tpot,itl",
        help="Comma-separated list of selected metrics to report percentiles. "
        "This argument specifies the metrics to report percentiles. "
        'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
        'Default value is "ttft,tpot,itl".',
    )
    parser.add_argument(
        "--metric-percentiles",
        type=str,
        default="99",
        help="Comma-separated list of percentiles for selected metrics. "
        'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
        'Default value is "99". '
        'Use "--percentile-metrics" to select metrics.',
    )
    parser.add_argument(
        "--goodput",
        nargs="+",
        required=False,
        help='Specify service level objectives for goodput as "KEY:VALUE" '
        "pairs, where the key is a metric name, and the value is in "
        'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
        "separated by spaces. Allowed request level metric names are "
        '"ttft", "tpot", "e2el". For more context on the definition of '
        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
    )

    # group for dataset specific arguments
    custom_group = parser.add_argument_group("custom dataset options")
    custom_group.add_argument(
        "--custom-output-len",
        type=int,
        default=256,
        help="Number of output tokens per request, used only for custom dataset.",
    )
    custom_group.add_argument(
        "--custom-skip-chat-template",
        action="store_true",
        help="Skip applying chat template to prompt, used only for custom dataset.",
    )

    sonnet_group = parser.add_argument_group("sonnet dataset options")
    sonnet_group.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help="Number of input tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help="Number of output tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help="Number of prefix tokens per request, used only for sonnet dataset.",
    )

    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
    sharegpt_group.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.",
    )

    random_group = parser.add_argument_group("random dataset options")
    random_group.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for random sampling. Must be in the range [0, 1) to define "
        "a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )
    random_group.add_argument(
        "--random-prefix-len",
        type=int,
        default=0,
        help=(
            "Number of fixed prefix tokens before the random context "
            "in a request. "
            "The total input length is the sum of `random-prefix-len` and "
            "a random "
            "context length sampled from [input_len * (1 - range_ratio), "
            "input_len * (1 + range_ratio)]."
        ),
    )

    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    hf_group.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )
    hf_group.add_argument(
        "--hf-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output lengths "
        "from the sampled HF dataset.",
    )

    sampling_group = parser.add_argument_group("sampling parameters")
    sampling_group.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Top-p sampling parameter. Only has effect on openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--top-k",
        type=int,
        default=None,
        help="Top-k sampling parameter. Only has effect on openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--min-p",
        type=float,
        default=None,
        help="Min-p sampling parameter. Only has effect on openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Temperature sampling parameter. Only has effect on "
        "openai-compatible backends. If not specified, default to greedy "
        "decoding (i.e. temperature==0.0).",
    )

    parser.add_argument(
        "--tokenizer-mode",
        type=str,
        default="auto",
        choices=["auto", "slow", "mistral", "custom"],
        help='The tokenizer mode.\n\n* "auto" will use the '
        'fast tokenizer if available.\n* "slow" will '
        "always use the slow tokenizer. \n* "
        '"mistral" will always use the `mistral_common` tokenizer. \n* '
        '"custom" will use --tokenizer to select the preregistered tokenizer.',
    )

    parser.add_argument(
        "--served-model-name",
        type=str,
        default=None,
        help="The model name used in the API. "
        "If not specified, the model name will be the "
        "same as the ``--model`` argument. ",
    )

    parser.add_argument(
        "--lora-modules",
        nargs="+",
        default=None,
        help="A subset of LoRA module names passed in when "
        "launching the server. For each request, the "
        "script chooses a LoRA module at random.",
    )

    parser.add_argument(
        "--ramp-up-strategy",
        type=str,
        default=None,
        choices=["linear", "exponential"],
        help="The ramp-up strategy. This would be used to "
        "ramp up the request rate from initial RPS to final "
        "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps) "
        "over the duration of the benchmark.",
    )
    parser.add_argument(
        "--ramp-up-start-rps",
        type=int,
        default=None,
        help="The starting request rate for ramp-up (RPS). "
        "Needs to be specified when --ramp-up-strategy is used.",
    )
    parser.add_argument(
        "--ramp-up-end-rps",
        type=int,
        default=None,
        help="The ending request rate for ramp-up (RPS). "
        "Needs to be specified when --ramp-up-strategy is used.",
    )

    return parser

1374

1375
1376
1377
if __name__ == "__main__":
    # Script entry point: parse the command line and run the benchmark.
    args = create_argument_parser().parse_args()
    main(args)