# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColPali late interaction model for multi-modal retrieval.

ColPali is a multi-vector retrieval model based on PaliGemma backbone
(SigLIP + Gemma) with ColBERT-style late interaction scoring (MaxSim).
It produces per-token embeddings for both text and image inputs.
"""

from io import BytesIO

import pybase64 as base64
import pytest
import torch
from PIL import Image

from vllm.entrypoints.chat_utils import (
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartTextParam,
)
from vllm.entrypoints.pooling.scoring.typing import ScoreMultiModalParam

from ....conftest import VllmRunner

# Checkpoints that every parametrized test in this module runs against.
MODELS = ["vidore/colpali-v1.3-hf"]

# Expected size of the last dimension of the token embeddings, per checkpoint.
EMBED_DIMS = {"vidore/colpali-v1.3-hf": 128}

# Plain-text queries; index 0 is the one used by the embedding/scoring helpers.
TEXT_QUERIES = [
    "What is the capital of France?",
    "Describe the contents of the document.",
]

# Plain-text documents; index 0 is the one paired with TEXT_QUERIES[0].
TEXT_DOCUMENTS = [
    "The capital of France is Paris.",
    "This document contains important financial data.",
]

# Sole value in the ``dtype`` parametrization of the tests below.
DTYPE = "half"
# Forwarded as ``gpu_memory_utilization`` to every runner created here.
GPU_MEMORY_UTILIZATION = 0.7


def _make_base64_image(
    width: int = 64, height: int = 64, color: tuple[int, int, int] = (255, 0, 0)
) -> str:
    """Render a solid-color RGB image as PNG and return it as a base64 data URI.

    Args:
        width: Image width passed to ``Image.new``.
        height: Image height passed to ``Image.new``.
        color: RGB fill color for the whole image.

    Returns:
        A ``data:image/png;base64,...`` string carrying the encoded PNG bytes.
    """
    png_stream = BytesIO()
    Image.new("RGB", (width, height), color).save(png_stream, format="PNG")
    png_bytes = png_stream.getvalue()
    b64 = base64.b64encode(png_bytes).decode()
    return f"data:image/png;base64,{b64}"


def _make_image_mm_param(
    image_uri: str,
    text: str | None = None,
) -> ScoreMultiModalParam:
    """Wrap an image URI, plus optional text, as a ScoreMultiModalParam.

    The image part is always first in ``content``; a text part is appended
    after it only when ``text`` is not None.

    Args:
        image_uri: Value stored under the ``"url"`` key of the image part.
        text: Optional text for a trailing text part.

    Returns:
        A ScoreMultiModalParam whose ``content`` holds the built parts.
    """
    image_part = ChatCompletionContentPartImageParam(
        type="image_url",
        image_url={"url": image_uri},
    )
    parts: list = [image_part]
    if text is not None:
        text_part = ChatCompletionContentPartTextParam(type="text", text=text)
        parts.append(text_part)
    return ScoreMultiModalParam(content=parts)


def _run_token_embed_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check the shape and unit norm of token embeddings for one text query.

    Embeds ``TEXT_QUERIES[0]`` with a pooling runner and asserts that the
    result is a single 2D tensor whose second dimension equals
    ``EMBED_DIMS[model]``, that it has more than one row, and that every row
    has an L2 norm close to 1.

    Args:
        vllm_runner: Runner factory, used as a context manager.
        model: Checkpoint name; also the key into ``EMBED_DIMS``.
        dtype: Dtype string forwarded to the runner.
    """
    runner_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **runner_kwargs) as engine:
        results = engine.token_embed([TEXT_QUERIES[0]])
        assert len(results) == 1

        token_embs = torch.tensor(results[0])
        # Expected layout is [num_tokens, embed_dim] with several token rows.
        assert token_embs.dim() == 2
        assert token_embs.shape[1] == EMBED_DIMS[model]
        assert token_embs.shape[0] > 1

        # Every row must be (approximately) unit length under the L2 norm.
        row_norms = torch.norm(token_embs, p=2, dim=-1)
        unit_norms = torch.ones_like(row_norms)
        torch.testing.assert_close(row_norms, unit_norms, rtol=1e-2, atol=1e-2)


def _run_late_interaction_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Verify MaxSim scoring matches manual computation.

    Embeds ``TEXT_QUERIES[0]`` and ``TEXT_DOCUMENTS[0]`` per token, scores the
    pair with ``compute_maxsim_score``, and asserts that the runner's own
    ``score`` result for the same pair agrees within 1% relative tolerance.

    Args:
        vllm_runner: Runner factory, used as a context manager.
        model: Checkpoint name to load.
        dtype: Dtype string forwarded to the runner.
    """
    # Kept as a function-local import, matching the original placement.
    from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXT_QUERIES[0]])
        d_outputs = vllm_model.token_embed([TEXT_DOCUMENTS[0]])

        q_emb = torch.tensor(q_outputs[0])
        d_emb = torch.tensor(d_outputs[0])

        # Reference score computed from the token embeddings fetched above.
        manual_score = compute_maxsim_score(q_emb, d_emb).item()

        # Score produced by the runner for the same query/document pair.
        vllm_scores = vllm_model.score(TEXT_QUERIES[0], TEXT_DOCUMENTS[0])

        assert len(vllm_scores) == 1
        assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)


def _run_relevance_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that on-topic documents outscore an off-topic one.

    Scores one machine-learning question against three documents and asserts
    that the first and third each score strictly higher than the second
    (the weather sentence).

    Args:
        vllm_runner: Runner factory, used as a context manager.
        model: Checkpoint name to load.
        dtype: Dtype string forwarded to the runner.
    """
    question = "What is machine learning?"
    candidates = [
        "Machine learning is a subset of artificial intelligence.",
        "The weather forecast shows rain tomorrow.",
        "Deep learning uses neural networks for complex tasks.",
    ]

    engine_ctx = vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with engine_ctx as engine:
        ranked = engine.score(question, candidates)

        assert len(ranked) == 3
        # Positions follow ``candidates``: ML, weather, deep learning.
        ml_score, weather_score, dl_score = ranked
        assert ml_score > weather_score, "ML doc should score higher than weather doc"
        assert dl_score > weather_score, "DL doc should score higher than weather doc"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_token_embed(
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
) -> None:
    """Run the token-embedding shape and normalization check per model/dtype."""
    _run_token_embed_test(vllm_runner, model, dtype=dtype)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_late_interaction_scoring(
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
) -> None:
    """Run the manual-vs-runner MaxSim score comparison per model/dtype."""
    _run_late_interaction_test(vllm_runner, model, dtype=dtype)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_relevance_ordering(
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
) -> None:
    """Run the relevant-vs-irrelevant document ordering check per model/dtype."""
    _run_relevance_test(vllm_runner, model, dtype=dtype)


# ── Multimodal scoring tests ────────────────────────────────


def _run_multimodal_text_query_image_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score one text query against two generated image documents.

    Only the number of results and the type of each score are asserted; no
    ordering between the two images is checked.

    Args:
        vllm_runner: Runner factory, used as a context manager.
        model: Checkpoint name to load.
        dtype: Dtype string forwarded to the runner.
    """
    fill_colors = ((255, 0, 0), (0, 0, 255))  # red image, then blue image
    image_docs = [
        _make_image_mm_param(_make_base64_image(64, 64, color=rgb))
        for rgb in fill_colors
    ]
    query = "Describe the red object"

    runner_kwargs = dict(
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with vllm_runner(model, **runner_kwargs) as engine:
        results = engine.llm.score(query, image_docs)

        assert len(results) == 2
        for result in results:
            assert isinstance(result.outputs.score, float)


def _run_multimodal_mixed_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score one text query against a text document and an image document.

    Asserts that both scores are floats and that the text document scores
    strictly higher than the generated solid-red image.

    Args:
        vllm_runner: Runner factory, used as a context manager.
        model: Checkpoint name to load.
        dtype: Dtype string forwarded to the runner.
    """
    image_doc = _make_image_mm_param(_make_base64_image(64, 64, color=(255, 0, 0)))
    text_doc = "The capital of France is Paris."
    query = "What is the capital of France?"
    documents: list = [text_doc, image_doc]

    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    ) as engine:
        results = engine.llm.score(query, documents)

        assert len(results) == 2
        # Results are in the same order as ``documents``: text, then image.
        text_result, image_result = results
        for result in (text_result, image_result):
            assert isinstance(result.outputs.score, float)
        # The matching text answer is expected to beat the unrelated image.
        assert text_result.outputs.score > image_result.outputs.score


def _run_multimodal_image_query_text_docs_test(
    vllm_runner: type[VllmRunner],
    model: str,
    *,
    dtype: str,
) -> None:
    """Score an image-plus-text query against two text documents.

    Only the number of results and the type of each score are asserted.

    Args:
        vllm_runner: Runner factory, used as a context manager.
        model: Checkpoint name to load.
        dtype: Dtype string forwarded to the runner.
    """
    image_query = _make_image_mm_param(
        _make_base64_image(64, 64, color=(255, 0, 0)),
        text="red color",
    )
    text_docs = [
        "A bright red sports car.",
        "The weather forecast shows rain tomorrow.",
    ]

    engine_ctx = vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
        gpu_memory_utilization=GPU_MEMORY_UTILIZATION,
    )
    with engine_ctx as engine:
        results = engine.llm.score(image_query, text_docs)

        assert len(results) == 2
        for result in results:
            assert isinstance(result.outputs.score, float)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_text_query_image_docs(
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
) -> None:
    """Run the text-query vs. image-documents scoring check per model/dtype."""
    _run_multimodal_text_query_image_docs_test(vllm_runner, model, dtype=dtype)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_mixed_docs(
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
) -> None:
    """Run the mixed text/image documents scoring check per model/dtype."""
    _run_multimodal_mixed_docs_test(vllm_runner, model, dtype=dtype)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", [DTYPE])
def test_colpali_multimodal_image_query_text_docs(
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
) -> None:
    """Run the image-query vs. text-documents scoring check per model/dtype."""
    _run_multimodal_image_query_text_docs_test(vllm_runner, model, dtype=dtype)