"examples/vscode:/vscode.git/clone" did not exist on "a00d88973daf9a151ecbd4c740ca99645715b9df"
Unverified Commit ff365eea authored by Maximilien de Bayser's avatar Maximilien de Bayser Committed by GitHub
Browse files

Support bge-m3 sparse embeddings and colbert embeddings (#14526)


Signed-off-by: default avatarMax de Bayser <mbayser@br.ibm.com>
Signed-off-by: default avatarMax de Bayser <maxdebayser@gmail.com>
parent 444e2e7e
...@@ -305,6 +305,44 @@ Expected output: ...@@ -305,6 +305,44 @@ Expected output:
An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy_client.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy_client.py) An OpenAI client example can be found here: [examples/pooling/embed/openai_embedding_matryoshka_fy_client.py](../../examples/pooling/embed/openai_embedding_matryoshka_fy_client.py)
## Specific models
### BAAI/bge-m3
The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`
the architecture is declared as `XLMRobertaModel`, which makes `vLLM` load it as a vanilla ROBERTA model without the
extra weights. To load the full model weights, override its architecture like this:
```shell
vllm serve BAAI/bge-m3 --hf-overrides '{"architectures": ["BgeM3EmbeddingModel"]}'
```
Then you obtain the sparse embeddings like this:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
"model": "BAAI/bge-m3",
"task": "token_classify",
"input": ["What is BGE M3?", "Defination of BM25"]
}'
```
Due to limitations in the the output schema, the output consists of a list of
token scores for each token for each input. This means that you'll have to call
`/tokenize` as well to be able to pair tokens with scores.
Refer to the tests in `tests/models/language/pooling/test_bge_m3.py` to see how
to do that.
You can obtain the colbert embeddings like this:
```shell
curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
"model": "BAAI/bge-m3",
"task": "token_embed",
"input": ["What is BGE M3?", "Defination of BM25"]
}'
```
## Deprecated Features ## Deprecated Features
### Encode task ### Encode task
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence from collections.abc import Sequence
import openai
import pytest import pytest
from tests.conftest import HfRunner from tests.conftest import HfRunner
...@@ -65,3 +66,16 @@ def correctness_test_embed_models( ...@@ -65,3 +66,16 @@ def correctness_test_embed_models(
hf_model_callback(hf_model) hf_model_callback(hf_model)
run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
async def run_client_embeddings(
client: openai.AsyncOpenAI,
model_name: str,
queries: list[str],
instruction: str = "",
) -> list[list[float]]:
outputs = await client.embeddings.create(
model=model_name,
input=[instruction + q for q in queries],
)
return [data.embedding for data in outputs.data]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import httpx
import openai
import pytest
import pytest_asyncio
import torch
from ....utils import RemoteOpenAIServer
from .embed_utils import run_client_embeddings
MODEL_NAME = "BAAI/bge-m3"
MAX_MODEL_LEN = 512
# Example from https://huggingface.co/BAAI/bge-m3
sentences_1 = ["What is BGE M3?", "Defination of BM25"]
sentences_2 = [
"BGE M3 is an embedding model supporting dense retrieval, "
"lexical matching and multi-vector interaction.",
"BM25 is a bag-of-words retrieval function that ranks a set "
"of documents based on the query terms appearing in each document",
]
similarity_reference = [[0.6265, 0.3477], [0.3499, 0.678]]
lexical_score_reference = [0.19554901123046875, 0.0]
colbert_score_reference = [0.7797, 0.4620]
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len",
str(MAX_MODEL_LEN),
"--hf-overrides",
'{"architectures": ["BgeM3EmbeddingModel"]}',
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
async def test_bge_m3_api_server_embedding(client: openai.AsyncOpenAI):
embeddings_list_1 = await run_client_embeddings(
client,
MODEL_NAME,
sentences_1,
)
embeddings_list_2 = await run_client_embeddings(
client,
MODEL_NAME,
sentences_2,
)
embeddings_1 = torch.tensor(embeddings_list_1)
embeddings_2 = torch.tensor(embeddings_list_2)
similarity = embeddings_1 @ embeddings_2.T
# reference values from BAAI/bge-m3 documentation
reference = torch.tensor(similarity_reference)
assert torch.allclose(similarity, reference, rtol=0.01)
async def tokenize(client: openai.AsyncOpenAI, sentences: list[str]) -> list[list[int]]:
futures = []
for sentence in sentences:
futures.append(
client.post(
"../tokenize",
body={"model": MODEL_NAME, "prompt": sentence},
cast_to=httpx.Response,
)
)
return [(await future).json()["tokens"] for future in futures]
async def sparse_embeddings(
client: openai.AsyncOpenAI, sentences: list[str]
) -> list[dict[int, float]]:
all_tokens = await tokenize(client, sentences)
result = await client.post(
"../pooling",
body={"model": MODEL_NAME, "input": sentences, "task": "token_classify"},
cast_to=httpx.Response,
)
all_embeddings = [data["data"] for data in result.json()["data"]]
ret = []
for sent_tokens, sent_emb in zip(all_tokens, all_embeddings):
token_embs = dict[int, float]()
if sent_tokens[0] == 0:
sent_tokens = sent_tokens[1:]
for token, val in zip(sent_tokens, sent_emb):
token_embs[token] = max(val, token_embs.get(token, 0.0))
ret.append(token_embs)
return ret
# Based on https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L129
def compute_lexical_matching_score(
lw1: dict[int, float], lw2: dict[int, float]
) -> float:
scores = 0.0
for token, weight in lw1.items():
if token in lw2:
scores += weight * lw2[token]
return scores
@pytest.mark.asyncio
async def test_bge_m3_api_server_sparse_embedding(client: openai.AsyncOpenAI):
embeddings_1 = await sparse_embeddings(client, sentences_1)
embeddings_2 = await sparse_embeddings(client, sentences_2)
lexical_scores_1_0_x_2_0 = compute_lexical_matching_score(
embeddings_1[0], embeddings_2[0]
)
assert lexical_scores_1_0_x_2_0 == pytest.approx(
lexical_score_reference[0], rel=0.01
)
lexical_scores_1_0_x_1_1 = compute_lexical_matching_score(
embeddings_1[0], embeddings_1[1]
)
assert lexical_scores_1_0_x_1_1 == pytest.approx(
lexical_score_reference[1], rel=0.01
)
# https://github.com/FlagOpen/FlagEmbedding/blob/6fd176266f2382878bcc69cd656cff425d52f49b/FlagEmbedding/inference/embedder/encoder_only/m3.py#L163
def colbert_score(q_reps: torch.Tensor, p_reps: torch.Tensor) -> torch.Tensor:
token_scores = torch.einsum("in,jn->ij", q_reps, p_reps)
scores, _ = token_scores.max(-1)
scores = torch.sum(scores) / q_reps.size(0)
return scores
@pytest.mark.asyncio
async def test_bge_m3_api_server_multi_vector(client: openai.AsyncOpenAI):
result_1 = await client.post(
"../pooling",
body={"model": MODEL_NAME, "input": sentences_1, "task": "token_embed"},
cast_to=httpx.Response,
)
embeddings_1 = [torch.tensor(data["data"]) for data in result_1.json()["data"]]
result_2 = await client.post(
"../pooling",
body={"model": MODEL_NAME, "input": sentences_2, "task": "token_embed"},
cast_to=httpx.Response,
)
embeddings_2 = [torch.tensor(data["data"]) for data in result_2.json()["data"]]
colbert_score_1_0_x_2_0 = colbert_score(embeddings_1[0], embeddings_2[0])
assert colbert_score_1_0_x_2_0 == pytest.approx(
colbert_score_reference[0], rel=0.01
)
colbert_score_1_0_x_2_1 = colbert_score(embeddings_1[0], embeddings_2[1])
assert colbert_score_1_0_x_2_1 == pytest.approx(
colbert_score_reference[1], rel=0.01
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np import numpy as np
import openai
import pytest import pytest
from scipy.spatial.distance import cosine from scipy.spatial.distance import cosine
...@@ -9,6 +8,7 @@ from vllm import LLM, SamplingParams ...@@ -9,6 +8,7 @@ from vllm import LLM, SamplingParams
from vllm.config import ModelConfig from vllm.config import ModelConfig
from ....utils import RemoteOpenAIServer from ....utils import RemoteOpenAIServer
from .embed_utils import run_client_embeddings
MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MODEL_NAME = "parasail-ai/GritLM-7B-vllm"
MAX_MODEL_LEN = 4000 MAX_MODEL_LEN = 4000
...@@ -55,18 +55,6 @@ def run_llm_encode( ...@@ -55,18 +55,6 @@ def run_llm_encode(
return [output.outputs.embedding for output in outputs] return [output.outputs.embedding for output in outputs]
async def run_client_embeddings(
client: openai.AsyncOpenAI,
queries: list[str],
instruction: str,
) -> list[list[float]]:
outputs = await client.embeddings.create(
model=MODEL_NAME,
input=[instruction + q for q in queries],
)
return [data.embedding for data in outputs.data]
def gritlm_instruction(instruction): def gritlm_instruction(instruction):
return ( return (
"<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n" "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"
...@@ -145,11 +133,13 @@ async def test_gritlm_api_server_embedding(): ...@@ -145,11 +133,13 @@ async def test_gritlm_api_server_embedding():
d_rep = await run_client_embeddings( d_rep = await run_client_embeddings(
client_embedding, client_embedding,
MODEL_NAME,
documents, documents,
d_instruction, d_instruction,
) )
q_rep = await run_client_embeddings( q_rep = await run_client_embeddings(
client_embedding, client_embedding,
MODEL_NAME,
queries, queries,
q_instruction, q_instruction,
) )
......
...@@ -513,6 +513,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -513,6 +513,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = { _EMBEDDING_EXAMPLE_MODELS = {
# [Text-only] # [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"), "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"BgeM3EmbeddingModel": _HfExamplesInfo("BAAI/bge-m3"),
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
"Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"), "Gemma3TextModel": _HfExamplesInfo("google/embeddinggemma-300m"),
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
......
...@@ -125,4 +125,49 @@ class IdentityPooler(Pooler): ...@@ -125,4 +125,49 @@ class IdentityPooler(Pooler):
return hidden_states return hidden_states
__all__ = ["DispatchPooler", "IdentityPooler"] class BOSEOSFilter(Pooler):
"""Filters the BOS and EOS token results from outputs."""
def __init__(
self,
pooler: Pooler,
bos_token_id: int = -1, # -1 disables the filtering
eos_token_id: int = -1,
) -> None:
super().__init__()
self.pooler = pooler
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
def get_supported_tasks(self) -> Set[PoolingTask]:
return self.pooler.get_supported_tasks()
def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
return PoolingParamsUpdate(requires_token_ids=True)
def forward(
self,
hidden_states: torch.Tensor | list[torch.Tensor],
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
pooled_outputs = self.pooler(hidden_states, pooling_metadata)
assert isinstance(pooled_outputs, list)
for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
pooled_data = pooled_outputs[i]
assert (
isinstance(pooled_data, torch.Tensor)
and pooled_data.shape[0] == prompt_len
)
token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
if token_ids[0] == self.bos_token_id:
pooled_data = pooled_data[1:]
if token_ids[-1] == self.eos_token_id:
pooled_data = pooled_data[:-1]
pooled_outputs[i] = pooled_data.squeeze()
return pooled_outputs
__all__ = ["BOSEOSFilter", "DispatchPooler", "IdentityPooler"]
...@@ -6,7 +6,11 @@ from typing import TypeAlias ...@@ -6,7 +6,11 @@ from typing import TypeAlias
import torch import torch
from vllm.config import PoolerConfig, get_current_vllm_config from vllm.config import PoolerConfig, get_current_vllm_config
from vllm.model_executor.layers.pooler import ClassifierFn, PoolingParamsUpdate from vllm.model_executor.layers.pooler import (
ClassifierFn,
PoolingParamsUpdate,
ProjectorFn,
)
from vllm.model_executor.layers.pooler.abstract import Pooler from vllm.model_executor.layers.pooler.abstract import Pooler
from vllm.model_executor.layers.pooler.activations import ( from vllm.model_executor.layers.pooler.activations import (
PoolerActivation, PoolerActivation,
...@@ -89,14 +93,18 @@ class TokenPooler(Pooler): ...@@ -89,14 +93,18 @@ class TokenPooler(Pooler):
return pooled_data return pooled_data
def pooler_for_token_embed(pooler_config: PoolerConfig): def pooler_for_token_embed(
pooler_config: PoolerConfig, projector: ProjectorFn | None = None
) -> TokenPooler:
pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type()) pooling = get_tok_pooling_method(pooler_config.get_tok_pooling_type())
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
model_config = vllm_config.model_config model_config = vllm_config.model_config
head = TokenEmbeddingPoolerHead( head = TokenEmbeddingPoolerHead(
head_dtype=model_config.head_dtype, head_dtype=model_config.head_dtype,
projector=_load_st_projector(model_config), projector=projector
if projector is not None
else _load_st_projector(model_config),
activation=PoolerNormalize(), activation=PoolerNormalize(),
) )
......
...@@ -234,6 +234,7 @@ _EMBEDDING_MODELS = { ...@@ -234,6 +234,7 @@ _EMBEDDING_MODELS = {
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
# [Multimodal] # [Multimodal]
"CLIPModel": ("clip", "CLIPEmbeddingModel"), "CLIPModel": ("clip", "CLIPEmbeddingModel"),
"LlavaNextForConditionalGeneration": ( "LlavaNextForConditionalGeneration": (
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Iterable from collections.abc import Iterable
import torch import torch
from torch import nn from torch import nn
from transformers import RobertaConfig from transformers import RobertaConfig
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, PoolerConfig, VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler from vllm.model_executor.layers.pooler import (
BOSEOSFilter,
DispatchPooler,
Pooler,
)
from vllm.model_executor.layers.pooler.seqwise import (
pooler_for_embed,
)
from vllm.model_executor.layers.pooler.tokwise import (
AllPool,
pooler_for_token_classify,
pooler_for_token_embed,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bert import ( from vllm.model_executor.models.bert import (
TOKEN_TYPE_SHIFT, TOKEN_TYPE_SHIFT,
BertEmbeddingModel, BertEmbeddingModel,
...@@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel): ...@@ -149,6 +164,98 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return loader.load_weights(weights_list, mapper=mapper) return loader.load_weights(weights_list, mapper=mapper)
def filter_secondary_weights(
all_weights: Iterable[tuple[str, torch.Tensor]],
secondary_weights: list[str],
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, torch.Tensor]]]:
all_weights1, all_weights2 = itertools.tee(all_weights)
def filtered(n):
return any(n.startswith(f) for f in secondary_weights)
return ((n, w) for n, w in all_weights1 if filtered(n)), (
(n, w) for n, w in all_weights2 if not filtered(n)
)
class BgeM3EmbeddingModel(RobertaEmbeddingModel):
"""A model that extends RobertaEmbeddingModel with sparse embeddings.
This class supports loading an additional sparse_linear.pt file
to create sparse embeddings as described in https://arxiv.org/abs/2402.03216
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.hidden_size = vllm_config.model_config.hf_config.hidden_size
model_config = vllm_config.model_config
self.head_dtype = model_config.head_dtype
self.bos_token_id = model_config.hf_config.bos_token_id
self.eos_token_id = model_config.hf_config.eos_token_id
super().__init__(vllm_config=vllm_config, prefix=prefix)
self.secondary_weight_prefixes = ["sparse_linear.", "colbert_linear."]
self.secondary_weight_files = [
prefix + "pt" for prefix in self.secondary_weight_prefixes
]
self.secondary_weights = [
DefaultModelLoader.Source(
model_or_path=vllm_config.model_config.model,
revision=None,
prefix=prefix,
allow_patterns_overrides=[filename],
)
for filename, prefix in zip(
self.secondary_weight_files, self.secondary_weight_prefixes
)
]
def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
self.sparse_linear = nn.Linear(self.hidden_size, 1, dtype=self.head_dtype)
self.colbert_linear = nn.Linear(
self.hidden_size, self.hidden_size, dtype=self.head_dtype
)
return DispatchPooler(
{
"embed": pooler_for_embed(pooler_config),
"token_embed": BOSEOSFilter(
pooler_for_token_embed(pooler_config, self.colbert_linear),
self.bos_token_id,
# for some reason m3 only filters the bos for colbert vectors
),
"token_classify": BOSEOSFilter(
pooler_for_token_classify(
pooler_config,
pooling=AllPool(),
classifier=self.sparse_linear,
act_fn=torch.relu,
),
self.bos_token_id,
self.eos_token_id,
),
}
)
def load_weights(self, all_weights: Iterable[tuple[str, torch.Tensor]]):
secondary, weights = filter_secondary_weights(
all_weights, self.secondary_weight_prefixes
)
super().load_weights(weights)
params_dict = dict(self.named_parameters())
for name, loaded_weight in secondary:
if any(
name.startswith(prefix) for prefix in self.secondary_weight_prefixes
):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
@default_pooling_type(seq_pooling_type="CLS") @default_pooling_type(seq_pooling_type="CLS")
class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
"""A model that uses Roberta to provide embedding functionalities. """A model that uses Roberta to provide embedding functionalities.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment