test_colbert.py 12.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Tests for ColBERT late interaction scoring.

Tests are parametrized across multiple ColBERT backbones to ensure the
generic ColBERT support works with different encoder architectures.
"""
8
9
10
11
12
13

import pytest
import torch

from vllm.entrypoints.pooling.score.utils import compute_maxsim_score

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -----------------------------------------------------------------------
# Model specs: backbone key -> {model, colbert_dim, max_model_len, extra_kwargs}
# -----------------------------------------------------------------------
def _colbert_spec(model, colbert_dim, max_model_len, architecture=None):
    """Build one COLBERT_MODELS entry.

    Args:
        model: HuggingFace model id passed to ``vllm_runner``.
        colbert_dim: Expected width of each token embedding.
        max_model_len: ``max_model_len`` passed to ``vllm_runner``.
        architecture: When given, forced through ``hf_overrides`` as the
            single entry of ``architectures``; otherwise no extra kwargs.
    """
    extra_kwargs = {}
    if architecture is not None:
        extra_kwargs["hf_overrides"] = {"architectures": [architecture]}
    return {
        "model": model,
        "colbert_dim": colbert_dim,
        "max_model_len": max_model_len,
        "extra_kwargs": extra_kwargs,
    }


# Key order here is the parametrization order of the ``colbert_spec`` fixture.
COLBERT_MODELS = {
    "bert": _colbert_spec("answerdotai/answerai-colbert-small-v1", 96, 512),
    "modernbert": _colbert_spec(
        "lightonai/GTE-ModernColBERT-v1", 128, 299, "ColBERTModernBertModel"
    ),
    "jina": _colbert_spec(
        "jinaai/jina-colbert-v2", 128, 8192, "ColBERTJinaRobertaModel"
    ),
}
45
46
47
48
49
50
51
52
53
54
55
56
57
58

# Paired inputs: TEXTS_2[i] is the passage that answers query TEXTS_1[i].
_CAPITALS = {"France": "Paris", "Germany": "Berlin"}

TEXTS_1 = [f"What is the capital of {country}?" for country in _CAPITALS]

TEXTS_2 = [
    f"The capital of {country} is {city}." for country, city in _CAPITALS.items()
]

# dtype for the parametrized vLLM runs; the HF comparison tests use float32.
DTYPE = "half"


59
60
61
62
63
64
65
66
67
68
69
# -----------------------------------------------------------------------
# Fixtures
# -----------------------------------------------------------------------


@pytest.fixture(scope="module", params=list(COLBERT_MODELS))
def colbert_spec(request):
    """Module-scoped spec dict for one backbone.

    Parametrized over every key of COLBERT_MODELS, so each test depending on
    it (directly or via the field fixtures) runs once per backbone.
    """
    backbone = request.param
    return COLBERT_MODELS[backbone]


70
@pytest.fixture(scope="module")
def colbert_model_name(colbert_spec):
    """HuggingFace model id of the current backbone (its ``"model"`` field)."""
    return colbert_spec["model"]

74

75
76
77
78
79
80
81
82
83
84
85
86
87
@pytest.fixture(scope="module")
def colbert_dim(colbert_spec):
    """Expected width of each token embedding for the current backbone."""
    dim = colbert_spec["colbert_dim"]
    return dim


@pytest.fixture(scope="module")
def colbert_max_model_len(colbert_spec):
    """``max_model_len`` to pass to ``vllm_runner`` for the current backbone."""
    max_len = colbert_spec["max_model_len"]
    return max_len


@pytest.fixture(scope="module")
def colbert_extra_kwargs(colbert_spec):
    """Extra keyword arguments splatted into ``vllm_runner`` for this backbone
    (e.g. ``hf_overrides``); may be empty."""
    runner_kwargs = colbert_spec["extra_kwargs"]
    return runner_kwargs
88

89
90
91
92
93
94
95
96
97
98
99
100
101

# -----------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------


def test_colbert_token_embed(
    vllm_runner,
    colbert_model_name,
    colbert_dim,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT model produces token embeddings.

    A single input must yield one 2-D ``[num_tokens, colbert_dim]`` matrix
    with more than one token row.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        outputs = vllm_model.token_embed([TEXTS_1[0]])

        assert len(outputs) == 1
        emb = torch.tensor(outputs[0])
        assert emb.dim() == 2
        assert emb.shape[1] == colbert_dim
        # Late interaction needs one vector per token, not a single pooled one.
        assert emb.shape[0] > 1


120
121
122
123
124
125
def test_colbert_late_interaction_1_to_1(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with 1:1 query-document pair.

    The score returned by ``vllm_model.score`` must match a MaxSim score
    computed by hand from the query and document token embeddings.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXTS_1[0]])
        d_outputs = vllm_model.token_embed([TEXTS_2[0]])

        q_emb = torch.tensor(q_outputs[0])
        d_emb = torch.tensor(d_outputs[0])

        manual_score = compute_maxsim_score(q_emb, d_emb).item()

        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])

        assert len(vllm_scores) == 1
        # Loose relative tolerance: these runs use DTYPE ("half").
        assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)


149
150
151
152
153
154
def test_colbert_late_interaction_1_to_N(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with 1:N query-documents.

    One query is scored against every document in TEXTS_2. Each vLLM score
    must match the hand-computed MaxSim score for the same document.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXTS_1[0]])
        d_outputs = vllm_model.token_embed(TEXTS_2)

        q_emb = torch.tensor(q_outputs[0])
        manual_scores = [
            compute_maxsim_score(q_emb, torch.tensor(d_out)).item()
            for d_out in d_outputs
        ]

        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)

        # Derive the expected count from the inputs rather than hard-coding it;
        # strict zip also catches token_embed returning the wrong count.
        assert len(vllm_scores) == len(TEXTS_2)
        for vllm_score, manual_score in zip(vllm_scores, manual_scores, strict=True):
            assert vllm_score == pytest.approx(manual_score, rel=0.01)


181
182
183
184
185
186
def test_colbert_late_interaction_N_to_N(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with N:N query-documents.

    Queries and documents are scored pairwise (TEXTS_1[i] with TEXTS_2[i]).
    Each vLLM score must match the hand-computed MaxSim score for that pair.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed(TEXTS_1)
        d_outputs = vllm_model.token_embed(TEXTS_2)

        manual_scores = [
            compute_maxsim_score(torch.tensor(q_out), torch.tensor(d_out)).item()
            for q_out, d_out in zip(q_outputs, d_outputs, strict=True)
        ]

        vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)

        # One score per query/document pair; count derived from the inputs.
        assert len(vllm_scores) == len(TEXTS_1)
        for vllm_score, manual_score in zip(vllm_scores, manual_scores, strict=True):
            assert vllm_score == pytest.approx(manual_score, rel=0.01)


212
213
214
215
216
217
218
def test_colbert_relevance_ordering(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT scores relevant documents higher than irrelevant."""
    query = "What is machine learning?"
    # Index 1 is the off-topic document; 0 and 2 are on-topic.
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "Python is a programming language.",
        "Deep learning uses neural networks.",
    ]

    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)

        assert len(scores) == 3
        assert scores[0] > scores[1], "ML doc should score higher than Python doc"
        assert scores[2] > scores[1], "DL doc should score higher than Python doc"


241
242
243
244
245
246
def test_colbert_embed_not_supported(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT model does not support 'embed' task.

    Calling ``embed`` must raise ``ValueError`` mentioning that the
    Embedding API is not supported.
    """
    with (
        vllm_runner(
            colbert_model_name,
            runner="pooling",
            dtype=DTYPE,
            max_model_len=colbert_max_model_len,
            enforce_eager=True,
            **colbert_extra_kwargs,
        ) as vllm_model,
        pytest.raises(ValueError, match="Embedding API is not supported"),
    ):
        vllm_model.embed([TEXTS_1[0]])


262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
# -----------------------------------------------------------------------
# Per-model HuggingFace comparison tests
# -----------------------------------------------------------------------


def _assert_embeddings_close(vllm_outputs, hf_embeddings):
    """Assert that vLLM and HuggingFace embeddings match."""
    for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
        vllm_emb = torch.tensor(vllm_out).float()

        assert hf_emb.shape == vllm_emb.shape, (
            f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
        )

        torch.testing.assert_close(
            vllm_emb,
            hf_emb,
            rtol=1e-2,
            atol=1e-2,
            msg=f"Embedding mismatch for text {i}",
        )


def test_colbert_hf_comparison_bert(vllm_runner):
    """Test that vLLM ColBERT produces same embeddings as HuggingFace (BERT).

    The HF reference runs ``BertModel``, projects its last hidden state with
    the checkpoint's ``linear.weight`` and L2-normalises each token vector.
    """
    import torch.nn.functional as F
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from transformers import AutoTokenizer, BertModel

    spec = COLBERT_MODELS["bert"]
    model_name = spec["model"]
    test_texts = [TEXTS_1[0], TEXTS_2[0]]

    # Read max_model_len / extra kwargs from the spec, as the other HF
    # comparison tests do, instead of duplicating the values here.
    with vllm_runner(
        model_name,
        runner="pooling",
        dtype="float32",
        max_model_len=spec["max_model_len"],
        enforce_eager=True,
        **spec["extra_kwargs"],
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(test_texts)

    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_bert = BertModel.from_pretrained(model_name)
    hf_bert.eval()

    weights_path = hf_hub_download(model_name, filename="model.safetensors")
    weights = load_file(weights_path)
    linear_weight = weights["linear.weight"]  # [96, 384]

    hf_embeddings = []
    for text in test_texts:
        inputs = hf_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = hf_bert(**inputs)
            hidden_states = outputs.last_hidden_state
            token_emb = F.linear(hidden_states, linear_weight)
            token_emb = F.normalize(token_emb, p=2, dim=-1)
            hf_embeddings.append(token_emb.squeeze(0).float())

    _assert_embeddings_close(vllm_outputs, hf_embeddings)
323
324


325
326
327
328
329
330
331
def test_colbert_hf_comparison_modernbert(vllm_runner):
    """Test that vLLM ColBERT produces same embeddings as HuggingFace
    (ModernBERT).

    The HF reference runs ``AutoModel``, projects its last hidden state with
    the sentence-transformers ``1_Dense`` layer's ``linear.weight`` and
    L2-normalises each token vector.
    """
    import torch.nn.functional as F
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from transformers import AutoModel, AutoTokenizer

    spec = COLBERT_MODELS["modernbert"]
    model_name = spec["model"]
    test_texts = [TEXTS_1[0], TEXTS_2[0]]

    with vllm_runner(
        model_name,
        runner="pooling",
        dtype="float32",
        max_model_len=spec["max_model_len"],
        enforce_eager=True,
        **spec["extra_kwargs"],
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(test_texts)

    hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
    hf_model = AutoModel.from_pretrained(model_name)
    hf_model.eval()

    # Load projection from sentence-transformers 1_Dense layer
    dense_path = hf_hub_download(model_name, filename="1_Dense/model.safetensors")
    dense_weights = load_file(dense_path)
    linear_weight = dense_weights["linear.weight"]  # [128, 768]

    hf_embeddings = []
    for text in test_texts:
        inputs = hf_tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = hf_model(**inputs)
            hidden_states = outputs.last_hidden_state
            token_emb = F.linear(hidden_states, linear_weight)
            token_emb = F.normalize(token_emb, p=2, dim=-1)
            hf_embeddings.append(token_emb.squeeze(0).float())

    _assert_embeddings_close(vllm_outputs, hf_embeddings)


def test_colbert_hf_comparison_jina(vllm_runner):
    """Check vLLM ColBERT token embeddings against a HuggingFace reference
    (Jina XLM-RoBERTa).

    The reference runs the remote-code Jina encoder in float32, projects its
    last hidden state with the checkpoint's ``linear.weight`` and
    L2-normalises each token vector.
    """
    import torch.nn.functional as F
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file
    from transformers import AutoModel, AutoTokenizer

    jina_spec = COLBERT_MODELS["jina"]
    repo_id = jina_spec["model"]
    texts = [TEXTS_1[0], TEXTS_2[0]]

    # vLLM side, in float32 (unlike the DTYPE used by the generic tests).
    with vllm_runner(
        repo_id,
        runner="pooling",
        dtype="float32",
        max_model_len=jina_spec["max_model_len"],
        enforce_eager=True,
        **jina_spec["extra_kwargs"],
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(texts)

    # HuggingFace side; this checkpoint is loaded with trust_remote_code.
    tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
    encoder = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
    encoder.eval()

    # Projection stored in the main checkpoint: [128, 1024].
    checkpoint = load_file(hf_hub_download(repo_id, filename="model.safetensors"))
    projection = checkpoint["linear.weight"].float()

    def reference_embedding(text):
        # Project and normalise one text's token vectors, without autograd.
        encoded = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            hidden = encoder(**encoded).last_hidden_state.float()
            token_emb = F.normalize(F.linear(hidden, projection), p=2, dim=-1)
        return token_emb.squeeze(0).float()

    hf_embeddings = [reference_embedding(text) for text in texts]
    _assert_embeddings_close(vllm_outputs, hf_embeddings)