# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColBERT late interaction scoring.

Tests are parametrized across several ColBERT backbones (BERT, ModernBERT,
Jina RoBERTa, LFM2) to check that the generic ColBERT support works with
different encoder architectures. vLLM token embeddings are also compared
against a HuggingFace reference implementation.
"""

import pytest
import torch

from vllm.entrypoints.pooling.scoring.utils import compute_maxsim_score

# -----------------------------------------------------------------------
# Model definitions, keyed by backbone name. Each entry holds:
#   model          - HF Hub model id passed to vllm_runner
#   colbert_dim    - expected width of each per-token embedding
#   max_model_len  - max_model_len passed to vllm_runner
#   extra_kwargs   - extra vllm_runner kwargs (e.g. architecture overrides)
#   hf_comparison  - how to build the HuggingFace reference: the safetensors
#                    file / key holding the projection weight, whether to
#                    trust remote code, and which transformers class to load
# -----------------------------------------------------------------------
COLBERT_MODELS = {
    "bert": {
        "model": "answerdotai/answerai-colbert-small-v1",
        "colbert_dim": 96,
        "max_model_len": 512,
        "extra_kwargs": {},
        "hf_comparison": {
            "weights_file": "model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": False,
            "model_cls": "BertModel",
        },
    },
    "modernbert": {
        "model": "lightonai/GTE-ModernColBERT-v1",
        "colbert_dim": 128,
        "max_model_len": 299,
        "extra_kwargs": {
            "hf_overrides": {
                "architectures": ["ColBERTModernBertModel"],
            },
        },
        "hf_comparison": {
            "weights_file": "1_Dense/model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": False,
            "model_cls": "AutoModel",
        },
    },
    "jina": {
        "model": "jinaai/jina-colbert-v2",
        "colbert_dim": 128,
        "max_model_len": 8192,
        "extra_kwargs": {
            "hf_overrides": {
                "architectures": ["ColBERTJinaRobertaModel"],
            },
        },
        "hf_comparison": {
            "weights_file": "model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": True,
            "model_cls": "AutoModel",
        },
    },
    "lfm2": {
        "model": "LiquidAI/LFM2-ColBERT-350M",
        "colbert_dim": 128,
        "max_model_len": 511,
        "extra_kwargs": {
            "hf_overrides": {
                "architectures": ["ColBERTLfm2Model"],
            },
        },
        "hf_comparison": {
            "weights_file": "1_Dense/model.safetensors",
            "weights_key": "linear.weight",
            "trust_remote_code": False,
            "model_cls": "AutoModel",
        },
    },
}

# Paired inputs: TEXTS_1[i] is a question whose answer is TEXTS_2[i].
TEXTS_1 = [
    "What is the capital of France?",
    "What is the capital of Germany?",
]

TEXTS_2 = [
    "The capital of France is Paris.",
    "The capital of Germany is Berlin.",
]

# dtype passed to vllm_runner by the scoring tests (the HuggingFace
# comparison test uses "float32" instead).
DTYPE = "half"


def _load_hf_model(model_name: str, hf_spec: dict, device: torch.device):
    """Instantiate the HuggingFace reference encoder for a ColBERT backbone.

    ``hf_spec["model_cls"]`` selects ``BertModel`` when it equals
    ``"BertModel"`` and ``AutoModel`` otherwise. ``hf_spec["trust_remote_code"]``
    (default ``False``) is forwarded to ``from_pretrained``. The model is
    moved to *device*, switched to eval mode and returned.
    """
    from transformers import AutoModel, BertModel

    model_cls = AutoModel
    if hf_spec["model_cls"] == "BertModel":
        model_cls = BertModel

    load_kwargs = {"trust_remote_code": hf_spec.get("trust_remote_code", False)}
    # Flash / Triton attention kernels need GPU tensors, so ask for the eager
    # attention implementation when running on CPU.
    if device.type == "cpu":
        load_kwargs["attn_implementation"] = "eager"

    hf_model = model_cls.from_pretrained(model_name, **load_kwargs)
    hf_model = hf_model.to(device)
    hf_model.eval()
    return hf_model


def _load_projection_weight(model_name: str, hf_spec: dict, device: torch.device):
    """Fetch the ColBERT projection matrix from the Hub and place it on *device*.

    Downloads ``hf_spec["weights_file"]`` from the *model_name* repository,
    reads it as a safetensors state dict, and returns the tensor stored under
    ``hf_spec["weights_key"]``.
    """
    from huggingface_hub import hf_hub_download
    from safetensors.torch import load_file

    local_path = hf_hub_download(model_name, filename=hf_spec["weights_file"])
    state_dict = load_file(local_path)
    projection = state_dict[hf_spec["weights_key"]]
    return projection.to(device)


def _compute_hf_colbert_embeddings(model, tokenizer, linear_weight, texts, device):
    """Run HF model + projection and return L2-normalised token embeddings."""
    import torch.nn.functional as F

    embeddings = []
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state.float()
            projected = F.linear(hidden, linear_weight.float())
            normalised = F.normalize(projected, p=2, dim=-1)
            embeddings.append(normalised.squeeze(0).cpu())
    return embeddings


def _assert_embeddings_close(vllm_outputs, hf_embeddings):
    """Assert that vLLM and HuggingFace embeddings match."""
    for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
        vllm_emb = torch.as_tensor(vllm_out).float()

        assert hf_emb.shape == vllm_emb.shape, (
            f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
        )

        torch.testing.assert_close(
            vllm_emb,
            hf_emb,
            rtol=1e-2,
            atol=1e-2,
            msg=f"Embedding mismatch for text {i}",
        )
156
157
158
159
160
161
162
163


@pytest.fixture(params=list(COLBERT_MODELS), scope="module")
def colbert_spec(request):
    """Provide the COLBERT_MODELS entry for the backbone being parametrized."""
    backbone = request.param
    return COLBERT_MODELS[backbone]


@pytest.fixture(scope="module")
def colbert_model_name(colbert_spec):
    """Return the HF Hub model id of the backbone selected by ``colbert_spec``."""
    return colbert_spec["model"]


@pytest.fixture(scope="module")
def colbert_dim(colbert_spec):
    """Return the expected per-token embedding width for the current backbone."""
    expected_dim = colbert_spec["colbert_dim"]
    return expected_dim


@pytest.fixture(scope="module")
def colbert_max_model_len(colbert_spec):
    """Return the ``max_model_len`` passed to vllm_runner for this backbone."""
    max_len = colbert_spec["max_model_len"]
    return max_len


@pytest.fixture(scope="module")
def colbert_extra_kwargs(colbert_spec):
    """Return extra vllm_runner kwargs (e.g. ``hf_overrides``) for this backbone."""
    runner_kwargs = colbert_spec["extra_kwargs"]
    return runner_kwargs


def test_colbert_token_embed(
    vllm_runner,
    colbert_model_name,
    colbert_dim,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT model produces token embeddings.

    A single query must yield one 2-D output whose second dimension equals
    ``colbert_dim`` and which has more than one row.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        outputs = vllm_model.token_embed([TEXTS_1[0]])

        assert len(outputs) == 1
        emb = torch.as_tensor(outputs[0])
        assert emb.dim() == 2
        assert emb.shape[1] == colbert_dim
        assert emb.shape[0] > 1


def test_colbert_late_interaction_1_to_1(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with 1:1 query-document pair.

    The score returned by ``vllm_model.score`` must match a MaxSim score
    computed manually from the query and document token embeddings.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXTS_1[0]])
        d_outputs = vllm_model.token_embed([TEXTS_2[0]])

        q_emb = torch.as_tensor(q_outputs[0])
        d_emb = torch.as_tensor(d_outputs[0])

        manual_score = compute_maxsim_score(q_emb, d_emb).item()

        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])

        assert len(vllm_scores) == 1
        assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)


def test_colbert_late_interaction_1_to_N(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with 1:N query-documents.

    One query is scored against every document in TEXTS_2. Each vLLM score
    must match the MaxSim score computed manually from token embeddings.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed([TEXTS_1[0]])
        d_outputs = vllm_model.token_embed(TEXTS_2)
        # Guard the zip below against silently comparing fewer documents.
        assert len(d_outputs) == len(TEXTS_2)

        q_emb = torch.as_tensor(q_outputs[0])
        manual_scores = [
            compute_maxsim_score(q_emb, torch.as_tensor(d_out)).item()
            for d_out in d_outputs
        ]

        vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)

        assert len(vllm_scores) == len(TEXTS_2)
        for vllm_score, manual_score in zip(vllm_scores, manual_scores):
            assert vllm_score == pytest.approx(manual_score, rel=0.01)


def test_colbert_late_interaction_N_to_N(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test ColBERT late interaction scoring with N:N query-documents.

    TEXTS_1[i] is scored against TEXTS_2[i]. Each vLLM score must match the
    MaxSim score computed manually from the paired token embeddings.
    """
    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        q_outputs = vllm_model.token_embed(TEXTS_1)
        d_outputs = vllm_model.token_embed(TEXTS_2)
        # zip() would silently drop unmatched pairs, so check counts first.
        assert len(q_outputs) == len(d_outputs) == len(TEXTS_1)

        manual_scores = [
            compute_maxsim_score(
                torch.as_tensor(q_out), torch.as_tensor(d_out)
            ).item()
            for q_out, d_out in zip(q_outputs, d_outputs)
        ]

        vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)

        assert len(vllm_scores) == len(TEXTS_1)
        for vllm_score, manual_score in zip(vllm_scores, manual_scores):
            assert vllm_score == pytest.approx(manual_score, rel=0.01)


def test_colbert_relevance_ordering(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT scores relevant documents higher than irrelevant.

    Both the machine-learning and deep-learning documents must outscore the
    unrelated Python document for a machine-learning query.
    """
    query = "What is machine learning?"
    documents = [
        "Machine learning is a subset of artificial intelligence.",
        "Python is a programming language.",
        "Deep learning uses neural networks.",
    ]

    with vllm_runner(
        colbert_model_name,
        runner="pooling",
        dtype=DTYPE,
        max_model_len=colbert_max_model_len,
        enforce_eager=True,
        **colbert_extra_kwargs,
    ) as vllm_model:
        scores = vllm_model.score(query, documents)

        assert len(scores) == len(documents)
        assert scores[0] > scores[1], "ML doc should score higher than Python doc"
        assert scores[2] > scores[1], "DL doc should score higher than Python doc"


def test_colbert_embed_not_supported(
    vllm_runner,
    colbert_model_name,
    colbert_max_model_len,
    colbert_extra_kwargs,
):
    """Test that ColBERT model does not support 'embed' task.

    Calling ``embed`` must raise ``ValueError`` mentioning that the Embedding
    API is not supported.
    """
    with (
        vllm_runner(
            colbert_model_name,
            runner="pooling",
            dtype=DTYPE,
            max_model_len=colbert_max_model_len,
            enforce_eager=True,
            **colbert_extra_kwargs,
        ) as vllm_model,
        pytest.raises(ValueError, match="Embedding API is not supported"),
    ):
        vllm_model.embed([TEXTS_1[0]])


@pytest.mark.parametrize("backend", list(COLBERT_MODELS.keys()))
def test_colbert_hf_comparison(vllm_runner, backend):
    """Test that vLLM ColBERT embeddings match HuggingFace for each backend.

    vLLM runs with dtype="float32" here rather than DTYPE. The HuggingFace
    reference (encoder + projection from ``hf_comparison``) is also computed
    in float32, and the two are compared token by token.
    """
    from transformers import AutoTokenizer

    spec = COLBERT_MODELS[backend]
    hf_spec = spec["hf_comparison"]
    model_name = spec["model"]

    # Narrow the types of values read from the untyped spec dict.
    assert isinstance(model_name, str)
    assert isinstance(hf_spec, dict)

    test_texts = [TEXTS_1[0], TEXTS_2[0]]

    with vllm_runner(
        model_name,
        runner="pooling",
        dtype="float32",
        max_model_len=spec["max_model_len"],
        enforce_eager=True,
        **spec["extra_kwargs"],
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(test_texts)

    # The HF reference is built only after the vLLM runner has exited,
    # presumably so both models are not resident on the device at once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=hf_spec.get("trust_remote_code", False),
    )
    hf_model = _load_hf_model(model_name, hf_spec, device)
    linear_weight = _load_projection_weight(model_name, hf_spec, device)

    hf_embeddings = _compute_hf_colbert_embeddings(
        hf_model,
        hf_tokenizer,
        linear_weight,
        test_texts,
        device,
    )

    _assert_embeddings_close(vllm_outputs, hf_embeddings)