Unverified Commit 1f5b7c41 authored by Reagan Lee's avatar Reagan Lee Committed by GitHub
Browse files

Add Multimodal Processor Benchmark (#29105)


Signed-off-by: default avatarReagan Lee <reaganjlee@gmail.com>
Signed-off-by: default avatarReagan <reaganjlee@gmail.com>
parent adcf682f
# vllm bench mm-processor
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Arguments
--8<-- "docs/argparse/bench_mm_processor.inc.md"
...@@ -92,6 +92,7 @@ def auto_mock(module_name: str, attr: str, max_mocks: int = 100): ...@@ -92,6 +92,7 @@ def auto_mock(module_name: str, attr: str, max_mocks: int = 100):
bench_latency = auto_mock("vllm.benchmarks", "latency") bench_latency = auto_mock("vllm.benchmarks", "latency")
bench_mm_processor = auto_mock("vllm.benchmarks", "mm_processor")
bench_serve = auto_mock("vllm.benchmarks", "serve") bench_serve = auto_mock("vllm.benchmarks", "serve")
bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs") bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs")
bench_sweep_plot_pareto = auto_mock( bench_sweep_plot_pareto = auto_mock(
...@@ -222,6 +223,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): ...@@ -222,6 +223,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
"run-batch": create_parser(openai_run_batch.make_arg_parser), "run-batch": create_parser(openai_run_batch.make_arg_parser),
# Benchmark CLI # Benchmark CLI
"bench_latency": create_parser(bench_latency.add_cli_args), "bench_latency": create_parser(bench_latency.add_cli_args),
"bench_mm_processor": create_parser(bench_mm_processor.add_cli_args),
"bench_serve": create_parser(bench_serve.add_cli_args), "bench_serve": create_parser(bench_serve.add_cli_args),
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args), "bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
"bench_sweep_plot_pareto": create_parser(bench_sweep_plot_pareto.add_cli_args), "bench_sweep_plot_pareto": create_parser(bench_sweep_plot_pareto.add_cli_args),
......
...@@ -1437,19 +1437,97 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1437,19 +1437,97 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
) )
random_group = parser.add_argument_group("random dataset options") random_group = parser.add_argument_group("random dataset options")
random_group.add_argument( add_random_dataset_base_args(random_group)
random_mm_group = parser.add_argument_group(
"random multimodal dataset options extended from random dataset"
)
add_random_multimodal_dataset_args(random_mm_group)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
)
hf_group.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset."
)
hf_group.add_argument(
"--hf-name",
type=str,
default=None,
help=(
"Name of the dataset on HuggingFace "
"(e.g., 'lmarena-ai/VisionArena-Chat'). "
"Specify this if your dataset-path is a local path."
),
)
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)
prefix_repetition_group = parser.add_argument_group(
"prefix repetition dataset options"
)
prefix_repetition_group.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=256,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=256,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=10,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for prefix "
"repetition dataset.",
)
def add_random_dataset_base_args(
parser_or_group: FlexibleArgumentParser | argparse._ArgumentGroup,
) -> None:
"""Add CLI arguments for base random dataset options.
This function adds arguments needed for:
- random (random dataset)
- random-mm (random multimodal dataset)
- random-rerank (random dataset for reranking)
Args:
parser_or_group: Either a parser or an argument group to add arguments to.
"""
parser_or_group.add_argument(
"--random-input-len", "--random-input-len",
type=int, type=int,
default=1024, default=1024,
help="Number of input tokens per request, used only for random sampling.", help="Number of input tokens per request, used only for random sampling.",
) )
random_group.add_argument( parser_or_group.add_argument(
"--random-output-len", "--random-output-len",
type=int, type=int,
default=128, default=128,
help="Number of output tokens per request, used only for random sampling.", help="Number of output tokens per request, used only for random sampling.",
) )
random_group.add_argument( parser_or_group.add_argument(
"--random-range-ratio", "--random-range-ratio",
type=float, type=float,
default=0.0, default=0.0,
...@@ -1458,7 +1536,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1458,7 +1536,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"a symmetric sampling range" "a symmetric sampling range"
"[length * (1 - range_ratio), length * (1 + range_ratio)].", "[length * (1 - range_ratio), length * (1 + range_ratio)].",
) )
random_group.add_argument( parser_or_group.add_argument(
"--random-prefix-len", "--random-prefix-len",
type=int, type=int,
default=0, default=0,
...@@ -1471,13 +1549,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1471,13 +1549,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"input_len * (1 + range_ratio)]." "input_len * (1 + range_ratio)]."
), ),
) )
random_group.add_argument( parser_or_group.add_argument(
"--random-batch-size", "--random-batch-size",
type=int, type=int,
default=1, default=1,
help=("Batch size for random sampling. Only used for embeddings benchmark."), help=("Batch size for random sampling. Only used for embeddings benchmark."),
) )
random_group.add_argument( parser_or_group.add_argument(
"--no-reranker", "--no-reranker",
action="store_true", action="store_true",
help=( help=(
...@@ -1486,11 +1564,19 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1486,11 +1564,19 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
), ),
) )
# random multimodal dataset options
random_mm_group = parser.add_argument_group( def add_random_multimodal_dataset_args(
"random multimodal dataset options extended from random dataset" parser_or_group: FlexibleArgumentParser | argparse._ArgumentGroup,
) ) -> None:
random_mm_group.add_argument( """Add CLI arguments for random multimodal dataset options.
This function adds arguments needed for:
- random-mm (random multimodal dataset)
Args:
parser_or_group: Either a parser or an argument group to add arguments to.
"""
parser_or_group.add_argument(
"--random-mm-base-items-per-request", "--random-mm-base-items-per-request",
type=int, type=int,
default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
...@@ -1500,7 +1586,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1500,7 +1586,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"--random-mm-num-mm-items-range-ratio." "--random-mm-num-mm-items-range-ratio."
), ),
) )
random_mm_group.add_argument( parser_or_group.add_argument(
"--random-mm-num-mm-items-range-ratio", "--random-mm-num-mm-items-range-ratio",
type=float, type=float,
default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
...@@ -1515,7 +1601,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1515,7 +1601,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"An error is raised if the computed min exceeds the max." "An error is raised if the computed min exceeds the max."
), ),
) )
random_mm_group.add_argument( parser_or_group.add_argument(
"--random-mm-limit-mm-per-prompt", "--random-mm-limit-mm-per-prompt",
type=json.loads, type=json.loads,
default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
...@@ -1559,7 +1645,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1559,7 +1645,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
return normalize(parsed) return normalize(parsed)
raise ValueError("Unsupported value for --random-mm-bucket-config.") raise ValueError("Unsupported value for --random-mm-bucket-config.")
random_mm_group.add_argument( parser_or_group.add_argument(
"--random-mm-bucket-config", "--random-mm-bucket-config",
type=_parse_mm_bucket_config, type=_parse_mm_bucket_config,
default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG, default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
...@@ -1580,63 +1666,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser): ...@@ -1580,63 +1666,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
), ),
) )
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
)
hf_group.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset."
)
hf_group.add_argument(
"--hf-name",
type=str,
default=None,
help=(
"Name of the dataset on HuggingFace "
"(e.g., 'lmarena-ai/VisionArena-Chat'). "
"Specify this if your dataset-path is a local path."
),
)
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)
prefix_repetition_group = parser.add_argument_group(
"prefix repetition dataset options"
)
prefix_repetition_group.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=256,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=256,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=10,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for prefix "
"repetition dataset.",
)
def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]: def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
if not hasattr(args, "request_id_prefix"): if not hasattr(args, "request_id_prefix"):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""Benchmark multimodal processor latency.
This benchmark measures the latency of the mm processor module
using multimodal prompts from datasets.
MM processor stats are automatically enabled.
Run:
vllm bench mm-processor \
--model <your_model> \
--dataset-name random-mm \
--num-prompts 10
"""
import argparse
import dataclasses
import json
import time
from datetime import datetime
from typing import Any
import numpy as np
from vllm.benchmarks.throughput import get_requests
from vllm.engine.arg_utils import EngineArgs
from vllm.multimodal.processing import (
get_timing_stats_from_engine_client,
)
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
def collect_mm_processor_stats(
    llm_engine: Any,
) -> dict[str, list[float]]:
    """
    Gather multimodal processor timings and group them by stage.

    Args:
        llm_engine: Engine handle forwarded unchanged to
            ``get_timing_stats_from_engine_client``.

    Returns:
        A dictionary keyed by stage name. Each value is a list with one
        timing (in seconds) per stats record returned by the engine client;
        a stage missing from a record contributes ``0.0``.
    """
    stage_names = (
        "hf_processor_time",
        "hashing_time",
        "cache_lookup_time",
        "prompt_update_time",
        "total_time",
    )

    per_request_stats = get_timing_stats_from_engine_client(llm_engine)

    timings = {stage: [] for stage in stage_names}
    for record in per_request_stats.values():
        for stage in stage_names:
            timings[stage].append(record.get(stage, 0.0))

    return timings
def calculate_mm_processor_metrics(
    stats_by_stage: dict[str, list[float]],
    selected_percentiles: list[float],
) -> dict[str, dict[str, float]]:
    """
    Summarise per-stage timings as mean/median/std plus chosen percentiles.

    Args:
        stats_by_stage: Stage name -> list of timings in seconds.
        selected_percentiles: Percentiles to report; each one is emitted
            under the key ``f"p{percentile}"``.

    Returns:
        Stage name -> summary dict with values in milliseconds. A stage
        with no samples reports ``0.0`` for every entry.
    """
    summary: dict[str, dict[str, float]] = {}

    for stage_name, times in stats_by_stage.items():
        if times:
            # Inputs are seconds; everything reported is milliseconds.
            samples_ms = [sample * 1000 for sample in times]
            stage_summary = {
                "mean": float(np.mean(samples_ms)),
                "median": float(np.median(samples_ms)),
                "std": float(np.std(samples_ms)),
            }
            for pct in selected_percentiles:
                stage_summary[f"p{pct}"] = float(np.percentile(samples_ms, pct))
        else:
            stage_summary = {"mean": 0.0, "median": 0.0, "std": 0.0}
            for pct in selected_percentiles:
                stage_summary[f"p{pct}"] = 0.0

        summary[stage_name] = stage_summary

    return summary
def validate_args(args):
    """
    Fill in defaults the mm_processor benchmark expects on ``args``.

    The tokenizer falls back to the model name when it is unset or empty,
    and ``dataset_path``, ``lora_path`` and ``max_loras`` are created as
    ``None`` when the namespace does not define them. ``args`` is mutated
    in place.
    """
    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # Only add attributes that are absent; existing values are left alone.
    for optional_attr in ("dataset_path", "lora_path", "max_loras"):
        if not hasattr(args, optional_attr):
            setattr(args, optional_attr, None)
def benchmark_multimodal_processor(
    args: argparse.Namespace,
) -> dict[str, Any]:
    """
    Run the multimodal processor benchmark.

    Builds an ``LLM`` from the CLI arguments, samples requests via
    ``get_requests``, runs them through ``llm.chat`` and then collects the
    multimodal processor timing stats from the engine.

    Args:
        args: Parsed CLI namespace. ``seed`` is set to 0 when it is None,
            and ``validate_args`` may add further defaults.

    Returns:
        A dict with the ``completed`` / ``failed`` request counts,
        end-to-end latency statistics in milliseconds (``mean_e2el_ms``,
        ``median_e2el_ms``, ``std_e2el_ms``, ``percentiles_e2el_ms``) and
        the per-stage ``mm_processor_stats`` summary.
    """
    from vllm import LLM, SamplingParams

    validate_args(args)

    if args.seed is None:
        args.seed = 0

    llm = LLM(**dataclasses.asdict(EngineArgs.from_cli_args(args)))
    requests = get_requests(args, llm.get_tokenizer())

    assert not any(
        (request.prompt_len + request.expected_output_len)
        > llm.llm_engine.model_config.max_model_len
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [request.prompt for request in requests]
    sampling_params = [
        SamplingParams(
            n=1,
            temperature=0.0,
            max_tokens=request.expected_output_len,
            detokenize=True,
        )
        for request in requests
    ]

    selected_percentiles = list(
        map(float, getattr(args, "metric_percentiles", "99").split(","))
    )

    freeze_gc_heap()

    print(f"Processing {len(prompts)} requests...")
    started_at = time.perf_counter()
    outputs = llm.chat(
        prompts, sampling_params, use_tqdm=not getattr(args, "disable_tqdm", False)
    )
    total_time = time.perf_counter() - started_at

    mm_stats_by_stage = collect_mm_processor_stats(
        llm.llm_engine,
    )

    if not any(mm_stats_by_stage.values()):
        print(
            "\n⚠️ Warning: No MM processor stats found in registry.\n"
            " This may indicate that:\n"
            " - No multimodal requests were processed\n"
            " - Stats were already retrieved (registry is cleared after retrieval)\n"
        )

    mm_processor_metrics = calculate_mm_processor_metrics(
        mm_stats_by_stage, selected_percentiles
    )

    completed = sum(1 for output in outputs if output.finished)
    failed = len(outputs) - completed

    # Per-request latency: first available end timestamp minus arrival time.
    latencies_ms = []
    for output in outputs:
        if not output.finished or output.metrics is None:
            continue
        request_metrics = output.metrics
        if getattr(request_metrics, "arrival_time", None) is None:
            continue
        for end_attr in ("finished_time", "last_token_time"):
            end_timestamp = getattr(request_metrics, end_attr, None)
            if end_timestamp is not None:
                latencies_ms.append(
                    (end_timestamp - request_metrics.arrival_time) * 1000
                )
                break

    # No usable per-request timestamps: spread the wall-clock time evenly
    # over the completed requests instead.
    if not latencies_ms and completed > 0:
        latencies_ms = [(total_time / completed) * 1000] * completed

    if latencies_ms:
        mean_e2el_ms, median_e2el_ms, std_e2el_ms = (
            float(reducer(latencies_ms)) for reducer in (np.mean, np.median, np.std)
        )
        percentiles_e2el_ms = [
            (p, float(np.percentile(latencies_ms, p))) for p in selected_percentiles
        ]
    else:
        mean_e2el_ms = median_e2el_ms = std_e2el_ms = 0.0
        percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]

    return {
        "completed": completed,
        "failed": failed,
        "mean_e2el_ms": mean_e2el_ms,
        "median_e2el_ms": median_e2el_ms,
        "std_e2el_ms": std_e2el_ms,
        "percentiles_e2el_ms": percentiles_e2el_ms,
        "mm_processor_stats": mm_processor_metrics,
    }
def add_cli_args(parser: argparse.ArgumentParser) -> None:
    """Register every CLI option used by the multimodal processor benchmark.

    Engine arguments are added first and ``enable_mm_processor_stats`` is
    defaulted to True on the parser. The benchmark's own options and the
    shared random / random-mm dataset options follow.
    """
    from vllm.benchmarks.datasets import (
        add_random_dataset_base_args,
        add_random_multimodal_dataset_args,
    )
    from vllm.engine.arg_utils import EngineArgs

    EngineArgs.add_cli_args(parser)

    parser.set_defaults(enable_mm_processor_stats=True)

    dataset_selection_options = (
        (
            "--dataset-name",
            dict(
                type=str,
                default="random-mm",
                choices=["random-mm", "random-rerank"],
                help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
            ),
        ),
        (
            "--num-prompts",
            dict(
                type=int,
                default=10,
                help="Number of prompts to process.",
            ),
        ),
    )
    for flag, options in dataset_selection_options:
        parser.add_argument(flag, **options)

    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    reporting_options = (
        (
            "--output-json",
            dict(
                type=str,
                default=None,
                help="Path to save the benchmark results in JSON format.",
            ),
        ),
        (
            "--metric-percentiles",
            dict(
                type=str,
                default="99",
                help="Comma-separated list of percentiles to calculate (e.g., '50,90,99').",
            ),
        ),
        (
            "--disable-tqdm",
            dict(
                action="store_true",
                help="Disable tqdm progress bar.",
            ),
        ),
    )
    for flag, options in reporting_options:
        parser.add_argument(flag, **options)
def main(args: argparse.Namespace) -> None:
    """Main entry point for the multimodal processor benchmark.

    Runs the benchmark, prints the per-stage processor timings and the
    end-to-end latency summary as tables, and, when ``args.output_json`` is
    set, writes the result (plus a ``config`` section and a ``timestamp``)
    to that path as JSON.

    Args:
        args: Parsed CLI namespace as produced by ``add_cli_args``.
    """
    print("Starting multimodal processor benchmark...")

    result = benchmark_multimodal_processor(args)

    # Parse the requested percentiles once; both tables below use them.
    selected_percentiles = [
        float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
    ]

    print("\n" + "=" * 80)
    print("Multimodal Processor Benchmark Results")
    print("=" * 80)

    if "mm_processor_stats" in result:
        print("\nMM Processor Timing (ms):")

        mm_data = []
        for stage, metrics in result["mm_processor_stats"].items():
            row = {
                "Stage": stage,
                "Mean": f"{metrics['mean']:.2f}",
                "Median": f"{metrics['median']:.2f}",
                "Std": f"{metrics['std']:.2f}",
            }
            for p in selected_percentiles:
                row[f"P{p}"] = f"{metrics.get(f'p{p}', 0.0):.2f}"
            mm_data.append(row)

        mm_df = pd.DataFrame(mm_data)
        print(mm_df.to_string(index=False))

    if "mean_e2el_ms" in result:
        print("\nEnd-to-End Latency (ms):")

        e2el_data = [
            {"Metric": "Mean", "Value (ms)": f"{result['mean_e2el_ms']:.2f}"},
            {"Metric": "Median", "Value (ms)": f"{result['median_e2el_ms']:.2f}"},
            {"Metric": "Std", "Value (ms)": f"{result['std_e2el_ms']:.2f}"},
        ]

        # (percentile, value) pairs -> mapping, so each lookup is direct
        # instead of a linear scan per requested percentile.
        e2el_by_percentile = dict(result["percentiles_e2el_ms"])
        for p in selected_percentiles:
            percentile_value = e2el_by_percentile.get(p, 0.0)
            e2el_data.append(
                {
                    "Metric": f"P{p}",
                    "Value (ms)": f"{percentile_value:.2f}",
                }
            )

        e2el_df = pd.DataFrame(e2el_data)
        print(e2el_df.to_string(index=False))

    if args.output_json:
        result["config"] = {
            "model": args.model,
            "num_prompts": args.num_prompts,
            "input_len": getattr(args, "random_input_len", None),
            "output_len": getattr(args, "random_output_len", None),
        }
        result["timestamp"] = datetime.now().isoformat()

        # Explicit encoding keeps the output independent of the locale.
        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        print(f"\nResults saved to {args.output_json}")
# Script entry point: build a parser with this benchmark's options, parse
# the command line, and run the benchmark.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark mm processor latency")
    add_cli_args(parser)
    args = parser.parse_args()
    main(args)
...@@ -24,10 +24,14 @@ from vllm.benchmarks.datasets import ( ...@@ -24,10 +24,14 @@ from vllm.benchmarks.datasets import (
MultiModalConversationDataset, MultiModalConversationDataset,
PrefixRepetitionRandomDataset, PrefixRepetitionRandomDataset,
RandomDataset, RandomDataset,
RandomDatasetForReranking,
RandomMultiModalDataset,
SampleRequest, SampleRequest,
ShareGPTDataset, ShareGPTDataset,
SonnetDataset, SonnetDataset,
VisionArenaDataset, VisionArenaDataset,
add_random_dataset_base_args,
add_random_multimodal_dataset_args,
) )
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
...@@ -342,8 +346,6 @@ def get_requests(args, tokenizer): ...@@ -342,8 +346,6 @@ def get_requests(args, tokenizer):
"lora_path": args.lora_path, "lora_path": args.lora_path,
"max_loras": args.max_loras, "max_loras": args.max_loras,
"num_requests": args.num_prompts, "num_requests": args.num_prompts,
"input_len": args.input_len,
"output_len": args.output_len,
} }
if args.dataset_name == "random" or ( if args.dataset_name == "random" or (
...@@ -351,12 +353,26 @@ def get_requests(args, tokenizer): ...@@ -351,12 +353,26 @@ def get_requests(args, tokenizer):
and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"} and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
): ):
sample_kwargs["range_ratio"] = args.random_range_ratio sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len # prefer random_* arguments, fall back to regular arguments
random_prefix_len = getattr(args, "random_prefix_len", None)
sample_kwargs["prefix_len"] = (
random_prefix_len if random_prefix_len is not None else args.prefix_len
)
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len if random_input_len is not None else args.input_len
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len if random_output_len is not None else args.output_len
)
dataset_cls = RandomDataset dataset_cls = RandomDataset
elif args.dataset_name == "sharegpt": elif args.dataset_name == "sharegpt":
dataset_cls = ShareGPTDataset dataset_cls = ShareGPTDataset
if args.backend == "vllm-chat": if args.backend == "vllm-chat":
sample_kwargs["enable_multimodal_chat"] = True sample_kwargs["enable_multimodal_chat"] = True
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, ( assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset." "Tokenizer/model must have chat template for sonnet dataset."
...@@ -364,9 +380,15 @@ def get_requests(args, tokenizer): ...@@ -364,9 +380,15 @@ def get_requests(args, tokenizer):
dataset_cls = SonnetDataset dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True sample_kwargs["return_prompt_formatted"] = True
if args.input_len is not None:
sample_kwargs["input_len"] = args.input_len
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
elif args.dataset_name == "burstgpt": elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf": elif args.dataset_name == "hf":
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset dataset_cls = VisionArenaDataset
common_kwargs["dataset_subset"] = None common_kwargs["dataset_subset"] = None
...@@ -395,6 +417,56 @@ def get_requests(args, tokenizer): ...@@ -395,6 +417,56 @@ def get_requests(args, tokenizer):
sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
sample_kwargs["output_len"] = args.prefix_repetition_output_len sample_kwargs["output_len"] = args.prefix_repetition_output_len
elif args.dataset_name == "random-mm":
dataset_cls = RandomMultiModalDataset
# prefer random_* arguments, fall back to regular arguments
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len
if random_input_len is not None
else getattr(args, "input_len", None)
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len
if random_output_len is not None
else getattr(args, "output_len", None)
)
sample_kwargs["base_items_per_request"] = getattr(
args, "random_mm_base_items_per_request", None
)
sample_kwargs["num_mm_items_range_ratio"] = getattr(
args, "random_mm_num_mm_items_range_ratio", None
)
sample_kwargs["limit_mm_per_prompt"] = getattr(
args, "random_mm_limit_mm_per_prompt", None
)
sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
sample_kwargs["enable_multimodal_chat"] = True
random_prefix_len = getattr(args, "random_prefix_len", None)
prefix_len = getattr(args, "prefix_len", None)
sample_kwargs["prefix_len"] = (
random_prefix_len if random_prefix_len is not None else prefix_len
)
sample_kwargs["range_ratio"] = args.random_range_ratio
elif args.dataset_name == "random-rerank":
dataset_cls = RandomDatasetForReranking
# prefer random_* arguments, fall back to regular arguments
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len
if random_input_len is not None
else getattr(args, "input_len", None)
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len
if random_output_len is not None
else getattr(args, "output_len", None)
)
sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
sample_kwargs["range_ratio"] = args.random_range_ratio
else: else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}") raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values # Remove None values
...@@ -451,8 +523,12 @@ def validate_args(args): ...@@ -451,8 +523,12 @@ def validate_args(args):
): ):
print("When dataset path is not set, it will default to random dataset") print("When dataset path is not set, it will default to random dataset")
args.dataset_name = "random" args.dataset_name = "random"
if args.input_len is None: random_input_len = getattr(args, "random_input_len", None)
raise ValueError("input_len must be provided for a random dataset") if args.input_len is None and random_input_len is None:
raise ValueError(
"Either --input-len or --random-input-len must be provided "
"for a random dataset"
)
# === Dataset Name Specific Checks === # === Dataset Name Specific Checks ===
# --hf-subset and --hf-split: only used # --hf-subset and --hf-split: only used
...@@ -485,26 +561,79 @@ def validate_args(args): ...@@ -485,26 +561,79 @@ def validate_args(args):
else: else:
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random' # --random-range-ratio: only used when dataset_name is 'random',
if args.dataset_name != "random" and args.random_range_ratio is not None: # 'random-mm', or 'random-rerank'
if (
args.dataset_name not in {"random", "random-mm", "random-rerank"}
and args.random_range_ratio is not None
):
warnings.warn( warnings.warn(
"--random-range-ratio will be ignored since \ "--random-range-ratio will be ignored since \
--dataset-name is not 'random'.", --dataset-name is not 'random', 'random-mm', or 'random-rerank'.",
stacklevel=2,
)
# --random-batch-size: only used when dataset_name is 'random-rerank'
if (
args.dataset_name != "random-rerank"
and getattr(args, "random_batch_size", None) is not None
) and args.random_batch_size != 1:
warnings.warn(
"--random-batch-size will be ignored since \
--dataset-name is not 'random-rerank'.",
stacklevel=2,
)
# --no-reranker: only used when dataset_name is 'random-rerank'
if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
warnings.warn(
"--no-reranker will be ignored since \
--dataset-name is not 'random-rerank'.",
stacklevel=2, stacklevel=2,
) )
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # --prefix-len: only used when dataset_name is 'random', 'random-mm',
# set. # 'sonnet', or not set.
if ( if (
args.dataset_name not in {"random", "sonnet", None} args.dataset_name not in {"random", "random-mm", "sonnet", None}
and args.prefix_len is not None and args.prefix_len is not None
): ):
warnings.warn( warnings.warn(
"--prefix-len will be ignored since --dataset-name\ "--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.", is not 'random', 'random-mm', 'sonnet', or not set.",
stacklevel=2, stacklevel=2,
) )
# === Random Dataset Argument Conflict Detection ===
# Check for conflicts between regular and random arguments when using
# random datasets
if args.dataset_name in {"random", "random-mm", "random-rerank"}:
random_input_len = getattr(args, "random_input_len", None)
random_output_len = getattr(args, "random_output_len", None)
random_prefix_len = getattr(args, "random_prefix_len", None)
if args.input_len is not None and random_input_len is not None:
warnings.warn(
"Both --input-len and --random-input-len are specified. "
"The random version (--random-input-len) will be preferred "
"in this run.",
stacklevel=2,
)
if args.output_len is not None and random_output_len is not None:
warnings.warn(
"Both --output-len and --random-output-len are specified. "
"The random version (--random-output-len) will be preferred "
"in this run.",
stacklevel=2,
)
if args.prefix_len is not None and random_prefix_len is not None:
warnings.warn(
"Both --prefix-len and --random-prefix-len are specified. "
"The random version (--random-prefix-len) will be preferred "
"in this run.",
stacklevel=2,
)
# === LoRA Settings === # === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm": if getattr(args, "enable_lora", False) and args.backend != "vllm":
raise ValueError("LoRA benchmarking is only supported for vLLM backend") raise ValueError("LoRA benchmarking is only supported for vLLM backend")
...@@ -554,7 +683,16 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -554,7 +683,16 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--dataset-name", "--dataset-name",
type=str, type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf", "prefix_repetition"], choices=[
"sharegpt",
"random",
"sonnet",
"burstgpt",
"hf",
"prefix_repetition",
"random-mm",
"random-rerank",
],
help="Name of the dataset to benchmark on.", help="Name of the dataset to benchmark on.",
default="sharegpt", default="sharegpt",
) )
...@@ -636,23 +774,19 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -636,23 +774,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Number of fixed prefix tokens before the random " help="Number of fixed prefix tokens before the random "
"context in a request (default: 0).", "context in a request (default: 0).",
) )
# random dataset
parser.add_argument(
"--random-range-ratio",
type=float,
default=0.0,
help="Range ratio for sampling input/output length, "
"used only for RandomDataset. Must be in the range [0, 1) to define "
"a symmetric sampling range "
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
)
# hf dataset # hf dataset
parser.add_argument( parser.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset." "--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.",
) )
parser.add_argument( parser.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset." "--hf-split",
type=str,
default=None,
help="Split of the HF dataset.",
) )
parser.add_argument( parser.add_argument(
"--profile", "--profile",
...@@ -662,31 +796,28 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -662,31 +796,28 @@ def add_cli_args(parser: argparse.ArgumentParser):
) )
# prefix repetition dataset # prefix repetition dataset
prefix_repetition_group = parser.add_argument_group( parser.add_argument(
"prefix repetition dataset options"
)
prefix_repetition_group.add_argument(
"--prefix-repetition-prefix-len", "--prefix-repetition-prefix-len",
type=int, type=int,
default=None, default=None,
help="Number of prefix tokens per request, used only for prefix " help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.", "repetition dataset.",
) )
prefix_repetition_group.add_argument( parser.add_argument(
"--prefix-repetition-suffix-len", "--prefix-repetition-suffix-len",
type=int, type=int,
default=None, default=None,
help="Number of suffix tokens per request, used only for prefix " help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.", "repetition dataset. Total input length is prefix_len + suffix_len.",
) )
prefix_repetition_group.add_argument( parser.add_argument(
"--prefix-repetition-num-prefixes", "--prefix-repetition-num-prefixes",
type=int, type=int,
default=None, default=None,
help="Number of prefixes to generate, used only for prefix repetition " help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.", "dataset. Prompts per prefix is num_requests // num_prefixes.",
) )
prefix_repetition_group.add_argument( parser.add_argument(
"--prefix-repetition-output-len", "--prefix-repetition-output-len",
type=int, type=int,
default=None, default=None,
...@@ -694,6 +825,10 @@ def add_cli_args(parser: argparse.ArgumentParser): ...@@ -694,6 +825,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
"repetition dataset.", "repetition dataset.",
) )
# (random, random-mm, random-rerank)
add_random_dataset_base_args(parser)
add_random_multimodal_dataset_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
......
...@@ -67,6 +67,14 @@ class ObservabilityConfig: ...@@ -67,6 +67,14 @@ class ObservabilityConfig:
enable_mfu_metrics: bool = False enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics.""" """Enable Model FLOPs Utilization (MFU) metrics."""
enable_mm_processor_stats: bool = False
"""Enable collection of timing statistics for multimodal processor operations.
This is for internal use only (e.g., benchmarks) and is not exposed as a CLI
argument."""
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""
@cached_property @cached_property
def collect_model_forward_time(self) -> bool: def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request.""" """Whether to collect model forward time for the request."""
......
...@@ -523,6 +523,7 @@ class EngineArgs: ...@@ -523,6 +523,7 @@ class EngineArgs:
ObservabilityConfig.enable_layerwise_nvtx_tracing ObservabilityConfig.enable_layerwise_nvtx_tracing
) )
enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
...@@ -1712,6 +1713,7 @@ class EngineArgs: ...@@ -1712,6 +1713,7 @@ class EngineArgs:
cudagraph_metrics=self.cudagraph_metrics, cudagraph_metrics=self.cudagraph_metrics,
enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing, enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
enable_mfu_metrics=self.enable_mfu_metrics, enable_mfu_metrics=self.enable_mfu_metrics,
enable_mm_processor_stats=self.enable_mm_processor_stats,
) )
# Compilation config overrides # Compilation config overrides
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.mm_processor import (
BenchmarkMMProcessorSubcommand,
)
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
...@@ -8,6 +11,7 @@ from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcomm ...@@ -8,6 +11,7 @@ from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcomm
__all__: list[str] = [ __all__: list[str] = [
"BenchmarkLatencySubcommand", "BenchmarkLatencySubcommand",
"BenchmarkMMProcessorSubcommand",
"BenchmarkServingSubcommand", "BenchmarkServingSubcommand",
"BenchmarkStartupSubcommand", "BenchmarkStartupSubcommand",
"BenchmarkSweepSubcommand", "BenchmarkSweepSubcommand",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.mm_processor import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkMMProcessorSubcommand(BenchmarkSubcommandBase):
    """The `mm-processor` subcommand for `vllm bench`.

    Both hooks simply forward to the functions imported from
    `vllm.benchmarks.mm_processor`.
    """

    name = "mm-processor"
    help = "Benchmark multimodal processor latency across different configurations."

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
        """Register the benchmark's command-line arguments on `parser`."""
        # Inside a method body the bare name resolves to the module-level
        # import, not to this classmethod, so this is not a recursive call.
        add_cli_args(parser)

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the benchmark entry point with the parsed arguments."""
        main(args)
...@@ -6,7 +6,7 @@ from typing import Any, cast ...@@ -6,7 +6,7 @@ from typing import Any, cast
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm.config import ModelConfig from vllm.config import ModelConfig, ObservabilityConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.cache import BaseMultiModalProcessorCache
...@@ -47,6 +47,7 @@ class InputPreprocessor: ...@@ -47,6 +47,7 @@ class InputPreprocessor:
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike | None,
observability_config: ObservabilityConfig | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None, mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None: ) -> None:
...@@ -54,6 +55,7 @@ class InputPreprocessor: ...@@ -54,6 +55,7 @@ class InputPreprocessor:
self.model_config = model_config self.model_config = model_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.observability_config = observability_config
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache self.mm_processor_cache = mm_processor_cache
...@@ -232,6 +234,7 @@ class InputPreprocessor: ...@@ -232,6 +234,7 @@ class InputPreprocessor:
if not hasattr(self, "_mm_processor"): if not hasattr(self, "_mm_processor"):
self._mm_processor = self.mm_registry.create_processor( self._mm_processor = self.mm_registry.create_processor(
self.model_config, self.model_config,
self.observability_config,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
cache=self.mm_processor_cache, cache=self.mm_processor_cache,
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextvars
import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
...@@ -53,7 +56,7 @@ if TYPE_CHECKING: ...@@ -53,7 +56,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig from vllm.config import ModelConfig, ObservabilityConfig
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder from .profiling import BaseDummyInputsBuilder
...@@ -63,6 +66,7 @@ else: ...@@ -63,6 +66,7 @@ else:
ProcessorMixin = object ProcessorMixin = object
ModelConfig = object ModelConfig = object
ObservabilityConfig = object
BaseMultiModalProcessorCache = object BaseMultiModalProcessorCache = object
...@@ -70,6 +74,127 @@ logger = init_logger(__name__) ...@@ -70,6 +74,127 @@ logger = init_logger(__name__)
_S = TypeVar("_S", str, list[int]) _S = TypeVar("_S", str, list[int])
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_request_id_context", default=None
)
def get_current_request_id() -> str | None:
"""Get the current request_id from the context, if available."""
return _request_id_context.get()
@contextmanager
def set_request_id(request_id: str) -> Generator[None, None, None]:
"""Context manager to set the request_id for the current context."""
token = _request_id_context.set(request_id)
try:
yield
finally:
_request_id_context.reset(token)
@dataclass
class MultiModalProcessorTimingStats:
    """Accumulated time, in seconds, per multimodal processor stage for one
    request."""

    hf_processor_time: float = 0.0
    """Seconds spent inside HuggingFace processor calls."""

    hashing_time: float = 0.0
    """Seconds spent hashing multimodal items."""

    cache_lookup_time: float = 0.0
    """Seconds spent on cache lookups and merges."""

    prompt_update_time: float = 0.0
    """Seconds spent applying prompt updates and locating placeholders."""

    total_time: float = 0.0
    """Total processing time in seconds."""

    def to_dict(self) -> dict[str, float]:
        """Return the stats as a plain mapping suitable for JSON output."""
        # Listed explicitly so the exported keys and their order stay fixed.
        exported = (
            "hf_processor_time",
            "hashing_time",
            "cache_lookup_time",
            "prompt_update_time",
            "total_time",
        )
        return {key: getattr(self, key) for key in exported}
def get_timing_stats_from_engine_client(
    engine_client: Any,
) -> dict[str, dict[str, float]]:
    """
    Collect every recorded timing stat reachable from an engine client.

    This is a best-effort lookup: any `AttributeError` or `RuntimeError`
    raised while walking the client's attributes yields an empty result
    instead of propagating.

    Args:
        engine_client: Object exposing `vllm_config.observability_config`
            and `input_processor.input_preprocessor`.

    Returns:
        A dictionary mapping request_id to stats dict, or `{}` when stats
        collection is disabled or the processing context cannot be reached.
    """
    try:
        stats_enabled = bool(
            engine_client.vllm_config.observability_config.enable_mm_processor_stats
        )
    except (AttributeError, RuntimeError):
        return {}
    if not stats_enabled:
        return {}

    try:
        preprocessor = engine_client.input_processor.input_preprocessor
        if not hasattr(preprocessor, "_get_mm_processor"):
            return {}

        mm_processor = preprocessor._get_mm_processor()
        if mm_processor is None or not hasattr(mm_processor, "info"):
            return {}

        # The stats registry lives on the processor's processing context.
        return mm_processor.info.ctx.get_all_timing_stats()
    except (AttributeError, RuntimeError):
        return {}
@contextmanager
def _timed_operation(ctx: "InputProcessingContext", stage_name: str):
    """
    Time the wrapped block and add the elapsed time to the current request's
    stats.

    The request_id comes from the context variable rather than a parameter.
    When there is no context, no active request_id, or no stats object
    registered for that request, the block runs untimed.

    Args:
        ctx: The InputProcessingContext containing the timing stats registry.
        stage_name: Name of the stage being timed. Recognized names are
            "hf_processor", "hashing", "cache_lookup" and "prompt_update";
            any other name contributes to `total_time` only.
    """
    request_id = get_current_request_id()
    stats = None
    if ctx is not None and request_id is not None:
        stats = ctx.get_timing_stats(request_id)

    if stats is None:
        yield
        return

    stage_field = {
        "hf_processor": "hf_processor_time",
        "hashing": "hashing_time",
        "cache_lookup": "cache_lookup_time",
        "prompt_update": "prompt_update_time",
    }.get(stage_name)

    began = time.perf_counter()
    try:
        yield
    finally:
        # Recorded in `finally` so a failing stage still counts its time.
        elapsed = time.perf_counter() - began
        if stage_field is not None:
            setattr(stats, stage_field, getattr(stats, stage_field) + elapsed)
        stats.total_time += elapsed
PromptSeq: TypeAlias = str | list[int] PromptSeq: TypeAlias = str | list[int]
"""A token sequence (list of token IDs) or text.""" """A token sequence (list of token IDs) or text."""
...@@ -951,6 +1076,21 @@ class InputProcessingContext: ...@@ -951,6 +1076,21 @@ class InputProcessingContext:
tokenizer: TokenizerLike | None tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
observability_config: "ObservabilityConfig | None" = field(
default=None, compare=False, repr=False
)
"""Configuration for observability features."""
timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
default_factory=dict, compare=False, repr=False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock: threading.Lock = field(
default_factory=threading.Lock, compare=False, repr=False
)
"""Lock for thread-safe access to timing_stats_registry."""
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
...@@ -1159,6 +1299,71 @@ class InputProcessingContext: ...@@ -1159,6 +1299,71 @@ class InputProcessingContext:
return self._postprocess_output(output) return self._postprocess_output(output)
def get_timing_stats(
    self, request_id: str
) -> MultiModalProcessorTimingStats | None:
    """
    Look up the timing stats registered for `request_id`.

    Returns None when stats collection is disabled (no observability config,
    or `enable_mm_processor_stats` is off) or when nothing has been
    registered for the request.
    """
    config = self.observability_config
    stats_enabled = config is not None and config.enable_mm_processor_stats
    if not stats_enabled:
        return None

    with self._timing_stats_registry_lock:
        found = self.timing_stats_registry.get(request_id)
    return found
def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
    """
    Register a fresh stats object for `request_id` and return it.

    Intended to be called once, at the start of processing a request. When
    stats collection is disabled, a throwaway stats object is returned and
    nothing is stored in the registry.

    Raises:
        ValueError: If stats are already registered for `request_id`.
    """
    config = self.observability_config
    if config is None or not config.enable_mm_processor_stats:
        return MultiModalProcessorTimingStats()

    fresh_stats = MultiModalProcessorTimingStats()
    with self._timing_stats_registry_lock:
        registry = self.timing_stats_registry
        if request_id in registry:
            raise ValueError(
                f"Timing stats already exist for request_id: {request_id}"
            )
        registry[request_id] = fresh_stats
    return fresh_stats
def clear_timing_stats_registry(self) -> int:
    """
    Remove every entry from the timing stats registry.

    Returns the number of entries removed, or 0 without touching the
    registry when stats collection is disabled.
    """
    config = self.observability_config
    stats_enabled = config is not None and config.enable_mm_processor_stats
    if not stats_enabled:
        return 0

    with self._timing_stats_registry_lock:
        registry = self.timing_stats_registry
        removed = len(registry)
        registry.clear()
    return removed
def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
    """
    Snapshot every registered request's stats as plain dictionaries.

    Returns a mapping of request_id to that request's `to_dict()` output,
    or an empty mapping when stats collection is disabled.
    """
    config = self.observability_config
    if config is None or not config.enable_mm_processor_stats:
        return {}

    snapshot: dict[str, dict[str, float]] = {}
    # Converted under the lock so the registry cannot change mid-iteration.
    with self._timing_stats_registry_lock:
        for request_id, stats in self.timing_stats_registry.items():
            snapshot[request_id] = stats.to_dict()
    return snapshot
class BaseProcessingInfo: class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing.""" """Base class to provide the information necessary for data processing."""
...@@ -1502,11 +1707,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1502,11 +1707,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Call the HF processor on the prompt text and Call the HF processor on the prompt text and
associated multi-modal data. associated multi-modal data.
""" """
return self.info.ctx.call_hf_processor( with _timed_operation(self.info.ctx, "hf_processor"):
self.info.get_hf_processor(**mm_kwargs), return self.info.ctx.call_hf_processor(
dict(text=prompt, **mm_data), self.info.get_hf_processor(**mm_kwargs),
dict(**mm_kwargs, **tok_kwargs), dict(text=prompt, **mm_data),
) dict(**mm_kwargs, **tok_kwargs),
)
def _hf_processor_applies_updates( def _hf_processor_applies_updates(
self, self,
...@@ -1854,12 +2060,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1854,12 +2060,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) )
# Use overrides if provided; fallback to data-dependent hashing. # Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items( with _timed_operation(self.info.ctx, "hashing"):
mm_data_items, mm_hashes = self._hash_mm_items(
hf_processor_mm_kwargs, mm_data_items,
tokenization_kwargs, hf_processor_mm_kwargs,
mm_uuids=mm_uuids, tokenization_kwargs,
) mm_uuids=mm_uuids,
)
mm_prompt_updates = self._get_mm_prompt_updates( mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items, mm_data_items,
...@@ -1900,18 +2107,20 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1900,18 +2107,20 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuids=mm_uuids, mm_uuids=mm_uuids,
) )
mm_hashes = self._hash_mm_items( with _timed_operation(self.info.ctx, "hashing"):
mm_data_items, mm_hashes = self._hash_mm_items(
hf_processor_mm_kwargs, mm_data_items,
tokenization_kwargs, hf_processor_mm_kwargs,
mm_uuids=mm_uuids, tokenization_kwargs,
) mm_uuids=mm_uuids,
)
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items( with _timed_operation(self.info.ctx, "cache_lookup"):
cache=cache, mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
mm_data_items=mm_data_items, cache=cache,
mm_hashes=mm_hashes, mm_data_items=mm_data_items,
) mm_hashes=mm_hashes,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`, # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal # so we can't apply prompt updates until the new multimodal
...@@ -1941,13 +2150,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1941,13 +2150,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs, mm_missing_kwargs,
) )
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( with _timed_operation(self.info.ctx, "cache_lookup"):
cache, mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
mm_hashes=mm_hashes, cache,
mm_is_cached=mm_is_cached, mm_hashes=mm_hashes,
mm_missing_kwargs=mm_missing_kwargs, mm_is_cached=mm_is_cached,
mm_missing_prompt_updates=mm_missing_prompt_updates, mm_missing_kwargs=mm_missing_kwargs,
) mm_missing_prompt_updates=mm_missing_prompt_updates,
)
mm_info = MultiModalProcessingInfo( mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs, kwargs=mm_kwargs,
...@@ -2129,6 +2339,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -2129,6 +2339,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
3. Extract information about the placeholder tokens from the 3. Extract information about the placeholder tokens from the
processed token IDs. processed token IDs.
""" """
request_id = get_current_request_id()
if request_id is not None:
self.info.ctx.create_timing_stats(request_id)
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
if tokenization_kwargs is None: if tokenization_kwargs is None:
...@@ -2147,13 +2361,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -2147,13 +2361,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
) )
# NOTE: tokenization_kwargs are not required to init processor # NOTE: tokenization_kwargs are not required to init processor
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates( with _timed_operation(self.info.ctx, "prompt_update"):
mm_items=mm_items, prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
prompt_ids=prompt_ids, mm_items=mm_items,
mm_kwargs=mm_info.kwargs, prompt_ids=prompt_ids,
mm_prompt_updates=mm_info.prompt_updates, mm_kwargs=mm_info.kwargs,
is_update_applied=is_update_applied, mm_prompt_updates=mm_info.prompt_updates,
) is_update_applied=is_update_applied,
)
mm_placeholder_ranges = { mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders] modality: [item.to_range() for item in placeholders]
......
...@@ -5,6 +5,7 @@ from dataclasses import dataclass ...@@ -5,6 +5,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
...@@ -22,7 +23,7 @@ from .profiling import ( ...@@ -22,7 +23,7 @@ from .profiling import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig, ObservabilityConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -148,6 +149,7 @@ class MultiModalRegistry: ...@@ -148,6 +149,7 @@ class MultiModalRegistry:
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None, profiler_limits: Mapping[str, int] | None = None,
observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
Get the maximum number of tokens per data item from each modality based Get the maximum number of tokens per data item from each modality based
...@@ -156,7 +158,9 @@ class MultiModalRegistry: ...@@ -156,7 +158,9 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return {} return {}
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len seq_len = model_config.max_model_len
...@@ -174,6 +178,7 @@ class MultiModalRegistry: ...@@ -174,6 +178,7 @@ class MultiModalRegistry:
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
""" """
Get the maximum number of multi-modal input instances for each modality Get the maximum number of multi-modal input instances for each modality
...@@ -182,7 +187,9 @@ class MultiModalRegistry: ...@@ -182,7 +187,9 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
return {} return {}
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits() return profiler.get_mm_limits()
...@@ -231,27 +238,32 @@ class MultiModalRegistry: ...@@ -231,27 +238,32 @@ class MultiModalRegistry:
def _create_processing_ctx( def _create_processing_ctx(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext: ) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init: if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer) return InputProcessingContext(
model_config, tokenizer, observability_config=observability_config
)
def _create_processing_info( def _create_processing_info(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*, *,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo: ) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer) ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
return factories.info(ctx) return factories.info(ctx)
def create_processor( def create_processor(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*, *,
tokenizer: TokenizerLike | None = None, tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
...@@ -265,7 +277,7 @@ class MultiModalRegistry: ...@@ -265,7 +277,7 @@ class MultiModalRegistry:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer) ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
return factories.build_processor(ctx, cache=cache) return factories.build_processor(ctx, cache=cache)
...@@ -276,13 +288,16 @@ class MultiModalRegistry: ...@@ -276,13 +288,16 @@ class MultiModalRegistry:
mm_counts: Mapping[str, int] | None = None, mm_counts: Mapping[str, int] | None = None,
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyDecoderData: ) -> DummyDecoderData:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`. The model is identified by `model_config`.
""" """
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config. # Extract configurable options from multimodal config.
...@@ -309,13 +324,16 @@ class MultiModalRegistry: ...@@ -309,13 +324,16 @@ class MultiModalRegistry:
mm_counts: Mapping[str, int] | None = None, mm_counts: Mapping[str, int] | None = None,
*, *,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyEncoderData: ) -> DummyEncoderData:
""" """
Create dummy data for profiling the memory usage of a model. Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`. The model is identified by `model_config`.
""" """
processor = self.create_processor(model_config, cache=cache) processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor) profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config. # Extract configurable options from multimodal config.
......
...@@ -15,7 +15,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry ...@@ -15,7 +15,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor, set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -60,6 +60,7 @@ class InputProcessor: ...@@ -60,6 +60,7 @@ class InputProcessor:
self.input_preprocessor = InputPreprocessor( self.input_preprocessor = InputPreprocessor(
self.model_config, self.model_config,
tokenizer, tokenizer,
self.vllm_config.observability_config,
mm_registry, mm_registry,
mm_processor_cache=self.mm_processor_cache, mm_processor_cache=self.mm_processor_cache,
) )
...@@ -493,11 +494,13 @@ class InputProcessor: ...@@ -493,11 +494,13 @@ class InputProcessor:
# 1. Tokenize text prompt, with LoRA request if one exists. # 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess # 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly. # multimodal data and expand prompt token ids accordingly.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( with set_request_id(request_id):
prompt, processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
tokenization_kwargs=tokenization_kwargs, prompt,
mm_uuids=mm_uuids, tokenization_kwargs=tokenization_kwargs,
) mm_uuids=mm_uuids,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.validate_request( current_platform.validate_request(
...@@ -641,6 +644,7 @@ class InputProcessor: ...@@ -641,6 +644,7 @@ class InputProcessor:
mm_registry = self.input_preprocessor.mm_registry mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor( mm_processor = mm_registry.create_processor(
model_config, model_config,
self.vllm_config.observability_config,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
assert isinstance(mm_processor, EncDecMultiModalProcessor) assert isinstance(mm_processor, EncDecMultiModalProcessor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment