Unverified Commit ecca3fee authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Frontend] Add `vllm bench sweep` to CLI (#27639)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9a0d2f0d
...@@ -5,4 +5,4 @@ nav: ...@@ -5,4 +5,4 @@ nav:
- complete.md - complete.md
- run-batch.md - run-batch.md
- vllm bench: - vllm bench:
- bench/*.md - bench/**/*.md
# vllm bench sweep plot
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Options
--8<-- "docs/argparse/bench_sweep_plot.md"
# vllm bench sweep serve
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Options
--8<-- "docs/argparse/bench_sweep_serve.md"
# vllm bench sweep serve_sla
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Options
--8<-- "docs/argparse/bench_sweep_serve_sla.md"
...@@ -1061,7 +1061,7 @@ Follow these steps to run the script: ...@@ -1061,7 +1061,7 @@ Follow these steps to run the script:
Example command: Example command:
```bash ```bash
python -m vllm.benchmarks.sweep.serve \ vllm bench sweep serve \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \ --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \ --serve-params benchmarks/serve_hparams.json \
...@@ -1109,7 +1109,7 @@ For example, to ensure E2E latency within different target values for 99% of req ...@@ -1109,7 +1109,7 @@ For example, to ensure E2E latency within different target values for 99% of req
Example command: Example command:
```bash ```bash
python -m vllm.benchmarks.sweep.serve_sla \ vllm bench sweep serve_sla \
--serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \ --serve-cmd 'vllm serve meta-llama/Llama-2-7b-chat-hf' \
--bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \ --bench-cmd 'vllm bench serve --model meta-llama/Llama-2-7b-chat-hf --backend vllm --endpoint /v1/completions --dataset-name sharegpt --dataset-path benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json' \
--serve-params benchmarks/serve_hparams.json \ --serve-params benchmarks/serve_hparams.json \
...@@ -1138,7 +1138,7 @@ The algorithm for adjusting the SLA variable is as follows: ...@@ -1138,7 +1138,7 @@ The algorithm for adjusting the SLA variable is as follows:
Example command: Example command:
```bash ```bash
python -m vllm.benchmarks.sweep.plot benchmarks/results/<timestamp> \ vllm bench sweep plot benchmarks/results/<timestamp> \
--var-x max_concurrency \ --var-x max_concurrency \
--row-by random_input_len \ --row-by random_input_len \
--col-by random_output_len \ --col-by random_output_len \
......
...@@ -56,15 +56,20 @@ def auto_mock(module, attr, max_mocks=50): ...@@ -56,15 +56,20 @@ def auto_mock(module, attr, max_mocks=50):
) )
latency = auto_mock("vllm.benchmarks", "latency") bench_latency = auto_mock("vllm.benchmarks", "latency")
serve = auto_mock("vllm.benchmarks", "serve") bench_serve = auto_mock("vllm.benchmarks", "serve")
throughput = auto_mock("vllm.benchmarks", "throughput") bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs")
bench_sweep_serve = auto_mock("vllm.benchmarks.sweep.serve", "SweepServeArgs")
bench_sweep_serve_sla = auto_mock(
"vllm.benchmarks.sweep.serve_sla", "SweepServeSLAArgs"
)
bench_throughput = auto_mock("vllm.benchmarks", "throughput")
AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs") AsyncEngineArgs = auto_mock("vllm.engine.arg_utils", "AsyncEngineArgs")
EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs") EngineArgs = auto_mock("vllm.engine.arg_utils", "EngineArgs")
ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand") ChatCommand = auto_mock("vllm.entrypoints.cli.openai", "ChatCommand")
CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand") CompleteCommand = auto_mock("vllm.entrypoints.cli.openai", "CompleteCommand")
cli_args = auto_mock("vllm.entrypoints.openai", "cli_args") openai_cli_args = auto_mock("vllm.entrypoints.openai", "cli_args")
run_batch = auto_mock("vllm.entrypoints.openai", "run_batch") openai_run_batch = auto_mock("vllm.entrypoints.openai", "run_batch")
FlexibleArgumentParser = auto_mock( FlexibleArgumentParser = auto_mock(
"vllm.utils.argparse_utils", "FlexibleArgumentParser" "vllm.utils.argparse_utils", "FlexibleArgumentParser"
) )
...@@ -114,6 +119,9 @@ class MarkdownFormatter(HelpFormatter): ...@@ -114,6 +119,9 @@ class MarkdownFormatter(HelpFormatter):
self._markdown_output.append(f"{action.help}\n\n") self._markdown_output.append(f"{action.help}\n\n")
if (default := action.default) != SUPPRESS: if (default := action.default) != SUPPRESS:
# Make empty string defaults visible
if default == "":
default = '""'
self._markdown_output.append(f"Default: `{default}`\n\n") self._markdown_output.append(f"Default: `{default}`\n\n")
def format_help(self): def format_help(self):
...@@ -150,17 +158,23 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): ...@@ -150,17 +158,23 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
# Create parsers to document # Create parsers to document
parsers = { parsers = {
# Engine args
"engine_args": create_parser(EngineArgs.add_cli_args), "engine_args": create_parser(EngineArgs.add_cli_args),
"async_engine_args": create_parser( "async_engine_args": create_parser(
AsyncEngineArgs.add_cli_args, async_args_only=True AsyncEngineArgs.add_cli_args, async_args_only=True
), ),
"serve": create_parser(cli_args.make_arg_parser), # CLI
"serve": create_parser(openai_cli_args.make_arg_parser),
"chat": create_parser(ChatCommand.add_cli_args), "chat": create_parser(ChatCommand.add_cli_args),
"complete": create_parser(CompleteCommand.add_cli_args), "complete": create_parser(CompleteCommand.add_cli_args),
"bench_latency": create_parser(latency.add_cli_args), "run-batch": create_parser(openai_run_batch.make_arg_parser),
"bench_throughput": create_parser(throughput.add_cli_args), # Benchmark CLI
"bench_serve": create_parser(serve.add_cli_args), "bench_latency": create_parser(bench_latency.add_cli_args),
"run-batch": create_parser(run_batch.make_arg_parser), "bench_serve": create_parser(bench_serve.add_cli_args),
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
"bench_sweep_serve": create_parser(bench_sweep_serve.add_cli_args),
"bench_sweep_serve_sla": create_parser(bench_sweep_serve_sla.add_cli_args),
"bench_throughput": create_parser(bench_throughput.add_cli_args),
} }
# Generate documentation for each parser # Generate documentation for each parser
......
...@@ -709,7 +709,7 @@ setup( ...@@ -709,7 +709,7 @@ setup(
ext_modules=ext_modules, ext_modules=ext_modules,
install_requires=get_requirements(), install_requires=get_requirements(),
extras_require={ extras_require={
"bench": ["pandas", "datasets"], "bench": ["pandas", "matplotlib", "seaborn", "datasets"],
"tensorizer": ["tensorizer==2.10.1"], "tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"], "fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"], "runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
......
...@@ -141,7 +141,7 @@ def attempt_to_make_names_unique(entries_and_traces): ...@@ -141,7 +141,7 @@ def attempt_to_make_names_unique(entries_and_traces):
""" """
def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: def group_trace_by_operations(trace_df: "pd.DataFrame") -> "pd.DataFrame":
def is_rms_norm(op_name: str): def is_rms_norm(op_name: str):
if "rms_norm_kernel" in op_name: if "rms_norm_kernel" in op_name:
return True return True
...@@ -370,12 +370,12 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: ...@@ -370,12 +370,12 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
def plot_trace_df( def plot_trace_df(
traces_df: pd.DataFrame, traces_df: "pd.DataFrame",
plot_metric: str, plot_metric: str,
plot_title: str, plot_title: str,
output: Path | None = None, output: Path | None = None,
): ):
def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: def get_phase_description(traces_df: "pd.DataFrame", phase: str) -> str:
phase_df = traces_df.query(f'phase == "{phase}"') phase_df = traces_df.query(f'phase == "{phase}"')
descs = phase_df["phase_desc"].to_list() descs = phase_df["phase_desc"].to_list()
assert all([desc == descs[0] for desc in descs]) assert all([desc == descs[0] for desc in descs])
...@@ -438,7 +438,7 @@ def main( ...@@ -438,7 +438,7 @@ def main(
top_k: int, top_k: int,
json_nodes_to_fold: list[str], json_nodes_to_fold: list[str],
): ):
def prepare_data(profile_json: dict, step_keys: list[str]) -> pd.DataFrame: def prepare_data(profile_json: dict, step_keys: list[str]) -> "pd.DataFrame":
def get_entries_and_traces(key: str): def get_entries_and_traces(key: str):
entries_and_traces: list[tuple[Any, Any]] = [] entries_and_traces: list[tuple[Any, Any]] = []
for root in profile_json[key]["summary_stats"]: for root in profile_json[key]["summary_stats"]:
...@@ -449,8 +449,8 @@ def main( ...@@ -449,8 +449,8 @@ def main(
return entries_and_traces return entries_and_traces
def keep_only_top_entries( def keep_only_top_entries(
df: pd.DataFrame, metric: str, top_k: int = 9 df: "pd.DataFrame", metric: str, top_k: int = 9
) -> pd.DataFrame: ) -> "pd.DataFrame":
df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others" df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, ["name"]] = "others"
return df return df
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG
from .plot import SweepPlotArgs
from .plot import main as plot_main
from .serve import SweepServeArgs
from .serve import main as serve_main
from .serve_sla import SweepServeSLAArgs
from .serve_sla import main as serve_sla_main
SUBCOMMANDS = (
(SweepServeArgs, serve_main),
(SweepServeSLAArgs, serve_sla_main),
(SweepPlotArgs, plot_main),
)
def add_cli_args(parser: argparse.ArgumentParser):
subparsers = parser.add_subparsers(required=True, dest="sweep_type")
for cmd, entrypoint in SUBCOMMANDS:
cmd_subparser = subparsers.add_parser(
cmd.parser_name,
description=cmd.parser_help,
usage=f"vllm bench sweep {cmd.parser_name} [options]",
)
cmd_subparser.set_defaults(dispatch_function=entrypoint)
cmd.add_cli_args(cmd_subparser)
cmd_subparser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(
subcmd=f"sweep {cmd.parser_name}"
)
def main(args: argparse.Namespace):
args.dispatch_function(args)
...@@ -8,16 +8,24 @@ from dataclasses import dataclass ...@@ -8,16 +8,24 @@ from dataclasses import dataclass
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from types import TracebackType from types import TracebackType
from typing import ClassVar
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from typing_extensions import Self, override from typing_extensions import Self, override
from vllm.utils.collection_utils import full_groupby from vllm.utils.collection_utils import full_groupby
from vllm.utils.import_utils import PlaceholderModule
from .utils import sanitize_filename from .utils import sanitize_filename
try:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
except ImportError:
plt = PlaceholderModule("matplotlib").placeholder_attr("pyplot")
pd = PlaceholderModule("pandas")
seaborn = PlaceholderModule("seaborn")
@dataclass @dataclass
class PlotFilterBase(ABC): class PlotFilterBase(ABC):
...@@ -40,7 +48,7 @@ class PlotFilterBase(ABC): ...@@ -40,7 +48,7 @@ class PlotFilterBase(ABC):
) )
@abstractmethod @abstractmethod
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this filter to a DataFrame.""" """Applies this filter to a DataFrame."""
raise NotImplementedError raise NotImplementedError
...@@ -48,7 +56,7 @@ class PlotFilterBase(ABC): ...@@ -48,7 +56,7 @@ class PlotFilterBase(ABC):
@dataclass @dataclass
class PlotEqualTo(PlotFilterBase): class PlotEqualTo(PlotFilterBase):
@override @override
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
try: try:
target = float(self.target) target = float(self.target)
except ValueError: except ValueError:
...@@ -60,28 +68,28 @@ class PlotEqualTo(PlotFilterBase): ...@@ -60,28 +68,28 @@ class PlotEqualTo(PlotFilterBase):
@dataclass @dataclass
class PlotLessThan(PlotFilterBase): class PlotLessThan(PlotFilterBase):
@override @override
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] < float(self.target)] return df[df[self.var] < float(self.target)]
@dataclass @dataclass
class PlotLessThanOrEqualTo(PlotFilterBase): class PlotLessThanOrEqualTo(PlotFilterBase):
@override @override
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] <= float(self.target)] return df[df[self.var] <= float(self.target)]
@dataclass @dataclass
class PlotGreaterThan(PlotFilterBase): class PlotGreaterThan(PlotFilterBase):
@override @override
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] > float(self.target)] return df[df[self.var] > float(self.target)]
@dataclass @dataclass
class PlotGreaterThanOrEqualTo(PlotFilterBase): class PlotGreaterThanOrEqualTo(PlotFilterBase):
@override @override
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
return df[df[self.var] >= float(self.target)] return df[df[self.var] >= float(self.target)]
...@@ -103,7 +111,7 @@ class PlotFilters(list[PlotFilterBase]): ...@@ -103,7 +111,7 @@ class PlotFilters(list[PlotFilterBase]):
return cls(PlotFilterBase.parse_str(e) for e in s.split(",")) return cls(PlotFilterBase.parse_str(e) for e in s.split(","))
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self: for item in self:
df = item.apply(df) df = item.apply(df)
...@@ -127,7 +135,7 @@ class PlotBinner: ...@@ -127,7 +135,7 @@ class PlotBinner:
f"Valid operators are: {sorted(PLOT_BINNERS)}", f"Valid operators are: {sorted(PLOT_BINNERS)}",
) )
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
"""Applies this binner to a DataFrame.""" """Applies this binner to a DataFrame."""
df = df.copy() df = df.copy()
df[self.var] = df[self.var] // self.bin_size * self.bin_size df[self.var] = df[self.var] // self.bin_size * self.bin_size
...@@ -147,7 +155,7 @@ class PlotBinners(list[PlotBinner]): ...@@ -147,7 +155,7 @@ class PlotBinners(list[PlotBinner]):
return cls(PlotBinner.parse_str(e) for e in s.split(",")) return cls(PlotBinner.parse_str(e) for e in s.split(","))
def apply(self, df: pd.DataFrame) -> pd.DataFrame: def apply(self, df: "pd.DataFrame") -> "pd.DataFrame":
for item in self: for item in self:
df = item.apply(df) df = item.apply(df)
...@@ -396,7 +404,54 @@ def plot( ...@@ -396,7 +404,54 @@ def plot(
) )
def add_cli_args(parser: argparse.ArgumentParser): @dataclass
class SweepPlotArgs:
output_dir: Path
fig_dir: Path
fig_by: list[str]
row_by: list[str]
col_by: list[str]
curve_by: list[str]
var_x: str
var_y: str
filter_by: PlotFilters
bin_by: PlotBinners
scale_x: str | None
scale_y: str | None
dry_run: bool
parser_name: ClassVar[str] = "plot"
parser_help: ClassVar[str] = "Plot performance curves from parameter sweep results."
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
curve_by = [] if not args.curve_by else args.curve_by.split(",")
row_by = [] if not args.row_by else args.row_by.split(",")
col_by = [] if not args.col_by else args.col_by.split(",")
fig_by = [] if not args.fig_by else args.fig_by.split(",")
return cls(
output_dir=output_dir,
fig_dir=output_dir / args.fig_dir,
fig_by=fig_by,
row_by=row_by,
col_by=col_by,
curve_by=curve_by,
var_x=args.var_x,
var_y=args.var_y,
filter_by=PlotFilters.parse_str(args.filter_by),
bin_by=PlotBinners.parse_str(args.bin_by),
scale_x=args.scale_x,
scale_y=args.scale_y,
dry_run=args.dry_run,
)
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"OUTPUT_DIR", "OUTPUT_DIR",
type=str, type=str,
...@@ -467,7 +522,7 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -467,7 +522,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
default="", default="",
help="A comma-separated list of statements indicating values to bin by. " help="A comma-separated list of statements indicating values to bin by. "
"This is useful to avoid plotting points that are too close together. " "This is useful to avoid plotting points that are too close together. "
"Example: `request_throughput%1` means " "Example: `request_throughput%%1` means "
"use a bin size of 1 for the `request_throughput` variable.", "use a bin size of 1 for the `request_throughput` variable.",
) )
parser.add_argument( parser.add_argument(
...@@ -493,38 +548,33 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -493,38 +548,33 @@ def add_cli_args(parser: argparse.ArgumentParser):
"then exits without drawing them.", "then exits without drawing them.",
) )
return parser
def main(args: argparse.Namespace):
output_dir = Path(args.OUTPUT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
curve_by = [] if not args.curve_by else args.curve_by.split(",")
row_by = [] if not args.row_by else args.row_by.split(",")
col_by = [] if not args.col_by else args.col_by.split(",")
fig_by = [] if not args.fig_by else args.fig_by.split(",")
plot( def run_main(args: SweepPlotArgs):
output_dir=output_dir, return plot(
fig_dir=output_dir / args.fig_dir, output_dir=args.output_dir,
fig_by=fig_by, fig_dir=args.fig_dir,
row_by=row_by, fig_by=args.fig_by,
col_by=col_by, row_by=args.row_by,
curve_by=curve_by, col_by=args.col_by,
curve_by=args.curve_by,
var_x=args.var_x, var_x=args.var_x,
var_y=args.var_y, var_y=args.var_y,
filter_by=PlotFilters.parse_str(args.filter_by), filter_by=args.filter_by,
bin_by=PlotBinners.parse_str(args.bin_by), bin_by=args.bin_by,
scale_x=args.scale_x, scale_x=args.scale_x,
scale_y=args.scale_y, scale_y=args.scale_y,
dry_run=args.dry_run, dry_run=args.dry_run,
) )
def main(args: argparse.Namespace):
run_main(SweepPlotArgs.from_cli_args(args))
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description=SweepPlotArgs.parser_help)
description="Plot performance curves from parameter sweep results." SweepPlotArgs.add_cli_args(parser)
)
add_cli_args(parser)
main(parser.parse_args()) main(parser.parse_args())
...@@ -7,13 +7,19 @@ import shlex ...@@ -7,13 +7,19 @@ import shlex
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import ClassVar
import pandas as pd from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem from .param_sweep import ParameterSweep, ParameterSweepItem
from .server import ServerProcess from .server import ServerProcess
from .utils import sanitize_filename from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@contextlib.contextmanager @contextlib.contextmanager
def run_server( def run_server(
...@@ -257,6 +263,9 @@ class SweepServeArgs: ...@@ -257,6 +263,9 @@ class SweepServeArgs:
dry_run: bool dry_run: bool
resume: str | None resume: str | None
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
serve_cmd = shlex.split(args.serve_cmd) serve_cmd = shlex.split(args.serve_cmd)
...@@ -401,9 +410,7 @@ def main(args: argparse.Namespace): ...@@ -401,9 +410,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description=SweepServeArgs.parser_help)
description="Run vLLM server benchmark under multiple settings."
)
SweepServeArgs.add_cli_args(parser) SweepServeArgs.add_cli_args(parser)
main(parser.parse_args()) main(parser.parse_args())
...@@ -7,17 +7,23 @@ import math ...@@ -7,17 +7,23 @@ import math
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Literal, get_args from typing import ClassVar, Literal, get_args
import pandas as pd
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.utils.import_utils import PlaceholderModule
from .param_sweep import ParameterSweep, ParameterSweepItem from .param_sweep import ParameterSweep, ParameterSweepItem
from .serve import SweepServeArgs, run_benchmark, run_server from .serve import SweepServeArgs, run_benchmark, run_server
from .server import ServerProcess from .server import ServerProcess
from .sla_sweep import SLASweep, SLASweepItem from .sla_sweep import SLASweep, SLASweepItem
from .utils import sanitize_filename from .utils import sanitize_filename
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
def _get_sla_base_path( def _get_sla_base_path(
output_dir: Path, output_dir: Path,
...@@ -399,6 +405,9 @@ class SweepServeSLAArgs(SweepServeArgs): ...@@ -399,6 +405,9 @@ class SweepServeSLAArgs(SweepServeArgs):
sla_params: SLASweep sla_params: SLASweep
sla_variable: SLAVariable sla_variable: SLAVariable
parser_name: ClassVar[str] = "serve_sla"
parser_help: ClassVar[str] = "Tune a variable to meet SLAs under multiple settings."
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
# NOTE: Don't use super() as `from_cli_args` calls `cls()` # NOTE: Don't use super() as `from_cli_args` calls `cls()`
...@@ -419,7 +428,8 @@ class SweepServeSLAArgs(SweepServeArgs): ...@@ -419,7 +428,8 @@ class SweepServeSLAArgs(SweepServeArgs):
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser = super().add_cli_args(parser) parser = super().add_cli_args(parser)
parser.add_argument( sla_group = parser.add_argument_group("sla options")
sla_group.add_argument(
"--sla-params", "--sla-params",
type=str, type=str,
required=True, required=True,
...@@ -431,7 +441,7 @@ class SweepServeSLAArgs(SweepServeArgs): ...@@ -431,7 +441,7 @@ class SweepServeSLAArgs(SweepServeArgs):
"the maximum `sla_variable` that satisfies the constraints for " "the maximum `sla_variable` that satisfies the constraints for "
"each combination of `serve_params`, `bench_params`, and `sla_params`.", "each combination of `serve_params`, `bench_params`, and `sla_params`.",
) )
parser.add_argument( sla_group.add_argument(
"--sla-variable", "--sla-variable",
type=str, type=str,
choices=get_args(SLAVariable), choices=get_args(SLAVariable),
...@@ -476,9 +486,7 @@ def main(args: argparse.Namespace): ...@@ -476,9 +486,7 @@ def main(args: argparse.Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
description="Tune a variable to meet SLAs under multiple settings."
)
SweepServeSLAArgs.add_cli_args(parser) SweepServeSLAArgs.add_cli_args(parser)
main(parser.parse_args()) main(parser.parse_args())
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
__all__: list[str] = [ __all__: list[str] = [
"BenchmarkLatencySubcommand", "BenchmarkLatencySubcommand",
"BenchmarkServingSubcommand", "BenchmarkServingSubcommand",
"BenchmarkSweepSubcommand",
"BenchmarkThroughputSubcommand", "BenchmarkThroughputSubcommand",
] ]
...@@ -6,7 +6,7 @@ from vllm.entrypoints.cli.types import CLISubcommand ...@@ -6,7 +6,7 @@ from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkSubcommandBase(CLISubcommand): class BenchmarkSubcommandBase(CLISubcommand):
"""The base class of subcommands for vllm bench.""" """The base class of subcommands for `vllm bench`."""
help: str help: str
......
...@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase ...@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase): class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
"""The `latency` subcommand for vllm bench.""" """The `latency` subcommand for `vllm bench`."""
name = "latency" name = "latency"
help = "Benchmark the latency of a single batch of requests." help = "Benchmark the latency of a single batch of requests."
......
...@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase ...@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkServingSubcommand(BenchmarkSubcommandBase): class BenchmarkServingSubcommand(BenchmarkSubcommandBase):
"""The `serve` subcommand for vllm bench.""" """The `serve` subcommand for `vllm bench`."""
name = "serve" name = "serve"
help = "Benchmark the online serving throughput." help = "Benchmark the online serving throughput."
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.sweep.cli import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkSweepSubcommand(BenchmarkSubcommandBase):
"""The `sweep` subcommand for `vllm bench`."""
name = "sweep"
help = "Benchmark for a parameter sweep."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)
...@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase ...@@ -7,7 +7,7 @@ from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase): class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
"""The `throughput` subcommand for vllm bench.""" """The `throughput` subcommand for `vllm bench`."""
name = "throughput" name = "throughput"
help = "Benchmark offline inference throughput." help = "Benchmark offline inference throughput."
......
...@@ -7,7 +7,6 @@ from collections.abc import Callable ...@@ -7,7 +7,6 @@ from collections.abc import Callable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import Any, Optional, TypeAlias from typing import Any, Optional, TypeAlias
import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
from torch.autograd.profiler import FunctionEvent from torch.autograd.profiler import FunctionEvent
...@@ -21,6 +20,12 @@ from vllm.profiler.utils import ( ...@@ -21,6 +20,12 @@ from vllm.profiler.utils import (
event_torch_op_stack_trace, event_torch_op_stack_trace,
indent_string, indent_string,
) )
from vllm.utils.import_utils import PlaceholderModule
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@dataclass @dataclass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment