# mm_processor.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""Benchmark multimodal processor latency.

This benchmark measures the latency of the mm processor module
using multimodal prompts from datasets.
MM processor stats are automatically enabled.

Run:
    vllm bench mm-processor \
        --model <your_model> \
        --dataset-name random-mm \
        --num-prompts 10
"""

import argparse
import dataclasses
import json
import time
from datetime import datetime
21
from typing import TYPE_CHECKING, Any
22
23
24

import numpy as np

25
26
27
28
from vllm.benchmarks.datasets import (
    MultiModalConversationDataset,
    VisionArenaDataset,
)
29
30
31
32
33
34
35
36
37
38
from vllm.benchmarks.throughput import get_requests
from vllm.engine.arg_utils import EngineArgs
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule

try:
    import pandas as pd
except ImportError:
    pd = PlaceholderModule("pandas")

if not TYPE_CHECKING:
    # At runtime LLMEngine is only used in annotations, so a plain ``object``
    # placeholder keeps this module importable without importing the V1
    # engine (avoids having to mock it during the docs build).
    LLMEngine = object
else:
    from vllm.v1.engine.llm_engine import LLMEngine


def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, float]]:
    """
    Gather every multimodal timing stat the engine recorded, keyed by request.

    Two sources are combined per request_id:

    * preprocessing stats (HF processor, hashing, cache lookup, prompt
      update), read from the renderer's multimodal processor context;
    * encoder forward-pass stats, fetched from the workers through
      ``collective_rpc("get_encoder_timing_stats")``.

    When more than one worker reports the same request, the larger
    ``encoder_forward_time`` and ``num_encoder_calls`` values are kept.

    Args:
        llm_engine: The LLM engine (provides ``vllm_config``, ``renderer``
            and ``collective_rpc``).

    Returns:
        A dict from request_id to the merged preprocessing and encoder
        stats for that request. Empty when ``enable_mm_processor_stats`` is
        not enabled in the observability config.

    Example:
        {
            'request-123': {
                'hf_processor_time': 0.45,
                'hashing_time': 0.02,
                'cache_lookup_time': 0.01,
                'prompt_update_time': 0.03,
                'preprocessor_total_time': 0.51,
                'encoder_forward_time': 0.23,
                'num_encoder_calls': 1
            }
        }
    """
    obs_config = llm_engine.vllm_config.observability_config
    if not (obs_config and obs_config.enable_mm_processor_stats):
        return {}

    mm_ctx = llm_engine.renderer.get_mm_processor().info.ctx
    preprocessing_by_id = mm_ctx.get_all_timing_stats()

    # Fold the per-worker encoder reports into one entry per request.
    encoder_by_id: dict[str, dict[str, float]] = {}
    for reported in llm_engine.collective_rpc("get_encoder_timing_stats"):
        for req_id, entry in (reported or {}).items():
            seen = encoder_by_id.get(req_id)
            if seen is None:
                encoder_by_id[req_id] = dict(entry)
                continue
            # Several workers may report the same request; keep the max.
            for metric, fallback in (
                ("encoder_forward_time", 0.0),
                ("num_encoder_calls", 0),
            ):
                seen[metric] = max(
                    seen.get(metric, fallback), entry.get(metric, fallback)
                )

    merged: dict[str, dict[str, float]] = {
        req_id: dict(prep) for req_id, prep in preprocessing_by_id.items()
    }
    for req_id, entry in encoder_by_id.items():
        if req_id in merged:
            merged[req_id].update(entry)
            continue
        # In the V1 engine the encoder-side request_id carries a suffix on
        # top of the original (preprocessing) request_id; drop the last
        # "-"-separated segment and retry the lookup.
        base_id = req_id.rpartition("-")[0]
        target = merged.get(base_id) if base_id else None
        if target is not None:
            target.update(entry)
        else:
            merged[req_id] = dict(entry)

    return merged

def collect_mm_processor_stats(
    llm_engine: LLMEngine,
    num_warmup_reqs: int = 0,
) -> dict[str, list[float]]:
    """
    Collect multimodal processor timing stats, grouped by stage.

    Args:
        llm_engine: Engine to read per-request stats from (via
            ``get_timing_stats_from_engine``).
        num_warmup_reqs: Number of leading per-request entries to skip as
            warmup. Negative values are treated as 0; a negative slice start
            would otherwise keep only the *last* entries.

    Returns:
        A dictionary mapping each stage name to the list of values recorded
        for the measured requests. Timing values are in seconds;
        ``"num_encoder_calls"`` holds per-request call counts. Every stage is
        present, with an empty list when no request reported it.
    """
    stat_keys = (
        "hf_processor_time",
        "hashing_time",
        "cache_lookup_time",
        "prompt_update_time",
        "preprocessor_total_time",
        "encoder_forward_time",
        "num_encoder_calls",
    )
    stats_by_stage: dict[str, list[float]] = {key: [] for key in stat_keys}

    all_stats = get_timing_stats_from_engine(llm_engine)
    # NOTE: assumes all_stats preserves request submission order, so the
    # warmup requests (which are run first) occupy the leading entries.
    measured = list(all_stats.values())[max(num_warmup_reqs, 0) :]

    for stats_dict in measured:
        for key in stat_keys:
            if key in stats_dict:
                stats_by_stage[key].append(stats_dict[key])

    return stats_by_stage


def calculate_mm_processor_metrics(
    stats_by_stage: dict[str, list[float]],
    selected_percentiles: list[float],
) -> dict[str, dict[str, float]]:
    """
    Calculate aggregate metrics from stats by stage.

    Timing stages are expected in seconds and are reported in milliseconds.
    ``"num_encoder_calls"`` is a count and is reported unscaled.

    Args:
        stats_by_stage: Stage name -> raw values (a list or 1-D array).
        selected_percentiles: Percentiles in [0, 100] to report; each one
            becomes a ``f"p{p}"`` key in the per-stage result.

    Returns:
        Stage name -> ``{"mean", "median", "std", "p<p>", ...}``. Stages
        with no values report 0.0 for every metric.
    """
    metrics: dict[str, dict[str, float]] = {}

    for stage_name, raw_values in stats_by_stage.items():
        # len() rather than truthiness so NumPy arrays are accepted too.
        if len(raw_values) == 0:
            metrics[stage_name] = {
                "mean": 0.0,
                "median": 0.0,
                "std": 0.0,
                **{f"p{p}": 0.0 for p in selected_percentiles},
            }
            continue

        values = np.asarray(raw_values, dtype=float)
        if stage_name != "num_encoder_calls":
            values = values * 1000.0  # seconds -> milliseconds

        metrics[stage_name] = {
            "mean": float(np.mean(values)),
            "median": float(np.median(values)),
            "std": float(np.std(values)),
            **{f"p{p}": float(np.percentile(values, p)) for p in selected_percentiles},
        }

    return metrics


def validate_args(args):
    """
    Validate command-line arguments for mm_processor benchmark.

    Fills in defaults on ``args`` in place:

    * ``tokenizer`` falls back to ``model`` when unset or empty;
    * ``dataset_path``, ``lora_path`` and ``max_loras`` are set to ``None``
      when the namespace does not define them at all.

    Raises:
        ValueError: if ``--dataset-name hf`` is used without a
            ``--dataset-path``, or with a path that is not one of the
            supported multimodal datasets.
    """
    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model
    # Ensure these attributes exist even when the parser did not define them.
    for attr in ("dataset_path", "lora_path", "max_loras"):
        if not hasattr(args, attr):
            setattr(args, attr, None)

    if args.dataset_name != "hf":
        return

    if not args.dataset_path:
        raise ValueError(
            "--dataset-path is required when using --dataset-name hf. "
            "For multimodal benchmarking, specify a dataset like "
            "'lmarena-ai/VisionArena-Chat'."
        )
    supported_mm_datasets = (
        VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
        | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
    )
    if args.dataset_path not in supported_mm_datasets:
        raise ValueError(
            f"{args.dataset_path} is not a supported multimodal dataset. "
            f"Supported multimodal datasets are: {sorted(supported_mm_datasets)}"
        )

def _e2e_latencies_ms(outputs: list[Any]) -> list[float]:
    """Return E2E latency (ms) for each finished output with full timing metrics.

    E2E latency is computed as TTFT + (last_token_ts - first_token_ts).
    Unfinished outputs and outputs missing any of those metrics are skipped.
    """
    latencies_ms: list[float] = []
    for output in outputs:
        if not output.finished or output.metrics is None:
            continue
        metrics = output.metrics
        ttft = getattr(metrics, "first_token_latency", None)
        first_ts = getattr(metrics, "first_token_ts", None)
        last_ts = getattr(metrics, "last_token_ts", None)
        if ttft is None or first_ts is None or last_ts is None:
            continue
        # Decode time is the duration between the first and last token;
        # clamp at 0 so a skewed timestamp pair cannot go negative.
        decode_time = max(0.0, last_ts - first_ts)
        latencies_ms.append((ttft + decode_time) * 1000)
    return latencies_ms


def _latency_summary_ms(
    values_ms: list[float], percentiles: list[float]
) -> tuple[float, float, float, list[tuple[float, float]]]:
    """Return (mean, median, std, [(p, value), ...]); all zeros when empty."""
    if not values_ms:
        return 0.0, 0.0, 0.0, [(p, 0.0) for p in percentiles]
    return (
        float(np.mean(values_ms)),
        float(np.median(values_ms)),
        float(np.std(values_ms)),
        [(p, float(np.percentile(values_ms, p))) for p in percentiles],
    )


def benchmark_multimodal_processor(
    args: argparse.Namespace,
) -> dict[str, Any]:
    """
    Run the multimodal processor benchmark.

    Builds an ``LLM`` from the CLI engine arguments and optionally runs
    ``args.num_warmups`` warmup requests, sampled with ``seed + 1``. It
    then times ``llm.chat`` over the benchmark requests and collects the
    per-stage MM processor stats from the engine. Stats recorded for the
    warmup requests are skipped.

    Args:
        args: Parsed CLI arguments (see ``add_cli_args``). Mutated in place:
            ``validate_args`` fills defaults and ``seed`` becomes 0 if unset.

    Returns:
        A dict with keys ``completed``, ``failed``, ``mean_e2el_ms``,
        ``median_e2el_ms``, ``std_e2el_ms``, ``percentiles_e2el_ms`` (a list
        of ``(percentile, value_ms)`` pairs), ``mm_processor_stats`` (see
        ``calculate_mm_processor_metrics``) and ``encoder_summary``.
    """
    from vllm import LLM, SamplingParams

    validate_args(args)

    if args.seed is None:
        args.seed = 0

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    tokenizer = llm.get_tokenizer()
    requests = get_requests(args, tokenizer)

    max_model_len = llm.llm_engine.model_config.max_model_len
    assert all(
        max_model_len >= request.prompt_len + request.expected_output_len
        for request in requests
    ), (
        "Please ensure that max_model_len is greater than or equal to the sum of "
        "prompt_len and expected_output_len for all requests."
    )

    prompts = [request.prompt for request in requests]
    sampling_params = [
        SamplingParams(
            n=1,
            temperature=0.0,
            max_tokens=request.expected_output_len,
            detokenize=True,
        )
        for request in requests
    ]

    # Blank entries (e.g. a trailing comma) are ignored instead of crashing.
    selected_percentiles = [
        float(p)
        for p in getattr(args, "metric_percentiles", "99").split(",")
        if p.strip()
    ]
    use_tqdm = not getattr(args, "disable_tqdm", False)

    freeze_gc_heap()

    num_warmups = getattr(args, "num_warmups", 0) or 0
    # Number of warmup requests actually generated; this many leading
    # stats entries are skipped later. The dataset may yield fewer than
    # num_warmups.
    num_warmup_reqs = 0
    if num_warmups > 0:
        print(f"Processing {num_warmups} warmup requests...")
        # Sample warmups from a copy of args with their own prompt count and
        # a shifted seed; the original args are left untouched.
        warmup_args = argparse.Namespace(**vars(args))
        warmup_args.num_prompts = num_warmups
        warmup_args.seed += 1
        warmup_requests = get_requests(warmup_args, tokenizer)
        num_warmup_reqs = len(warmup_requests)
        llm.chat(
            [req.prompt for req in warmup_requests],
            [
                SamplingParams(max_tokens=req.expected_output_len)
                for req in warmup_requests
            ],
            use_tqdm=use_tqdm,
        )

    print(f"Processing {len(prompts)} requests...")
    start_time = time.perf_counter()
    outputs = llm.chat(prompts, sampling_params, use_tqdm=use_tqdm)
    total_time = time.perf_counter() - start_time

    mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine, num_warmup_reqs)

    if not any(mm_stats_by_stage.values()):
        print(
            "\n⚠️  Warning: No MM processor stats found in registry.\n"
            "   This may indicate that:\n"
            "   - No multimodal requests were processed\n"
            "   - Stats were already retrieved (registry is cleared after retrieval)\n"
        )

    mm_processor_metrics = calculate_mm_processor_metrics(
        mm_stats_by_stage, selected_percentiles
    )

    completed = sum(1 for o in outputs if o.finished)
    failed = len(outputs) - completed

    e2el_times = _e2e_latencies_ms(outputs)
    if not e2el_times and completed > 0:
        print(
            "\n⚠️  Warning: Detailed end-to-end latency metrics not available.\n"
            "   Falling back to average request latency "
            "(total_time / num_completed_requests).\n"
        )
        e2el_times = [total_time / completed * 1000] * completed

    mean_e2el_ms, median_e2el_ms, std_e2el_ms, percentiles_e2el_ms = (
        _latency_summary_ms(e2el_times, selected_percentiles)
    )

    encoder_calls = mm_stats_by_stage.get("num_encoder_calls")
    encoder_summary = (
        {
            "total_encoder_calls": int(sum(encoder_calls)),
            "num_requests_with_encoder_calls": len(encoder_calls),
        }
        if encoder_calls
        else {}
    )

    return {
        "completed": completed,
        "failed": failed,
        "mean_e2el_ms": mean_e2el_ms,
        "median_e2el_ms": median_e2el_ms,
        "std_e2el_ms": std_e2el_ms,
        "percentiles_e2el_ms": percentiles_e2el_ms,
        "mm_processor_stats": mm_processor_metrics,
        "encoder_summary": encoder_summary,
    }


def add_cli_args(parser: argparse.ArgumentParser) -> None:
    """
    Add CLI arguments for the multimodal processor benchmark.

    Registers the engine arguments (``EngineArgs.add_cli_args``), the
    dataset selection and random multimodal dataset arguments, the
    HuggingFace dataset arguments, and the benchmark/reporting options.

    Args:
        parser: The parser to extend in place.
    """
    EngineArgs.add_cli_args(parser)

    # The benchmark reads per-request MM processor stats from the engine,
    # which get_timing_stats_from_engine only returns when this is enabled.
    parser.set_defaults(enable_mm_processor_stats=True)

    parser.add_argument(
        "--dataset-name",
        type=str,
        default="random-mm",
        choices=["random-mm", "hf"],
        help="Name of the dataset to benchmark on. Defaults to 'random-mm'.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=10,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--num-warmups",
        type=int,
        default=1,
        help="Number of warmup prompts to process.",
    )

    from vllm.benchmarks.datasets import (
        add_random_dataset_base_args,
        add_random_multimodal_dataset_args,
    )

    add_random_dataset_base_args(parser)
    add_random_multimodal_dataset_args(parser)

    # HuggingFace dataset arguments
    parser.add_argument(
        "--dataset-path",
        type=str,
        default=None,
        help="Path to the dataset file or HuggingFace dataset name "
        "(e.g., 'yale-nlp/MMVU', 'lmarena-ai/VisionArena-Chat').",
    )
    parser.add_argument(
        "--hf-subset",
        type=str,
        default=None,
        help="Subset of the HuggingFace dataset (optional).",
    )
    parser.add_argument(
        "--hf-split",
        type=str,
        default=None,
        help="Split of the HuggingFace dataset (e.g., 'train', 'test', 'validation').",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=None,
        help="Output length for each request. "
        "Overrides the default output lengths from the dataset.",
    )

    # Reporting options
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the benchmark results in JSON format.",
    )
    parser.add_argument(
        "--metric-percentiles",
        type=str,
        default="99",
        help="Comma-separated list of percentiles to calculate (e.g., '50,90,99').",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Disable tqdm progress bar.",
    )


def main(args: argparse.Namespace) -> None:
    """
    Main entry point for the multimodal processor benchmark.

    Runs the benchmark and prints two tables: per-stage MM processor metrics
    and end-to-end latency. When ``args.output_json`` is set, it also writes
    the result dict to that path, with a ``config`` block and a timestamp
    added.
    """
    print("Starting multimodal processor benchmark...")
    result = benchmark_multimodal_processor(args)

    # Parsed once for both tables. Blank entries are ignored so a trailing
    # comma does not crash the report.
    selected_percentiles = [
        float(p)
        for p in getattr(args, "metric_percentiles", "99").split(",")
        if p.strip()
    ]

    print("\n" + "=" * 80)
    print("Multimodal Processor Benchmark Results")
    print("=" * 80)

    if "mm_processor_stats" in result:
        print("\nMM Processor Metrics:")
        mm_rows = []
        for stage, stage_metrics in result["mm_processor_stats"].items():
            # num_encoder_calls is a count; every other stage is in ms.
            unit = "" if stage == "num_encoder_calls" else " (ms)"
            row = {
                "Stage": stage + unit,
                "Mean": f"{stage_metrics['mean']:.2f}",
                "Median": f"{stage_metrics['median']:.2f}",
                "Std": f"{stage_metrics['std']:.2f}",
            }
            for p in selected_percentiles:
                row[f"P{p}"] = f"{stage_metrics.get(f'p{p}', 0.0):.2f}"
            mm_rows.append(row)

        print(pd.DataFrame(mm_rows).to_string(index=False))

        encoder_summary = result.get("encoder_summary")
        if encoder_summary:
            total_calls = encoder_summary["total_encoder_calls"]
            num_requests = encoder_summary["num_requests_with_encoder_calls"]
            print(
                f"\nSummary: {total_calls} total encoder calls "
                f"across {num_requests} requests."
            )

    if "mean_e2el_ms" in result:
        print("\nEnd-to-End Latency (ms):")
        e2el_rows = [
            {"Metric": "Mean", "Value (ms)": f"{result['mean_e2el_ms']:.2f}"},
            {"Metric": "Median", "Value (ms)": f"{result['median_e2el_ms']:.2f}"},
            {"Metric": "Std", "Value (ms)": f"{result['std_e2el_ms']:.2f}"},
        ]
        percentile_lookup = dict(result["percentiles_e2el_ms"])
        for p in selected_percentiles:
            e2el_rows.append(
                {
                    "Metric": f"P{p}",
                    "Value (ms)": f"{percentile_lookup.get(p, 0.0):.2f}",
                }
            )

        print(pd.DataFrame(e2el_rows).to_string(index=False))

    if args.output_json:
        result["config"] = {
            "model": args.model,
            "num_prompts": args.num_prompts,
            "input_len": getattr(args, "random_input_len", None),
            "output_len": getattr(args, "random_output_len", None),
        }
        result["timestamp"] = datetime.now().isoformat()

        with open(args.output_json, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)
        print(f"\nResults saved to {args.output_json}")


if __name__ == "__main__":
    # Standalone entry point: build the parser, register the benchmark
    # arguments on it, and run with whatever was passed on the command line.
    cli_parser = argparse.ArgumentParser(description="Benchmark mm processor latency")
    add_cli_args(cli_parser)
    main(cli_parser.parse_args())