test_process_multi_modal_uuids.py 5.35 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config

# Media fixtures shared by every test in this module. They are created once,
# when the module is imported, from vLLM's named image/video assets.
cherry_pil_image, stop_pil_image = (
    ImageAsset("cherry_blossom").pil_image,
    ImageAsset("stop_sign").pil_image,
)
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays


def _build_renderer(
    *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> HfRenderer:
    """Create an ``HfRenderer`` for ``Qwen/Qwen2.5-VL-3B-Instruct``.

    Args:
        mm_cache_gb: Value forwarded as ``mm_processor_cache_gb``; the tests
            in this module use ``0.0`` to disable the processor cache.
        enable_prefix_caching: Forwarded to ``CacheConfig``.

    Returns:
        A renderer built from a ``VllmConfig`` whose tokenizer name and
        tokenizer kwargs are derived from the model config.
    """
    model_cfg = ModelConfig(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        max_model_len=128,
        mm_processor_cache_gb=mm_cache_gb,
    )
    cache_cfg = CacheConfig(enable_prefix_caching=enable_prefix_caching)
    vllm_cfg = VllmConfig(model_config=model_cfg, cache_config=cache_cfg)

    # Only the tokenizer name and its extra kwargs are used from this tuple.
    _, tok_name, _, tok_kwargs = tokenizer_args_from_config(model_cfg)
    renderer_tok_kwargs = {**tok_kwargs, "tokenizer_name": tok_name}

    return HfRenderer.from_config(vllm_cfg, tokenizer_kwargs=renderer_tok_kwargs)


def test_multi_modal_uuids_length_mismatch_raises():
    """Supplying fewer UUIDs than items for a modality raises ``ValueError``."""
    renderer = _build_renderer()

    two_images = [cherry_pil_image, stop_pil_image]
    mm_data = {"image": two_images}
    # Deliberate mismatch: two image items, a single UUID.
    mm_uuids = {"image": ["hash_cherry"]}

    parsed_items = renderer.get_mm_processor().info.parse_mm_data(mm_data)

    with pytest.raises(ValueError, match="must have same length as"):
        renderer._process_mm_uuids(mm_data, parsed_items, mm_uuids, "req-1")


def test_multi_modal_uuids_missing_modality_raises():
    """A modality given as ``None`` data with no UUID entry is rejected.

    The video modality is present with ``None`` data, while UUIDs are
    supplied for images only, so processing is expected to raise.
    """
    renderer = _build_renderer()

    mm_data = {"image": [cherry_pil_image], "video": None}
    mm_uuids = {"image": ["hash_cherry"]}  # no "video" key, on purpose

    parsed_items = renderer.get_mm_processor().info.parse_mm_data(mm_data)

    with pytest.raises(ValueError, match="is empty but .* is missing"):
        renderer._process_mm_uuids(mm_data, parsed_items, mm_uuids, "req-2")


@pytest.mark.parametrize(
    "mm_cache_gb, enable_prefix_caching",
    [
        (4.0, True),  # default behavior
        (4.0, False),  # prefix caching disabled
        (0.0, True),  # processor cache disabled
    ],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
    monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
    """``None`` entries in user UUIDs are accepted and returned unchanged.

    Each parametrized configuration keeps at least one of the processor
    cache and prefix caching enabled.
    """
    renderer = _build_renderer(
        mm_cache_gb=mm_cache_gb, enable_prefix_caching=enable_prefix_caching
    )

    mm_data = {
        "image": [cherry_pil_image, stop_pil_image],
        "video": baby_reading_np_ndarrays,
    }
    # Same two-image input for every configuration: no UUID for the first
    # image, an explicit hash for the second, and None for the video.
    mm_uuids = {"image": [None, "hash_stop"], "video": None}

    parsed_items = renderer.get_mm_processor().info.parse_mm_data(mm_data)
    result = renderer._process_mm_uuids(
        mm_data, parsed_items, mm_uuids, "req-3"
    )

    assert result == mm_uuids


@pytest.mark.parametrize(
    "mm_cache_gb, enable_prefix_caching",
    [
        (4.0, True),  # default behavior
        (4.0, False),  # prefix caching disabled
        (0.0, True),  # processor cache disabled
    ],
)
def test_multi_modal_uuids_accepts_empty(
    monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
):
    """Empty per-modality inputs are accepted and UUIDs returned unchanged.

    ``None`` data denotes cached multi-modal input that requires UUIDs. An
    empty list means there is no input for that modality, so the empty video
    list here does not need UUIDs.
    """
    renderer = _build_renderer(
        mm_cache_gb=mm_cache_gb, enable_prefix_caching=enable_prefix_caching
    )

    mm_data = {"image": [], "video": []}  # type: ignore[var-annotated]
    mm_uuids = {"image": [], "video": None}  # type: ignore[var-annotated]

    parsed_items = renderer.get_mm_processor().info.parse_mm_data(mm_data)
    result = renderer._process_mm_uuids(
        mm_data, parsed_items, mm_uuids, "req-4"
    )

    assert result == mm_uuids


def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
    """User-supplied UUIDs are replaced when all multi-modal caching is off.

    With the processor cache sized to 0 GB and prefix caching disabled, the
    renderer is expected to ignore the caller's UUIDs. Instead it builds one
    override per item from the request id, the modality and the item index.
    """
    renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False)

    request_id = "req-42"
    mm_data = {
        "image": [cherry_pil_image, stop_pil_image],
        "video": baby_reading_np_ndarrays,
    }
    mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}

    mm_processor = renderer.get_mm_processor()
    mm_items = mm_processor.info.parse_mm_data(mm_data)
    processed_mm_uuids = renderer._process_mm_uuids(
        mm_data, mm_items, mm_uuids, request_id
    )

    # Inspect the *returned* UUIDs. Checking the input ``mm_uuids`` would
    # pass trivially, since this test built it with these keys and lengths.
    assert set(processed_mm_uuids.keys()) == {"image", "video"}
    assert len(processed_mm_uuids["image"]) == 2
    assert len(processed_mm_uuids["video"]) == 1

    # Expect request-id-based overrides, one per item, suffixed by index.
    # Separate asserts (with the UUID as message) show which check failed.
    for modality in ("image", "video"):
        for idx, uuid in enumerate(processed_mm_uuids[modality]):
            assert uuid.startswith(f"{request_id}-{modality}-"), uuid
            assert uuid.endswith(f"-{idx}"), uuid