Unverified Commit 439afa4e authored by Ilya Boytsov's avatar Ilya Boytsov Committed by GitHub
Browse files

feat: Add ColBERT late interaction model support (#33686)


Signed-off-by: default avatarIlya Boytsov <ilyaboytsov1805@gmail.com>
Signed-off-by: default avatarIlya Boytsov <boytsovpanamera@mail.ru>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent fa4e0fb0
...@@ -307,6 +307,62 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed ...@@ -307,6 +307,62 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed
## Specific models ## Specific models
### ColBERT Late Interaction Models
[ColBERT](https://arxiv.org/abs/2004.12832) (Contextualized Late Interaction over BERT) is a retrieval model that uses per-token embeddings and MaxSim scoring for document ranking. Unlike single-vector embedding models, ColBERT retains token-level representations and computes relevance scores through late interaction, providing better accuracy while being more efficient than cross-encoders.
vLLM supports ColBERT models for reranking tasks, automatically applying MaxSim scoring for query-document relevance:
```shell
vllm serve answerdotai/answerai-colbert-small-v1
```
Currently supports ColBERT models with standard BERT encoders (e.g., `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0`).
ColBERT models with modified encoder architectures are not yet supported, including BERT variants with rotary embeddings (e.g., `jinaai/jina-colbert-v2`) or other custom encoders (e.g., `LiquidAI/LFM2-ColBERT-350M`).
If your standard BERT ColBERT model's config doesn't specify the architecture as `HF_ColBERT`, override it with:
```shell
vllm serve your-colbert-model --hf-overrides '{"architectures": ["HF_ColBERT"]}'
```
Then you can use the rerank endpoint:
```shell
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
"model": "answerdotai/answerai-colbert-small-v1",
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks."
]
}'
```
Or the score endpoint:
```shell
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
"model": "answerdotai/answerai-colbert-small-v1",
"text_1": "What is machine learning?",
"text_2": ["Machine learning is a subset of AI.", "The weather is sunny."]
}'
```
You can also get the raw token embeddings using the pooling endpoint with `token_embed` task:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
"model": "answerdotai/answerai-colbert-small-v1",
"input": "What is machine learning?",
"task": "token_embed"
}'
```
An example can be found here: [examples/pooling/score/colbert_rerank_online.py](../../examples/pooling/score/colbert_rerank_online.py)
### BAAI/bge-m3 ### BAAI/bge-m3
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json` The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Example of using ColBERT late interaction model for reranking.
ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings
and MaxSim scoring for document reranking, providing better accuracy than
single-vector models while being more efficient than cross-encoders.
Start the server with:
vllm serve answerdotai/answerai-colbert-small-v1
Then run this script:
python colbert_rerank_online.py
"""
import json
import requests
url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {
"model": "answerdotai/answerai-colbert-small-v1",
"query": "What is machine learning?",
"documents": [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks for complex tasks.",
"The weather today is sunny.",
],
}
def main():
response = requests.post(url, headers=headers, json=data)
if response.status_code == 200:
print("ColBERT Rerank Request successful!")
result = response.json()
print(json.dumps(result, indent=2))
# Show ranked results
print("\nRanked documents (most relevant first):")
for item in result["results"]:
doc_idx = item["index"]
score = item["relevance_score"]
print(f" Score {score:.4f}: {data['documents'][doc_idx]}")
else:
print(f"Request failed with status code: {response.status_code}")
print(response.text)
if __name__ == "__main__":
main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Online API tests for ColBERT late interaction scoring."""
import pytest
import requests
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
MODEL_NAME = "answerdotai/answerai-colbert-small-v1"
COLBERT_DIM = 96 # This model uses 96-dimensional output
DTYPE = "half"
MAX_MODEL_LEN = 512
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len",
str(MAX_MODEL_LEN),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT rerank endpoint."""
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
]
rerank_response = requests.post(
server.url_for("rerank"),
json={
"model": model_name,
"query": query,
"documents": documents,
},
)
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())
assert rerank.id is not None
assert rerank.results is not None
assert len(rerank.results) == 2
# The relevant document (Paris) should have higher score
paris_result = next(r for r in rerank.results if r.index == 1)
brazil_result = next(r for r in rerank.results if r.index == 0)
assert paris_result.relevance_score > brazil_result.relevance_score
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_rerank_top_n(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT rerank with top_n parameter."""
query = "What is the capital of France?"
documents = [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.",
"Machine learning is a field of AI.",
]
rerank_response = requests.post(
server.url_for("rerank"),
json={
"model": model_name,
"query": query,
"documents": documents,
"top_n": 2,
},
)
rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json())
assert len(rerank.results) == 2
# Top result should be about Paris (index 1)
assert rerank.results[0].index == 1
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_score(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT score endpoint."""
text_1 = "What is the capital of France?"
text_2 = ["The capital of France is Paris.", "Python is a language."]
score_response = requests.post(
server.url_for("score"),
json={
"model": model_name,
"text_1": text_1,
"text_2": text_2,
},
)
score_response.raise_for_status()
score = ScoreResponse.model_validate(score_response.json())
assert score.id is not None
assert score.data is not None
assert len(score.data) == 2
# The relevant document should have higher score
assert score.data[0].score > score.data[1].score
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str):
"""Test ColBERT token_embed task via pooling endpoint."""
text = "What is the capital of France?"
pooling_response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": text,
"task": "token_embed",
},
)
pooling_response.raise_for_status()
pooling = pooling_response.json()
assert "data" in pooling
assert len(pooling["data"]) == 1
# Token embeddings should be 2D
embeddings = pooling["data"][0]["data"]
assert isinstance(embeddings, list)
assert len(embeddings) > 0 # Should have tokens
assert len(embeddings[0]) == COLBERT_DIM
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str):
"""Test that ColBERT model does not support 'embed' task."""
text = "What is the capital of France?"
pooling_response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": text,
"task": "embed",
},
)
# Should return error
assert pooling_response.status_code == 400
assert "Task embed is not supported" in pooling_response.text
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ColBERT late interaction scoring."""
import pytest
import torch
from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
# suitable for testing (based on BERT-base)
COLBERT_MODEL = "answerdotai/answerai-colbert-small-v1"
COLBERT_DIM = 96 # This model uses 96-dimensional output
TEXTS_1 = [
"What is the capital of France?",
"What is the capital of Germany?",
]
TEXTS_2 = [
"The capital of France is Paris.",
"The capital of Germany is Berlin.",
]
DTYPE = "half"
@pytest.fixture(scope="module")
def colbert_model_name():
return COLBERT_MODEL
def test_colbert_token_embed(vllm_runner, colbert_model_name):
"""Test that ColBERT model produces token embeddings."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
enforce_eager=True,
) as vllm_model:
# Get token embeddings for a single text
outputs = vllm_model.token_embed([TEXTS_1[0]])
assert len(outputs) == 1
# Token embeddings should be 2D: [num_tokens, colbert_dim]
emb = torch.tensor(outputs[0])
assert emb.dim() == 2
assert emb.shape[1] == COLBERT_DIM
# Should have at least a few tokens
assert emb.shape[0] > 1
def test_colbert_late_interaction_1_to_1(vllm_runner, colbert_model_name):
"""Test ColBERT late interaction scoring with 1:1 query-document pair."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
enforce_eager=True,
) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed([TEXTS_2[0]])
q_emb = torch.tensor(q_outputs[0])
d_emb = torch.tensor(d_outputs[0])
# Compute MaxSim manually
manual_score = compute_maxsim_score(q_emb, d_emb).item()
# Use the score API (which should internally use _late_interaction_score)
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])
assert len(vllm_scores) == 1
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
"""Test ColBERT late interaction scoring with 1:N query-documents."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
enforce_eager=True,
) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed(TEXTS_2)
q_emb = torch.tensor(q_outputs[0])
# Compute MaxSim manually for each document
manual_scores = []
for d_out in d_outputs:
d_emb = torch.tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
# Use the score API
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)
assert len(vllm_scores) == 2
for i in range(2):
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
"""Test ColBERT late interaction scoring with N:N query-documents."""
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
enforce_eager=True,
) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed(TEXTS_1)
d_outputs = vllm_model.token_embed(TEXTS_2)
# Compute MaxSim manually for each pair
manual_scores = []
for q_out, d_out in zip(q_outputs, d_outputs):
q_emb = torch.tensor(q_out)
d_emb = torch.tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
# Use the score API
vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)
assert len(vllm_scores) == 2
for i in range(2):
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
"""Test that ColBERT scores relevant documents higher than irrelevant ones."""
query = "What is machine learning?"
documents = [
"Machine learning is a subset of artificial intelligence.",
"Python is a programming language.",
"Deep learning uses neural networks.",
]
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
enforce_eager=True,
) as vllm_model:
scores = vllm_model.score(query, documents)
assert len(scores) == 3
# ML-related documents should score higher than unrelated Python doc
# Document 0 (ML definition) should be most relevant
# Document 2 (Deep learning) should also be relevant
# Document 1 (Python) should be least relevant
assert scores[0] > scores[1], "ML doc should score higher than Python doc"
assert scores[2] > scores[1], "DL doc should score higher than Python doc"
def test_colbert_embed_not_supported(vllm_runner, colbert_model_name):
"""Test that ColBERT model does not support 'embed' task."""
with (
vllm_runner(
colbert_model_name,
runner="pooling",
dtype=DTYPE,
max_model_len=512,
enforce_eager=True,
) as vllm_model,
pytest.raises(ValueError, match="Embedding API is not supported"),
):
vllm_model.embed([TEXTS_1[0]])
def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace."""
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoTokenizer, BertModel
test_texts = [TEXTS_1[0], TEXTS_2[0]]
# Get vLLM embeddings first (to avoid GPU memory contention)
# Use fp32 to match HuggingFace default precision for fair comparison
with vllm_runner(
colbert_model_name,
runner="pooling",
dtype="float32",
max_model_len=512,
enforce_eager=True,
) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts)
# Get HuggingFace reference embeddings on CPU
# Load the base BERT model and manually apply the ColBERT linear projection
hf_tokenizer = AutoTokenizer.from_pretrained(colbert_model_name)
hf_bert = BertModel.from_pretrained(colbert_model_name)
hf_bert.eval()
# Load the ColBERT linear weights from safetensors
weights_path = hf_hub_download(colbert_model_name, filename="model.safetensors")
weights = load_file(weights_path)
linear_weight = weights["linear.weight"] # [96, 384]
hf_embeddings = []
for text in test_texts:
inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = hf_bert(**inputs)
# Get last hidden state: [1, seq_len, 384]
hidden_states = outputs.last_hidden_state
# Apply ColBERT linear projection: [1, seq_len, 96]
token_emb = F.linear(hidden_states, linear_weight)
# L2 normalize
token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float())
# Compare embeddings
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
vllm_emb = torch.tensor(vllm_out).float()
# Print first few components for debugging
print(f"\n=== Text {i}: '{test_texts[i][:30]}...' ===")
print(f"HF shape: {hf_emb.shape}, vLLM shape: {vllm_emb.shape}")
print(f"HF first token, first 10 dims: {hf_emb[0, :10].tolist()}")
print(f"vLLM first token, first 10 dims: {vllm_emb[0, :10].tolist()}")
print(f"HF last token, first 10 dims: {hf_emb[-1, :10].tolist()}")
print(f"vLLM last token, first 10 dims: {vllm_emb[-1, :10].tolist()}")
# Should have same shape
assert hf_emb.shape == vllm_emb.shape, (
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
)
# Should have same values (with tolerance for fp16)
torch.testing.assert_close(
vllm_emb,
hf_emb,
rtol=1e-2,
atol=1e-2,
msg=f"Embedding mismatch for text {i}",
)
...@@ -520,6 +520,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -520,6 +520,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only] # [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"), "BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
......
...@@ -1411,6 +1411,11 @@ class ModelConfig: ...@@ -1411,6 +1411,11 @@ class ModelConfig:
self._model_info.supports_cross_encoding or self.convert_type == "classify" self._model_info.supports_cross_encoding or self.convert_type == "classify"
) )
@property
def is_late_interaction(self) -> bool:
"""Check if model uses late interaction (ColBERT-style) scoring."""
return self._model_info.supports_late_interaction
@property @property
def is_pp_supported(self) -> bool: def is_pp_supported(self) -> bool:
return self._model_info.supports_pp return self._model_info.supports_pp
......
...@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -44,6 +44,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreMultiModalParam, ScoreMultiModalParam,
_cosine_similarity, _cosine_similarity,
compress_token_type_ids, compress_token_type_ids,
compute_maxsim_score,
get_score_prompt, get_score_prompt,
validate_score_input, validate_score_input,
) )
...@@ -1368,6 +1369,87 @@ class LLM: ...@@ -1368,6 +1369,87 @@ class LLM:
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput) items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items] return [ScoringRequestOutput.from_base(item) for item in items]
def _late_interaction_score(
self,
data_1: list[ScoreData],
data_2: list[ScoreData],
*,
use_tqdm: bool | Callable[..., tqdm],
pooling_params: PoolingParams | None,
lora_request: list[LoRARequest] | LoRARequest | None,
tokenization_kwargs: dict[str, Any],
) -> list[ScoringRequestOutput]:
"""
Late interaction scoring (ColBERT MaxSim).
Encodes queries and documents into per-token embeddings, then computes
MaxSim: sum over query tokens of max similarity to any document token.
"""
from vllm.outputs import PoolingOutput
tokenizer = self.get_tokenizer()
# Extract text from ScoreData
text_1: list[str] = []
for text in data_1:
if not isinstance(text, str):
raise NotImplementedError(
"Late interaction scores currently do not support multimodal input."
)
text_1.append(text)
text_2: list[str] = []
for text in data_2:
if not isinstance(text, str):
raise NotImplementedError(
"Late interaction scores currently do not support multimodal input."
)
text_2.append(text)
encoded_output: list[PoolingRequestOutput] = self.encode(
text_1 + text_2,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
pooling_task="token_embed",
tokenization_kwargs=tokenization_kwargs,
)
encoded_output_1: list[PoolingRequestOutput] = encoded_output[0 : len(text_1)]
encoded_output_2: list[PoolingRequestOutput] = encoded_output[len(text_1) :]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
# Compute MaxSim scores
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2 in zip(encoded_output_1, encoded_output_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=PoolingOutput(data=maxsim_score),
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
finished=True,
)
)
items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
return [ScoringRequestOutput.from_base(item) for item in items]
def _cross_encoding_score( def _cross_encoding_score(
self, self,
data_1: list[ScoreData], data_1: list[ScoreData],
...@@ -1497,7 +1579,11 @@ class LLM: ...@@ -1497,7 +1579,11 @@ class LLM:
) )
supported_tasks = self.supported_tasks supported_tasks = self.supported_tasks
if all(t not in supported_tasks for t in ("embed", "classify")): # Late interaction models (e.g., ColBERT) use token_embed for scoring
is_late_interaction = model_config.is_late_interaction
if not is_late_interaction and all(
t not in supported_tasks for t in ("embed", "classify")
):
raise ValueError( raise ValueError(
"Score API is not supported by this model. " "Score API is not supported by this model. "
"Try converting the model using " "Try converting the model using "
...@@ -1538,6 +1624,15 @@ class LLM: ...@@ -1538,6 +1624,15 @@ class LLM:
tokenization_kwargs=encode_kwargs, tokenization_kwargs=encode_kwargs,
score_template=chat_template, score_template=chat_template,
) )
elif is_late_interaction:
return self._late_interaction_score(
score_data_1,
score_data_2,
use_tqdm=use_tqdm,
pooling_params=pooling_params,
lora_request=lora_request,
tokenization_kwargs=encode_kwargs,
)
else: else:
return self._embedding_score( return self._embedding_score(
score_data_1, score_data_1,
......
...@@ -37,7 +37,11 @@ def register_pooling_api_routers( ...@@ -37,7 +37,11 @@ def register_pooling_api_routers(
app.include_router(embed_router) app.include_router(embed_router)
if "score" in supported_tasks or "embed" in supported_tasks: # Score/rerank endpoints are available for:
# - "score" task (cross-encoder models)
# - "embed" task (bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT)
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
from vllm.entrypoints.pooling.score.api_router import router as score_router from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(score_router) app.include_router(score_router)
...@@ -101,6 +105,10 @@ def init_pooling_state( ...@@ -101,6 +105,10 @@ def init_pooling_state(
if "classify" in supported_tasks if "classify" in supported_tasks
else None else None
) )
# ServingScores handles score/rerank for:
# - "score" task (cross-encoder models)
# - "embed" task (bi-encoder models)
# - "token_embed" task (late interaction models like ColBERT)
state.openai_serving_scores = ( state.openai_serving_scores = (
ServingScores( ServingScores(
engine_client, engine_client,
...@@ -109,6 +117,6 @@ def init_pooling_state( ...@@ -109,6 +117,6 @@ def init_pooling_state(
score_template=resolved_chat_template, score_template=resolved_chat_template,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
) )
if ("embed" in supported_tasks or "score" in supported_tasks) if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None else None
) )
...@@ -31,6 +31,7 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -31,6 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs, ScoreInputs,
_cosine_similarity, _cosine_similarity,
compress_token_type_ids, compress_token_type_ids,
compute_maxsim_score,
get_score_prompt, get_score_prompt,
validate_score_input, validate_score_input,
) )
...@@ -68,9 +69,12 @@ class ServingScores(OpenAIServing): ...@@ -68,9 +69,12 @@ class ServingScores(OpenAIServing):
self.is_cross_encoder = self.model_config.is_cross_encoder self.is_cross_encoder = self.model_config.is_cross_encoder
self.is_multimodal_model = self.model_config.is_multimodal_model self.is_multimodal_model = self.model_config.is_multimodal_model
self.architecture = self.model_config.architecture self.architecture = self.model_config.architecture
self.is_late_interaction = self.model_config.is_late_interaction
if self.is_cross_encoder: if self.is_cross_encoder:
self._score_func = self._cross_encoding_score self._score_func = self._cross_encoding_score
elif self.is_late_interaction:
self._score_func = self._late_interaction_score
else: else:
self._score_func = self._embedding_score self._score_func = self._embedding_score
...@@ -172,6 +176,142 @@ class ServingScores(OpenAIServing): ...@@ -172,6 +176,142 @@ class ServingScores(OpenAIServing):
return final_res_batch return final_res_batch
async def _late_interaction_score(
self,
data_1: list[ScoreData],
data_2: list[ScoreData],
request: RerankRequest | ScoreRequest,
request_id: str,
lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
"""
Late interaction scoring (ColBERT MaxSim).
Encodes queries and documents into per-token embeddings, then computes
MaxSim: sum over query tokens of max similarity to any document token.
"""
input_texts: list[str] = []
for text in data_1 + data_2:
if not isinstance(text, str):
raise NotImplementedError(
"Late interaction scores currently do not support multimodal input."
)
input_texts.append(text)
model_config = self.model_config
tokenizer = self.renderer.get_tokenizer()
encode_async = make_async(
tokenizer.encode,
executor=self._tokenizer_executor,
)
tokenization_kwargs = request.build_tok_params(model_config).get_encode_kwargs()
tokenized_prompts = await asyncio.gather(
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
)
engine_prompts: list[TokensPrompt] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append(
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
)
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
# Use token_embed task for late interaction models
from vllm import PoolingParams
pooling_params = PoolingParams(
task="token_embed",
truncate_prompt_tokens=request.truncate_prompt_tokens,
use_activation=request.use_activation,
)
try:
pooling_params.verify("token_embed", self.model_config)
except ValueError as e:
return self.create_error_response(str(e))
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(
request_id_item,
input_texts[i],
params=pooling_params,
lora_request=lora_request,
)
generators.append(
self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
)
result_generator = merge_async_iterators(*generators)
# Collect token embeddings
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts)
async for i, res in result_generator:
embeddings[i] = res
# Split into query and document embeddings
emb_data_1: list[PoolingRequestOutput] = []
emb_data_2: list[PoolingRequestOutput] = []
for i in range(0, len(data_1)):
assert (emb := embeddings[i]) is not None
emb_data_1.append(emb)
for i in range(len(data_1), len(embeddings)):
assert (emb := embeddings[i]) is not None
emb_data_2.append(emb)
# Expand queries if 1:N scoring
if len(emb_data_1) == 1:
emb_data_1 = emb_data_1 * len(emb_data_2)
# Compute MaxSim scores
from vllm.outputs import PoolingOutput
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
# emb_1.outputs.data: [query_len, dim]
# emb_2.outputs.data: [doc_len, dim]
q_emb = emb_1.outputs.data
d_emb = emb_2.outputs.data
maxsim_score = compute_maxsim_score(q_emb, d_emb)
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=PoolingOutput(data=maxsim_score),
prompt_token_ids=tokens,
num_cached_tokens=emb_1.num_cached_tokens + emb_2.num_cached_tokens,
finished=True,
)
)
return scores
async def _cross_encoding_score( async def _cross_encoding_score(
self, self,
data_1: list[ScoreData], data_1: list[ScoreData],
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from typing import Any, TypeAlias, cast from typing import Any, TypeAlias, cast
import torch
from torch.nn import CosineSimilarity from torch.nn import CosineSimilarity
from typing_extensions import Required, TypedDict from typing_extensions import Required, TypedDict
...@@ -34,6 +35,23 @@ ScoreContentPartParam: TypeAlias = ( ...@@ -34,6 +35,23 @@ ScoreContentPartParam: TypeAlias = (
) )
def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tensor:
"""
Compute ColBERT MaxSim score.
Args:
q_emb: Query token embeddings [query_len, dim]
d_emb: Document token embeddings [doc_len, dim]
Returns:
MaxSim score (sum over query tokens of max similarity to any doc token)
"""
# [query_len, doc_len]
token_scores = torch.matmul(q_emb, d_emb.T)
# Max over document tokens, sum over query tokens
return token_scores.amax(dim=-1).sum()
class ScoreMultiModalParam(TypedDict, total=False): class ScoreMultiModalParam(TypedDict, total=False):
""" """
A specialized parameter type for scoring multimodal content A specialized parameter type for scoring multimodal content
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
ColBERT late interaction model for retrieval and reranking.
ColBERT uses per-token embeddings and late interaction (MaxSim) scoring
instead of single-vector representations or cross-encoder concatenation.
Reference: https://arxiv.org/abs/2004.12832
"""
from collections.abc import Iterable
from typing import ClassVar, Literal
import torch
from torch import nn
from vllm.config import PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import Pooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_embed
from .bert import BertEmbeddingModel, BertModel
from .interfaces_base import default_pooling_type
@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(BertEmbeddingModel):
"""ColBERT late interaction model for retrieval/reranking.
This model extends BertEmbeddingModel with a ColBERT-style linear
projection layer for per-token embeddings. It supports only:
- "token_embed" task: Per-token embeddings for late interaction
ColBERT is fundamentally a per-token embedding model - the linear
projection is trained for per-token representations, not for CLS
pooling. Use a dedicated dense embedding model if you need single-
vector representations.
The ColBERT scoring (MaxSim) is computed externally, either client-side
or via the late interaction scoring path in ServingScores.
Attributes:
colbert_linear: Linear projection from hidden_size to colbert_dim
supports_late_interaction: Flag indicating this model uses late
interaction scoring
"""
# Mark this model as supporting late interaction scoring
supports_late_interaction: ClassVar[Literal[True]] = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Get config before calling super().__init__
config = vllm_config.model_config.hf_config
self.hidden_size = config.hidden_size
self.head_dtype = vllm_config.model_config.head_dtype
# ColBERT dimension - check various config field names used by different
# ColBERT implementations. If not found in config, will be inferred
# from loaded weights in load_weights()
self.colbert_dim: int | None = (
getattr(config, "colbert_dim", None)
or getattr(config, "dim", None)
or getattr(config, "projection_dim", None)
)
# Initialize parent (this will call _build_pooler)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
return BertModel(vllm_config=vllm_config, prefix=prefix)
def _build_colbert_linear(self) -> nn.Linear:
"""Build the ColBERT linear projection layer."""
if self.colbert_dim is None:
raise ValueError("colbert_dim must be set before building the linear layer")
return nn.Linear(
self.hidden_size,
self.colbert_dim,
bias=False,
dtype=self.head_dtype,
)
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
# ColBERT linear projection: hidden_size -> colbert_dim
# Original ColBERT uses bias=False
# If colbert_dim is not set from config, it will be inferred during
# load_weights and the linear layer will be created there
if self.colbert_dim is not None:
self.colbert_linear = self._build_colbert_linear()
else:
# Placeholder - will be created when weights are loaded
self.colbert_linear = None
# ColBERT only supports token_embed - it's fundamentally a per-token
# embedding model.
return pooler_for_token_embed(
pooler_config,
projector=self.colbert_linear,
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def _strip(name: str) -> str:
for p in ("model.", "bert."):
if name.startswith(p):
name = name[len(p) :]
return name
weights_list = list(weights)
model_side: list[tuple[str, torch.Tensor]] = []
colbert_side: list[tuple[str, torch.Tensor]] = []
for name, weight in weights_list:
stripped = _strip(name)
# Handle different checkpoint naming conventions for ColBERT linear
if stripped in ("linear.weight", "colbert_linear.weight"):
colbert_side.append(("colbert_linear.weight", weight))
elif stripped.startswith("linear.") or stripped.startswith(
"colbert_linear."
):
new_name = stripped.replace("linear.", "colbert_linear.")
colbert_side.append((new_name, weight))
else:
model_side.append((stripped, weight))
# Load base BERT weights using BertModel.load_weights which handles QKV fusion
loaded: set[str] = set()
loaded_model = self.model.load_weights(model_side)
loaded.update({"model." + n for n in loaded_model})
# Load ColBERT linear weights
if colbert_side:
for name, weight in colbert_side:
if name == "colbert_linear.weight":
# Infer colbert_dim from weights if not set in config
if self.colbert_dim is None:
# Weight shape is [colbert_dim, hidden_size]
self.colbert_dim = weight.shape[0]
# Create the linear layer now that we know the dimension
self.colbert_linear = self._build_colbert_linear()
# Move to the same device as the model's existing parameters
device = next(self.model.parameters()).device
self.colbert_linear.to(device)
# Update the pooler's projector to use the new linear layer
self.pooler.head.projector = self.colbert_linear
# Load weights directly into the pooler's projector
weight = weight.to(self.pooler.head.projector.weight.device)
self.pooler.head.projector.weight.data.copy_(weight)
loaded.add("pooler.head.projector.weight")
break
return loaded
...@@ -981,6 +981,40 @@ def supports_cross_encoding( ...@@ -981,6 +981,40 @@ def supports_cross_encoding(
return is_pooling_model(model) and _supports_cross_encoding(model) return is_pooling_model(model) and _supports_cross_encoding(model)
@runtime_checkable
class SupportsLateInteraction(Protocol):
"""The interface required for all models that support late interaction.
Late interaction models (like ColBERT) encode queries and documents
separately into per-token embeddings, then compute similarity via
MaxSim (max over document tokens, sum over query tokens).
"""
supports_late_interaction: ClassVar[Literal[True]] = True
@overload
def supports_late_interaction(
model: type[object],
) -> TypeIs[type[SupportsLateInteraction]]: ...
@overload
def supports_late_interaction(model: object) -> TypeIs[SupportsLateInteraction]: ...
def _supports_late_interaction(
model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
return getattr(model, "supports_late_interaction", False)
def supports_late_interaction(
model: type[object] | object,
) -> TypeIs[type[SupportsLateInteraction]] | TypeIs[SupportsLateInteraction]:
return is_pooling_model(model) and _supports_late_interaction(model)
class SupportsQuant: class SupportsQuant:
"""The interface required for all models that support quantization.""" """The interface required for all models that support quantization."""
......
...@@ -49,6 +49,7 @@ from .interfaces import ( ...@@ -49,6 +49,7 @@ from .interfaces import (
is_hybrid, is_hybrid,
requires_raw_input_tokens, requires_raw_input_tokens,
supports_cross_encoding, supports_cross_encoding,
supports_late_interaction,
supports_mamba_prefix_caching, supports_mamba_prefix_caching,
supports_multimodal, supports_multimodal,
supports_multimodal_encoder_tp_data, supports_multimodal_encoder_tp_data,
...@@ -205,6 +206,7 @@ _EMBEDDING_MODELS = { ...@@ -205,6 +206,7 @@ _EMBEDDING_MODELS = {
# [Text-only] # [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"), "BertModel": ("bert", "BertEmbeddingModel"),
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"), "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
"HF_ColBERT": ("colbert", "ColBERTModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3TextModel": ("gemma3", "Gemma3Model"), "Gemma3TextModel": ("gemma3", "Gemma3Model"),
...@@ -593,6 +595,7 @@ class _ModelInfo: ...@@ -593,6 +595,7 @@ class _ModelInfo:
default_seq_pooling_type: SequencePoolingType default_seq_pooling_type: SequencePoolingType
default_tok_pooling_type: TokenPoolingType default_tok_pooling_type: TokenPoolingType
supports_cross_encoding: bool supports_cross_encoding: bool
supports_late_interaction: bool
supports_multimodal: bool supports_multimodal: bool
supports_multimodal_raw_input_only: bool supports_multimodal_raw_input_only: bool
requires_raw_input_tokens: bool requires_raw_input_tokens: bool
...@@ -616,6 +619,7 @@ class _ModelInfo: ...@@ -616,6 +619,7 @@ class _ModelInfo:
default_tok_pooling_type=get_default_tok_pooling_type(model), default_tok_pooling_type=get_default_tok_pooling_type(model),
attn_type=get_attn_type(model), attn_type=get_attn_type(model),
supports_cross_encoding=supports_cross_encoding(model), supports_cross_encoding=supports_cross_encoding(model),
supports_late_interaction=supports_late_interaction(model),
supports_multimodal=supports_multimodal(model), supports_multimodal=supports_multimodal(model),
supports_multimodal_raw_input_only=supports_multimodal_raw_input_only( supports_multimodal_raw_input_only=supports_multimodal_raw_input_only(
model model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment