# test_process_multi_modal_uuids.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, ModelConfig, VllmConfig
9
from vllm.multimodal.parse import parse_mm_uuids
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from vllm.renderers.hf import HfRenderer
from vllm.tokenizers.registry import tokenizer_args_from_config

# Media fixtures shared by the tests in this module. These are module-level
# assignments, so they are evaluated once, when the module is imported.
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
# Used below as a single video item (the tests pair it with one video UUID).
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays


def _build_renderer(
    *, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
) -> HfRenderer:
    """Build an ``HfRenderer`` for Qwen2.5-VL-3B-Instruct.

    Args:
        mm_cache_gb: Value passed as ``mm_processor_cache_gb``; the tests in
            this module use ``0.0`` to disable the processor cache.
        enable_prefix_caching: Forwarded to ``CacheConfig``.

    Returns:
        A renderer created via ``HfRenderer.from_config`` whose tokenizer
        kwargs are derived from the model config.
    """
    model_cfg = ModelConfig(
        model="Qwen/Qwen2.5-VL-3B-Instruct",
        max_model_len=128,
        mm_processor_cache_gb=mm_cache_gb,
    )
    cache_cfg = CacheConfig(enable_prefix_caching=enable_prefix_caching)
    full_cfg = VllmConfig(model_config=model_cfg, cache_config=cache_cfg)

    # Only the tokenizer name and its extra kwargs are needed by the renderer.
    _, tok_name, _, tok_kwargs = tokenizer_args_from_config(model_cfg)
    renderer_tokenizer_kwargs = {**tok_kwargs, "tokenizer_name": tok_name}

    return HfRenderer.from_config(
        full_cfg,
        tokenizer_kwargs=renderer_tokenizer_kwargs,
    )


def test_multi_modal_uuids_length_mismatch_raises():
    """Supplying fewer UUIDs than data items for a modality raises ValueError."""
    renderer = _build_renderer()
    mm_processor = renderer.get_mm_processor()

    mm_data = {"image": [cherry_pil_image, stop_pil_image]}

    # (request id, UUIDs) pairs; each supplies fewer UUIDs than the 2 images.
    cases: list[tuple[str, dict[str, list[str]]]] = [
        ("req-1a", {"image": []}),  # 2 items but 0 uuids provided
        ("req-1b", {"image": ["hash_cherry"]}),  # 2 items but only 1 uuid provided
    ]

    for request_id, mm_uuids in cases:
        # Build the parsed items afresh for each case, as a new request would.
        mm_data_items = mm_processor.info.parse_mm_data(mm_data)
        mm_uuid_items = parse_mm_uuids(mm_uuids)

        with pytest.raises(ValueError, match="must have same length as"):
            renderer._process_mm_uuids(
                mm_data, mm_data_items, mm_uuid_items, request_id
            )


def test_multi_modal_uuids_missing_modality_raises():
    """A ``None`` modality with no corresponding UUID entry raises ValueError."""
    renderer = _build_renderer()

    mm_data = {
        "image": [cherry_pil_image],
        "video": None,
    }

    # Only image UUIDs are provided. The video entry is ``None`` (cached input
    # that requires a UUID), so the missing video UUID must be rejected.
    mm_uuids = {"image": ["hash_cherry"]}

    mm_processor = renderer.get_mm_processor()
    mm_data_items = mm_processor.info.parse_mm_data(mm_data)
    mm_uuid_items = parse_mm_uuids(mm_uuids)

    with pytest.raises(ValueError, match="is empty but .* is missing"):
        renderer._process_mm_uuids(mm_data, mm_data_items, mm_uuid_items, "req-2")


@pytest.mark.parametrize(
    "mm_cache_gb, enable_prefix_caching",
    [
        (4.0, True),  # default behavior
        (4.0, False),  # prefix caching disabled
        (0.0, True),  # processor cache disabled
    ],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
    mm_cache_gb: float, enable_prefix_caching: bool
):
    """User-supplied UUIDs, including ``None`` entries, are returned unchanged.

    Images get a partial UUID list (``[None, "hash_stop"]``) and video gets a
    bare ``None``. This is checked for each parametrized caching configuration,
    each of which keeps at least one cache enabled.
    """
    renderer = _build_renderer(
        mm_cache_gb=mm_cache_gb,
        enable_prefix_caching=enable_prefix_caching,
    )

    mm_data = {
        "image": [cherry_pil_image, stop_pil_image],
        "video": baby_reading_np_ndarrays,
    }

    # Use a consistent two-image scenario across all configurations.
    mm_uuids = {"image": [None, "hash_stop"], "video": None}

    mm_processor = renderer.get_mm_processor()
    mm_data_items = mm_processor.info.parse_mm_data(mm_data)
    mm_uuid_items = parse_mm_uuids(mm_uuids)

    processed_mm_uuids = renderer._process_mm_uuids(
        mm_data, mm_data_items, mm_uuid_items, "req-3"
    )

    assert processed_mm_uuids == mm_uuids


@pytest.mark.parametrize(
    "mm_cache_gb, enable_prefix_caching",
    [
        (4.0, True),  # default behavior
        (4.0, False),  # prefix caching disabled
        (0.0, True),  # processor cache disabled
    ],
)
def test_multi_modal_uuids_accepts_empty(
    mm_cache_gb: float, enable_prefix_caching: bool
):
    """Empty modality data and UUID entries are accepted and passed through."""
    renderer = _build_renderer(
        mm_cache_gb=mm_cache_gb,
        enable_prefix_caching=enable_prefix_caching,
    )

    # ``None`` denotes cached multi-modal input that requires UUIDs, whereas
    # an empty list denotes no multi-modal input for that modality.
    mm_data = {"image": [], "video": [], "audio": None}  # type: ignore[var-annotated]
    mm_uuids = {"image": [], "video": None, "audio": []}  # type: ignore[var-annotated]

    mm_processor = renderer.get_mm_processor()
    mm_data_items = mm_processor.info.parse_mm_data(mm_data)
    mm_uuid_items = parse_mm_uuids(mm_uuids)

    processed_mm_uuids = renderer._process_mm_uuids(
        mm_data, mm_data_items, mm_uuid_items, "req-4"
    )

    assert processed_mm_uuids == mm_uuids


152
def test_multi_modal_uuids_ignored_when_caching_disabled():
    """User UUIDs are replaced by request-id-based ones when caching is off.

    When both the processor cache is 0 and prefix caching is disabled, the
    renderer builds UUID overrides from the request id instead of using the
    user-supplied ``mm_uuids``.
    """
    renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False)

    request_id = "req-42"
    mm_data = {
        "image": [cherry_pil_image, stop_pil_image],
        "video": baby_reading_np_ndarrays,
    }
    mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}

    mm_processor = renderer.get_mm_processor()
    mm_data_items = mm_processor.info.parse_mm_data(mm_data)
    mm_uuid_items = parse_mm_uuids(mm_uuids)

    processed_mm_uuids = renderer._process_mm_uuids(
        mm_data, mm_data_items, mm_uuid_items, request_id
    )

    # Expect request-id-based overrides to be passed through. Assert on the
    # *processed* UUIDs: checking ``mm_uuids`` (the input) passes trivially.
    assert set(processed_mm_uuids) == {"image", "video"}
    assert len(processed_mm_uuids["image"]) == 2
    assert len(processed_mm_uuids["video"]) == 1

    # Each override must start with "<request_id>-<modality>-" and end with
    # "-<index>" of the item within its modality.
    for modality in ("image", "video"):
        for idx, item_uuid in enumerate(processed_mm_uuids[modality]):
            assert item_uuid.startswith(f"{request_id}-{modality}-")
            assert item_uuid.endswith(f"-{idx}")