test_colbert.py 12 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Tests for ColBERT late interaction scoring.

Tests are parametrized across multiple ColBERT backbones to ensure the
generic ColBERT support works with different encoder architectures.
"""
8
9
10
11
12
13

import pytest
import torch

from vllm.entrypoints.pooling.score.utils import compute_maxsim_score

14
15
16
17
18
19
20
21
22
# -----------------------------------------------------------------------
# Model definitions, keyed by a short backbone name.  Each spec holds:
#   model          - HuggingFace model identifier
#   colbert_dim    - expected width of each token embedding
#   max_model_len  - ``max_model_len`` passed to ``vllm_runner``
#   extra_kwargs   - extra keyword arguments splatted into ``vllm_runner``
#   hf_comparison  - how to load the HF reference model and the ColBERT
#                    projection weight in ``test_colbert_hf_comparison``
# -----------------------------------------------------------------------
COLBERT_MODELS = {
    "bert": {
        "model": "answerdotai/answerai-colbert-small-v1",
        "colbert_dim": 96,
        "max_model_len": 512,
        "extra_kwargs": {},
        "hf_comparison": {
            "weights_file": "model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": False,
            "model_cls": "BertModel",
        },
    },
    "modernbert": {
        "model": "lightonai/GTE-ModernColBERT-v1",
        "colbert_dim": 128,
        "max_model_len": 299,
        "extra_kwargs": {
            "hf_overrides": {
                "architectures": ["ColBERTModernBertModel"],
            },
        },
        "hf_comparison": {
            "weights_file": "1_Dense/model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": False,
            "model_cls": "AutoModel",
        },
    },
    "jina": {
        "model": "jinaai/jina-colbert-v2",
        "colbert_dim": 128,
        "max_model_len": 8192,
        "extra_kwargs": {
            "hf_overrides": {
                "architectures": ["ColBERTJinaRobertaModel"],
            },
        },
        "hf_comparison": {
            "weights_file": "model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": True,
            "model_cls": "AutoModel",
        },
    },
    "lfm2": {
        "model": "LiquidAI/LFM2-ColBERT-350M",
        "colbert_dim": 128,
        "max_model_len": 511,
        "extra_kwargs": {
            "hf_overrides": {
                "architectures": ["ColBERTLfm2Model"],
            },
        },
        "hf_comparison": {
            "weights_file": "1_Dense/model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": False,
            "model_cls": "AutoModel",
        },
    },
}
79

80

81
82
83
84
85
86
87
88
89
90
91
92
93
# Query-side inputs: the scoring tests pass these as the first argument of
# ``score`` and embed them as the "q" side of the manual MaxSim computation.
TEXTS_1 = [
    "What is the capital of France?",
    "What is the capital of Germany?",
]

# Document-side inputs, index-aligned with TEXTS_1 for the N:N scoring test.
TEXTS_2 = [
    "The capital of France is Paris.",
    "The capital of Germany is Berlin.",
]

# dtype passed to ``vllm_runner`` by every test except the HF comparison,
# which runs vLLM in "float32" instead.
DTYPE = "half"


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
def _load_hf_model(model_name: str, hf_spec: dict, device: torch.device):
    """Load the HF reference model on *device* with a compatible attention impl.

    Args:
        model_name: HuggingFace model identifier passed to ``from_pretrained``.
        hf_spec: The ``hf_comparison`` entry of a model spec.  ``model_cls``
            selects ``BertModel`` (when equal to ``"BertModel"``) or
            ``AutoModel`` (anything else); ``trust_remote_code`` (default
            ``False``) is forwarded to ``from_pretrained``.
        device: Device the loaded model is moved to.

    Returns:
        The loaded model, switched to eval mode.
    """
    from transformers import AutoModel, BertModel

    cls = BertModel if hf_spec["model_cls"] == "BertModel" else AutoModel

    load_kwargs = {"trust_remote_code": hf_spec.get("trust_remote_code", False)}
    # Flash / Triton kernels require GPU tensors; fall back to eager on CPU.
    if device.type == "cpu":
        load_kwargs["attn_implementation"] = "eager"

    model = cls.from_pretrained(model_name, **load_kwargs).to(device)
    model.eval()

    # Transformers 5.0 weight materialization can clear non-persistent
    # buffers (e.g. rotary inv_freq) that were registered with
    # persistent=False.  Re-compute them so the model produces valid output.
    for mod in model.modules():
        if hasattr(mod, "_compute_inv_freq") and hasattr(mod, "inv_freq"):
            mod.inv_freq = mod._compute_inv_freq(device=device)

    return model


def _load_projection_weight(model_name: str, hf_spec: dict, device: torch.device):
    """Fetch the ColBERT linear projection weight and move it to *device*.

    The safetensors file named by ``hf_spec["weights_file"]`` is downloaded
    from the ``model_name`` repository, and the tensor stored under
    ``hf_spec["weights_key"]`` is returned.
    """
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    local_file = hf_hub_download(model_name, filename=hf_spec["weights_file"])
    state = load_file(local_file)
    projection = state[hf_spec["weights_key"]]
    return projection.to(device)


def _compute_hf_colbert_embeddings(model, tokenizer, linear_weight, texts, device):
    """Embed each text with the HF model and the ColBERT projection.

    Every text is tokenized on its own, run through ``model``, projected with
    ``linear_weight`` (in float32) and L2-normalised along the last dimension.

    Returns:
        A list with one CPU tensor per input text, with the leading batch
        dimension squeezed away.
    """
    from torch.nn import functional as nnf

    def _embed_one(text):
        batch = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            states = model(**batch).last_hidden_state.float()
            token_vecs = nnf.normalize(
                nnf.linear(states, linear_weight.float()),
                p=2,
                dim=-1,
            )
            return token_vecs.squeeze(0).cpu()

    return [_embed_one(text) for text in texts]


def _assert_embeddings_close(vllm_outputs, hf_embeddings):
    """Assert that vLLM and HuggingFace embeddings match.

    Args:
        vllm_outputs: One embedding per text from vLLM; each item is
            converted with ``torch.as_tensor(...).float()``.
        hf_embeddings: One reference tensor per text, in the same order.

    Raises:
        AssertionError: If the number of embeddings differs, or if any pair
            differs in shape or in value beyond ``rtol=1e-2`` / ``atol=1e-2``.
    """
    vllm_outputs = list(vllm_outputs)
    hf_embeddings = list(hf_embeddings)

    # zip() stops at the shorter input, so without this check a missing
    # output on either side would go unnoticed.
    assert len(hf_embeddings) == len(vllm_outputs), (
        f"Embedding count mismatch: HF {len(hf_embeddings)} "
        f"vs vLLM {len(vllm_outputs)}"
    )

    for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
        vllm_emb = torch.as_tensor(vllm_out).float()

        assert hf_emb.shape == vllm_emb.shape, (
            f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
        )

        torch.testing.assert_close(
            vllm_emb,
            hf_emb,
            rtol=1e-2,
            atol=1e-2,
            msg=f"Embedding mismatch for text {i}",
        )
164
165
166
167
168
169
170
171


@pytest.fixture(params=list(COLBERT_MODELS), scope="module")
def colbert_spec(request):
    """Look up the spec dict of the backbone chosen by this parametrization."""
    backbone = request.param
    return COLBERT_MODELS[backbone]


172
@pytest.fixture(scope="module")
def colbert_model_name(colbert_spec):
    """Return the HuggingFace model identifier of the current model spec."""
    return colbert_spec["model"]

176

177
178
179
180
181
182
183
184
185
186
187
188
189
@pytest.fixture(scope="module")
def colbert_dim(colbert_spec):
    """Return the ``colbert_dim`` entry of the current model spec."""
    dim = colbert_spec["colbert_dim"]
    return dim


@pytest.fixture(scope="module")
def colbert_max_model_len(colbert_spec):
    """Return the ``max_model_len`` entry of the current model spec."""
    max_len = colbert_spec["max_model_len"]
    return max_len


@pytest.fixture(scope="module")
def colbert_extra_kwargs(colbert_spec):
    """Return the extra ``vllm_runner`` keyword arguments of the current spec."""
    runner_kwargs = colbert_spec["extra_kwargs"]
    return runner_kwargs
190

191
192
193
194
195
196
197
198

def test_colbert_token_embed(
    vllm_runner,
    colbert_model_name,
    colbert_dim,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT model produces token embeddings."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        outputs = vllm_model.token_embed([TEXTS_1[0]])

        # One input text -> one output.
        assert len(outputs) == 1
        emb = torch.as_tensor(outputs[0])
        # The output must be a 2-D matrix whose second dimension equals
        # colbert_dim and whose first dimension holds more than one row.
        assert emb.dim() == 2
        assert emb.shape[1] == colbert_dim
        assert emb.shape[0] > 1


217
218
219
220
221
222
def test_colbert_late_interaction_1_to_1(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with 1:1 query-document pair."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXTS_1[0]])
        d_outputs = vllm_model.token_embed([TEXTS_2[0]])

        q_emb = torch.as_tensor(q_outputs[0])
        d_emb = torch.as_tensor(d_outputs[0])

        # Reference score computed directly from the token embeddings.
        manual_score = compute_maxsim_score(q_emb, d_emb).item()

        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])

        assert len(vllm_scores) == 1
        assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)


246
247
248
249
250
251
def test_colbert_late_interaction_1_to_N(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with 1:N query-documents."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXTS_1[0]])
        d_outputs = vllm_model.token_embed(TEXTS_2)

        q_emb = torch.as_tensor(q_outputs[0])

        # Reference scores: the single query against each document embedding.
        manual_scores = [
            compute_maxsim_score(q_emb, torch.as_tensor(d_out)).item()
            for d_out in d_outputs
        ]

        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)

        # Tie the expected count to the input list instead of a literal 2, and
        # check both sides so the zip below cannot silently skip a pair.
        assert len(manual_scores) == len(TEXTS_2)
        assert len(vllm_scores) == len(TEXTS_2)
        for vllm_score, manual_score in zip(vllm_scores, manual_scores):
            assert vllm_score == pytest.approx(manual_score, rel=0.01)


278
279
280
281
282
283
def test_colbert_late_interaction_N_to_N(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with N:N query-documents."""
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed(TEXTS_1)
        d_outputs = vllm_model.token_embed(TEXTS_2)

        # Guard the pairing below: zip would otherwise drop unmatched items.
        assert len(q_outputs) == len(TEXTS_1)
        assert len(d_outputs) == len(TEXTS_2)

        # Reference scores: query i against document i.
        manual_scores = [
            compute_maxsim_score(
                torch.as_tensor(q_out), torch.as_tensor(d_out)
            ).item()
            for q_out, d_out in zip(q_outputs, d_outputs)
        ]

        vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)

        # Tie the expected count to the input list instead of a literal 2.
        assert len(manual_scores) == len(TEXTS_2)
        assert len(vllm_scores) == len(TEXTS_2)
        for vllm_score, manual_score in zip(vllm_scores, manual_scores):
            assert vllm_score == pytest.approx(manual_score, rel=0.01)


309
310
311
312
313
314
315
def test_colbert_relevance_ordering(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT scores relevant documents higher than irrelevant."""
    query = "What is machine learning?"
    # documents[1] is the off-topic one; both others must outrank it.
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "Python is a programming language.",
        "Deep learning uses neural networks.",
    ]

    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)

        assert len(scores) == 3
        assert scores[0] > scores[1], "ML doc should score higher than Python doc"
        assert scores[2] > scores[1], "DL doc should score higher than Python doc"


338
339
340
341
342
343
def test_colbert_embed_not_supported(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT model does not support 'embed' task."""
    with (
        vllm_runner(
            colbert_model_name,
            runner="pooling",
            dtype=DTYPE,
            max_model_len=colbert_max_model_len,
            enforce_eager=True,
            **colbert_extra_kwargs,
        ) as vllm_model,
        # Calling embed() must raise a ValueError with this message.
        pytest.raises(ValueError, match="Embedding API is not supported"),
    ):
        vllm_model.embed([TEXTS_1[0]])


359
360
361
362
@pytest.mark.parametrize("backend", list(COLBERT_MODELS.keys()))
def test_colbert_hf_comparison(vllm_runner, backend):
    """Test that vLLM ColBERT embeddings match HuggingFace for each backend.

    vLLM is run in float32 to produce token embeddings for one query and one
    document.  The same texts are then embedded with the HF model plus the
    ColBERT projection weight, and the two results are compared.
    """
    from transformers import AutoTokenizer

    spec = COLBERT_MODELS[backend]
    hf_spec = spec["hf_comparison"]
    model_name = spec["model"]
    # Fail fast if a spec entry has an unexpected type.
    assert isinstance(model_name, str)
    assert isinstance(hf_spec, dict)

    test_texts = [TEXTS_1[0], TEXTS_2[0]]

    with vllm_runner(
        model_name,
        runner="pooling",
        dtype="float32",
        max_model_len=spec["max_model_len"],
        enforce_eager=True,
        **spec["extra_kwargs"],
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(test_texts)

    # The HF reference is loaded only after the vLLM context has exited.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=hf_spec.get("trust_remote_code", False),
    )
    hf_model = _load_hf_model(model_name, hf_spec, device)
    linear_weight = _load_projection_weight(model_name, hf_spec, device)

    hf_embeddings = _compute_hf_colbert_embeddings(
        hf_model,
        hf_tokenizer,
        linear_weight,
        test_texts,
        device,
    )

    _assert_embeddings_close(vllm_outputs, hf_embeddings)