Unverified commit 071d863e, authored by Ilya Boytsov and committed by GitHub

Extend ColBERT support to non-standard BERT backbones (#34170)


Signed-off-by: Ilya Boytsov <ilya.boytsov@aleph-alpha.com>
parent 0916e796
@@ -311,20 +311,31 @@ An OpenAI client example can be found here: [examples/pooling/embed/openai_embed
 [ColBERT](https://arxiv.org/abs/2004.12832) (Contextualized Late Interaction over BERT) is a retrieval model that uses per-token embeddings and MaxSim scoring for document ranking. Unlike single-vector embedding models, ColBERT retains token-level representations and computes relevance scores through late interaction, providing better accuracy while being more efficient than cross-encoders.
-vLLM supports ColBERT models for reranking tasks, automatically applying MaxSim scoring for query-document relevance:
+vLLM supports ColBERT models with multiple encoder backbones:
+| Architecture | Backbone | Example HF Models |
+|---|---|---|
+| `HF_ColBERT` | BERT | `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0` |
+| `ColBERTModernBertModel` | ModernBERT | `lightonai/GTE-ModernColBERT-v1` |
+| `ColBERTJinaRobertaModel` | Jina XLM-RoBERTa | `jinaai/jina-colbert-v2` |
+**BERT-based ColBERT** models work out of the box:
 ```shell
 vllm serve answerdotai/answerai-colbert-small-v1
 ```
-Currently supports ColBERT models with standard BERT encoders (e.g., `answerdotai/answerai-colbert-small-v1`, `colbert-ir/colbertv2.0`).
-ColBERT models with modified encoder architectures are not yet supported, including BERT variants with rotary embeddings (e.g., `jinaai/jina-colbert-v2`) or other custom encoders (e.g., `LiquidAI/LFM2-ColBERT-350M`).
-If your standard BERT ColBERT model's config doesn't specify the architecture as `HF_ColBERT`, override it with:
+For **non-BERT backbones**, use `--hf-overrides` to set the correct architecture:
 ```shell
-vllm serve your-colbert-model --hf-overrides '{"architectures": ["HF_ColBERT"]}'
+# ModernBERT backbone
+vllm serve lightonai/GTE-ModernColBERT-v1 \
+    --hf-overrides '{"architectures": ["ColBERTModernBertModel"]}'
+# Jina XLM-RoBERTa backbone
+vllm serve jinaai/jina-colbert-v2 \
+    --hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' \
+    --trust-remote-code
 ```
 Then you can use the rerank endpoint:
......
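As a minimal sketch of the MaxSim late-interaction scoring described in the documentation hunk above: each query token embedding is matched with its most similar document token embedding, and those maxima are summed over the query tokens. The function name `maxsim_score` and the toy vectors are illustrative assumptions, not vLLM's implementation (the tests in this commit import `compute_maxsim_score` from `vllm.entrypoints.pooling.score.utils`). The embeddings are assumed to be L2-normalized, so the dot product equals cosine similarity.

```python
from typing import Sequence

Vector = Sequence[float]


def dot(a: Vector, b: Vector) -> float:
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def maxsim_score(query_tokens: Sequence[Vector], doc_tokens: Sequence[Vector]) -> float:
    """Sum, over query tokens, of the best dot product against any document token."""
    return sum(max(dot(q, d) for d in doc_tokens) for q in query_tokens)


# Toy, hand-made 2-d "token embeddings" (hypothetical values, already unit length).
query = [[1.0, 0.0], [0.0, 1.0]]
relevant_doc = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
unrelated_doc = [[-1.0, 0.0], [0.6, 0.8]]

relevant = maxsim_score(query, relevant_doc)
unrelated = maxsim_score(query, unrelated_doc)
print(relevant, unrelated)
```

Because the score is a sum over query tokens, a document that covers every query token well outranks one that matches only some of them.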
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
-Example of using ColBERT late interaction model for reranking.
+Example of using ColBERT late interaction models for reranking and scoring.
 ColBERT (Contextualized Late Interaction over BERT) uses per-token embeddings
 and MaxSim scoring for document reranking, providing better accuracy than
 single-vector models while being more efficient than cross-encoders.
-Start the server with:
+vLLM supports ColBERT with multiple encoder backbones. Start the server
+with one of the following:
+    # BERT backbone (works out of the box)
     vllm serve answerdotai/answerai-colbert-small-v1
+    # ModernBERT backbone
+    vllm serve lightonai/GTE-ModernColBERT-v1 \
+        --hf-overrides '{"architectures": ["ColBERTModernBertModel"]}'
+    # Jina XLM-RoBERTa backbone
+    vllm serve jinaai/jina-colbert-v2 \
+        --hf-overrides '{"architectures": ["ColBERTJinaRobertaModel"]}' \
+        --trust-remote-code
 Then run this script:
     python colbert_rerank_online.py
 """
@@ -18,39 +30,62 @@ import json
 import requests
-url = "http://127.0.0.1:8000/rerank"
+# Change this to match the model you started the server with
+MODEL = "answerdotai/answerai-colbert-small-v1"
+BASE_URL = "http://127.0.0.1:8000"
 headers = {"accept": "application/json", "Content-Type": "application/json"}
-data = {
-    "model": "answerdotai/answerai-colbert-small-v1",
-    "query": "What is machine learning?",
-    "documents": [
-        "Machine learning is a subset of artificial intelligence.",
-        "Python is a programming language.",
-        "Deep learning uses neural networks for complex tasks.",
-        "The weather today is sunny.",
-    ],
-}
+documents = [
+    "Machine learning is a subset of artificial intelligence.",
+    "Python is a programming language.",
+    "Deep learning uses neural networks for complex tasks.",
+    "The weather today is sunny.",
+]
+def rerank_example():
+    """Use the /rerank endpoint to rank documents by query relevance."""
+    print("=== Rerank Example ===")
+    data = {
+        "model": MODEL,
+        "query": "What is machine learning?",
+        "documents": documents,
+    }
+    response = requests.post(f"{BASE_URL}/rerank", headers=headers, json=data)
+    result = response.json()
+    print(json.dumps(result, indent=2))
+    print("\nRanked documents (most relevant first):")
+    for item in result["results"]:
+        doc_idx = item["index"]
+        score = item["relevance_score"]
+        print(f" Score {score:.4f}: {documents[doc_idx]}")
+def score_example():
+    """Use the /score endpoint for pairwise query-document scoring."""
+    print("\n=== Score Example ===")
+    data = {
+        "model": MODEL,
+        "text_1": "What is machine learning?",
+        "text_2": [
+            "Machine learning is a subset of AI.",
+            "The weather is sunny.",
+        ],
+    }
+    response = requests.post(f"{BASE_URL}/score", headers=headers, json=data)
+    result = response.json()
+    print(json.dumps(result, indent=2))
 def main():
-    response = requests.post(url, headers=headers, json=data)
-    if response.status_code == 200:
-        print("ColBERT Rerank Request successful!")
-        result = response.json()
-        print(json.dumps(result, indent=2))
-        # Show ranked results
-        print("\nRanked documents (most relevant first):")
-        for item in result["results"]:
-            doc_idx = item["index"]
-            score = item["relevance_score"]
-            print(f" Score {score:.4f}: {data['documents'][doc_idx]}")
-    else:
-        print(f"Request failed with status code: {response.status_code}")
-        print(response.text)
+    rerank_example()
+    score_example()
 if __name__ == "__main__":
......
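The serve commands listed in the documentation and in the example script's docstring differ only in the architecture override and the `--trust-remote-code` flag. A small sketch of assembling them from a lookup table follows. The helper name `build_serve_command` and the `BACKBONES` mapping are hypothetical conveniences, not part of vLLM; the model names, architecture strings, and flags are the ones that appear in this commit.

```python
import json
import shlex

# Model -> (architecture override or None, needs --trust-remote-code),
# copied from the table and commands shown in this commit.
BACKBONES = {
    "answerdotai/answerai-colbert-small-v1": (None, False),
    "lightonai/GTE-ModernColBERT-v1": ("ColBERTModernBertModel", False),
    "jinaai/jina-colbert-v2": ("ColBERTJinaRobertaModel", True),
}


def build_serve_command(model: str) -> list[str]:
    """Return the argv list for `vllm serve` for one of the models above."""
    architecture, trust_remote_code = BACKBONES[model]
    argv = ["vllm", "serve", model]
    if architecture is not None:
        # --hf-overrides takes a JSON object as a single argument.
        argv += ["--hf-overrides", json.dumps({"architectures": [architecture]})]
    if trust_remote_code:
        argv.append("--trust-remote-code")
    return argv


for name in BACKBONES:
    # shlex.join adds the shell quoting needed around the JSON argument.
    print(shlex.join(build_serve_command(name)))
```

Keeping the command as an argv list avoids hand-quoting the JSON; `shlex.join` is only used to print a copy-pasteable shell line.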
@@ -8,10 +8,8 @@ import requests
 from tests.utils import RemoteOpenAIServer
 from vllm.entrypoints.pooling.score.protocol import RerankResponse, ScoreResponse
-# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
 MODEL_NAME = "answerdotai/answerai-colbert-small-v1"
-COLBERT_DIM = 96 # This model uses 96-dimensional output
+COLBERT_DIM = 96
-DTYPE = "half"
 MAX_MODEL_LEN = 512
@@ -26,129 +24,119 @@ def server():
     yield remote_server
@pytest.mark.parametrize("model_name", [MODEL_NAME]) class TestColBERTOnline:
def test_colbert_rerank(server: RemoteOpenAIServer, model_name: str): def test_rerank(self, server: RemoteOpenAIServer):
"""Test ColBERT rerank endpoint.""" """Test ColBERT rerank endpoint."""
query = "What is the capital of France?" query = "What is the capital of France?"
documents = [ documents = [
"The capital of Brazil is Brasilia.", "The capital of Brazil is Brasilia.",
"The capital of France is Paris.", "The capital of France is Paris.",
] ]
rerank_response = requests.post( rerank_response = requests.post(
server.url_for("rerank"), server.url_for("rerank"),
json={ json={
"model": model_name, "model": MODEL_NAME,
"query": query, "query": query,
"documents": documents, "documents": documents,
}, },
) )
rerank_response.raise_for_status() rerank_response.raise_for_status()
rerank = RerankResponse.model_validate(rerank_response.json()) rerank = RerankResponse.model_validate(rerank_response.json())
assert rerank.id is not None assert rerank.id is not None
assert rerank.results is not None assert rerank.results is not None
assert len(rerank.results) == 2 assert len(rerank.results) == 2
# The relevant document (Paris) should have higher score paris_result = next(r for r in rerank.results if r.index == 1)
paris_result = next(r for r in rerank.results if r.index == 1) brazil_result = next(r for r in rerank.results if r.index == 0)
brazil_result = next(r for r in rerank.results if r.index == 0)
assert paris_result.relevance_score > brazil_result.relevance_score
assert paris_result.relevance_score > brazil_result.relevance_score
def test_rerank_top_n(self, server: RemoteOpenAIServer):
"""Test ColBERT rerank with top_n parameter."""
@pytest.mark.parametrize("model_name", [MODEL_NAME]) query = "What is the capital of France?"
def test_colbert_rerank_top_n(server: RemoteOpenAIServer, model_name: str): documents = [
"""Test ColBERT rerank with top_n parameter.""" "The capital of Brazil is Brasilia.",
query = "What is the capital of France?" "The capital of France is Paris.",
documents = [ "Machine learning is a field of AI.",
"The capital of Brazil is Brasilia.", ]
"The capital of France is Paris.",
"Machine learning is a field of AI.", rerank_response = requests.post(
] server.url_for("rerank"),
json={
rerank_response = requests.post( "model": MODEL_NAME,
server.url_for("rerank"), "query": query,
json={ "documents": documents,
"model": model_name, "top_n": 2,
"query": query, },
"documents": documents, )
"top_n": 2, rerank_response.raise_for_status()
}, rerank = RerankResponse.model_validate(rerank_response.json())
)
rerank_response.raise_for_status() assert len(rerank.results) == 2
rerank = RerankResponse.model_validate(rerank_response.json()) assert rerank.results[0].index == 1
assert len(rerank.results) == 2 def test_score(self, server: RemoteOpenAIServer):
# Top result should be about Paris (index 1) """Test ColBERT score endpoint."""
assert rerank.results[0].index == 1 text_1 = "What is the capital of France?"
text_2 = ["The capital of France is Paris.", "Python is a language."]
@pytest.mark.parametrize("model_name", [MODEL_NAME]) score_response = requests.post(
def test_colbert_score(server: RemoteOpenAIServer, model_name: str): server.url_for("score"),
"""Test ColBERT score endpoint.""" json={
text_1 = "What is the capital of France?" "model": MODEL_NAME,
text_2 = ["The capital of France is Paris.", "Python is a language."] "text_1": text_1,
"text_2": text_2,
score_response = requests.post( },
server.url_for("score"), )
json={ score_response.raise_for_status()
"model": model_name, score = ScoreResponse.model_validate(score_response.json())
"text_1": text_1,
"text_2": text_2, assert score.id is not None
}, assert score.data is not None
) assert len(score.data) == 2
score_response.raise_for_status()
score = ScoreResponse.model_validate(score_response.json()) assert score.data[0].score > score.data[1].score
assert score.id is not None def test_token_embed(self, server: RemoteOpenAIServer):
assert score.data is not None """Test ColBERT token_embed task via pooling endpoint."""
assert len(score.data) == 2 text = "What is the capital of France?"
# The relevant document should have higher score pooling_response = requests.post(
assert score.data[0].score > score.data[1].score server.url_for("pooling"),
json={
"model": MODEL_NAME,
@pytest.mark.parametrize("model_name", [MODEL_NAME]) "input": text,
def test_colbert_token_embed(server: RemoteOpenAIServer, model_name: str): "task": "token_embed",
"""Test ColBERT token_embed task via pooling endpoint.""" },
text = "What is the capital of France?" )
pooling_response.raise_for_status()
pooling_response = requests.post( pooling = pooling_response.json()
server.url_for("pooling"),
json={ assert "data" in pooling
"model": model_name, assert len(pooling["data"]) == 1
"input": text,
"task": "token_embed", embeddings = pooling["data"][0]["data"]
}, assert isinstance(embeddings, list)
) assert len(embeddings) > 0
pooling_response.raise_for_status() assert len(embeddings[0]) == COLBERT_DIM
pooling = pooling_response.json()
def test_embed_not_supported(self, server: RemoteOpenAIServer):
assert "data" in pooling """Test that ColBERT model does not support 'embed' task."""
assert len(pooling["data"]) == 1 task = "embed"
text = "What is the capital of France?"
# Token embeddings should be 2D
embeddings = pooling["data"][0]["data"] response = requests.post(
assert isinstance(embeddings, list) server.url_for("pooling"),
assert len(embeddings) > 0 # Should have tokens json={
assert len(embeddings[0]) == COLBERT_DIM "model": MODEL_NAME,
"input": text,
"task": task,
@pytest.mark.parametrize("model_name", [MODEL_NAME]) },
def test_colbert_embed_not_supported(server: RemoteOpenAIServer, model_name: str): )
"""Test that ColBERT model does not support 'embed' task."""
task = "embed" assert response.json()["error"]["type"] == "BadRequestError"
text = "What is the capital of France?" assert response.json()["error"]["message"].startswith(
f"Unsupported task: {task!r}"
response = requests.post( )
server.url_for("pooling"),
json={
"model": model_name,
"input": text,
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(f"Unsupported task: {task!r}")
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Tests for ColBERT late interaction scoring."""
+"""Tests for ColBERT late interaction scoring.
+Tests are parametrized across multiple ColBERT backbones to ensure the
+generic ColBERT support works with different encoder architectures.
+"""
 import pytest
 import torch
 from vllm.entrypoints.pooling.score.utils import compute_maxsim_score
-# ColBERT model - using answerai-colbert-small-v1 as it's a smaller model
-# suitable for testing (based on BERT-base)
-COLBERT_MODEL = "answerdotai/answerai-colbert-small-v1"
-COLBERT_DIM = 96 # This model uses 96-dimensional output
+# -----------------------------------------------------------------------
+# Model definitions: (model_name, colbert_dim, extra vllm_runner kwargs)
+# -----------------------------------------------------------------------
+COLBERT_MODELS = {
+    "bert": {
+        "model": "answerdotai/answerai-colbert-small-v1",
+        "colbert_dim": 96,
+        "max_model_len": 512,
+        "extra_kwargs": {},
+    },
+    "modernbert": {
+        "model": "lightonai/GTE-ModernColBERT-v1",
+        "colbert_dim": 128,
+        "max_model_len": 299,
+        "extra_kwargs": {
+            "hf_overrides": {
+                "architectures": ["ColBERTModernBertModel"],
+            },
+        },
+    },
+    "jina": {
+        "model": "jinaai/jina-colbert-v2",
+        "colbert_dim": 128,
+        "max_model_len": 8192,
+        "extra_kwargs": {
+            "hf_overrides": {
+                "architectures": ["ColBERTJinaRobertaModel"],
+            },
+        },
+    },
+}
 TEXTS_1 = [
     "What is the capital of France?",
@@ -25,80 +56,121 @@ TEXTS_2 = [
 DTYPE = "half"
+# -----------------------------------------------------------------------
+# Fixtures
+# -----------------------------------------------------------------------
+@pytest.fixture(params=list(COLBERT_MODELS.keys()), scope="module")
+def colbert_spec(request):
+    """Return the model spec dict for the current parametrization."""
+    return COLBERT_MODELS[request.param]
 @pytest.fixture(scope="module")
-def colbert_model_name():
-    return COLBERT_MODEL
+def colbert_model_name(colbert_spec):
+    return colbert_spec["model"]
+@pytest.fixture(scope="module")
+def colbert_dim(colbert_spec):
+    return colbert_spec["colbert_dim"]
+@pytest.fixture(scope="module")
+def colbert_max_model_len(colbert_spec):
+    return colbert_spec["max_model_len"]
+@pytest.fixture(scope="module")
+def colbert_extra_kwargs(colbert_spec):
+    return colbert_spec["extra_kwargs"]
def test_colbert_token_embed(vllm_runner, colbert_model_name):
# -----------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------
def test_colbert_token_embed(
vllm_runner,
colbert_model_name,
colbert_dim,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test that ColBERT model produces token embeddings.""" """Test that ColBERT model produces token embeddings."""
with vllm_runner( with vllm_runner(
colbert_model_name, colbert_model_name,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
max_model_len=512, max_model_len=colbert_max_model_len,
enforce_eager=True, enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model: ) as vllm_model:
# Get token embeddings for a single text
outputs = vllm_model.token_embed([TEXTS_1[0]]) outputs = vllm_model.token_embed([TEXTS_1[0]])
assert len(outputs) == 1 assert len(outputs) == 1
# Token embeddings should be 2D: [num_tokens, colbert_dim]
emb = torch.tensor(outputs[0]) emb = torch.tensor(outputs[0])
assert emb.dim() == 2 assert emb.dim() == 2
assert emb.shape[1] == COLBERT_DIM assert emb.shape[1] == colbert_dim
# Should have at least a few tokens
assert emb.shape[0] > 1 assert emb.shape[0] > 1
def test_colbert_late_interaction_1_to_1(vllm_runner, colbert_model_name): def test_colbert_late_interaction_1_to_1(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test ColBERT late interaction scoring with 1:1 query-document pair.""" """Test ColBERT late interaction scoring with 1:1 query-document pair."""
with vllm_runner( with vllm_runner(
colbert_model_name, colbert_model_name,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
max_model_len=512, max_model_len=colbert_max_model_len,
enforce_eager=True, enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model: ) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed([TEXTS_1[0]]) q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed([TEXTS_2[0]]) d_outputs = vllm_model.token_embed([TEXTS_2[0]])
q_emb = torch.tensor(q_outputs[0]) q_emb = torch.tensor(q_outputs[0])
d_emb = torch.tensor(d_outputs[0]) d_emb = torch.tensor(d_outputs[0])
# Compute MaxSim manually
manual_score = compute_maxsim_score(q_emb, d_emb).item() manual_score = compute_maxsim_score(q_emb, d_emb).item()
# Use the score API (which should internally use _late_interaction_score)
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0]) vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2[0])
assert len(vllm_scores) == 1 assert len(vllm_scores) == 1
assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01) assert vllm_scores[0] == pytest.approx(manual_score, rel=0.01)
def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name): def test_colbert_late_interaction_1_to_N(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test ColBERT late interaction scoring with 1:N query-documents.""" """Test ColBERT late interaction scoring with 1:N query-documents."""
with vllm_runner( with vllm_runner(
colbert_model_name, colbert_model_name,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
max_model_len=512, max_model_len=colbert_max_model_len,
enforce_eager=True, enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model: ) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed([TEXTS_1[0]]) q_outputs = vllm_model.token_embed([TEXTS_1[0]])
d_outputs = vllm_model.token_embed(TEXTS_2) d_outputs = vllm_model.token_embed(TEXTS_2)
q_emb = torch.tensor(q_outputs[0]) q_emb = torch.tensor(q_outputs[0])
# Compute MaxSim manually for each document
manual_scores = [] manual_scores = []
for d_out in d_outputs: for d_out in d_outputs:
d_emb = torch.tensor(d_out) d_emb = torch.tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item()) manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
# Use the score API
vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2) vllm_scores = vllm_model.score(TEXTS_1[0], TEXTS_2)
assert len(vllm_scores) == 2 assert len(vllm_scores) == 2
...@@ -106,27 +178,30 @@ def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name): ...@@ -106,27 +178,30 @@ def test_colbert_late_interaction_1_to_N(vllm_runner, colbert_model_name):
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01) assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name): def test_colbert_late_interaction_N_to_N(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test ColBERT late interaction scoring with N:N query-documents.""" """Test ColBERT late interaction scoring with N:N query-documents."""
with vllm_runner( with vllm_runner(
colbert_model_name, colbert_model_name,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
max_model_len=512, max_model_len=colbert_max_model_len,
enforce_eager=True, enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model: ) as vllm_model:
# Get token embeddings
q_outputs = vllm_model.token_embed(TEXTS_1) q_outputs = vllm_model.token_embed(TEXTS_1)
d_outputs = vllm_model.token_embed(TEXTS_2) d_outputs = vllm_model.token_embed(TEXTS_2)
# Compute MaxSim manually for each pair
manual_scores = [] manual_scores = []
for q_out, d_out in zip(q_outputs, d_outputs): for q_out, d_out in zip(q_outputs, d_outputs):
q_emb = torch.tensor(q_out) q_emb = torch.tensor(q_out)
d_emb = torch.tensor(d_out) d_emb = torch.tensor(d_out)
manual_scores.append(compute_maxsim_score(q_emb, d_emb).item()) manual_scores.append(compute_maxsim_score(q_emb, d_emb).item())
# Use the score API
vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2) vllm_scores = vllm_model.score(TEXTS_1, TEXTS_2)
assert len(vllm_scores) == 2 assert len(vllm_scores) == 2
...@@ -134,8 +209,13 @@ def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name): ...@@ -134,8 +209,13 @@ def test_colbert_late_interaction_N_to_N(vllm_runner, colbert_model_name):
assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01) assert vllm_scores[i] == pytest.approx(manual_scores[i], rel=0.01)
def test_colbert_relevance_ordering(vllm_runner, colbert_model_name): def test_colbert_relevance_ordering(
"""Test that ColBERT scores relevant documents higher than irrelevant ones.""" vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test that ColBERT scores relevant documents higher than irrelevant."""
query = "What is machine learning?" query = "What is machine learning?"
documents = [ documents = [
"Machine learning is a subset of artificial intelligence.", "Machine learning is a subset of artificial intelligence.",
...@@ -147,48 +227,73 @@ def test_colbert_relevance_ordering(vllm_runner, colbert_model_name): ...@@ -147,48 +227,73 @@ def test_colbert_relevance_ordering(vllm_runner, colbert_model_name):
colbert_model_name, colbert_model_name,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
max_model_len=512, max_model_len=colbert_max_model_len,
enforce_eager=True, enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model: ) as vllm_model:
scores = vllm_model.score(query, documents) scores = vllm_model.score(query, documents)
assert len(scores) == 3 assert len(scores) == 3
# ML-related documents should score higher than unrelated Python doc
# Document 0 (ML definition) should be most relevant
# Document 2 (Deep learning) should also be relevant
# Document 1 (Python) should be least relevant
assert scores[0] > scores[1], "ML doc should score higher than Python doc" assert scores[0] > scores[1], "ML doc should score higher than Python doc"
assert scores[2] > scores[1], "DL doc should score higher than Python doc" assert scores[2] > scores[1], "DL doc should score higher than Python doc"
def test_colbert_embed_not_supported(vllm_runner, colbert_model_name): def test_colbert_embed_not_supported(
vllm_runner,
colbert_model_name,
colbert_max_model_len,
colbert_extra_kwargs,
):
"""Test that ColBERT model does not support 'embed' task.""" """Test that ColBERT model does not support 'embed' task."""
with ( with (
vllm_runner( vllm_runner(
colbert_model_name, colbert_model_name,
runner="pooling", runner="pooling",
dtype=DTYPE, dtype=DTYPE,
max_model_len=512, max_model_len=colbert_max_model_len,
enforce_eager=True, enforce_eager=True,
**colbert_extra_kwargs,
) as vllm_model, ) as vllm_model,
pytest.raises(ValueError, match="Embedding API is not supported"), pytest.raises(ValueError, match="Embedding API is not supported"),
): ):
vllm_model.embed([TEXTS_1[0]]) vllm_model.embed([TEXTS_1[0]])
def test_colbert_hf_comparison(vllm_runner, colbert_model_name): # -----------------------------------------------------------------------
"""Test that vLLM ColBERT produces same embeddings as HuggingFace.""" # Per-model HuggingFace comparison tests
# -----------------------------------------------------------------------
def _assert_embeddings_close(vllm_outputs, hf_embeddings):
"""Assert that vLLM and HuggingFace embeddings match."""
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
vllm_emb = torch.tensor(vllm_out).float()
assert hf_emb.shape == vllm_emb.shape, (
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}"
)
torch.testing.assert_close(
vllm_emb,
hf_emb,
rtol=1e-2,
atol=1e-2,
msg=f"Embedding mismatch for text {i}",
)
def test_colbert_hf_comparison_bert(vllm_runner):
"""Test that vLLM ColBERT produces same embeddings as HuggingFace (BERT)."""
import torch.nn.functional as F import torch.nn.functional as F
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from safetensors.torch import load_file from safetensors.torch import load_file
from transformers import AutoTokenizer, BertModel from transformers import AutoTokenizer, BertModel
model_name = COLBERT_MODELS["bert"]["model"]
test_texts = [TEXTS_1[0], TEXTS_2[0]] test_texts = [TEXTS_1[0], TEXTS_2[0]]
# Get vLLM embeddings first (to avoid GPU memory contention)
# Use fp32 to match HuggingFace default precision for fair comparison
with vllm_runner( with vllm_runner(
colbert_model_name, model_name,
runner="pooling", runner="pooling",
dtype="float32", dtype="float32",
max_model_len=512, max_model_len=512,
...@@ -196,14 +301,11 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name): ...@@ -196,14 +301,11 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts) vllm_outputs = vllm_model.token_embed(test_texts)
# Get HuggingFace reference embeddings on CPU hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the base BERT model and manually apply the ColBERT linear projection hf_bert = BertModel.from_pretrained(model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(colbert_model_name)
hf_bert = BertModel.from_pretrained(colbert_model_name)
hf_bert.eval() hf_bert.eval()
# Load the ColBERT linear weights from safetensors weights_path = hf_hub_download(model_name, filename="model.safetensors")
weights_path = hf_hub_download(colbert_model_name, filename="model.safetensors")
weights = load_file(weights_path) weights = load_file(weights_path)
linear_weight = weights["linear.weight"] # [96, 384] linear_weight = weights["linear.weight"] # [96, 384]
...@@ -212,36 +314,103 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name): ...@@ -212,36 +314,103 @@ def test_colbert_hf_comparison(vllm_runner, colbert_model_name):
inputs = hf_tokenizer(text, return_tensors="pt") inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad(): with torch.no_grad():
outputs = hf_bert(**inputs) outputs = hf_bert(**inputs)
# Get last hidden state: [1, seq_len, 384]
hidden_states = outputs.last_hidden_state hidden_states = outputs.last_hidden_state
# Apply ColBERT linear projection: [1, seq_len, 96]
token_emb = F.linear(hidden_states, linear_weight) token_emb = F.linear(hidden_states, linear_weight)
# L2 normalize
token_emb = F.normalize(token_emb, p=2, dim=-1) token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float()) hf_embeddings.append(token_emb.squeeze(0).float())
# Compare embeddings _assert_embeddings_close(vllm_outputs, hf_embeddings)
for i, (hf_emb, vllm_out) in enumerate(zip(hf_embeddings, vllm_outputs)):
vllm_emb = torch.tensor(vllm_out).float()
# Print first few components for debugging
print(f"\n=== Text {i}: '{test_texts[i][:30]}...' ===")
print(f"HF shape: {hf_emb.shape}, vLLM shape: {vllm_emb.shape}")
print(f"HF first token, first 10 dims: {hf_emb[0, :10].tolist()}")
print(f"vLLM first token, first 10 dims: {vllm_emb[0, :10].tolist()}")
print(f"HF last token, first 10 dims: {hf_emb[-1, :10].tolist()}")
print(f"vLLM last token, first 10 dims: {vllm_emb[-1, :10].tolist()}")
# Should have same shape def test_colbert_hf_comparison_modernbert(vllm_runner):
assert hf_emb.shape == vllm_emb.shape, ( """Test that vLLM ColBERT produces same embeddings as HuggingFace
f"Shape mismatch for text {i}: HF {hf_emb.shape} vs vLLM {vllm_emb.shape}" (ModernBERT)."""
) import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
# Should have same values (with tolerance for fp16) spec = COLBERT_MODELS["modernbert"]
torch.testing.assert_close( model_name = spec["model"]
vllm_emb, test_texts = [TEXTS_1[0], TEXTS_2[0]]
hf_emb,
rtol=1e-2, with vllm_runner(
atol=1e-2, model_name,
msg=f"Embedding mismatch for text {i}", runner="pooling",
) dtype="float32",
max_model_len=spec["max_model_len"],
enforce_eager=True,
**spec["extra_kwargs"],
) as vllm_model:
vllm_outputs = vllm_model.token_embed(test_texts)
hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModel.from_pretrained(model_name)
hf_model.eval()
# Load projection from sentence-transformers 1_Dense layer
dense_path = hf_hub_download(model_name, filename="1_Dense/model.safetensors")
dense_weights = load_file(dense_path)
linear_weight = dense_weights["linear.weight"] # [128, 768]
hf_embeddings = []
for text in test_texts:
inputs = hf_tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = hf_model(**inputs)
hidden_states = outputs.last_hidden_state
token_emb = F.linear(hidden_states, linear_weight)
token_emb = F.normalize(token_emb, p=2, dim=-1)
hf_embeddings.append(token_emb.squeeze(0).float())
_assert_embeddings_close(vllm_outputs, hf_embeddings)
+
+
+def test_colbert_hf_comparison_jina(vllm_runner):
+    """Test that vLLM ColBERT produces same embeddings as HuggingFace
+    (Jina XLM-RoBERTa)."""
+    import torch.nn.functional as F
+    from huggingface_hub import hf_hub_download
+    from safetensors.torch import load_file
+    from transformers import AutoModel, AutoTokenizer
+
+    spec = COLBERT_MODELS["jina"]
+    model_name = spec["model"]
+    test_texts = [TEXTS_1[0], TEXTS_2[0]]
+
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype="float32",
+        max_model_len=spec["max_model_len"],
+        enforce_eager=True,
+        **spec["extra_kwargs"],
+    ) as vllm_model:
+        vllm_outputs = vllm_model.token_embed(test_texts)
+
+    hf_tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+    )
+    hf_model = AutoModel.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+    )
+    hf_model.eval()
+
+    # Load projection from main checkpoint
+    weights_path = hf_hub_download(model_name, filename="model.safetensors")
+    weights = load_file(weights_path)
+    linear_weight = weights["linear.weight"]  # [128, 1024]
+
+    hf_embeddings = []
+    for text in test_texts:
+        inputs = hf_tokenizer(text, return_tensors="pt")
+        with torch.no_grad():
+            outputs = hf_model(**inputs)
+            hidden_states = outputs.last_hidden_state
+            token_emb = F.linear(hidden_states.float(), linear_weight.float())
+            token_emb = F.normalize(token_emb, p=2, dim=-1)
+        hf_embeddings.append(token_emb.squeeze(0).float())
+
+    _assert_embeddings_close(vllm_outputs, hf_embeddings)
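Both comparison tests build their reference by projecting each token's hidden state and L2-normalising the result. As a rough illustration of how such per-token embeddings are then consumed, here is a minimal pure-Python sketch of MaxSim scoring. Plain lists stand in for tensors, and `l2_normalize` and `maxsim` are hypothetical helper names, not part of vLLM or of this commit.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def maxsim(query_tokens, doc_tokens):
    """Late-interaction score.

    For every query token, take the largest dot product against any
    document token, then sum those maxima over the query tokens.
    """
    q = [l2_normalize(v) for v in query_tokens]
    d = [l2_normalize(v) for v in doc_tokens]
    return sum(
        max(sum(a * b for a, b in zip(qv, dv)) for dv in d)
        for qv in q
    )


# Toy 2-dimensional "token embeddings" (made-up values).
query = [[1.0, 0.0], [0.0, 2.0]]
doc = [[3.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
score = maxsim(query, doc)
```

Each query token here has an exactly aligned document token after normalisation, so every per-token maximum is 1 and the score equals the number of query tokens.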
@@ -529,6 +529,15 @@ _EMBEDDING_EXAMPLE_MODELS = {
     # [Text-only]
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "HF_ColBERT": _HfExamplesInfo("answerdotai/answerai-colbert-small-v1"),
+    "ColBERTModernBertModel": _HfExamplesInfo(
+        "lightonai/GTE-ModernColBERT-v1",
+        hf_overrides={"architectures": ["ColBERTModernBertModel"]},
+    ),
+    "ColBERTJinaRobertaModel": _HfExamplesInfo(
+        "jinaai/jina-colbert-v2",
+        trust_remote_code=True,
+        hf_overrides={"architectures": ["ColBERTJinaRobertaModel"]},
+    ),
     "BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
     "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
......
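The registry entries above pass `hf_overrides={"architectures": [...]}` for the non-BERT backbones. On the command line the same override is given to `--hf-overrides` as a JSON string. The sketch below only assembles such a command; `serve_command` is a hypothetical helper written for illustration, and nothing here launches a server.

```python
import json
import shlex


def serve_command(model, architecture, trust_remote_code=False):
    """Build a `vllm serve` command line that overrides the architecture."""
    overrides = json.dumps({"architectures": [architecture]})
    args = ["vllm", "serve", model, "--hf-overrides", overrides]
    if trust_remote_code:
        # The Jina entry in the registry sets trust_remote_code=True.
        args.append("--trust-remote-code")
    # shlex.join quotes the JSON payload so a POSIX shell passes it intact.
    return shlex.join(args)


modernbert_cmd = serve_command(
    "lightonai/GTE-ModernColBERT-v1", "ColBERTModernBertModel"
)
jina_cmd = serve_command(
    "jinaai/jina-colbert-v2", "ColBERTJinaRobertaModel", trust_remote_code=True
)
```

Building the JSON with `json.dumps` rather than by hand avoids quoting mistakes in the nested list.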
@@ -6,6 +6,14 @@ ColBERT late interaction model for retrieval and reranking.
 ColBERT uses per-token embeddings and late interaction (MaxSim) scoring
 instead of single-vector representations or cross-encoder concatenation.

+This module provides:
+
+- :class:`ColBERTMixin` — mixin that adds ColBERT late-interaction support
+  to any embedding model.
+- :class:`ColBERTModel` — ColBERT with BERT backbone (original architecture).
+- :class:`ColBERTModernBertModel` — ColBERT with ModernBERT backbone.
+- :class:`ColBERTJinaRobertaModel` — ColBERT with Jina XLM-RoBERTa backbone.
+
 Reference: https://arxiv.org/abs/2004.12832
 """
@@ -23,51 +31,60 @@ from .bert import BertEmbeddingModel, BertModel
 from .interfaces_base import default_pooling_type

-@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
-class ColBERTModel(BertEmbeddingModel):
-    """ColBERT late interaction model for retrieval/reranking.
-
-    This model extends BertEmbeddingModel with a ColBERT-style linear
-    projection layer for per-token embeddings. It supports only:
-    - "token_embed" task: Per-token embeddings for late interaction
-
-    ColBERT is fundamentally a per-token embedding model - the linear
-    projection is trained for per-token representations, not for CLS
-    pooling. Use a dedicated dense embedding model if you need single-
-    vector representations.
-
-    The ColBERT scoring (MaxSim) is computed externally, either client-side
-    or via the late interaction scoring path in ServingScores.
-
-    Attributes:
-        colbert_linear: Linear projection from hidden_size to colbert_dim
-        supports_late_interaction: Flag indicating this model uses late
-            interaction scoring
+class ColBERTMixin:
+    """Mixin that adds ColBERT late interaction support to any embedding model.
+
+    ColBERT (Contextualized Late Interaction over BERT) uses per-token
+    embeddings with a linear projection layer. This mixin provides:
+
+    - ``supports_late_interaction`` class-var
+    - ColBERT linear projection initialisation / lazy creation
+    - Weight loading helpers for the projection layer
+    - A builder for the token-embedding pooler
+
+    **Integration:**
+
+    1. Inherit from both ``ColBERTMixin`` and ``nn.Module``.
+    2. In ``__init__``: call ``super().__init__()``, then
+       :meth:`_init_colbert_components`, then create ``self.model``
+       (the backbone) and ``self.pooler`` via :meth:`_build_colbert_pooler`.
+    3. In ``load_weights``: use :meth:`_load_colbert_weights` to separate
+       the ColBERT projection weight, then delegate the rest to the backbone.
     """

-    # Mark this model as supporting late interaction scoring
     supports_late_interaction: ClassVar[Literal[True]] = True

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        # Get config before calling super().__init__
-        config = vllm_config.model_config.hf_config
-        self.hidden_size = config.hidden_size
-        self.head_dtype = vllm_config.model_config.head_dtype
-
-        # ColBERT dimension - check various config field names used by different
-        # ColBERT implementations. If not found in config, will be inferred
-        # from loaded weights in load_weights()
-        self.colbert_dim: int | None = (
-            getattr(config, "colbert_dim", None)
-            or getattr(config, "dim", None)
-            or getattr(config, "projection_dim", None)
-        )
-
-        # Initialize parent (this will call _build_pooler)
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
-    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
-        return BertModel(vllm_config=vllm_config, prefix=prefix)
+    # Set during _init_colbert_components
+    colbert_dim: int | None
+    colbert_linear: nn.Linear | None
+    hidden_size: int
+    head_dtype: torch.dtype
+
+    # ------------------------------------------------------------------ init
+
+    def _init_colbert_components(
+        self,
+        hidden_size: int,
+        colbert_dim: int | None,
+        head_dtype: torch.dtype,
+    ) -> None:
+        """Initialise ColBERT projection layer.
+
+        Args:
+            hidden_size: Hidden dimension of the encoder backbone.
+            colbert_dim: Output dimension for ColBERT embeddings. If
+                ``None``, will be inferred from weights during loading (or
+                auto-loaded from sentence-transformers Dense layers).
+            head_dtype: Data type for the projection layer.
+        """
+        self.hidden_size = hidden_size
+        self.colbert_dim = colbert_dim
+        self.head_dtype = head_dtype
+
+        if colbert_dim is not None:
+            self.colbert_linear = self._build_colbert_linear()
+        else:
+            self.colbert_linear = None

     def _build_colbert_linear(self) -> nn.Linear:
         """Build the ColBERT linear projection layer."""
@@ -80,24 +97,127 @@ class ColBERTModel(BertEmbeddingModel):
             dtype=self.head_dtype,
         )

-    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
-        # ColBERT linear projection: hidden_size -> colbert_dim
-        # Original ColBERT uses bias=False
-        # If colbert_dim is not set from config, it will be inferred during
-        # load_weights and the linear layer will be created there
-        if self.colbert_dim is not None:
-            self.colbert_linear = self._build_colbert_linear()
-        else:
-            # Placeholder - will be created when weights are loaded
-            self.colbert_linear = None
-
-        # ColBERT only supports token_embed - it's fundamentally a per-token
-        # embedding model.
+    # ---------------------------------------------------------------- pooler
+
+    def _build_colbert_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        """Build pooler for ColBERT token embeddings.
+
+        When ``colbert_linear`` is set, it is used as the projector.
+        Otherwise ``pooler_for_token_embed`` falls back to auto-loading
+        sentence-transformers Dense layers (``1_Dense/`` etc.).
+        """
         return pooler_for_token_embed(
             pooler_config,
             projector=self.colbert_linear,
         )
+
+    # --------------------------------------------------------- config helper
+
+    @classmethod
+    def get_colbert_dim_from_config(cls, hf_config) -> int | None:
+        """Extract ColBERT dimension from a HuggingFace config.
+
+        Checks ``colbert_dim``, ``dim`` and ``projection_dim`` in that order.
+        """
+        return (
+            getattr(hf_config, "colbert_dim", None)
+            or getattr(hf_config, "dim", None)
+            or getattr(hf_config, "projection_dim", None)
+        )
+
+    # -------------------------------------------------------- weight loading
+
+    def _load_colbert_weights(
+        self,
+        weights: Iterable[tuple[str, torch.Tensor]],
+        colbert_weight_names: tuple[str, ...] = (
+            "linear.weight",
+            "colbert_linear.weight",
+        ),
+    ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]:
+        """Separate and load ColBERT projection weights.
+
+        Scans *weights* for entries whose name ends with one of
+        *colbert_weight_names*. The matching weight is loaded into
+        ``self.colbert_linear`` (creating it first if ``colbert_dim`` was
+        not known at init time).
+
+        Args:
+            weights: Iterable of ``(name, tensor)`` weight pairs.
+            colbert_weight_names: Suffixes that identify the ColBERT linear
+                weight.
+
+        Returns:
+            ``(remaining_weights, loaded_names)`` — the weights that were
+            **not** consumed and the set of names that were loaded.
+        """
+        weights_list = list(weights)
+        other_weights: list[tuple[str, torch.Tensor]] = []
+        colbert_weight: tuple[str, torch.Tensor] | None = None
+
+        for name, weight in weights_list:
+            if any(name.endswith(cw) for cw in colbert_weight_names):
+                colbert_weight = (name, weight)
+            else:
+                other_weights.append((name, weight))
+
+        loaded: set[str] = set()
+        if colbert_weight is not None:
+            _name, weight = colbert_weight
+            if weight.dim() == 2:
+                # Infer colbert_dim from weight shape if not set
+                if self.colbert_dim is None:
+                    self.colbert_dim = weight.shape[0]
+                    self.colbert_linear = self._build_colbert_linear()
+                    # Update the pooler's projector
+                    if hasattr(self, "pooler") and hasattr(self.pooler, "head"):
+                        self.pooler.head.projector = self.colbert_linear
+
+                assert self.colbert_linear is not None
+                # Move to same device as model
+                if hasattr(self, "model"):
+                    device = next(self.model.parameters()).device
+                    self.colbert_linear.to(device)
+
+                weight = weight.to(self.colbert_linear.weight.device)
+                self.colbert_linear.weight.data.copy_(weight)
+                loaded.add("pooler.head.projector.weight")
+
+        return other_weights, loaded
+
+
+# -----------------------------------------------------------------------
+# Concrete model: ColBERT + BERT backbone (original architecture)
+# -----------------------------------------------------------------------
+
+
+@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
+class ColBERTModel(ColBERTMixin, BertEmbeddingModel):
+    """ColBERT late interaction model with BERT backbone.
+
+    Supports the ``token_embed`` task (per-token embeddings for late
+    interaction). MaxSim scoring is computed externally.
+    """
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        # Must run before super().__init__ because _build_pooler reads these.
+        colbert_dim = self.get_colbert_dim_from_config(config)
+        self._init_colbert_components(
+            hidden_size=config.hidden_size,
+            colbert_dim=colbert_dim,
+            head_dtype=vllm_config.model_config.head_dtype,
+        )
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
+        return BertModel(vllm_config=vllm_config, prefix=prefix)
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        return self._build_colbert_pooler(pooler_config)
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         def _strip(name: str) -> str:
             for p in ("model.", "bert."):
@@ -111,7 +231,7 @@ class ColBERTModel(BertEmbeddingModel):
         for name, weight in weights_list:
             stripped = _strip(name)
-            # Handle different checkpoint naming conventions for ColBERT linear
+            # Handle different checkpoint naming conventions
             if stripped in ("linear.weight", "colbert_linear.weight"):
                 colbert_side.append(("colbert_linear.weight", weight))
             elif stripped.startswith("linear.") or stripped.startswith(
@@ -122,31 +242,178 @@ class ColBERTModel(BertEmbeddingModel):
             else:
                 model_side.append((stripped, weight))

-        # Load base BERT weights using BertModel.load_weights which handles QKV fusion
         loaded: set[str] = set()
         loaded_model = self.model.load_weights(model_side)
         loaded.update({"model." + n for n in loaded_model})

-        # Load ColBERT linear weights
         if colbert_side:
-            for name, weight in colbert_side:
-                if name == "colbert_linear.weight":
-                    # Infer colbert_dim from weights if not set in config
-                    if self.colbert_dim is None:
-                        # Weight shape is [colbert_dim, hidden_size]
-                        self.colbert_dim = weight.shape[0]
-                        # Create the linear layer now that we know the dimension
-                        self.colbert_linear = self._build_colbert_linear()
-                        # Move to the same device as the model's existing parameters
-                        device = next(self.model.parameters()).device
-                        self.colbert_linear.to(device)
-                        # Update the pooler's projector to use the new linear layer
-                        self.pooler.head.projector = self.colbert_linear
-
-                    # Load weights directly into the pooler's projector
-                    weight = weight.to(self.pooler.head.projector.weight.device)
-                    self.pooler.head.projector.weight.data.copy_(weight)
-                    loaded.add("pooler.head.projector.weight")
-                    break
+            _, colbert_loaded = self._load_colbert_weights(colbert_side)
+            loaded.update(colbert_loaded)
+
+        return loaded
+
+
+# -----------------------------------------------------------------------
+# Concrete model: ColBERT + ModernBERT backbone
+# -----------------------------------------------------------------------
+
+from .modernbert import ModernBertModel  # noqa: E402
+
+
+@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
+class ColBERTModernBertModel(ColBERTMixin, nn.Module):
+    """ColBERT late interaction model with ModernBERT backbone.
+
+    For ``lightonai/GTE-ModernColBERT-v1`` and similar models.
+    The projection is auto-loaded from sentence-transformers ``1_Dense/``
+    when not present in the main checkpoint.
+    """
+
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        colbert_dim = self.get_colbert_dim_from_config(config)
+        self._init_colbert_components(
+            hidden_size=config.hidden_size,
+            colbert_dim=colbert_dim,
+            head_dtype=vllm_config.model_config.head_dtype,
+        )
+        self.model = ModernBertModel(
+            vllm_config=vllm_config,
+            prefix=prefix,
+        )
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+        self.pooler = self._build_colbert_pooler(pooler_config)
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors=None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return self.model(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+            intermediate_tensors=intermediate_tensors,
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        other_weights, colbert_loaded = self._load_colbert_weights(weights)
+        # Strip "model." prefix added by the embedding adapter
+        model_weights = [
+            (n[len("model.") :] if n.startswith("model.") else n, w)
+            for n, w in other_weights
+        ]
+        loaded_model = self.model.load_weights(model_weights)
+        loaded = {"model." + n for n in loaded_model} | colbert_loaded
+
+        # When the ST projector was auto-loaded during init
+        # (not from the main checkpoint), mark its params as loaded
+        # so the weight validator doesn't complain.
+        if hasattr(self.pooler, "head"):
+            head = self.pooler.head
+            projector = getattr(head, "projector", None)
+            if projector is not None and isinstance(projector, nn.Module):
+                for name, _ in projector.named_parameters():
+                    loaded.add(f"pooler.head.projector.{name}")
+
+        return loaded
+
+
+# -----------------------------------------------------------------------
+# Concrete model: ColBERT + Jina XLM-RoBERTa backbone
+# -----------------------------------------------------------------------
+
+from .bert_with_rope import JinaRobertaModel  # noqa: E402
+
+
+@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
+class ColBERTJinaRobertaModel(ColBERTMixin, nn.Module):
+    """ColBERT late interaction model with Jina XLM-RoBERTa backbone.
+
+    For ``jinaai/jina-colbert-v2`` and similar models.
+    """
+
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        colbert_dim = self.get_colbert_dim_from_config(config)
+        self._init_colbert_components(
+            hidden_size=config.hidden_size,
+            colbert_dim=colbert_dim,
+            head_dtype=vllm_config.model_config.head_dtype,
+        )
+        self.model = JinaRobertaModel(
+            vllm_config=vllm_config,
+            prefix=prefix,
+        )
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+        self.pooler = self._build_colbert_pooler(pooler_config)
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors=None,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return self.model(
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
+            intermediate_tensors=intermediate_tensors,
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        weights_list = list(weights)
+        model_side: list[tuple[str, torch.Tensor]] = []
+        colbert_side: list[tuple[str, torch.Tensor]] = []
+
+        for name, weight in weights_list:
+            stripped = name
+            # Strip "model." prefix added by the embedding adapter
+            if stripped.startswith("model."):
+                stripped = stripped[len("model.") :]
+            # Strip "roberta." prefix from checkpoint
+            if stripped.startswith("roberta."):
+                stripped = stripped[len("roberta.") :]
+
+            if stripped in ("linear.weight", "colbert_linear.weight"):
+                colbert_side.append(("colbert_linear.weight", weight))
+            elif stripped.startswith("pooler."):
+                # Skip HF pooler weights (not used in ColBERT)
+                continue
+            else:
+                model_side.append((stripped, weight))
+
+        loaded: set[str] = set()
+        loaded_model = self.model.load_weights(model_side)
+        loaded.update({"model." + n for n in loaded_model})
+
+        if colbert_side:
+            _, colbert_loaded = self._load_colbert_weights(colbert_side)
+            loaded.update(colbert_loaded)
+
         return loaded
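The weight-loading helper in this file separates the projection weight from the backbone weights by name suffix and, when no dimension was found in the config, infers it from the weight's first axis. A minimal pure-Python sketch of that idea follows. Nested lists stand in for tensors, and `split_projection_weight`, `infer_dims` and the checkpoint names are hypothetical, chosen only for illustration.

```python
def split_projection_weight(
    weights,
    suffixes=("linear.weight", "colbert_linear.weight"),
):
    """Partition (name, matrix) pairs into the projection weight and the rest.

    Matching is by name suffix; if several names match, the last one wins.
    """
    projection = None
    remaining = []
    for name, matrix in weights:
        if any(name.endswith(suffix) for suffix in suffixes):
            projection = (name, matrix)
        else:
            remaining.append((name, matrix))
    return projection, remaining


def infer_dims(matrix):
    """Return (colbert_dim, hidden_size) for a [colbert_dim, hidden_size] matrix."""
    return len(matrix), len(matrix[0])


# Hypothetical checkpoint contents with tiny made-up shapes.
checkpoint = [
    ("model.embeddings.tok_embeddings.weight", [[0.0] * 4 for _ in range(10)]),
    ("linear.weight", [[0.0] * 4 for _ in range(2)]),
    ("model.final_norm.weight", [[1.0] * 4]),
]

projection, remaining = split_projection_weight(checkpoint)
colbert_dim, hidden_size = infer_dims(projection[1])
```

Because the match is on suffixes, any parameter whose name happens to end in `linear.weight` would be captured as the projection. Whether a supported checkpoint contains such a name is not something the diff states.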
@@ -629,6 +629,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
     "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
     "XLMRobertaModel": JinaRobertaModelConfig,
+    "ColBERTJinaRobertaModel": JinaRobertaModelConfig,
     "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
......
@@ -208,6 +208,8 @@ _EMBEDDING_MODELS = {
     "BertModel": ("bert", "BertEmbeddingModel"),
     "BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
     "HF_ColBERT": ("colbert", "ColBERTModel"),
+    "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
+    "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
     "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "Gemma3TextModel": ("gemma3", "Gemma3Model"),
......
@@ -1068,9 +1068,11 @@ def try_get_dense_modules(
         if isinstance(modules, dict):
             modules = modules.get("modules", [])

-        dense_modules = [
-            m for m in modules if m.get("type") == "sentence_transformers.models.Dense"
-        ]
+        _DENSE_MODULE_TYPES = {
+            "sentence_transformers.models.Dense",
+            "pylate.models.Dense.Dense",
+        }
+        dense_modules = [m for m in modules if m.get("type") in _DENSE_MODULE_TYPES]
         if not dense_modules:
             return None
......
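The last hunk widens the Dense-layer filter from a single type string to a set of two. A small standalone sketch of the same filter is shown below; `select_dense_modules` is a hypothetical name and the sample `modules.json` payload is made up for illustration.

```python
DENSE_MODULE_TYPES = {
    "sentence_transformers.models.Dense",
    "pylate.models.Dense.Dense",
}


def select_dense_modules(modules):
    """Return the entries of a modules.json payload that describe Dense layers."""
    # Some payloads wrap the list in a dict under the "modules" key.
    if isinstance(modules, dict):
        modules = modules.get("modules", [])
    return [m for m in modules if m.get("type") in DENSE_MODULE_TYPES]


# Illustrative payload: one encoder entry and one PyLate Dense entry.
modules_json = [
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Dense", "type": "pylate.models.Dense.Dense"},
]

dense = select_dense_modules(modules_json)
```

With the previous equality check against `sentence_transformers.models.Dense` alone, the PyLate entry in this sample would have been filtered out.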