Unverified Commit b4fe16c7 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Add `vllm bench [latency, throughput]` CLI commands (#16508)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent bc5dd4f6
...@@ -341,6 +341,13 @@ steps: ...@@ -341,6 +341,13 @@ steps:
commands: commands:
- bash scripts/run-benchmarks.sh - bash scripts/run-benchmarks.sh
- label: Benchmarks CLI Test # 10min
source_file_dependencies:
- vllm/
- tests/benchmarks/
commands:
- pytest -v -s benchmarks/
- label: Quantization Test # 33min - label: Quantization Test # 33min
source_file_dependencies: source_file_dependencies:
- csrc/ - csrc/
......
# SPDX-License-Identifier: Apache-2.0
import subprocess
import pytest
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark
def test_bench_latency():
command = [
"vllm", "bench", "latency", "--model", MODEL_NAME, "--input-len", "32",
"--output-len", "1", "--enforce-eager", "--load-format", "dummy"
]
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
# SPDX-License-Identifier: Apache-2.0
import subprocess
import pytest
from ..utils import RemoteOpenAIServer
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.benchmark
def test_bench_serve(server):
command = [
"vllm",
"bench",
"serve",
"--model",
MODEL_NAME,
"--host",
server.host,
"--port",
str(server.port),
"--random-input-len",
"32",
"--random-output-len",
"4",
"--num-prompts",
"5",
]
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
# SPDX-License-Identifier: Apache-2.0
import subprocess
import pytest
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.benchmark
def test_bench_throughput():
command = [
"vllm", "bench", "throughput", "--model", MODEL_NAME, "--input-len",
"32", "--output-len", "1", "--enforce-eager", "--load-format", "dummy"
]
result = subprocess.run(command, capture_output=True, text=True)
print(result.stdout)
print(result.stderr)
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import dataclasses
import json
import os
import time
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.benchmarks.utils import (convert_to_pytorch_benchmark_format,
write_to_json)
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType
from vllm.sampling_params import BeamSearchParams
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
results: dict[str, Any]) -> None:
pt_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={"latency": results["latencies"]},
extra_info={k: results[k]
for k in ["avg_latency", "percentiles"]})
if pt_records:
pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
write_to_json(pt_file, pt_records)
def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--input-len", type=int, default=32)
parser.add_argument("--output-len", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument(
"--n",
type=int,
default=1,
help="Number of generated sequences per prompt.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--num-iters-warmup",
type=int,
default=10,
help="Number of iterations to run for warmup.",
)
parser.add_argument("--num-iters",
type=int,
default=30,
help="Number of iterations to run.")
parser.add_argument(
"--profile",
action="store_true",
help="profile the generation process of a single batch",
)
parser.add_argument(
"--profile-result-dir",
type=str,
default=None,
help=("path to save the pytorch profiler output. Can be visualized "
"with ui.perfetto.dev or Tensorboard."),
)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the latency results in JSON format.",
)
parser.add_argument(
"--disable-detokenize",
action="store_true",
help=("Do not detokenize responses (i.e. do not include "
"detokenization time in the latency measurement)"),
)
parser = EngineArgs.add_cli_args(parser)
def main(args: argparse.Namespace):
print(args)
engine_args = EngineArgs.from_cli_args(args)
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(**dataclasses.asdict(engine_args))
assert llm.llm_engine.model_config.max_model_len >= (
args.input_len +
args.output_len), ("Please ensure that max_model_len is greater than"
" the sum of input_len and output_len.")
sampling_params = SamplingParams(
n=args.n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=args.output_len,
detokenize=not args.disable_detokenize,
)
print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompts: list[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
def llm_generate():
if not args.use_beam_search:
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
else:
llm.beam_search(
dummy_prompts,
BeamSearchParams(
beam_width=args.n,
max_tokens=args.output_len,
ignore_eos=True,
),
)
def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir)),
) as p:
llm_generate()
print(p.key_averages().table(sort_by="self_cuda_time_total"))
else:
start_time = time.perf_counter()
llm_generate()
end_time = time.perf_counter()
latency = end_time - start_time
return latency
print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None)
if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
profile_dir = (Path(".") / "vllm_benchmark_result" /
f"latency_result_{time.time()}")
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return
# Benchmark.
latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None))
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90, 99]
percentiles = np.percentile(latencies, percentages)
print(f"Avg latency: {np.mean(latencies)} seconds")
for percentage, percentile in zip(percentages, percentiles):
print(f"{percentage}% percentile latency: {percentile} seconds")
# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.latency import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
""" The `latency` subcommand for vllm bench. """
def __init__(self):
self.name = "latency"
super().__init__()
@property
def help(self) -> str:
return "Benchmark the latency of a single batch of requests."
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)
def cmd_init() -> list[CLISubcommand]:
return [BenchmarkLatencySubcommand()]
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import vllm.entrypoints.cli.benchmark.latency
import vllm.entrypoints.cli.benchmark.serve import vllm.entrypoints.cli.benchmark.serve
import vllm.entrypoints.cli.benchmark.throughput
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
# TODO: Add the rest of the benchmark subcommands here,
# e.g., throughput, latency, etc.
BENCHMARK_CMD_MODULES = [ BENCHMARK_CMD_MODULES = [
vllm.entrypoints.cli.benchmark.latency,
vllm.entrypoints.cli.benchmark.serve, vllm.entrypoints.cli.benchmark.serve,
vllm.entrypoints.cli.benchmark.throughput,
] ]
......
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.throughput import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
""" The `throughput` subcommand for vllm bench. """
def __init__(self):
self.name = "throughput"
super().__init__()
@property
def help(self) -> str:
return "Benchmark offline inference throughput."
def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)
def cmd_init() -> list[CLISubcommand]:
return [BenchmarkThroughputSubcommand()]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment