Unverified Commit 0cb7b065 authored by Breno Baldas Skuk's avatar Breno Baldas Skuk Committed by GitHub
Browse files

Feature/benchmark/random mm data/images (#23119)


Signed-off-by: default avatarbreno.skuk <breno.skuk@hcompany.ai>
parent 2da02dd0
......@@ -59,6 +59,12 @@ become available.
<td style="text-align: center;"></td>
<td><code>synthetic</code></td>
</tr>
<tr>
<td><strong>RandomMultiModal (Image/Video)</strong></td>
<td style="text-align: center;">🟡</td>
<td style="text-align: center;">🚧</td>
<td><code>synthetic</code> </td>
</tr>
<tr>
<td><strong>Prefix Repetition</strong></td>
<td style="text-align: center;"></td>
......@@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \
--endpoint /v1/chat/completion
```
### Synthetic Random Images (random-mm)
Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
Notes:
- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
- Video sampling is not yet implemented.
Start the server (example):
```bash
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
--dtype bfloat16 \
--max-model-len 16384 \
--limit-mm-per-prompt '{"image": 3, "video": 0}' \
--mm-processor-kwargs max_pixels=1003520
```
Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`.
Ex.1: Fixed number of items and a single image resolutionm, enforcing generation of approx 40 tokens:
```bash
vllm bench serve \
--backend openai-chat \
--model Qwen/Qwen2.5-VL-3B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name random-mm \
--num-prompts 100 \
--max-concurrency 10 \
--random-prefix-len 25 \
--random-input-len 300 \
--random-output-len 40 \
--random-range-ratio 0.2 \
--random-mm-base-items-per-request 2 \
--random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
--random-mm-bucket-config '{(224, 224, 1): 1.0}' \
--request-rate inf \
--ignore-eos \
--seed 42
```
The number of items per request can be controlled by passing multiple image buckets:
```bash
--random-mm-base-items-per-request 2 \
--random-mm-num-mm-items-range-ratio 0.5 \
--random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
--random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
```
Flags specific to `random-mm`:
- `--random-mm-base-items-per-request`: base number of multimodal items per request.
- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported).
Behavioral notes:
- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
How sampling works:
- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`.
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
</details>
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Any, NamedTuple, Optional, cast
import numpy as np
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
SampleRequest)
@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
# Use a small, commonly available tokenizer
return AutoTokenizer.from_pretrained("gpt2")
class Params(NamedTuple):
num_requests: int
prefix_len: int
range_ratio: float
input_len: int
output_len: int
@pytest.fixture(scope="session")
def random_dataset_params() -> Params:
return Params(num_requests=16,
prefix_len=7,
range_ratio=0.3,
input_len=50,
output_len=20)
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
"""Project a SampleRequest into a comparable tuple."""
return (req.prompt, req.prompt_len, req.expected_output_len)
def _collect_samples(dataset: RandomDataset,
tokenizer: PreTrainedTokenizerBase,
num_requests: int = 16,
prefix_len: int = 7,
range_ratio: float = 0.3,
input_len: int = 50,
output_len: int = 20) -> list[tuple[str, int, int]]:
samples = dataset.sample(
tokenizer=tokenizer,
num_requests=num_requests,
prefix_len=prefix_len,
range_ratio=range_ratio,
input_len=input_len,
output_len=output_len,
)
return [_fingerprint_sample(s) for s in samples]
@pytest.mark.benchmark
def test_random_dataset_same_seed(
hf_tokenizer: PreTrainedTokenizerBase,
random_dataset_params: Params) -> None:
"""Same seed should yield identical outputs, even if global RNGs change.
This guards against accidental reliance on Python's random or np.random
in RandomDataset after moving to numpy.default_rng.
"""
p = random_dataset_params
common_seed = 123
dataset_a = RandomDataset(random_seed=common_seed)
dataset_b = RandomDataset(random_seed=common_seed)
a = _collect_samples(dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
# Perturb global RNG state to ensure isolation
random.seed(999)
_ = [random.random() for _ in range(100)]
np.random.seed(888)
_ = [np.random.random() for _ in range(100)]
b = _collect_samples(dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
assert a == b
@pytest.mark.benchmark
def test_random_dataset_different_seeds(
hf_tokenizer: PreTrainedTokenizerBase,
random_dataset_params: Params) -> None:
"""Different seeds should change outputs with overwhelming likelihood."""
p = random_dataset_params
seed_a = 0
dataset_a = RandomDataset(random_seed=seed_a)
a = _collect_samples(dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
seed_b = 999
dataset_b = RandomDataset(random_seed=seed_b)
# Perturb global RNG with same seed as dataset_a to ensure isolation
random.seed(seed_a)
np.random.seed(seed_a)
b = _collect_samples(dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
assert a != b
# -----------------------------
# RandomMultiModalDataset tests
# -----------------------------
def _mm_fingerprint_sample(
req: SampleRequest,
) -> tuple[str, int, int, int, list[str]]:
"""Create a compact fingerprint for multimodal samples.
Includes:
- prompt string
- prompt_len
- expected_output_len
- count of multimodal items
- per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
"""
items = req.multi_modal_data or []
item_prefixes: list[str] = []
for it in items:
if isinstance(it, dict) and it.get("type") == "image_url":
url = it.get("image_url", {}).get("url", "")
# Only keep a short identifying prefix to avoid huge strings
item_prefixes.append(f"image:{url[:22]}")
elif isinstance(it, dict) and it.get("type") == "video_url":
url = it.get("video_url", {}).get("url", "")
item_prefixes.append(f"video:{url[:22]}")
else:
item_prefixes.append("unknown:")
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
item_prefixes)
def _collect_mm_samples(
dataset: RandomMultiModalDataset,
tokenizer: PreTrainedTokenizerBase,
*,
num_requests: int = 8,
prefix_len: int = 3,
range_ratio: float = 0.0,
input_len: int = 20,
output_len: int = 5,
base_items_per_request: int = 2,
num_mm_items_range_ratio: float = 0.0,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
enable_multimodal_chat: bool = False,
) -> list[SampleRequest]:
if limit_mm_per_prompt is None:
limit_mm_per_prompt = {"image": 5, "video": 0}
if bucket_config is None:
bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
return dataset.sample(
tokenizer=tokenizer,
num_requests=num_requests,
prefix_len=prefix_len,
range_ratio=range_ratio,
input_len=input_len,
output_len=output_len,
base_items_per_request=base_items_per_request,
num_mm_items_range_ratio=num_mm_items_range_ratio,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
enable_multimodal_chat=enable_multimodal_chat,
)
@pytest.mark.benchmark
def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
seed = 42
ds_a = RandomMultiModalDataset(random_seed=seed)
ds_b = RandomMultiModalDataset(random_seed=seed)
a = _collect_mm_samples(ds_a, hf_tokenizer)
b = _collect_mm_samples(ds_b, hf_tokenizer)
fa = [_mm_fingerprint_sample(s) for s in a]
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa == fb
@pytest.mark.benchmark
def test_random_mm_different_seeds(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds_a = RandomMultiModalDataset(random_seed=0)
ds_b = RandomMultiModalDataset(random_seed=999)
a = _collect_mm_samples(ds_a, hf_tokenizer)
b = _collect_mm_samples(ds_b, hf_tokenizer)
fa = [_mm_fingerprint_sample(s) for s in a]
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa != fb
@pytest.mark.benchmark
def test_random_mm_respects_limits(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# Requesting 3 items with a per-prompt limit of 1 should error per current
# design (dataset refuses to silently clamp below the requested baseline).
with pytest.raises(ValueError):
_collect_mm_samples(
ds,
hf_tokenizer,
num_requests=12,
base_items_per_request=3,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 1, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
@pytest.mark.benchmark
def test_random_mm_zero_prob_entries_are_removed(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# Second bucket has zero probability and should be ignored after
# normalization
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=6,
base_items_per_request=2,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 10, "video": 0},
bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
)
for s in samples:
assert isinstance(s.multi_modal_data, list)
typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
for it in typed_mm:
assert it.get("type") == "image_url"
@pytest.mark.benchmark
def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
ds = RandomMultiModalDataset(random_seed=0)
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=0,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 5, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
for s in samples:
assert s.multi_modal_data == []
@pytest.mark.benchmark
def test_random_mm_num_items_per_prompt(
hf_tokenizer: PreTrainedTokenizerBase) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# Fixed number of images per prompt
# set num_mm_items_range_ratio to 0.0
# TODO: modify video values when video sampling is implemented
samples_fixed_items = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=3,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 3, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
# Must have 5 requests each with 3 mm items per prompt
assert len(samples_fixed_items) == 5
for s in samples_fixed_items:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 3
for it in mm_data:
assert it.get("type") == "image_url"
@pytest.mark.benchmark
def test_random_mm_bucket_config_not_mutated(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# This bucket config is not normalized to sum to 1
# and has more buckets than requested images
original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
# Keep a snapshot to compare after sampling
snapshot = dict(original)
_ = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=4,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 1, "video": 0},
bucket_config=original,
)
# Ensure the original dict content is unchanged
assert original == snapshot
# Vary number of mm items per prompt
# set num_mm_items_range_ratio to 0.5
samples_varying_items = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=2,
num_mm_items_range_ratio=0.5,
limit_mm_per_prompt={"image": 4, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
# Must have 5 requests each with less than 4 mm items per prompt
# but at least 1 mm item per prompt
assert len(samples_varying_items) == 5
for s in samples_varying_items:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) <= 4
assert len(mm_data) >= 1
for it in mm_data:
assert it.get("type") == "image_url"
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment