Unverified Commit fd68cd13 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fixes for SLA finder (#35537)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 0edf101d
...@@ -112,6 +112,7 @@ Example command: ...@@ -112,6 +112,7 @@ Example command:
vllm bench sweep serve_sla \ vllm bench sweep serve_sla \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100' \ --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 100' \
--sla-variable max_concurrency \
--serve-params benchmarks/serve_hparams.json \ --serve-params benchmarks/serve_hparams.json \
--bench-params benchmarks/bench_hparams.json --bench-params benchmarks/bench_hparams.json
-o benchmarks/results -o benchmarks/results
...@@ -119,8 +120,8 @@ vllm bench sweep serve_sla \ ...@@ -119,8 +120,8 @@ vllm bench sweep serve_sla \
The algorithm for scanning through different values of `sla_variable` can be summarized as follows: The algorithm for scanning through different values of `sla_variable` can be summarized as follows:
1. Run the benchmark once with `sla_variable = 1` to simulate serial inference. This results in the lowest possible latency and throughput. 1. Run the benchmark by sending requests one at a time (serial inference). This results in the lowest possible latency and throughput.
2. Run the benchmark once with `sla_variable = num_prompts` to simulate batch inference over the whole dataset. This results in the highest possible latency and throughput. 2. Run the benchmark by sending all requests at once (batch inference). This results in the highest possible latency and throughput.
3. Estimate the maximum value of `sla_variable` that can be supported by the server without oversaturating it. 3. Estimate the maximum value of `sla_variable` that can be supported by the server without oversaturating it.
4. Run the benchmark over intermediate values of `sla_variable` uniformly using the remaining iterations. 4. Run the benchmark over intermediate values of `sla_variable` uniformly using the remaining iterations.
...@@ -129,6 +130,9 @@ You can override the number of iterations in the algorithm by setting `--sla-ite ...@@ -129,6 +130,9 @@ You can override the number of iterations in the algorithm by setting `--sla-ite
!!! tip !!! tip
This is our equivalent of [GuideLLM's `--profile sweep`](https://github.com/vllm-project/guidellm/blob/v0.5.3/src/guidellm/benchmark/profiles.py#L575). This is our equivalent of [GuideLLM's `--profile sweep`](https://github.com/vllm-project/guidellm/blob/v0.5.3/src/guidellm/benchmark/profiles.py#L575).
In general, `--sla-variable max_concurrency` produces more reliable results because it directly controls the workload imposed on the vLLM engine.
Nevertheless, we default to `--sla-variable request_rate` to maintain similar behavior as GuideLLM.
## Startup Benchmark ## Startup Benchmark
`vllm bench sweep startup` runs `vllm bench startup` across parameter combinations to compare cold/warm startup time for different engine settings. `vllm bench sweep startup` runs `vllm bench startup` across parameter combinations to compare cold/warm startup time for different engine settings.
...@@ -197,23 +201,32 @@ Control the variables to plot via `--var-x` and `--var-y`, optionally applying ` ...@@ -197,23 +201,32 @@ Control the variables to plot via `--var-x` and `--var-y`, optionally applying `
Example commands for visualizing [SLA Scanner](#sla-scanner) results: Example commands for visualizing [SLA Scanner](#sla-scanner) results:
```bash ```bash
# Latency increases as the request rate increases # Name of the directory that stores the results
vllm bench sweep plot benchmarks/results/<timestamp> \ TIMESTAMP=$1
--var-x request_rate \
--var-y p99_ttft_ms \ # Latency increases as the workload increases
--row-by random_input_len \ vllm bench sweep plot benchmarks/results/$TIMESTAMP \
--col-by random_output_len \ --var-x max_concurrency \
--var-y median_ttft_ms \
--col-by _benchmark_name \
--curve-by max_num_seqs,max_num_batched_tokens \
--fig-name latency_curve
# Throughput saturates as workload increases
vllm bench sweep plot benchmarks/results/$TIMESTAMP \
--var-x max_concurrency \
--var-y total_token_throughput \
--col-by _benchmark_name \
--curve-by max_num_seqs,max_num_batched_tokens \ --curve-by max_num_seqs,max_num_batched_tokens \
--filter-by 'request_rate<=128' --fig-name throughput_curve
# Tradeoff between latency and throughput # Tradeoff between latency and throughput
vllm bench sweep plot benchmarks/results/<timestamp> \ vllm bench sweep plot benchmarks/results/$TIMESTAMP \
--var-x request_throughput \ --var-x total_token_throughput \
--var-y median_ttft_ms \ --var-y median_ttft_ms \
--row-by random_input_len \ --col-by _benchmark_name \
--col-by random_output_len \
--curve-by max_num_seqs,max_num_batched_tokens \ --curve-by max_num_seqs,max_num_batched_tokens \
--filter-by 'request_rate<=128' --fig-name latency_throughput
``` ```
!!! tip !!! tip
......
...@@ -60,6 +60,8 @@ except ImportError: ...@@ -60,6 +60,8 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_NUM_PROMPTS = 1000
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Data Classes # Data Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
...@@ -1338,7 +1340,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1338,7 +1340,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
parser.add_argument( parser.add_argument(
"--num-prompts", "--num-prompts",
type=int, type=int,
default=1000, default=DEFAULT_NUM_PROMPTS,
help="Number of prompts to process.", help="Number of prompts to process.",
) )
parser.add_argument( parser.add_argument(
......
...@@ -324,6 +324,11 @@ def _plot_fig( ...@@ -324,6 +324,11 @@ def _plot_fig(
df = filter_by.apply(df) df = filter_by.apply(df)
df = bin_by.apply(df) df = bin_by.apply(df)
if len(df) == 0:
print(f"No data to plot. Filters: {filter_by}")
print("[END FIGURE]")
return
# Sort by curve_by columns alphabetically for consistent legend ordering # Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by: if curve_by:
df = df.sort_values(by=curve_by) df = df.sort_values(by=curve_by)
...@@ -570,13 +575,13 @@ class SweepPlotArgs: ...@@ -570,13 +575,13 @@ class SweepPlotArgs:
parser.add_argument( parser.add_argument(
"--var-x", "--var-x",
type=str, type=str,
default="request_throughput", default="total_token_throughput",
help="The variable for the x-axis.", help="The variable for the x-axis.",
) )
parser.add_argument( parser.add_argument(
"--var-y", "--var-y",
type=str, type=str,
default="p99_ttft_ms", default="median_ttft_ms",
help="The variable for the y-axis", help="The variable for the y-axis",
) )
parser.add_argument( parser.add_argument(
......
...@@ -138,12 +138,16 @@ def _get_comb_base_path( ...@@ -138,12 +138,16 @@ def _get_comb_base_path(
output_dir: Path, output_dir: Path,
serve_comb: ParameterSweepItem, serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem, bench_comb: ParameterSweepItem,
*,
extra_parts: tuple[str, ...] = (),
): ):
parts = list[str]() parts = list[str]()
if serve_comb: if serve_comb:
parts.extend(("SERVE-", serve_comb.name)) parts.extend(("SERVE-", serve_comb.name))
if bench_comb: if bench_comb:
parts.extend(("BENCH-", bench_comb.name)) parts.extend(("BENCH-", bench_comb.name))
if extra_parts:
parts.extend(extra_parts)
return output_dir / sanitize_filename("-".join(parts)) return output_dir / sanitize_filename("-".join(parts))
......
...@@ -10,6 +10,7 @@ from typing import ClassVar, Literal, get_args ...@@ -10,6 +10,7 @@ from typing import ClassVar, Literal, get_args
import numpy as np import numpy as np
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
from vllm.utils.import_utils import PlaceholderModule from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem from .param_sweep import ParameterSweep, ParameterSweepItem
...@@ -65,7 +66,12 @@ def run_comb_sla( ...@@ -65,7 +66,12 @@ def run_comb_sla(
bench_cmd, bench_cmd,
serve_comb=serve_comb, serve_comb=serve_comb,
bench_comb=bench_comb_sla, bench_comb=bench_comb_sla,
base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla), base_path=_get_comb_base_path(
output_dir,
serve_comb,
bench_comb,
extra_parts=("SLA-", f"{sla_variable}={sla_value}"),
),
num_runs=num_runs, num_runs=num_runs,
dry_run=dry_run, dry_run=dry_run,
link_vars=link_vars, link_vars=link_vars,
...@@ -93,11 +99,25 @@ def explore_sla( ...@@ -93,11 +99,25 @@ def explore_sla(
if sla_iters < 2: if sla_iters < 2:
raise ValueError("`sla_iters` should be at least 2") raise ValueError("`sla_iters` should be at least 2")
dataset_size = DEFAULT_NUM_PROMPTS
if "num_prompts" in bench_comb:
dataset_size = int(bench_comb["num_prompts"]) # type: ignore
else:
for i, arg in enumerate(bench_cmd):
if arg == "--num-prompts" and i + 1 < len(bench_cmd):
dataset_size = int(bench_cmd[i + 1])
break
elif arg.startswith("--num-prompts="):
dataset_size = int(arg.split("=", 1)[1])
break
print(f"Dataset size: {dataset_size}")
serial_comb_data = run_comb_sla( serial_comb_data = run_comb_sla(
server, server,
bench_cmd, bench_cmd,
serve_comb=serve_comb, serve_comb=serve_comb,
bench_comb=bench_comb, bench_comb=bench_comb | {"max_concurrency": 1},
output_dir=output_dir, output_dir=output_dir,
num_runs=num_runs, num_runs=num_runs,
dry_run=dry_run, dry_run=dry_run,
...@@ -109,13 +129,13 @@ def explore_sla( ...@@ -109,13 +129,13 @@ def explore_sla(
server, server,
bench_cmd, bench_cmd,
serve_comb=serve_comb, serve_comb=serve_comb,
bench_comb=bench_comb, bench_comb=bench_comb | {"max_concurrency": dataset_size},
output_dir=output_dir, output_dir=output_dir,
num_runs=num_runs, num_runs=num_runs,
dry_run=dry_run, dry_run=dry_run,
link_vars=link_vars, link_vars=link_vars,
sla_variable=sla_variable, sla_variable=sla_variable,
sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore sla_value=dataset_size,
) )
if serial_comb_data is None or batch_comb_data is None: if serial_comb_data is None or batch_comb_data is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment