Unverified Commit 1f5b7c41 authored by Reagan Lee's avatar Reagan Lee Committed by GitHub
Browse files

Add Multimodal Processor Benchmark (#29105)


Signed-off-by: default avatarReagan Lee <reaganjlee@gmail.com>
Signed-off-by: default avatarReagan <reaganjlee@gmail.com>
parent adcf682f
# vllm bench mm-processor
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Arguments
--8<-- "docs/argparse/bench_mm_processor.inc.md"
......@@ -92,6 +92,7 @@ def auto_mock(module_name: str, attr: str, max_mocks: int = 100):
bench_latency = auto_mock("vllm.benchmarks", "latency")
bench_mm_processor = auto_mock("vllm.benchmarks", "mm_processor")
bench_serve = auto_mock("vllm.benchmarks", "serve")
bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs")
bench_sweep_plot_pareto = auto_mock(
......@@ -222,6 +223,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
"run-batch": create_parser(openai_run_batch.make_arg_parser),
# Benchmark CLI
"bench_latency": create_parser(bench_latency.add_cli_args),
"bench_mm_processor": create_parser(bench_mm_processor.add_cli_args),
"bench_serve": create_parser(bench_serve.add_cli_args),
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
"bench_sweep_plot_pareto": create_parser(bench_sweep_plot_pareto.add_cli_args),
......
......@@ -1437,19 +1437,97 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
)
random_group = parser.add_argument_group("random dataset options")
random_group.add_argument(
add_random_dataset_base_args(random_group)
random_mm_group = parser.add_argument_group(
"random multimodal dataset options extended from random dataset"
)
add_random_multimodal_dataset_args(random_mm_group)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
)
hf_group.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset."
)
hf_group.add_argument(
"--hf-name",
type=str,
default=None,
help=(
"Name of the dataset on HuggingFace "
"(e.g., 'lmarena-ai/VisionArena-Chat'). "
"Specify this if your dataset-path is a local path."
),
)
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)
prefix_repetition_group = parser.add_argument_group(
"prefix repetition dataset options"
)
prefix_repetition_group.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=256,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=256,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=10,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for prefix "
"repetition dataset.",
)
def add_random_dataset_base_args(
parser_or_group: FlexibleArgumentParser | argparse._ArgumentGroup,
) -> None:
"""Add CLI arguments for base random dataset options.
This function adds arguments needed for:
- random (random dataset)
- random-mm (random multimodal dataset)
- random-rerank (random dataset for reranking)
Args:
parser_or_group: Either a parser or an argument group to add arguments to.
"""
parser_or_group.add_argument(
"--random-input-len",
type=int,
default=1024,
help="Number of input tokens per request, used only for random sampling.",
)
random_group.add_argument(
parser_or_group.add_argument(
"--random-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for random sampling.",
)
random_group.add_argument(
parser_or_group.add_argument(
"--random-range-ratio",
type=float,
default=0.0,
......@@ -1458,7 +1536,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"a symmetric sampling range"
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
)
random_group.add_argument(
parser_or_group.add_argument(
"--random-prefix-len",
type=int,
default=0,
......@@ -1471,13 +1549,13 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"input_len * (1 + range_ratio)]."
),
)
random_group.add_argument(
parser_or_group.add_argument(
"--random-batch-size",
type=int,
default=1,
help=("Batch size for random sampling. Only used for embeddings benchmark."),
)
random_group.add_argument(
parser_or_group.add_argument(
"--no-reranker",
action="store_true",
help=(
......@@ -1486,11 +1564,19 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
),
)
# random multimodal dataset options
random_mm_group = parser.add_argument_group(
"random multimodal dataset options extended from random dataset"
)
random_mm_group.add_argument(
def add_random_multimodal_dataset_args(
parser_or_group: FlexibleArgumentParser | argparse._ArgumentGroup,
) -> None:
"""Add CLI arguments for random multimodal dataset options.
This function adds arguments needed for:
- random-mm (random multimodal dataset)
Args:
parser_or_group: Either a parser or an argument group to add arguments to.
"""
parser_or_group.add_argument(
"--random-mm-base-items-per-request",
type=int,
default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
......@@ -1500,7 +1586,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"--random-mm-num-mm-items-range-ratio."
),
)
random_mm_group.add_argument(
parser_or_group.add_argument(
"--random-mm-num-mm-items-range-ratio",
type=float,
default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
......@@ -1515,7 +1601,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"An error is raised if the computed min exceeds the max."
),
)
random_mm_group.add_argument(
parser_or_group.add_argument(
"--random-mm-limit-mm-per-prompt",
type=json.loads,
default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
......@@ -1559,7 +1645,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
return normalize(parsed)
raise ValueError("Unsupported value for --random-mm-bucket-config.")
random_mm_group.add_argument(
parser_or_group.add_argument(
"--random-mm-bucket-config",
type=_parse_mm_bucket_config,
default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
......@@ -1580,63 +1666,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
),
)
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
)
hf_group.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset."
)
hf_group.add_argument(
"--hf-name",
type=str,
default=None,
help=(
"Name of the dataset on HuggingFace "
"(e.g., 'lmarena-ai/VisionArena-Chat'). "
"Specify this if your dataset-path is a local path."
),
)
hf_group.add_argument(
"--hf-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output lengths "
"from the sampled HF dataset.",
)
prefix_repetition_group = parser.add_argument_group(
"prefix repetition dataset options"
)
prefix_repetition_group.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=256,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=256,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=10,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
prefix_repetition_group.add_argument(
"--prefix-repetition-output-len",
type=int,
default=128,
help="Number of output tokens per request, used only for prefix "
"repetition dataset.",
)
def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
if not hasattr(args, "request_id_prefix"):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""Benchmark multimodal processor latency.
This benchmark measures the latency of the mm processor module
using multimodal prompts from datasets.
MM processor stats are automatically enabled.
Run:
vllm bench mm-processor \
--model <your_model> \
--dataset-name random-mm \
--num-prompts 10 \
"""
import argparse
import dataclasses
import json
import time
from datetime import datetime
from typing import Any
import numpy as np
from vllm.benchmarks.throughput import get_requests
from vllm.engine.arg_utils import EngineArgs
from vllm.multimodal.processing import (
get_timing_stats_from_engine_client,
)
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
def collect_mm_processor_stats(
llm_engine: Any,
) -> dict[str, list[float]]:
"""
Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds).
"""
all_stats = get_timing_stats_from_engine_client(llm_engine)
stats_by_stage = {
"hf_processor_time": [],
"hashing_time": [],
"cache_lookup_time": [],
"prompt_update_time": [],
"total_time": [],
}
for stats_dict in all_stats.values():
stats_by_stage["hf_processor_time"].append(
stats_dict.get("hf_processor_time", 0.0)
)
stats_by_stage["hashing_time"].append(stats_dict.get("hashing_time", 0.0))
stats_by_stage["cache_lookup_time"].append(
stats_dict.get("cache_lookup_time", 0.0)
)
stats_by_stage["prompt_update_time"].append(
stats_dict.get("prompt_update_time", 0.0)
)
stats_by_stage["total_time"].append(stats_dict.get("total_time", 0.0))
return stats_by_stage
def calculate_mm_processor_metrics(
stats_by_stage: dict[str, list[float]],
selected_percentiles: list[float],
) -> dict[str, dict[str, float]]:
"""
Calculate aggregate metrics from stats by stage.
"""
metrics = {}
for stage_name, times in stats_by_stage.items():
if not times:
metrics[stage_name] = {
"mean": 0.0,
"median": 0.0,
"std": 0.0,
**{f"p{p}": 0.0 for p in selected_percentiles},
}
continue
times_ms = [t * 1000 for t in times]
metrics[stage_name] = {
"mean": float(np.mean(times_ms)),
"median": float(np.median(times_ms)),
"std": float(np.std(times_ms)),
**{
f"p{p}": float(np.percentile(times_ms, p)) for p in selected_percentiles
},
}
return metrics
def validate_args(args):
"""
Validate command-line arguments for mm_processor benchmark.
"""
if not getattr(args, "tokenizer", None):
args.tokenizer = args.model
if not hasattr(args, "dataset_path"):
args.dataset_path = None
if not hasattr(args, "lora_path"):
args.lora_path = None
if not hasattr(args, "max_loras"):
args.max_loras = None
def benchmark_multimodal_processor(
args: argparse.Namespace,
) -> dict[str, Any]:
"""
Run the multimodal processor benchmark.
"""
from vllm import LLM, SamplingParams
validate_args(args)
if args.seed is None:
args.seed = 0
engine_args = EngineArgs.from_cli_args(args)
llm = LLM(**dataclasses.asdict(engine_args))
tokenizer = llm.get_tokenizer()
requests = get_requests(args, tokenizer)
assert all(
llm.llm_engine.model_config.max_model_len
>= (request.prompt_len + request.expected_output_len)
for request in requests
), (
"Please ensure that max_model_len is greater than the sum of "
"prompt_len and expected_output_len for all requests."
)
prompts = [request.prompt for request in requests]
expected_output_lens = [request.expected_output_len for request in requests]
sampling_params = [
SamplingParams(
n=1,
temperature=0.0,
max_tokens=output_len,
detokenize=True,
)
for output_len in expected_output_lens
]
selected_percentiles = [
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
]
freeze_gc_heap()
print(f"Processing {len(prompts)} requests...")
start_time = time.perf_counter()
outputs = llm.chat(
prompts, sampling_params, use_tqdm=not getattr(args, "disable_tqdm", False)
)
end_time = time.perf_counter()
total_time = end_time - start_time
mm_stats_by_stage = collect_mm_processor_stats(
llm.llm_engine,
)
if not any(mm_stats_by_stage.values()):
print(
"\n⚠️ Warning: No MM processor stats found in registry.\n"
" This may indicate that:\n"
" - No multimodal requests were processed\n"
" - Stats were already retrieved (registry is cleared after retrieval)\n"
)
mm_processor_metrics = calculate_mm_processor_metrics(
mm_stats_by_stage, selected_percentiles
)
completed = len([o for o in outputs if o.finished])
failed = len(outputs) - completed
e2el_times = []
for output in outputs:
if not output.finished or output.metrics is None:
continue
metrics = output.metrics
for attr in ("finished_time", "last_token_time"):
if (
getattr(metrics, attr, None) is not None
and getattr(metrics, "arrival_time", None) is not None
):
e2el_times.append(
(getattr(metrics, attr) - metrics.arrival_time) * 1000
)
break
if not e2el_times and completed > 0:
avg_time_per_request = total_time / completed
e2el_times = [avg_time_per_request * 1000] * completed
if e2el_times:
mean_e2el_ms = float(np.mean(e2el_times))
median_e2el_ms = float(np.median(e2el_times))
std_e2el_ms = float(np.std(e2el_times))
percentiles_e2el_ms = [
(p, float(np.percentile(e2el_times, p))) for p in selected_percentiles
]
else:
mean_e2el_ms = 0.0
median_e2el_ms = 0.0
std_e2el_ms = 0.0
percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]
benchmark_result = {
"completed": completed,
"failed": failed,
"mean_e2el_ms": mean_e2el_ms,
"median_e2el_ms": median_e2el_ms,
"std_e2el_ms": std_e2el_ms,
"percentiles_e2el_ms": percentiles_e2el_ms,
"mm_processor_stats": mm_processor_metrics,
}
return benchmark_result
def add_cli_args(parser: argparse.ArgumentParser) -> None:
"""Add CLI arguments for the multimodal processor benchmark."""
from vllm.engine.arg_utils import EngineArgs
EngineArgs.add_cli_args(parser)
parser.set_defaults(enable_mm_processor_stats=True)
parser.add_argument(
"--dataset-name",
type=str,
default="random-mm",
choices=["random-mm", "random-rerank"],
help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
)
parser.add_argument(
"--num-prompts",
type=int,
default=10,
help="Number of prompts to process.",
)
from vllm.benchmarks.datasets import (
add_random_dataset_base_args,
add_random_multimodal_dataset_args,
)
add_random_dataset_base_args(parser)
add_random_multimodal_dataset_args(parser)
parser.add_argument(
"--output-json",
type=str,
default=None,
help="Path to save the benchmark results in JSON format.",
)
parser.add_argument(
"--metric-percentiles",
type=str,
default="99",
help="Comma-separated list of percentiles to calculate (e.g., '50,90,99').",
)
parser.add_argument(
"--disable-tqdm",
action="store_true",
help="Disable tqdm progress bar.",
)
def main(args: argparse.Namespace) -> None:
"""Main entry point for the multimodal processor benchmark."""
print("Starting multimodal processor benchmark...")
result = benchmark_multimodal_processor(args)
print("\n" + "=" * 80)
print("Multimodal Processor Benchmark Results")
print("=" * 80)
if "mm_processor_stats" in result:
print("\nMM Processor Timing (ms):")
selected_percentiles = [
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
]
mm_data = []
for stage, metrics in result["mm_processor_stats"].items():
row = {
"Stage": stage,
"Mean": f"{metrics['mean']:.2f}",
"Median": f"{metrics['median']:.2f}",
"Std": f"{metrics['std']:.2f}",
}
for p in selected_percentiles:
row[f"P{p}"] = f"{metrics.get(f'p{p}', 0.0):.2f}"
mm_data.append(row)
mm_df = pd.DataFrame(mm_data)
print(mm_df.to_string(index=False))
if "mean_e2el_ms" in result:
print("\nEnd-to-End Latency (ms):")
selected_percentiles = [
float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
]
e2el_data = [
{"Metric": "Mean", "Value (ms)": f"{result['mean_e2el_ms']:.2f}"},
{"Metric": "Median", "Value (ms)": f"{result['median_e2el_ms']:.2f}"},
{"Metric": "Std", "Value (ms)": f"{result['std_e2el_ms']:.2f}"},
]
for p in selected_percentiles:
percentile_value = next(
(val for pct, val in result["percentiles_e2el_ms"] if pct == p),
0.0,
)
e2el_data.append(
{
"Metric": f"P{p}",
"Value (ms)": f"{percentile_value:.2f}",
}
)
e2el_df = pd.DataFrame(e2el_data)
print(e2el_df.to_string(index=False))
if args.output_json:
result["config"] = {
"model": args.model,
"num_prompts": args.num_prompts,
"input_len": getattr(args, "random_input_len", None),
"output_len": getattr(args, "random_output_len", None),
}
result["timestamp"] = datetime.now().isoformat()
with open(args.output_json, "w") as f:
json.dump(result, f, indent=2)
print(f"\nResults saved to {args.output_json}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark mm processor latency")
add_cli_args(parser)
args = parser.parse_args()
main(args)
......@@ -24,10 +24,14 @@ from vllm.benchmarks.datasets import (
MultiModalConversationDataset,
PrefixRepetitionRandomDataset,
RandomDataset,
RandomDatasetForReranking,
RandomMultiModalDataset,
SampleRequest,
ShareGPTDataset,
SonnetDataset,
VisionArenaDataset,
add_random_dataset_base_args,
add_random_multimodal_dataset_args,
)
from vllm.benchmarks.lib.utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
......@@ -342,8 +346,6 @@ def get_requests(args, tokenizer):
"lora_path": args.lora_path,
"max_loras": args.max_loras,
"num_requests": args.num_prompts,
"input_len": args.input_len,
"output_len": args.output_len,
}
if args.dataset_name == "random" or (
......@@ -351,12 +353,26 @@ def get_requests(args, tokenizer):
and args.dataset_name not in {"prefix_repetition", "random-mm", "random-rerank"}
):
sample_kwargs["range_ratio"] = args.random_range_ratio
sample_kwargs["prefix_len"] = args.prefix_len
# prefer random_* arguments, fall back to regular arguments
random_prefix_len = getattr(args, "random_prefix_len", None)
sample_kwargs["prefix_len"] = (
random_prefix_len if random_prefix_len is not None else args.prefix_len
)
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len if random_input_len is not None else args.input_len
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len if random_output_len is not None else args.output_len
)
dataset_cls = RandomDataset
elif args.dataset_name == "sharegpt":
dataset_cls = ShareGPTDataset
if args.backend == "vllm-chat":
sample_kwargs["enable_multimodal_chat"] = True
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
elif args.dataset_name == "sonnet":
assert tokenizer.chat_template or tokenizer.default_chat_template, (
"Tokenizer/model must have chat template for sonnet dataset."
......@@ -364,9 +380,15 @@ def get_requests(args, tokenizer):
dataset_cls = SonnetDataset
sample_kwargs["prefix_len"] = args.prefix_len
sample_kwargs["return_prompt_formatted"] = True
if args.input_len is not None:
sample_kwargs["input_len"] = args.input_len
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
if args.output_len is not None:
sample_kwargs["output_len"] = args.output_len
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs["dataset_subset"] = None
......@@ -395,6 +417,56 @@ def get_requests(args, tokenizer):
sample_kwargs["suffix_len"] = args.prefix_repetition_suffix_len
sample_kwargs["num_prefixes"] = args.prefix_repetition_num_prefixes
sample_kwargs["output_len"] = args.prefix_repetition_output_len
elif args.dataset_name == "random-mm":
dataset_cls = RandomMultiModalDataset
# prefer random_* arguments, fall back to regular arguments
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len
if random_input_len is not None
else getattr(args, "input_len", None)
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len
if random_output_len is not None
else getattr(args, "output_len", None)
)
sample_kwargs["base_items_per_request"] = getattr(
args, "random_mm_base_items_per_request", None
)
sample_kwargs["num_mm_items_range_ratio"] = getattr(
args, "random_mm_num_mm_items_range_ratio", None
)
sample_kwargs["limit_mm_per_prompt"] = getattr(
args, "random_mm_limit_mm_per_prompt", None
)
sample_kwargs["bucket_config"] = getattr(args, "random_mm_bucket_config", None)
sample_kwargs["enable_multimodal_chat"] = True
random_prefix_len = getattr(args, "random_prefix_len", None)
prefix_len = getattr(args, "prefix_len", None)
sample_kwargs["prefix_len"] = (
random_prefix_len if random_prefix_len is not None else prefix_len
)
sample_kwargs["range_ratio"] = args.random_range_ratio
elif args.dataset_name == "random-rerank":
dataset_cls = RandomDatasetForReranking
# prefer random_* arguments, fall back to regular arguments
random_input_len = getattr(args, "random_input_len", None)
sample_kwargs["input_len"] = (
random_input_len
if random_input_len is not None
else getattr(args, "input_len", None)
)
random_output_len = getattr(args, "random_output_len", None)
sample_kwargs["output_len"] = (
random_output_len
if random_output_len is not None
else getattr(args, "output_len", None)
)
sample_kwargs["batchsize"] = getattr(args, "random_batch_size", 1)
sample_kwargs["is_reranker"] = not getattr(args, "no_reranker", False)
sample_kwargs["range_ratio"] = args.random_range_ratio
else:
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
# Remove None values
......@@ -451,8 +523,12 @@ def validate_args(args):
):
print("When dataset path is not set, it will default to random dataset")
args.dataset_name = "random"
if args.input_len is None:
raise ValueError("input_len must be provided for a random dataset")
random_input_len = getattr(args, "random_input_len", None)
if args.input_len is None and random_input_len is None:
raise ValueError(
"Either --input-len or --random-input-len must be provided "
"for a random dataset"
)
# === Dataset Name Specific Checks ===
# --hf-subset and --hf-split: only used
......@@ -485,26 +561,79 @@ def validate_args(args):
else:
raise ValueError(f"{args.dataset_path} is not supported by hf dataset.")
# --random-range-ratio: only used when dataset_name is 'random'
if args.dataset_name != "random" and args.random_range_ratio is not None:
# --random-range-ratio: only used when dataset_name is 'random',
# 'random-mm', or 'random-rerank'
if (
args.dataset_name not in {"random", "random-mm", "random-rerank"}
and args.random_range_ratio is not None
):
warnings.warn(
"--random-range-ratio will be ignored since \
--dataset-name is not 'random'.",
--dataset-name is not 'random', 'random-mm', or 'random-rerank'.",
stacklevel=2,
)
# --random-batch-size: only used when dataset_name is 'random-rerank'
if (
args.dataset_name != "random-rerank"
and getattr(args, "random_batch_size", None) is not None
) and args.random_batch_size != 1:
warnings.warn(
"--random-batch-size will be ignored since \
--dataset-name is not 'random-rerank'.",
stacklevel=2,
)
# --no-reranker: only used when dataset_name is 'random-rerank'
if args.dataset_name != "random-rerank" and getattr(args, "no_reranker", False):
warnings.warn(
"--no-reranker will be ignored since \
--dataset-name is not 'random-rerank'.",
stacklevel=2,
)
# --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
# set.
# --prefix-len: only used when dataset_name is 'random', 'random-mm',
# 'sonnet', or not set.
if (
args.dataset_name not in {"random", "sonnet", None}
args.dataset_name not in {"random", "random-mm", "sonnet", None}
and args.prefix_len is not None
):
warnings.warn(
"--prefix-len will be ignored since --dataset-name\
is not 'random', 'sonnet', or not set.",
is not 'random', 'random-mm', 'sonnet', or not set.",
stacklevel=2,
)
# === Random Dataset Argument Conflict Detection ===
# Check for conflicts between regular and random arguments when using
# random datasets
if args.dataset_name in {"random", "random-mm", "random-rerank"}:
random_input_len = getattr(args, "random_input_len", None)
random_output_len = getattr(args, "random_output_len", None)
random_prefix_len = getattr(args, "random_prefix_len", None)
if args.input_len is not None and random_input_len is not None:
warnings.warn(
"Both --input-len and --random-input-len are specified. "
"The random version (--random-input-len) will be preferred "
"in this run.",
stacklevel=2,
)
if args.output_len is not None and random_output_len is not None:
warnings.warn(
"Both --output-len and --random-output-len are specified. "
"The random version (--random-output-len) will be preferred "
"in this run.",
stacklevel=2,
)
if args.prefix_len is not None and random_prefix_len is not None:
warnings.warn(
"Both --prefix-len and --random-prefix-len are specified. "
"The random version (--random-prefix-len) will be preferred "
"in this run.",
stacklevel=2,
)
# === LoRA Settings ===
if getattr(args, "enable_lora", False) and args.backend != "vllm":
raise ValueError("LoRA benchmarking is only supported for vLLM backend")
......@@ -554,7 +683,16 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--dataset-name",
type=str,
choices=["sharegpt", "random", "sonnet", "burstgpt", "hf", "prefix_repetition"],
choices=[
"sharegpt",
"random",
"sonnet",
"burstgpt",
"hf",
"prefix_repetition",
"random-mm",
"random-rerank",
],
help="Name of the dataset to benchmark on.",
default="sharegpt",
)
......@@ -636,23 +774,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Number of fixed prefix tokens before the random "
"context in a request (default: 0).",
)
# random dataset
parser.add_argument(
"--random-range-ratio",
type=float,
default=0.0,
help="Range ratio for sampling input/output length, "
"used only for RandomDataset. Must be in the range [0, 1) to define "
"a symmetric sampling range "
"[length * (1 - range_ratio), length * (1 + range_ratio)].",
)
# hf dtaset
parser.add_argument(
"--hf-subset", type=str, default=None, help="Subset of the HF dataset."
"--hf-subset",
type=str,
default=None,
help="Subset of the HF dataset.",
)
parser.add_argument(
"--hf-split", type=str, default=None, help="Split of the HF dataset."
"--hf-split",
type=str,
default=None,
help="Split of the HF dataset.",
)
parser.add_argument(
"--profile",
......@@ -662,31 +796,28 @@ def add_cli_args(parser: argparse.ArgumentParser):
)
# prefix repetition dataset
prefix_repetition_group = parser.add_argument_group(
"prefix repetition dataset options"
)
prefix_repetition_group.add_argument(
parser.add_argument(
"--prefix-repetition-prefix-len",
type=int,
default=None,
help="Number of prefix tokens per request, used only for prefix "
"repetition dataset.",
)
prefix_repetition_group.add_argument(
parser.add_argument(
"--prefix-repetition-suffix-len",
type=int,
default=None,
help="Number of suffix tokens per request, used only for prefix "
"repetition dataset. Total input length is prefix_len + suffix_len.",
)
prefix_repetition_group.add_argument(
parser.add_argument(
"--prefix-repetition-num-prefixes",
type=int,
default=None,
help="Number of prefixes to generate, used only for prefix repetition "
"dataset. Prompts per prefix is num_requests // num_prefixes.",
)
prefix_repetition_group.add_argument(
parser.add_argument(
"--prefix-repetition-output-len",
type=int,
default=None,
......@@ -694,6 +825,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
"repetition dataset.",
)
# (random, random-mm, random-rerank)
add_random_dataset_base_args(parser)
add_random_multimodal_dataset_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
......
......@@ -67,6 +67,14 @@ class ObservabilityConfig:
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""
enable_mm_processor_stats: bool = False
"""Enable collection of timing statistics for multimodal processor operations.
This is for internal use only (e.g., benchmarks) and is not exposed as a CLI
argument."""
enable_mfu_metrics: bool = False
"""Enable Model FLOPs Utilization (MFU) metrics."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""
......
......@@ -523,6 +523,7 @@ class EngineArgs:
ObservabilityConfig.enable_layerwise_nvtx_tracing
)
enable_mfu_metrics: bool = ObservabilityConfig.enable_mfu_metrics
enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
......@@ -1712,6 +1713,7 @@ class EngineArgs:
cudagraph_metrics=self.cudagraph_metrics,
enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
enable_mfu_metrics=self.enable_mfu_metrics,
enable_mm_processor_stats=self.enable_mm_processor_stats,
)
# Compilation config overrides
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.mm_processor import (
BenchmarkMMProcessorSubcommand,
)
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
......@@ -8,6 +11,7 @@ from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcomm
__all__: list[str] = [
"BenchmarkLatencySubcommand",
"BenchmarkMMProcessorSubcommand",
"BenchmarkServingSubcommand",
"BenchmarkStartupSubcommand",
"BenchmarkSweepSubcommand",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.mm_processor import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkMMProcessorSubcommand(BenchmarkSubcommandBase):
"""The `mm-processor` subcommand for `vllm bench`."""
name = "mm-processor"
help = "Benchmark multimodal processor latency across different configurations."
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
add_cli_args(parser)
@staticmethod
def cmd(args: argparse.Namespace) -> None:
main(args)
......@@ -6,7 +6,7 @@ from typing import Any, cast
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
......@@ -47,6 +47,7 @@ class InputPreprocessor:
self,
model_config: ModelConfig,
tokenizer: TokenizerLike | None,
observability_config: ObservabilityConfig | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None:
......@@ -54,6 +55,7 @@ class InputPreprocessor:
self.model_config = model_config
self.tokenizer = tokenizer
self.observability_config = observability_config
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
......@@ -232,6 +234,7 @@ class InputPreprocessor:
if not hasattr(self, "_mm_processor"):
self._mm_processor = self.mm_registry.create_processor(
self.model_config,
self.observability_config,
tokenizer=self.tokenizer,
cache=self.mm_processor_cache,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextvars
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import lru_cache
......@@ -53,7 +56,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig
from vllm.config import ModelConfig, ObservabilityConfig
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
......@@ -63,6 +66,7 @@ else:
ProcessorMixin = object
ModelConfig = object
ObservabilityConfig = object
BaseMultiModalProcessorCache = object
......@@ -70,6 +74,127 @@ logger = init_logger(__name__)
_S = TypeVar("_S", str, list[int])
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_request_id_context", default=None
)
def get_current_request_id() -> str | None:
"""Get the current request_id from the context, if available."""
return _request_id_context.get()
@contextmanager
def set_request_id(request_id: str) -> Generator[None, None, None]:
"""Context manager to set the request_id for the current context."""
token = _request_id_context.set(request_id)
try:
yield
finally:
_request_id_context.reset(token)
@dataclass
class MultiModalProcessorTimingStats:
"""Per-request timing statistics for multimodal processor stages."""
hf_processor_time: float = 0.0
"""Time spent in HuggingFace processor calls (seconds)."""
hashing_time: float = 0.0
"""Time spent computing multimodal item hashes (seconds)."""
cache_lookup_time: float = 0.0
"""Time spent in cache lookups and merges (seconds)."""
prompt_update_time: float = 0.0
"""Time spent applying prompt updates and finding placeholders (seconds)."""
total_time: float = 0.0
"""Total processing time (seconds)."""
def to_dict(self) -> dict[str, float]:
"""Convert stats to a dictionary for JSON serialization."""
return {
"hf_processor_time": self.hf_processor_time,
"hashing_time": self.hashing_time,
"cache_lookup_time": self.cache_lookup_time,
"prompt_update_time": self.prompt_update_time,
"total_time": self.total_time,
}
def get_timing_stats_from_engine_client(
engine_client: Any,
) -> dict[str, dict[str, float]]:
"""
Get all timing stats from the context associated with the engine client.
Args:
engine_client: The engine client that has input_processor.
Returns:
A dictionary mapping request_id to stats dict.
"""
try:
if not engine_client.vllm_config.observability_config.enable_mm_processor_stats:
return {}
except (AttributeError, RuntimeError):
return {}
try:
input_processor = engine_client.input_processor
input_preprocessor = input_processor.input_preprocessor
if hasattr(input_preprocessor, "_get_mm_processor"):
mm_processor = input_preprocessor._get_mm_processor()
if mm_processor is not None and hasattr(mm_processor, "info"):
ctx = mm_processor.info.ctx
return ctx.get_all_timing_stats()
except (AttributeError, RuntimeError):
pass
return {}
@contextmanager
def _timed_operation(ctx: "InputProcessingContext", stage_name: str):
"""
Context manager to time an operation using the context's timing stats.
The request_id is automatically retrieved from the context variable,
so it doesn't need to be passed as a parameter.
Args:
ctx: The InputProcessingContext containing the timing stats registry.
stage_name: Name of the stage being timed.
"""
request_id = get_current_request_id()
if ctx is None or request_id is None:
yield
return
stats = ctx.get_timing_stats(request_id)
if stats is None:
yield
return
start_time = time.perf_counter()
try:
yield
finally:
elapsed = time.perf_counter() - start_time
if stage_name == "hf_processor":
stats.hf_processor_time += elapsed
elif stage_name == "hashing":
stats.hashing_time += elapsed
elif stage_name == "cache_lookup":
stats.cache_lookup_time += elapsed
elif stage_name == "prompt_update":
stats.prompt_update_time += elapsed
stats.total_time += elapsed
PromptSeq: TypeAlias = str | list[int]
"""A token sequence (list of token IDs) or text."""
......@@ -951,6 +1076,21 @@ class InputProcessingContext:
tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs."""
observability_config: "ObservabilityConfig | None" = field(
default=None, compare=False, repr=False
)
"""Configuration for observability features."""
timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
default_factory=dict, compare=False, repr=False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock: threading.Lock = field(
default_factory=threading.Lock, compare=False, repr=False
)
"""Lock for thread-safe access to timing_stats_registry."""
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
......@@ -1159,6 +1299,71 @@ class InputProcessingContext:
return self._postprocess_output(output)
def get_timing_stats(
self, request_id: str
) -> MultiModalProcessorTimingStats | None:
"""
Get timing stats for a request.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return None
with self._timing_stats_registry_lock:
return self.timing_stats_registry.get(request_id)
def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
"""
Create and store timing stats in the registry for a request.
This should be called at the start of processing for a request.
The stats object is created immediately and stored in the registry.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return MultiModalProcessorTimingStats()
with self._timing_stats_registry_lock:
if request_id in self.timing_stats_registry:
raise ValueError(
f"Timing stats already exist for request_id: {request_id}"
)
stats = MultiModalProcessorTimingStats()
self.timing_stats_registry[request_id] = stats
return stats
def clear_timing_stats_registry(self) -> int:
"""
Clear all stats from the registry. Returns the number of stats cleared.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return 0
with self._timing_stats_registry_lock:
count = len(self.timing_stats_registry)
self.timing_stats_registry.clear()
return count
def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
"""
Get all timing stats as a dictionary for API endpoints.
"""
if (
self.observability_config is None
or not self.observability_config.enable_mm_processor_stats
):
return {}
with self._timing_stats_registry_lock:
return {
rid: stats.to_dict()
for rid, stats in self.timing_stats_registry.items()
}
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
......@@ -1502,11 +1707,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Call the HF processor on the prompt text and
associated multi-modal data.
"""
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
with _timed_operation(self.info.ctx, "hf_processor"):
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _hf_processor_applies_updates(
self,
......@@ -1854,12 +2060,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
with _timed_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
......@@ -1900,18 +2107,20 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuids=mm_uuids,
)
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
with _timed_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
)
with _timed_operation(self.info.ctx, "cache_lookup"):
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
......@@ -1941,13 +2150,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs,
)
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_is_cached=mm_is_cached,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)
with _timed_operation(self.info.ctx, "cache_lookup"):
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_is_cached=mm_is_cached,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)
mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs,
......@@ -2129,6 +2339,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
request_id = get_current_request_id()
if request_id is not None:
self.info.ctx.create_timing_stats(request_id)
mm_items = self._to_mm_items(mm_data)
if tokenization_kwargs is None:
......@@ -2147,13 +2361,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
is_update_applied=is_update_applied,
)
with _timed_operation(self.info.ctx, "prompt_update"):
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
is_update_applied=is_update_applied,
)
mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]
......
......@@ -5,6 +5,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
......@@ -22,7 +23,7 @@ from .profiling import (
)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__)
......@@ -148,6 +149,7 @@ class MultiModalRegistry:
*,
cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None,
observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
......@@ -156,7 +158,9 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
......@@ -174,6 +178,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
......@@ -182,7 +187,9 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
......@@ -231,27 +238,32 @@ class MultiModalRegistry:
def _create_processing_ctx(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
return InputProcessingContext(
model_config, tokenizer, observability_config=observability_config
)
def _create_processing_info(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
return factories.info(ctx)
def create_processor(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
......@@ -265,7 +277,7 @@ class MultiModalRegistry:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
return factories.build_processor(ctx, cache=cache)
......@@ -276,13 +288,16 @@ class MultiModalRegistry:
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyDecoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
......@@ -309,13 +324,16 @@ class MultiModalRegistry:
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyEncoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(model_config, cache=cache)
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
......
......@@ -15,7 +15,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.processing import EncDecMultiModalProcessor, set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
......@@ -60,6 +60,7 @@ class InputProcessor:
self.input_preprocessor = InputPreprocessor(
self.model_config,
tokenizer,
self.vllm_config.observability_config,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)
......@@ -493,11 +494,13 @@ class InputProcessor:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
with set_request_id(request_id):
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
from vllm.platforms import current_platform
current_platform.validate_request(
......@@ -641,6 +644,7 @@ class InputProcessor:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
self.vllm_config.observability_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment