Unverified Commit 02acd168 authored by Sophie du Couédic's avatar Sophie du Couédic Committed by GitHub
Browse files

[Benchmarks] Plot benchmark timeline and requests statistics (#35220)


Signed-off-by: default avatarSophie du Couédic <sop@zurich.ibm.com>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent ab87f852
...@@ -1033,7 +1033,7 @@ setup( ...@@ -1033,7 +1033,7 @@ setup(
ext_modules=ext_modules, ext_modules=ext_modules,
install_requires=get_requirements(), install_requires=get_requirements(),
extras_require={ extras_require={
"bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy"], "bench": ["pandas", "matplotlib", "seaborn", "datasets", "scipy", "plotly"],
"tensorizer": ["tensorizer==2.10.1"], "tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.2.2"], "fastsafetensors": ["fastsafetensors >= 0.2.2"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"], "runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"],
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Generate plots for benchmark results."""
from pathlib import Path
from typing import Any
from vllm.utils.import_utils import PlaceholderModule
try:
import plotly.express as px
import plotly.io as pio
except ImportError:
_plotly = PlaceholderModule("plotly")
px = _plotly.placeholder_attr("express")
pio = _plotly.placeholder_attr("io")
try:
import matplotlib.pyplot as plt
except ImportError:
_matplotlib = PlaceholderModule("matplotlib")
plt = _matplotlib.placeholder_attr("pyplot")
def generate_timeline_plot(
results: list[dict[str, Any]],
output_path: Path,
colors: list[str] | None = None,
itl_thresholds: list[float] | None = None,
labels: list[str] | None = None,
) -> None:
"""
Generate an HTML timeline plot from benchmark results.
Args:
results: List of per-request result dictionaries containing:
- start_time: Request start time (seconds)
- ttft: Time to first token (seconds)
- itl: List of inter-token latencies (seconds)
- latency: Total request latency (seconds)
- prompt_len: Number of prompt tokens
- output_tokens: Number of output tokens
output_path: Path where the HTML file will be saved
colors: List of colors for ITL categories (default: green, orange, red, black)
itl_thresholds: ITL thresholds in seconds (default: [1.0, 4.0, 6.0])
labels: Labels for ITL categories (default based on thresholds)
"""
# Set defaults
if colors is None:
colors = ["#109618", "#FF7F0E", "#D62728"]
if itl_thresholds is None:
itl_thresholds = [0.025, 0.050]
if labels is None:
labels = [
f"ITL < {itl_thresholds[0] * 1000:.0f}ms",
f"{itl_thresholds[0] * 1000:.0f}ms ≤ ITL < {itl_thresholds[1] * 1000:.0f}ms", # noqa
f"ITL ≥ {itl_thresholds[1] * 1000:.0f}ms",
]
labels_colors = {"TTFT": "#636EFA", **dict(zip(labels, colors))}
labels_order = ["TTFT"] + labels
timeline_data = construct_timeline_data(results, itl_thresholds, labels)
if not timeline_data:
print("No timeline data to plot")
return
# Create the plot
fig = px.timeline(
timeline_data,
x_start="start",
x_end="end",
y="request_id",
color="type",
color_discrete_map=labels_colors,
category_orders={"type": labels_order},
hover_data=[
"prompt_tokens",
"output_tokens",
"req_start_time",
"req_finish_time",
"segment_start",
"segment_end",
"duration",
],
)
# Customize hover template to show only time without date
fig.update_traces(
hovertemplate="<b>%{y}</b><br>"
"Type: %{fullData.name}<br>"
"Start: %{customdata[4]}<br>"
"End: %{customdata[5]}<br>"
"Duration: %{customdata[6]}<br>"
"Prompt Tokens: %{customdata[0]}<br>"
"Output Tokens: %{customdata[1]}<br>"
"Request Start Time: %{customdata[2]}<br>"
"Request End Time: %{customdata[3]}<br>"
"<extra></extra>"
)
fig.update_yaxes(autorange="reversed")
fig.update_layout(
xaxis_title="Time",
yaxis_title="Request ID",
showlegend=True,
)
# Save to HTML
pio.write_html(fig, str(output_path))
print(f"Timeline plot saved to: {output_path}")
def construct_timeline_data(
requests_data: list[dict[str, Any]],
itl_thresholds: list[float],
labels: list[str],
) -> list[dict[str, Any]]:
"""
Construct timeline data from request results.
Args:
requests_data: List of per-request result dictionaries
itl_thresholds: ITL thresholds in seconds
labels: Labels for ITL categories
Returns:
List of timeline segments for plotting
"""
def tostr(sec_time: float) -> str:
"""Convert seconds to HH:MM:SS.mmm format."""
h = int(sec_time // 3600)
assert h < 100, "time seems to last more than 100 hours"
m = int((sec_time % 3600) // 60)
s = sec_time % 60
return f"{h:02d}:{m:02d}:{s:06.3f}"
def itl_type(itl: float) -> str:
"""Categorize ITL based on thresholds."""
if itl < itl_thresholds[0]:
return labels[0]
elif itl < itl_thresholds[1]:
return labels[1]
else:
return labels[2]
# Find the earliest start time to use as t0
t0 = None
for request in requests_data:
start_time = request.get("start_time")
if start_time is not None and (t0 is None or start_time < t0):
t0 = start_time
if t0 is None:
return []
timeline_data = []
for i, request in enumerate(requests_data):
start_time = request.get("start_time")
ttft = request.get("ttft")
itl = request.get("itl", [])
latency = request.get("latency")
prompt_len = request.get("prompt_len", 0)
output_tokens = request.get("output_tokens", 0)
# Skip requests without required data
if start_time is None or ttft is None or latency is None:
continue
# Normalize start time
start_time = start_time - t0
start_time_str = tostr(start_time)
# TTFT segment
ttft_end = start_time + ttft
ttft_end_str = tostr(ttft_end)
timeline_data.append(
{
"request_id": f"Req {i}",
"start": start_time_str,
"end": ttft_end_str,
"type": "TTFT",
"prompt_tokens": prompt_len,
"output_tokens": output_tokens,
"req_start_time": tostr(start_time),
"req_finish_time": tostr(start_time + latency),
"segment_start": start_time_str,
"segment_end": ttft_end_str,
"duration": f"{ttft:.3f}s",
}
)
# ITL segments
prev_time = ttft_end
prev_time_str = ttft_end_str
for itl_value in itl:
itl_end = prev_time + itl_value
itl_end_str = tostr(itl_end)
timeline_data.append(
{
"request_id": f"Req {i}",
"start": prev_time_str,
"end": itl_end_str,
"type": itl_type(itl_value),
"prompt_tokens": prompt_len,
"output_tokens": output_tokens,
"req_start_time": tostr(start_time),
"req_finish_time": tostr(start_time + latency),
"segment_start": prev_time_str,
"segment_end": itl_end_str,
"duration": f"{itl_value:.3f}s",
}
)
prev_time = itl_end
prev_time_str = itl_end_str
return timeline_data
def generate_dataset_stats_plot(
results: list[dict[str, Any]],
output_path: Path,
) -> None:
"""
Generate a matplotlib figure with dataset statistics.
Creates a figure with 4 subplots:
- Top-left: Prompt tokens distribution (histogram)
- Top-right: Output tokens distribution (histogram)
- Bottom-left: Prompt+output tokens distribution (histogram)
- Bottom-right: Stacked bar chart (request_id vs tokens)
Args:
results: List of per-request result dictionaries containing:
- prompt_len: Number of prompt tokens
- output_tokens: Number of output tokens
output_path: Path where the figure will be saved
"""
# Extract data
prompt_tokens = []
output_tokens = []
total_tokens = []
for request in results:
prompt_len = request.get("prompt_len", 0)
output_len = request.get("output_tokens", 0)
prompt_tokens.append(prompt_len)
output_tokens.append(output_len)
total_tokens.append(prompt_len + output_len)
if not prompt_tokens:
print("No data available for dataset statistics plot")
return
# Create figure with 4 subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
# Top-left: Prompt tokens distribution
ax1.hist(prompt_tokens, bins=30, color="steelblue", edgecolor="black", alpha=0.7)
ax1.set_xlabel("Prompt Tokens")
ax1.set_ylabel("Frequency")
ax1.set_title("Prompt Tokens Distribution")
ax1.grid(True, alpha=0.3)
# Top-right: Output tokens distribution
ax2.hist(output_tokens, bins=30, color="coral", edgecolor="black", alpha=0.7)
ax2.set_xlabel("Output Tokens")
ax2.set_ylabel("Frequency")
ax2.set_title("Output Tokens Distribution")
ax2.grid(True, alpha=0.3)
# Bottom-left: Prompt+output tokens distribution
ax3.hist(
total_tokens, bins=30, color="mediumseagreen", edgecolor="black", alpha=0.7
)
ax3.set_xlabel("Total Tokens (Prompt + Output)")
ax3.set_ylabel("Frequency")
ax3.set_title("Total Tokens Distribution")
ax3.grid(True, alpha=0.3)
# Bottom-right: Stacked bar chart
request_ids = list(range(len(prompt_tokens)))
ax4.bar(
request_ids, prompt_tokens, label="Prompt Tokens", color="steelblue", alpha=0.7
)
ax4.bar(
request_ids,
output_tokens,
bottom=prompt_tokens,
label="Output Tokens",
color="coral",
alpha=0.7,
)
ax4.set_xlabel("Request ID")
ax4.set_ylabel("Tokens")
ax4.set_title("Tokens per Request (Stacked)")
ax4.legend()
ax4.grid(True, alpha=0.3, axis="y")
# Adjust layout to prevent overlap
plt.tight_layout()
# Save figure
plt.savefig(str(output_path), dpi=150, bbox_inches="tight")
plt.close(fig)
print(f"Dataset statistics plot saved to: {output_path}")
...@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable ...@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
import aiohttp import aiohttp
...@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format( ...@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format(
write_to_json(pt_file, pt_records) write_to_json(pt_file, pt_records)
def compute_result_filename(
args: argparse.Namespace,
model_id: str,
label: str,
current_dt: str,
) -> str | None:
"""Compute the result filename based on benchmark configuration.
Args:
args: Command line arguments containing result configuration
model_id: The model identifier
label: The benchmark label
current_dt: Current datetime string
Returns:
The computed filename path or None if no result saving is requested
"""
if not (args.plot_timeline or args.save_result or args.append_result):
return None
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
return file_name
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
add_dataset_parser(parser) add_dataset_parser(parser)
parser.add_argument( parser.add_argument(
...@@ -1535,6 +1579,30 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -1535,6 +1579,30 @@ def add_cli_args(parser: argparse.ArgumentParser):
"connecting to servers with self-signed certificates.", "connecting to servers with self-signed certificates.",
) )
parser.add_argument(
"--plot-timeline",
action="store_true",
help="Generate an HTML timeline plot showing request execution. "
"The plot will be saved alongside the results JSON file.",
)
parser.add_argument(
"--timeline-itl-thresholds",
type=float,
nargs=2,
default=[25.0, 50.0],
metavar=("THRESHOLD1", "THRESHOLD2"),
help="ITL thresholds in milliseconds for timeline plot coloring. "
"Specify two values to categorize inter-token latencies into three groups: "
"below first threshold (green), between thresholds (orange), "
"and above second threshold (red). Default: 25 50 (milliseconds).",
)
parser.add_argument(
"--plot-dataset-stats",
action="store_true",
help="Generate a matplotlib figure with dataset statistics showing "
"prompt tokens, output tokens, and combined token distributions.",
)
def main(args: argparse.Namespace) -> dict[str, Any]: def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args)) return asyncio.run(main_async(args))
...@@ -1770,6 +1838,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: ...@@ -1770,6 +1838,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Merge with benchmark result # Merge with benchmark result
result_json = {**result_json, **benchmark_result} result_json = {**result_json, **benchmark_result}
# Compute file_name once before using it for plots or saving results
file_name = compute_result_filename(args, model_id, label, current_dt)
# Generate timeline plot if requested
if args.plot_timeline:
try:
from vllm.benchmarks.plot import generate_timeline_plot
# Prepare per-request data for timeline
per_request_data = []
start_times = benchmark_result.get("start_times", [])
ttfts = benchmark_result.get("ttfts", [])
itls = benchmark_result.get("itls", [])
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if start_times and ttfts and itls:
for i in range(len(start_times)):
# Calculate latency as ttft + sum of all itls
latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i]
per_request_data.append(
{
"start_time": start_times[i],
"ttft": ttfts[i],
"itl": itls[i],
"latency": latency,
"prompt_len": input_lens[i],
"output_tokens": output_lens[i],
}
)
timeline_path = Path(file_name).with_suffix(".timeline.html")
# Convert thresholds from milliseconds to seconds
itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds]
generate_timeline_plot(
per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec
)
else:
warnings.warn(
"Timeline plot requires detailed metrics. "
"Ensure the benchmark completed successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2)
# Generate dataset statistics plot if requested
if args.plot_dataset_stats:
try:
from vllm.benchmarks.plot import generate_dataset_stats_plot
# Prepare per-request data for dataset stats
per_request_data = []
input_lens = benchmark_result.get("input_lens", [])
output_lens = benchmark_result.get("output_lens", [])
if input_lens and output_lens:
for req_input_len, req_output_len in zip(input_lens, output_lens):
per_request_data.append(
{
"prompt_len": req_input_len,
"output_tokens": req_output_len,
}
)
stats_path = Path(file_name).with_suffix(".dataset_stats.png")
generate_dataset_stats_plot(per_request_data, stats_path)
else:
warnings.warn(
"Dataset statistics plot requires input and "
"output length data. Ensure the benchmark completed "
"successfully.",
stacklevel=2,
)
except Exception as e:
warnings.warn(
f"Failed to generate dataset statistics plot: {e}", stacklevel=2
)
if not args.save_detailed: if not args.save_detailed:
# Remove fields with too many data points # Remove fields with too many data points
for field in [ for field in [
...@@ -1788,22 +1936,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: ...@@ -1788,22 +1936,6 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Save to file # Save to file
if args.save_result or args.append_result: if args.save_result or args.append_result:
base_model_id = model_id.split("/")[-1]
max_concurrency_str = (
f"-concurrency{args.max_concurrency}"
if args.max_concurrency is not None
else ""
)
label = label or args.backend
if args.ramp_up_strategy is not None:
file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
else:
file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
if args.result_filename:
file_name = args.result_filename
if args.result_dir:
os.makedirs(args.result_dir, exist_ok=True)
file_name = os.path.join(args.result_dir, file_name)
with open( with open(
file_name, mode="a+" if args.append_result else "w", encoding="utf-8" file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
) as outfile: ) as outfile:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment