# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""Benchmark online serving throughput.
4
5

On the server side, start the vLLM OpenAI API server:

    vllm serve <your_model> \
        --swap-space 16 \
        --disable-log-requests

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --model <your_model> \
        --dataset-name sharegpt \
        --dataset-path <path to dataset> \
        --request-rate <request_rate> \
        --num-prompts <num_prompts>

    By default, <request_rate> is inf and <num_prompts> is 1000.

    When using the tgi backend, add
        --endpoint /generate_stream
    to the end of the command above.
"""
24

25
26
import argparse
import asyncio
27
import gc
28
import json
29
import os
30
31
import random
import time
32
import warnings
33
from collections.abc import AsyncGenerator, Iterable
34
35
from dataclasses import dataclass
from datetime import datetime
36
from typing import Any, Literal, Optional
37
38

import numpy as np
39
from tqdm.asyncio import tqdm
40
from transformers import PreTrainedTokenizerBase
41

42
43
44
45
46
47
48
from backend_request_func import (
    ASYNC_REQUEST_FUNCS,
    OPENAI_COMPATIBLE_BACKENDS,
    RequestFuncInput,
    RequestFuncOutput,
)

49
50
51
52
try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer
53

54
55
56
57
58
try:
    from vllm.utils import FlexibleArgumentParser
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser

59
60
61
62
63
from benchmark_dataset import (
    AIMODataset,
    ASRDataset,
    BurstGPTDataset,
    ConversationDataset,
64
    CustomDataset,
65
66
67
68
69
70
71
72
73
74
    HuggingFaceDataset,
    InstructCoderDataset,
    MTBenchDataset,
    NextEditPredictionDataset,
    RandomDataset,
    SampleRequest,
    ShareGPTDataset,
    SonnetDataset,
    VisionArenaDataset,
)
75
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
76

77
78
# Milliseconds per second. Goodput SLOs are given in milliseconds, while the
# per-request latencies they are compared against are in seconds.
MILLISECONDS_TO_SECONDS_CONVERSION = 1_000

79
80
81
82
83
84
85

@dataclass
class BenchmarkMetrics:
    """Aggregated results of a single serving benchmark run.

    Counts are numbers of requests or tokens. Throughput values are per
    second of benchmark wall-clock time. Every ``*_ms`` field is expressed
    in milliseconds. Each ``percentiles_*_ms`` entry is a
    ``(percentile, value_ms)`` pair.
    """

    # Request / token totals.
    completed: int
    total_input: int
    total_output: int

    # Throughput (per second of benchmark duration).
    request_throughput: float
    request_goodput: float
    output_throughput: float
    total_token_throughput: float

    # TTFT: time to first token.
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]

    # TPOT: time per output token, excluding the first token.
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    percentiles_tpot_ms: list[tuple[float, float]]

    # ITL: inter-token latency.
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]

    # E2EL: end-to-end latency per request, i.e. the client-side time from
    # sending a request to receiving its complete response.
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]
108
109


110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def _get_current_request_rate(
    ramp_up_strategy: Optional[Literal["linear", "exponential"]],
    ramp_up_start_rps: Optional[int],
    ramp_up_end_rps: Optional[int],
    request_index: int,
    total_requests: int,
    request_rate: float,
) -> float:
    if (
        ramp_up_strategy
        and ramp_up_start_rps is not None
        and ramp_up_end_rps is not None
    ):
        progress = request_index / max(total_requests - 1, 1)
        if ramp_up_strategy == "linear":
            increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
            return ramp_up_start_rps + increase
        elif ramp_up_strategy == "exponential":
            ratio = ramp_up_end_rps / ramp_up_start_rps
            return ramp_up_start_rps * (ratio**progress)
        else:
            raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
    return request_rate


135
async def get_request(
    input_requests: list[SampleRequest],
    request_rate: float,
    burstiness: float = 1.0,
    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
    ramp_up_start_rps: Optional[int] = None,
    ramp_up_end_rps: Optional[int] = None,
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
    """
    Asynchronously generates requests at a specified rate
    with OPTIONAL burstiness and OPTIONAL ramp-up strategy.

    Args:
        input_requests:
            A list of input requests, each represented as a SampleRequest.
            Any other iterable is materialized into a list first, because
            the ramp-up schedule needs the total number of requests.
        request_rate:
            The rate at which requests are generated (requests/s).
            Not used when a ramp-up strategy with both start and end RPS
            is given.
        burstiness (optional):
            The burstiness factor of the request generation.
            Only takes effect when request_rate is not inf.
            Default value is 1, which follows a Poisson process.
            Otherwise, the request intervals follow a gamma distribution.
            A lower burstiness value (0 < burstiness < 1) results
            in more bursty requests, while a higher burstiness value
            (burstiness > 1) results in a more uniform arrival of requests.
        ramp_up_strategy (optional):
            The ramp-up strategy. Can be "linear" or "exponential".
            If None, uses constant request rate (specified by request_rate).
        ramp_up_start_rps (optional):
            The starting request rate for ramp-up.
        ramp_up_end_rps (optional):
            The ending request rate for ramp-up.

    Yields:
        ``(request, current_request_rate)`` for each request, in order.
        ``current_request_rate`` is the rate used to space this request from
        the next one.
    """
    assert burstiness > 0, (
        f"A positive burstiness factor is expected, but given {burstiness}."
    )
    # Convert to list to get length for ramp-up calculations
    if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
        input_requests = list(input_requests)

    total_requests = len(input_requests)

    for request_index, request in enumerate(input_requests):
        current_request_rate = _get_current_request_rate(
            ramp_up_strategy,
            ramp_up_start_rps,
            ramp_up_end_rps,
            request_index,
            total_requests,
            request_rate,
        )

        yield request, current_request_rate

        if request_index == total_requests - 1:
            # Nothing left to send. Sleeping one more interval here would only
            # delay the caller and stretch the measured benchmark duration.
            break

        if current_request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue

        if current_request_rate <= 0:
            # A zero rate (e.g. a linear ramp-up starting at 0 RPS) has no
            # finite inter-arrival time, and dividing by it below would raise
            # ZeroDivisionError. Move straight on to the next request; its
            # ramped rate determines the following gap.
            continue

        theta = 1.0 / (current_request_rate * burstiness)

        # Sample the request interval from the gamma distribution.
        # If burstiness is 1, it follows exponential distribution.
        interval = np.random.gamma(shape=burstiness, scale=theta)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


205
def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
    selected_percentile_metrics: list[str],
    selected_percentiles: list[float],
    goodput_config_dict: dict[str, float],
) -> tuple[BenchmarkMetrics, list[int]]:
    """Aggregate per-request outputs into a ``BenchmarkMetrics`` summary.

    Args:
        input_requests: The requests that were sent, index-aligned with
            ``outputs``. Only ``prompt_len`` is read, for successful requests.
        outputs: Per-request results produced by the backend request function.
        dur_s: Benchmark duration in seconds; all throughputs divide by it.
        tokenizer: Used to count output tokens when a backend did not report
            ``output_tokens``.
        selected_percentile_metrics: Accepted for interface compatibility and
            not used here; every latency metric is always computed.
        selected_percentiles: Percentiles to compute for each latency metric.
        goodput_config_dict: SLOs in milliseconds, keyed by "ttft", "tpot"
            and/or "e2el". A successful request counts towards goodput only
            when it meets every configured SLO. Empty means no goodput.

    Returns:
        ``(metrics, actual_output_lens)``. ``actual_output_lens`` holds the
        output token count of each request, with 0 for failed requests.
    """
    actual_output_lens: list[int] = []
    total_input = 0
    completed = 0
    good_completed = 0
    itls: list[float] = []
    # TPOTs of requests with more than one output token (for statistics).
    tpots: list[float] = []
    # One TPOT per successful request (0 when output_len <= 1). This list
    # stays index-aligned with ``ttfts``/``e2els`` for the goodput check.
    all_tpots: list[float] = []
    ttfts: list[float] = []
    e2els: list[float] = []

    for i, output in enumerate(outputs):
        if not output.success:
            actual_output_lens.append(0)
            continue

        output_len = output.output_tokens
        if not output_len:
            # We use the tokenizer to count the number of output tokens
            # for some serving backends instead of looking at
            # len(output.itl) since multiple output tokens may be
            # bundled together.
            # Note: this may inflate the output token count slightly.
            output_len = len(
                tokenizer(output.generated_text, add_special_tokens=False).input_ids
            )
        actual_output_lens.append(output_len)
        total_input += input_requests[i].prompt_len

        tpot = 0
        if output_len > 1:
            latency_minus_ttft = output.latency - output.ttft
            tpot = latency_minus_ttft / (output_len - 1)
            tpots.append(tpot)
        # Note: if output_len <= 1, we regard tpot as 0 for goodput
        all_tpots.append(tpot)

        itls += output.itl
        ttfts.append(output.ttft)
        e2els.append(output.latency)
        completed += 1

    if goodput_config_dict:
        # Collect, in a fixed order, the per-request series that have an SLO,
        # together with that SLO converted from milliseconds to seconds.
        valid_metrics = []
        slo_values = []
        for slo_name, series in (("ttft", ttfts), ("tpot", all_tpots), ("e2el", e2els)):
            if slo_name in goodput_config_dict:
                valid_metrics.append(series)
                slo_values.append(
                    goodput_config_dict[slo_name] / MILLISECONDS_TO_SECONDS_CONVERSION
                )

        # Each ``req_metric`` holds one request's values for the chosen series.
        for req_metric in zip(*valid_metrics):
            if all(slo >= value for slo, value in zip(slo_values, req_metric)):
                good_completed += 1

    if completed == 0:
        warnings.warn(
            "All requests failed. This is likely due to a misconfiguration "
            "on the benchmark arguments.",
            stacklevel=2,
        )

    total_output = sum(actual_output_lens)

    # ``values or 0`` substitutes a scalar 0 for an empty list (ttfts is empty
    # if streaming is not supported by backend), so every statistic reports 0.
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        request_throughput=completed / dur_s,
        request_goodput=good_completed / dur_s,
        output_throughput=total_output / dur_s,
        total_token_throughput=(total_input + total_output) / dur_s,
        mean_ttft_ms=np.mean(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        percentiles_ttft_ms=[
            (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles
        ],
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        percentiles_tpot_ms=[
            (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles
        ],
        mean_itl_ms=np.mean(itls or 0) * 1000,
        std_itl_ms=np.std(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
        percentiles_itl_ms=[
            (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles
        ],
        mean_e2el_ms=np.mean(e2els or 0) * 1000,
        std_e2el_ms=np.std(e2els or 0) * 1000,
        median_e2el_ms=np.median(e2els or 0) * 1000,
        percentiles_e2el_ms=[
            (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles
        ],
    )

    return metrics, actual_output_lens
321

322
323
324
325

async def benchmark(
    backend: str,
    api_url: str,
    base_url: str,
    model_id: str,
    model_name: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: list[SampleRequest],
    logprobs: Optional[int],
    request_rate: float,
    burstiness: float,
    disable_tqdm: bool,
    profile: bool,
    selected_percentile_metrics: list[str],
    selected_percentiles: list[float],
    ignore_eos: bool,
    goodput_config_dict: dict[str, float],
    max_concurrency: Optional[int],
    lora_modules: Optional[Iterable[str]],
    extra_body: Optional[dict],
    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
    ramp_up_start_rps: Optional[int] = None,
    ramp_up_end_rps: Optional[int] = None,
):
    """Run the serving benchmark against a live server, print and return results.

    The run proceeds in these steps:
      1. Send one test request (the first input request) to validate the setup.
      2. Optionally start the server profiler.
      3. Dispatch every request on the schedule produced by ``get_request``.
      4. Wait for all requests, then optionally stop the profiler.
      5. Print the summary.

    Args:
        backend: Key into ``ASYNC_REQUEST_FUNCS`` selecting the request function.
        api_url: URL every benchmark request is sent to.
        base_url: Server root; "/start_profile" and "/stop_profile" are
            appended to it when ``profile`` is set.
        model_id, model_name: Sent with each request unless a LoRA module is
            chosen for that request.
        tokenizer: Forwarded to ``calculate_metrics`` for token counting.
        input_requests: Requests to send; must not be empty.
        logprobs, ignore_eos, extra_body: Forwarded on every request.
        request_rate, burstiness, ramp_up_strategy, ramp_up_start_rps,
        ramp_up_end_rps: Arrival-process settings passed to ``get_request``.
        disable_tqdm: Disable the progress bar.
        profile: Start/stop the server profiler around the main run.
        selected_percentile_metrics: Which of "ttft", "tpot", "itl", "e2el"
            to print and record.
        selected_percentiles: Percentiles reported for those metrics.
        goodput_config_dict: Goodput SLOs in milliseconds (empty: no goodput).
        max_concurrency: Upper bound on in-flight requests; falsy = unlimited.
        lora_modules: If given, each request uses a randomly chosen module
            as its model.

    Returns:
        A dict of aggregate metrics, per-request raw data and, for ramp-up
        runs, ``rps_change_events``.

    Raises:
        ValueError: If ``backend`` is unknown, ``input_requests`` is empty,
            or the initial test request fails.
    """
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    if not input_requests:
        raise ValueError("No input requests to benchmark.")

    print("Starting initial single prompt test run...")
    first_request = input_requests[0]
    test_prompt = first_request.prompt
    test_prompt_len = first_request.prompt_len
    test_output_len = first_request.expected_output_len
    test_mm_content = first_request.multi_modal_data

    assert test_mm_content is None or isinstance(test_mm_content, dict)
    test_input = RequestFuncInput(
        model=model_id,
        model_name=model_name,
        prompt=test_prompt,
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
        logprobs=logprobs,
        multi_modal_content=test_mm_content,
        ignore_eos=ignore_eos,
        extra_body=extra_body,
    )

    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please make sure benchmark arguments "
            f"are correctly specified. Error: {test_output.error}"
        )
    print("Initial test run completed. Starting main benchmark run...")

    lora_iter = None
    lora_pool = list(lora_modules) if lora_modules else []
    if lora_pool:
        # For each input request, choose a LoRA module at random.
        lora_iter = iter([random.choice(lora_pool) for _ in input_requests])

    def make_profile_input(endpoint: str) -> RequestFuncInput:
        # Start and stop profiler calls reuse the test request; only the URL
        # differs, so both calls send the same fields.
        return RequestFuncInput(
            model=model_id,
            model_name=model_name,
            prompt=test_prompt,
            api_url=base_url + endpoint,
            prompt_len=test_prompt_len,
            output_len=test_output_len,
            logprobs=logprobs,
            multi_modal_content=test_mm_content,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
        )

    if profile:
        print("Starting profiler...")
        profile_output = await request_func(
            request_func_input=make_profile_input("/start_profile")
        )
        if profile_output.success:
            print("Profiler started")

    distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"

    if ramp_up_strategy is not None:
        print(
            f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase "
            f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over "
            "the duration of the benchmark."
        )
    else:
        print(f"Traffic request rate: {request_rate} RPS.")

    print(f"Burstiness factor: {burstiness} ({distribution})")
    print(f"Maximum request concurrency: {max_concurrency}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    # This can be used once the minimum Python version is 3.10 or higher,
    # and it will simplify the code in limited_request_func.
    #    semaphore = (asyncio.Semaphore(max_concurrency)
    #                 if max_concurrency else contextlib.nullcontext())
    semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None

    async def limited_request_func(request_func_input, pbar):
        if semaphore is None:
            return await request_func(request_func_input=request_func_input, pbar=pbar)
        async with semaphore:
            return await request_func(request_func_input=request_func_input, pbar=pbar)

    benchmark_start_time = time.perf_counter()
    tasks: list[asyncio.Task] = []

    # For ramp-up runs, record the wall-clock time at which each whole-number
    # RPS level is first reached.
    rps_change_events = []
    last_int_rps = -1
    if ramp_up_strategy is not None and ramp_up_start_rps is not None:
        last_int_rps = ramp_up_start_rps
        rps_change_events.append(
            {
                "rps": last_int_rps,
                "timestamp": datetime.now().isoformat(),
            }
        )

    async for request, current_request_rate in get_request(
        input_requests,
        request_rate,
        burstiness,
        ramp_up_strategy,
        ramp_up_start_rps,
        ramp_up_end_rps,
    ):
        if ramp_up_strategy is not None:
            current_int_rps = int(current_request_rate)
            if current_int_rps > last_int_rps:
                timestamp = datetime.now().isoformat()
                for rps_val in range(last_int_rps + 1, current_int_rps + 1):
                    rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
                last_int_rps = current_int_rps

        prompt, prompt_len, output_len, mm_content = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
            request.multi_modal_data,
        )
        req_model_id, req_model_name = model_id, model_name
        if lora_iter is not None:
            req_lora_module = next(lora_iter)
            req_model_id, req_model_name = req_lora_module, req_lora_module

        request_func_input = RequestFuncInput(
            model=req_model_id,
            model_name=req_model_name,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            logprobs=logprobs,
            multi_modal_content=mm_content,
            ignore_eos=ignore_eos,
            extra_body=extra_body,
        )
        task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
        tasks.append(asyncio.create_task(task))
    outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)

    # Stop the clock as soon as the last request completes. The profiler
    # shutdown request below must not be counted as benchmark time, or it
    # would deflate every throughput figure.
    benchmark_duration = time.perf_counter() - benchmark_start_time

    if profile:
        print("Stopping profiler...")
        profile_output = await request_func(
            request_func_input=make_profile_input("/stop_profile")
        )
        if profile_output.success:
            print("Profiler stopped")

    if pbar is not None:
        pbar.close()

    metrics, actual_output_lens = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentile_metrics=selected_percentile_metrics,
        selected_percentiles=selected_percentiles,
        goodput_config_dict=goodput_config_dict,
    )

    print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
    print(
        "{:<40} {:<10.2f}".format(
            "Request throughput (req/s):", metrics.request_throughput
        )
    )
    if goodput_config_dict:
        print(
            "{:<40} {:<10.2f}".format(
                "Request goodput (req/s):", metrics.request_goodput
            )
        )
    print(
        "{:<40} {:<10.2f}".format(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
    print(
        "{:<40} {:<10.2f}".format(
            "Total Token throughput (tok/s):", metrics.total_token_throughput
        )
    )

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "request_goodput": metrics.request_goodput if goodput_config_dict else None,
        "output_throughput": metrics.output_throughput,
        "total_token_throughput": metrics.total_token_throughput,
        "input_lens": [output.prompt_len for output in outputs],
        "output_lens": actual_output_lens,
        "ttfts": [output.ttft for output in outputs],
        "itls": [output.itl for output in outputs],
        "generated_texts": [output.generated_text for output in outputs],
        "errors": [output.error for output in outputs],
    }

    if rps_change_events:
        result["rps_change_events"] = rps_change_events

    def process_one_metric(
        # E.g., "ttft"
        metric_attribute_name: str,
        # E.g., "TTFT"
        metric_name: str,
        # E.g., "Time to First Token"
        metric_header: str,
    ):
        # Print the statistics of the specified metric and add them to
        # ``result``; metrics that were not selected are skipped entirely.
        if metric_attribute_name not in selected_percentile_metrics:
            return
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_attribute_name}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_attribute_name}_ms"),
            )
        )
        result[f"mean_{metric_attribute_name}_ms"] = getattr(
            metrics, f"mean_{metric_attribute_name}_ms"
        )
        result[f"median_{metric_attribute_name}_ms"] = getattr(
            metrics, f"median_{metric_attribute_name}_ms"
        )
        result[f"std_{metric_attribute_name}_ms"] = getattr(
            metrics, f"std_{metric_attribute_name}_ms"
        )
        for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"):
            # Render whole-number percentiles without a decimal (P99, not P99.0).
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
            result[f"p{p_word}_{metric_attribute_name}_ms"] = value

    process_one_metric("ttft", "TTFT", "Time to First Token")
    process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)")
    process_one_metric("itl", "ITL", "Inter-token Latency")
    process_one_metric("e2el", "E2EL", "End-to-end Latency")

    print("=" * 50)

    return result
615
616


617
618
def check_goodput_args(args):
    """Parse and validate the ``--goodput`` service level objectives.

    Returns:
        A dict mapping each SLO name ("ttft", "tpot" or "e2el") to its
        threshold in milliseconds. The dict is empty when ``args.goodput``
        is unset or empty.

    Raises:
        ValueError: On an unknown SLO name or a negative SLO value.
        argparse.ArgumentTypeError: Propagated from ``parse_goodput`` for
            malformed ``KEY:VALUE`` pairs.
    """
    valid_names = ["ttft", "tpot", "e2el"]
    if not args.goodput:
        return {}

    slo_config = parse_goodput(args.goodput)
    for name, threshold in slo_config.items():
        if name not in valid_names:
            raise ValueError(
                f"Invalid metric name found, {name}: {threshold}. "
                "The service level objective name should be one of "
                f"{str(valid_names)}. "
            )
        if threshold < 0:
            raise ValueError(
                f"Invalid value found, {name}: {threshold}. "
                "The service level objective value should be "
                "non-negative."
            )
    return slo_config
637
638
639


def parse_goodput(slo_pairs):
    """Convert ``"NAME:MILLISECONDS"`` strings into a ``{name: float}`` dict.

    Later pairs override earlier ones with the same name.

    Raises:
        argparse.ArgumentTypeError: If a pair does not split into exactly
            one key and one value on ":", or its value is not a number.
    """
    try:
        return {
            name: float(value)
            for name, value in (pair.split(":") for pair in slo_pairs)
        }
    except ValueError as err:
        raise argparse.ArgumentTypeError(
            "Invalid format found for service level objectives. "
            'Specify service level objectives for goodput as "KEY:VALUE" '
            "pairs, where the key is a metric name, and the value is a "
            "number in milliseconds."
        ) from err
653
654


655
656
657
def save_to_pytorch_benchmark_format(
    args: argparse.Namespace, results: dict[str, Any], file_name: str
) -> None:
    """Write headline latency metrics as PyTorch benchmark-format records.

    The records go to ``<file_name without extension>.pytorch.json``. All
    other scalar results, minus the bulky raw per-request lists, are attached
    as ``extra_info``.

    Only metric keys actually present in ``results`` are exported. When
    ``results`` comes from ``benchmark``, keys such as ``p99_ttft_ms`` exist
    only if that metric and percentile were selected. Previously a run
    without them aborted here with a KeyError.
    """
    metrics = [
        "median_ttft_ms",
        "mean_ttft_ms",
        "std_ttft_ms",
        "p99_ttft_ms",
        "mean_tpot_ms",
        "median_tpot_ms",
        "std_tpot_ms",
        "p99_tpot_ms",
        "median_itl_ms",
        "mean_itl_ms",
        "std_itl_ms",
        "p99_itl_ms",
    ]
    # These raw data might be useful, but they are rather big. They can be added
    # later if needed
    ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
    excluded_from_extra_info = set(metrics) | set(ignored_metrics)
    pt_records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={k: [results[k]] for k in metrics if k in results},
        extra_info={
            k: v for k, v in results.items() if k not in excluded_from_extra_info
        },
    )
    if pt_records:
        # Don't use json suffix here as we don't want CI to pick it up
        pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
        write_to_json(pt_file, pt_records)
688
689


690
691
692
693
694
def _validate_ramp_up_args(args: argparse.Namespace) -> None:
    """Reject inconsistent ``--ramp-up-*`` / ``--request-rate`` combinations.

    Does nothing when ``--ramp-up-strategy`` is not set.

    Raises:
        ValueError: if ramp-up is combined with an explicit ``--request-rate``,
            if either bound is missing or negative, if the start rate exceeds
            the end rate, or if an exponential ramp starts at 0 RPS.
    """
    if args.ramp_up_strategy is None:
        return
    if args.request_rate != float("inf"):
        raise ValueError(
            "When using ramp-up, do not specify --request-rate. "
            "The request rate will be controlled by ramp-up parameters. "
            "Please remove the --request-rate argument."
        )
    if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
        raise ValueError(
            "When using --ramp-up-strategy, both --ramp-up-start-rps and "
            "--ramp-up-end-rps must be specified"
        )
    if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
        raise ValueError("Ramp-up start and end RPS must be non-negative")
    if args.ramp_up_start_rps > args.ramp_up_end_rps:
        raise ValueError("Ramp-up start RPS must be less than or equal to end RPS")
    if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0:
        raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")


def _sample_sonnet_requests(args: argparse.Namespace, tokenizer):
    """Sample requests from the sonnet dataset (``--dataset-name sonnet``).

    For the ``openai-chat`` backend the prompt is returned unformatted;
    otherwise it is returned already formatted with the tokenizer's chat
    template, which must therefore exist.

    Raises:
        ValueError: if a formatted prompt is needed but the tokenizer has no
            chat template.
    """
    dataset = SonnetDataset(dataset_path=args.dataset_path)
    # For the "sonnet" dataset, formatting depends on the backend (presumably
    # because the chat endpoint applies the template itself - confirm).
    return_prompt_formatted = args.backend != "openai-chat"
    if return_prompt_formatted and not (
        getattr(tokenizer, "chat_template", None)
        # `default_chat_template` is absent in newer transformers releases.
        or getattr(tokenizer, "default_chat_template", None)
    ):
        raise ValueError("Tokenizer/model must have chat template for sonnet dataset.")
    return dataset.sample(
        num_requests=args.num_prompts,
        input_len=args.sonnet_input_len,
        output_len=args.sonnet_output_len,
        prefix_len=args.sonnet_prefix_len,
        tokenizer=tokenizer,
        return_prompt_formatted=return_prompt_formatted,
    )


def _sample_hf_requests(args: argparse.Namespace, tokenizer):
    """Sample requests from a HuggingFace dataset (``--dataset-name hf``).

    The dataset class is selected from ``--dataset-path``. For most classes
    ``args.hf_split`` is overwritten in place with ``"train"``; VisionArena
    additionally clears ``args.hf_subset``.

    Raises:
        ValueError: if no HuggingFaceDataset subclass supports the path, or
            if a multi-modal dataset is used with a backend other than
            ``openai-chat`` / ``openai-audio``.
    """
    # all following datasets are implemented from the
    # HuggingFaceDataset base class
    if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = VisionArenaDataset
        args.hf_split = "train"
        args.hf_subset = None
    elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = InstructCoderDataset
        args.hf_split = "train"
    elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = MTBenchDataset
        args.hf_split = "train"
    elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = ConversationDataset
    elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
        dataset_class = AIMODataset
        args.hf_split = "train"
    elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS:  # noqa: E501
        dataset_class = NextEditPredictionDataset
        args.hf_split = "train"
    elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
        dataset_class = ASRDataset
        args.hf_split = "train"
    else:
        supported_datasets = {
            dataset_name
            for cls in HuggingFaceDataset.__subclasses__()
            for dataset_name in cls.SUPPORTED_DATASET_PATHS
        }
        raise ValueError(
            f"Unsupported dataset path: {args.dataset_path}. "
            "Huggingface dataset only supports dataset_path"
            f" from one of following: {supported_datasets}. "
            "Please consider contributing if you would "
            "like to add support for additional dataset formats."
        )

    if dataset_class.IS_MULTIMODAL and args.backend not in (
        "openai-chat",
        "openai-audio",
    ):
        # multi-modal benchmark is only available on OpenAI Chat backend.
        raise ValueError(
            "Multi-modal content is only supported on 'openai-chat' and "
            "'openai-audio' backend."
        )
    return dataset_class(
        dataset_path=args.dataset_path,
        dataset_subset=args.hf_subset,
        dataset_split=args.hf_split,
        random_seed=args.seed,
        no_stream=args.no_stream,
    ).sample(
        num_requests=args.num_prompts,
        tokenizer=tokenizer,
        output_len=args.hf_output_len,
    )


def _load_input_requests(args: argparse.Namespace, tokenizer):
    """Sample the benchmark requests for the dataset named by ``--dataset-name``.

    Raises:
        ValueError: if no dataset name is given or the name is unknown.
    """
    if args.dataset_name is None:
        raise ValueError(
            "Please specify '--dataset-name' and the corresponding "
            "'--dataset-path' if required."
        )

    if args.dataset_name == "custom":
        dataset = CustomDataset(dataset_path=args.dataset_path)
        return dataset.sample(
            num_requests=args.num_prompts,
            tokenizer=tokenizer,
            output_len=args.custom_output_len,
            skip_chat_template=args.custom_skip_chat_template,
        )
    if args.dataset_name == "sonnet":
        return _sample_sonnet_requests(args, tokenizer)
    if args.dataset_name == "hf":
        return _sample_hf_requests(args, tokenizer)

    # For datasets that follow a similar structure, use a mapping. The lambdas
    # defer construction so that only the selected dataset is loaded.
    dataset_mapping = {
        "sharegpt": lambda: ShareGPTDataset(
            random_seed=args.seed, dataset_path=args.dataset_path
        ).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            output_len=args.sharegpt_output_len,
        ),
        "burstgpt": lambda: BurstGPTDataset(
            random_seed=args.seed, dataset_path=args.dataset_path
        ).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
        "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
            tokenizer=tokenizer,
            num_requests=args.num_prompts,
            prefix_len=args.random_prefix_len,
            input_len=args.random_input_len,
            output_len=args.random_output_len,
            range_ratio=args.random_range_ratio,
        ),
    }
    # Keep the try to the lookup only: a KeyError raised while sampling must
    # not be misreported as an unknown dataset name.
    try:
        sample_requests = dataset_mapping[args.dataset_name]
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
    return sample_requests()


def _collect_sampling_params(args: argparse.Namespace) -> dict[str, Any]:
    """Build the sampling parameters forwarded to the backend as ``extra_body``.

    Only explicitly set ``--top-p/--top-k/--min-p/--temperature`` values are
    included. Temperature defaults to 0.0 (greedy decoding). The llama.cpp
    backend additionally gets ``cache_prompt=False``.

    Raises:
        ValueError: if sampling parameters are given for a backend that is not
            openai-compatible.
    """
    sampling_params = {
        k: v
        for k, v in {
            "top_p": args.top_p,
            "top_k": args.top_k,
            "min_p": args.min_p,
            "temperature": args.temperature,
        }.items()
        if v is not None
    }

    # Sampling parameters are only supported by openai-compatible backend.
    if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS:
        raise ValueError(
            "Sampling parameters are only supported by openai-compatible backends."
        )

    sampling_params.setdefault("temperature", 0.0)  # Default to greedy decoding.

    if args.backend == "llama.cpp":
        # Disable prompt caching in llama.cpp backend
        sampling_params["cache_prompt"] = False
    return sampling_params


def _save_benchmark_results(
    args: argparse.Namespace,
    benchmark_result: dict[str, Any],
    backend: str,
    model_id: str,
    tokenizer_id: str,
) -> None:
    """Save the run configuration and ``benchmark_result`` to a JSON file.

    The file is overwritten, or appended to (newline-separated) with
    ``--append-result``. A PyTorch-format copy is also produced via
    ``save_to_pytorch_benchmark_format``.

    Raises:
        ValueError: if a ``--metadata`` item has no ``=``.
    """
    # Setup
    current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
    result_json: dict[str, Any] = {
        "date": current_dt,
        "backend": backend,
        "model_id": model_id,
        "tokenizer_id": tokenizer_id,
        "num_prompts": args.num_prompts,
    }

    # Metadata: split on the first "=" only so values may themselves contain "=".
    for item in args.metadata or []:
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError("Invalid metadata format. Please use KEY=VALUE format.")
        result_json[key.strip()] = value.strip()

    # Traffic
    result_json["request_rate"] = (
        args.request_rate if args.request_rate < float("inf") else "inf"
    )
    result_json["burstiness"] = args.burstiness
    result_json["max_concurrency"] = args.max_concurrency

    if args.ramp_up_strategy is not None:
        result_json["ramp_up_strategy"] = args.ramp_up_strategy
        result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
        result_json["ramp_up_end_rps"] = args.ramp_up_end_rps

    # Merge with benchmark result (benchmark values win on key collisions).
    result_json = {**result_json, **benchmark_result}

    if not args.save_detailed:
        # Remove fields with too many data points
        for field in (
            "input_lens",
            "output_lens",
            "ttfts",
            "itls",
            "generated_texts",
            "errors",
        ):
            result_json.pop(field, None)

    # Save to file
    base_model_id = model_id.split("/")[-1]
    max_concurrency_str = (
        f"-concurrency{args.max_concurrency}"
        if args.max_concurrency is not None
        else ""
    )
    if args.ramp_up_strategy is not None:
        file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
    else:
        file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa
    if args.result_filename:
        file_name = args.result_filename
    if args.result_dir:
        os.makedirs(args.result_dir, exist_ok=True)
        file_name = os.path.join(args.result_dir, file_name)

    with open(
        file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
    ) as outfile:
        # Separate appended results from existing content with a newline.
        if args.append_result and outfile.tell() != 0:
            outfile.write("\n")
        json.dump(result_json, outfile)
    save_to_pytorch_benchmark_format(args, result_json, file_name)


def main(args: argparse.Namespace):
    """Run the online serving benchmark configured by the parsed CLI ``args``.

    Seeds the RNGs, validates the ramp-up options, builds the request URLs,
    loads the tokenizer, samples the input requests, runs ``benchmark`` under
    ``asyncio.run`` and, when ``--save-result`` or ``--append-result`` is
    given, saves the results as JSON.

    Raises:
        ValueError: on invalid or inconsistent command-line options.
    """
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    backend = args.backend
    model_id = args.model
    model_name = args.served_model_name
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
    tokenizer_mode = args.tokenizer_mode

    _validate_ramp_up_args(args)

    if args.base_url is not None:
        base_url = args.base_url
    else:
        base_url = f"http://{args.host}:{args.port}"
    api_url = f"{base_url}{args.endpoint}"

    tokenizer = get_tokenizer(
        tokenizer_id,
        tokenizer_mode=tokenizer_mode,
        trust_remote_code=args.trust_remote_code,
    )

    input_requests = _load_input_requests(args, tokenizer)
    goodput_config_dict = check_goodput_args(args)
    sampling_params = _collect_sampling_params(args)

    # Avoid GC processing "static" data - reduce pause times.
    gc.collect()
    gc.freeze()

    benchmark_result = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            base_url=base_url,
            model_id=model_id,
            model_name=model_name,
            tokenizer=tokenizer,
            input_requests=input_requests,
            logprobs=args.logprobs,
            request_rate=args.request_rate,
            burstiness=args.burstiness,
            disable_tqdm=args.disable_tqdm,
            profile=args.profile,
            selected_percentile_metrics=args.percentile_metrics.split(","),
            selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")],
            ignore_eos=args.ignore_eos,
            goodput_config_dict=goodput_config_dict,
            max_concurrency=args.max_concurrency,
            lora_modules=args.lora_modules,
            extra_body=sampling_params,
            ramp_up_strategy=args.ramp_up_strategy,
            ramp_up_start_rps=args.ramp_up_start_rps,
            ramp_up_end_rps=args.ramp_up_end_rps,
        )
    )

    # Save config and results to json
    if args.save_result or args.append_result:
        _save_benchmark_results(
            args,
            benchmark_result,
            backend=backend,
            model_id=model_id,
            tokenizer_id=tokenizer_id,
        )
996
997


998
def create_argument_parser():
    """Build the command-line parser for the online serving benchmark.

    Returns:
        A ``FlexibleArgumentParser`` (plain ``argparse.ArgumentParser`` when
        vLLM is not importable) with the general options, the per-dataset and
        sampling option groups, and the ramp-up options.
    """
    parser = FlexibleArgumentParser(
        description="Benchmark the online serving throughput."
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    # Use 127.0.0.1 here instead of localhost to force the use of ipv4
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/v1/completions",
        help="API endpoint.",
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="sharegpt",
        choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        default=None,
        help="Path to the sharegpt/sonnet dataset. "
        "Or the huggingface dataset ID if using HF dataset.",
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Do not load the dataset in streaming mode.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        default=None,
        help="Maximum number of concurrent requests. This can be used "
        "to help simulate an environment where a higher level component "
        "is enforcing a maximum number of concurrent requests. While the "
        "--request-rate argument controls the rate at which requests are "
        "initiated, this argument will control how many are actually allowed "
        "to execute at a time. This means that when used in combination, the "
        "actual request rate may be lower than specified with --request-rate, "
        "if the server is not processing requests fast enough to keep up.",
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help="Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--logprobs",
        type=int,
        default=None,
        help=(
            "Number of logprobs-per-token to compute & return as part of "
            "the request. If unspecified, then either (1) if beam search "
            "is disabled, no logprobs are computed & a single dummy "
            "logprob is returned for each token; or (2) if beam search "
            "is enabled, 1 logprob per token is computed."
        ),
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process or gamma distribution "
        "to synthesize the request arrival times.",
    )
    parser.add_argument(
        "--burstiness",
        type=float,
        default=1.0,
        help="Burstiness factor of the request generation. "
        "Only takes effect when request_rate is not inf. "
        "Default value is 1, which follows Poisson process. "
        "Otherwise, the request intervals follow a gamma distribution. "
        "A lower burstiness value (0 < burstiness < 1) results in more "
        "bursty requests. A higher burstiness value (burstiness > 1) "
        "results in a more uniform arrival of requests.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from huggingface",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="Use Torch Profiler. The endpoint must be launched with "
        "VLLM_TORCH_PROFILER_DIR to enable profiler.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file",
    )
    parser.add_argument(
        "--save-detailed",
        action="store_true",
        help="When saving the results, whether to include per request "
        "information such as response, error, ttfts, itls, etc.",
    )
    parser.add_argument(
        "--append-result",
        action="store_true",
        help="Append the benchmark result to the existing json file.",
    )
    parser.add_argument(
        "--metadata",
        metavar="KEY=VALUE",
        nargs="*",
        help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
        "for metadata of this run to be saved in the result JSON file "
        "for record keeping purposes.",
    )
    parser.add_argument(
        "--result-dir",
        type=str,
        default=None,
        help="Specify directory to save benchmark json results. "
        "If not specified, results are saved in the current directory.",
    )
    parser.add_argument(
        "--result-filename",
        type=str,
        default=None,
        help="Specify the filename to save benchmark json results. "
        "If not specified, results will be saved in "
        "{backend}-{request_rate}qps[-concurrency{max_concurrency}]-"
        "{base_model_id}-{current_dt}.json format "
        "(or a ramp-up variant when --ramp-up-strategy is set).",
    )
    parser.add_argument(
        "--ignore-eos",
        action="store_true",
        help="Set ignore_eos flag when sending the benchmark request. "
        "Warning: ignore_eos is not supported in deepspeed_mii and tgi.",
    )
    parser.add_argument(
        "--percentile-metrics",
        type=str,
        default="ttft,tpot,itl",
        help="Comma-separated list of selected metrics to report percentiles. "
        "This argument specifies the metrics to report percentiles. "
        'Allowed metric names are "ttft", "tpot", "itl", "e2el". '
        'Default value is "ttft,tpot,itl".',
    )
    parser.add_argument(
        "--metric-percentiles",
        type=str,
        default="99",
        help="Comma-separated list of percentiles for selected metrics. "
        'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". '
        'Default value is "99". '
        'Use "--percentile-metrics" to select metrics.',
    )
    parser.add_argument(
        "--goodput",
        nargs="+",
        required=False,
        help='Specify service level objectives for goodput as "KEY:VALUE" '
        "pairs, where the key is a metric name, and the value is in "
        'milliseconds. Multiple "KEY:VALUE" pairs can be provided, '
        "separated by spaces. Allowed request level metric names are "
        '"ttft", "tpot", "e2el". For more context on the definition of '
        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve",
    )

    # group for dataset specific arguments
    custom_group = parser.add_argument_group("custom dataset options")
    custom_group.add_argument(
        "--custom-output-len",
        type=int,
        default=256,
        help="Number of output tokens per request, used only for custom dataset.",
    )
    custom_group.add_argument(
        "--custom-skip-chat-template",
        action="store_true",
        help="Skip applying chat template to prompt, used only for custom dataset.",
    )

    sonnet_group = parser.add_argument_group("sonnet dataset options")
    sonnet_group.add_argument(
        "--sonnet-input-len",
        type=int,
        default=550,
        help="Number of input tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-output-len",
        type=int,
        default=150,
        help="Number of output tokens per request, used only for sonnet dataset.",
    )
    sonnet_group.add_argument(
        "--sonnet-prefix-len",
        type=int,
        default=200,
        help="Number of prefix tokens per request, used only for sonnet dataset.",
    )

    sharegpt_group = parser.add_argument_group("sharegpt dataset options")
    sharegpt_group.add_argument(
        "--sharegpt-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output length "
        "from the ShareGPT dataset.",
    )

    random_group = parser.add_argument_group("random dataset options")
    random_group.add_argument(
        "--random-input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-output-len",
        type=int,
        default=128,
        help="Number of output tokens per request, used only for random sampling.",
    )
    random_group.add_argument(
        "--random-range-ratio",
        type=float,
        default=0.0,
        help="Range ratio for sampling input/output length, "
        "used only for random sampling. Must be in the range [0, 1) to define "
        "a symmetric sampling range "
        "[length * (1 - range_ratio), length * (1 + range_ratio)].",
    )
    random_group.add_argument(
        "--random-prefix-len",
        type=int,
        default=0,
        help=(
            "Number of fixed prefix tokens before the random context "
            "in a request. "
            "The total input length is the sum of `random-prefix-len` and "
            "a random "
            "context length sampled from [input_len * (1 - range_ratio), "
            "input_len * (1 + range_ratio)]."
        ),
    )

    hf_group = parser.add_argument_group("hf dataset options")
    hf_group.add_argument(
        "--hf-subset", type=str, default=None, help="Subset of the HF dataset."
    )
    hf_group.add_argument(
        "--hf-split", type=str, default=None, help="Split of the HF dataset."
    )
    hf_group.add_argument(
        "--hf-output-len",
        type=int,
        default=None,
        help="Output length for each request. Overrides the output lengths "
        "from the sampled HF dataset.",
    )

    sampling_group = parser.add_argument_group("sampling parameters")
    sampling_group.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Top-p sampling parameter. Only has effect on openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--top-k",
        type=int,
        default=None,
        help="Top-k sampling parameter. Only has effect on openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--min-p",
        type=float,
        default=None,
        help="Min-p sampling parameter. Only has effect on openai-compatible backends.",
    )
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=None,
        help="Temperature sampling parameter. Only has effect on "
        "openai-compatible backends. If not specified, default to greedy "
        "decoding (i.e. temperature==0.0).",
    )

    parser.add_argument(
        "--tokenizer-mode",
        type=str,
        default="auto",
        choices=["auto", "slow", "mistral", "custom"],
        help='The tokenizer mode.\n\n* "auto" will use the '
        'fast tokenizer if available.\n* "slow" will '
        "always use the slow tokenizer. \n* "
        '"mistral" will always use the `mistral_common` tokenizer. \n* '
        '"custom" will use --tokenizer to select the preregistered tokenizer.',
    )

    parser.add_argument(
        "--served-model-name",
        type=str,
        default=None,
        help="The model name used in the API. "
        "If not specified, the model name will be the "
        "same as the ``--model`` argument. ",
    )

    parser.add_argument(
        "--lora-modules",
        nargs="+",
        default=None,
        help="A subset of LoRA module names passed in when "
        "launching the server. For each request, the "
        "script chooses a LoRA module at random.",
    )

    parser.add_argument(
        "--ramp-up-strategy",
        type=str,
        default=None,
        choices=["linear", "exponential"],
        help="The ramp-up strategy. This would be used to "
        "ramp up the request rate from initial RPS to final "
        "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps) "
        "over the duration of the benchmark.",
    )
    parser.add_argument(
        "--ramp-up-start-rps",
        type=int,
        default=None,
        help="The starting request rate for ramp-up (RPS). "
        "Needs to be specified when --ramp-up-strategy is used.",
    )
    parser.add_argument(
        "--ramp-up-end-rps",
        type=int,
        default=None,
        help="The ending request rate for ramp-up (RPS). "
        "Needs to be specified when --ramp-up-strategy is used.",
    )

    return parser

1380

1381
1382
1383
if __name__ == "__main__":
    # Script entry point: parse the command line and run the benchmark.
    parser = create_argument_parser()
    args = parser.parse_args()
    main(args)