Unverified Commit 2cdf8604 authored by Abhijit Roy's avatar Abhijit Roy Committed by GitHub
Browse files

Add Jina Embeddings v5 model support (fixes #38633) (#39575)


Signed-off-by: default avatarAbhijit <abroy@redhat.com>
Signed-off-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: default avatarwang.yuqi <yuqi.wang@daocloud.io>
parent 78453792
...@@ -45,6 +45,7 @@ You can compute pairwise similarity scores to build a similarity matrix using th ...@@ -45,6 +45,7 @@ You can compute pairwise similarity scores to build a similarity matrix using th
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | | `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | |
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | | `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | |
| `JinaEmbeddingsV5Model`<sup>C</sup> | Qwen3-based with task-specific LoRA adapters | `jinaai/jina-embeddings-v5-text-small` (see note) | ✅︎ | ✅︎ |
| `LlamaBidirectionalModel`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-embed-1b-v2`, etc. | ✅︎ | ✅︎ | | `LlamaBidirectionalModel`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-embed-1b-v2`, etc. | ✅︎ | ✅︎ |
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | | `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | |
...@@ -73,6 +74,12 @@ You can compute pairwise similarity scores to build a similarity matrix using th ...@@ -73,6 +74,12 @@ You can compute pairwise similarity scores to build a similarity matrix using th
!!! note !!! note
`jinaai/jina-embeddings-v3` supports multiple tasks through LoRA, while vllm temporarily only supports text-matching tasks by merging LoRA weights. `jinaai/jina-embeddings-v3` supports multiple tasks through LoRA, while vllm temporarily only supports text-matching tasks by merging LoRA weights.
!!! note
`jinaai/jina-embeddings-v5-text-small` ships with four task-specific LoRA adapters
(`retrieval`, `text-matching`, `classification`, `clustering`). vLLM merges the
selected adapter into the base weights at load time. Choose the task with
`--hf-overrides '{"jina_task": "<task>"}'`; the default is `retrieval`.
### Multimodal Models ### Multimodal Models
!!! note !!! note
......
...@@ -364,6 +364,7 @@ class HfRunner: ...@@ -364,6 +364,7 @@ class HfRunner:
model_name: str, model_name: str,
dtype: str = "auto", dtype: str = "auto",
*, *,
revision: str | None = None,
model_kwargs: dict[str, Any] | None = None, model_kwargs: dict[str, Any] | None = None,
trust_remote_code: bool = True, trust_remote_code: bool = True,
is_sentence_transformer: bool = False, is_sentence_transformer: bool = False,
...@@ -383,6 +384,7 @@ class HfRunner: ...@@ -383,6 +384,7 @@ class HfRunner:
self._init( self._init(
model_name=model_name, model_name=model_name,
dtype=dtype, dtype=dtype,
revision=revision,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
is_sentence_transformer=is_sentence_transformer, is_sentence_transformer=is_sentence_transformer,
...@@ -396,6 +398,7 @@ class HfRunner: ...@@ -396,6 +398,7 @@ class HfRunner:
model_name: str, model_name: str,
dtype: str = "auto", dtype: str = "auto",
*, *,
revision: str | None = None,
model_kwargs: dict[str, Any] | None = None, model_kwargs: dict[str, Any] | None = None,
trust_remote_code: bool = True, trust_remote_code: bool = True,
is_sentence_transformer: bool = False, is_sentence_transformer: bool = False,
...@@ -437,6 +440,7 @@ class HfRunner: ...@@ -437,6 +440,7 @@ class HfRunner:
self.model = SentenceTransformer( self.model = SentenceTransformer(
model_name, model_name,
revision=revision,
device=self.device, device=self.device,
model_kwargs=model_kwargs, model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
...@@ -447,6 +451,7 @@ class HfRunner: ...@@ -447,6 +451,7 @@ class HfRunner:
self.model = CrossEncoder( self.model = CrossEncoder(
model_name, model_name,
revision=revision,
device=self.device, device=self.device,
automodel_args=model_kwargs, automodel_args=model_kwargs,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
...@@ -456,6 +461,7 @@ class HfRunner: ...@@ -456,6 +461,7 @@ class HfRunner:
nn.Module, nn.Module,
auto_cls.from_pretrained( auto_cls.from_pretrained(
model_name, model_name,
revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**model_kwargs, **model_kwargs,
), ),
......
...@@ -74,10 +74,25 @@ class MtebEmbedMixin(mteb.EncoderProtocol): ...@@ -74,10 +74,25 @@ class MtebEmbedMixin(mteb.EncoderProtocol):
return sim return sim
class HfMtebEncoder(MtebEmbedMixin):
    """MTEB encoder that forwards text batches to a wrapped model's ``encode``."""

    def __init__(self, model):
        # Object exposing ``encode(list_of_texts)``; stored as-is.
        self.model = model

    def encode(
        self,
        inputs: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        """Flatten the ``"text"`` field of every batch and encode it in one call.

        Extra positional and keyword arguments are accepted but not used.
        """
        texts = []
        for batch in inputs:
            texts.extend(batch["text"])
        return self.model.encode(texts)
class VllmMtebEncoder(MtebEmbedMixin): class VllmMtebEncoder(MtebEmbedMixin):
def __init__(self, vllm_model): def __init__(self, vllm_model, prompt_prefix: str | None = None):
self.llm = vllm_model self.llm = vllm_model
self.rng = np.random.default_rng(seed=42) self.rng = np.random.default_rng(seed=42)
self.prompt_prefix = prompt_prefix
def encode( def encode(
self, self,
...@@ -87,7 +102,11 @@ class VllmMtebEncoder(MtebEmbedMixin): ...@@ -87,7 +102,11 @@ class VllmMtebEncoder(MtebEmbedMixin):
) -> np.ndarray: ) -> np.ndarray:
# Hoping to discover potential scheduling # Hoping to discover potential scheduling
# issues by randomizing the order. # issues by randomizing the order.
sentences = [text for batch in inputs for text in batch["text"]] sentences = [
self.prompt_prefix + text if self.prompt_prefix else text
for batch in inputs
for text in batch["text"]
]
r = self.rng.permutation(len(sentences)) r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r] sentences = [sentences[i] for i in r]
outputs = self.llm.embed(sentences, use_tqdm=False) outputs = self.llm.embed(sentences, use_tqdm=False)
...@@ -143,6 +162,7 @@ def mteb_test_embed_models( ...@@ -143,6 +162,7 @@ def mteb_test_embed_models(
vllm_extra_kwargs=None, vllm_extra_kwargs=None,
hf_model_callback=None, hf_model_callback=None,
atol=MTEB_EMBED_TOL, atol=MTEB_EMBED_TOL,
prompt_prefix: str | None = None,
): ):
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs) vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
...@@ -182,7 +202,7 @@ def mteb_test_embed_models( ...@@ -182,7 +202,7 @@ def mteb_test_embed_models(
) )
vllm_main_score = run_mteb_embed_task( vllm_main_score = run_mteb_embed_task(
VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS VllmMtebEncoder(vllm_model, prompt_prefix=prompt_prefix), MTEB_EMBED_TASKS
) )
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
head_dtype = model_config.head_dtype head_dtype = model_config.head_dtype
...@@ -210,7 +230,9 @@ def mteb_test_embed_models( ...@@ -210,7 +230,9 @@ def mteb_test_embed_models(
if hf_model_callback is not None: if hf_model_callback is not None:
hf_model_callback(hf_model) hf_model_callback(hf_model)
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) st_main_score = run_mteb_embed_task(
HfMtebEncoder(hf_model), MTEB_EMBED_TASKS
)
st_dtype = next(hf_model.model.parameters()).dtype st_dtype = next(hf_model.model.parameters()).dtype
# Check embeddings close to hf outputs # Check embeddings close to hf outputs
......
...@@ -28,7 +28,16 @@ EMBEDDING_MODELS = [ ...@@ -28,7 +28,16 @@ EMBEDDING_MODELS = [
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
) ),
EmbedModelInfo(
"jinaai/jina-embeddings-v5-text-small",
mteb_score=0.794535707854956,
architecture="JinaEmbeddingsV5Model",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
),
] ]
RERANK_MODELS = [ RERANK_MODELS = [
...@@ -46,11 +55,18 @@ RERANK_MODELS = [ ...@@ -46,11 +55,18 @@ RERANK_MODELS = [
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None: def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
task = "retrieval" if "v5" in model_info.name else "text-matching"
prompt_prefix: str | None = "Document: " if "v5" in model_info.name else None
def hf_model_callback(model): def hf_model_callback(model):
model.encode = partial(model.encode, task="text-matching") model.encode = partial(model.encode, task=task)
mteb_test_embed_models( mteb_test_embed_models(
hf_runner, vllm_runner, model_info, hf_model_callback=hf_model_callback hf_runner,
vllm_runner,
model_info,
hf_model_callback=hf_model_callback,
prompt_prefix=prompt_prefix,
) )
...@@ -58,8 +74,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) - ...@@ -58,8 +74,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -
def test_embed_models_correctness( def test_embed_models_correctness(
hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts
) -> None: ) -> None:
task = "retrieval" if "v5" in model_info.name else "text-matching"
def hf_model_callback(model): def hf_model_callback(model):
model.encode = partial(model.encode, task="text-matching") model.encode = partial(model.encode, task=task)
correctness_test_embed_models( correctness_test_embed_models(
hf_runner, hf_runner,
...@@ -97,12 +115,14 @@ def test_matryoshka( ...@@ -97,12 +115,14 @@ def test_matryoshka(
# ST will strip the input texts, see test_embedding.py # ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts] example_prompts = [str(s).strip() for s in example_prompts]
task = "retrieval" if "v5" in model_info.name else "text-matching"
with hf_runner( with hf_runner(
model_info.name, model_info.name,
dtype=dtype, dtype=dtype,
is_sentence_transformer=True, is_sentence_transformer=True,
) as hf_model: ) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching") hf_outputs = hf_model.encode(example_prompts, task=task)
hf_outputs = matryoshka_fy(hf_outputs, dimensions) hf_outputs = matryoshka_fy(hf_outputs, dimensions)
with vllm_runner( with vllm_runner(
......
...@@ -609,6 +609,10 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -609,6 +609,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["GteNewModel"]}, hf_overrides={"architectures": ["GteNewModel"]},
), ),
"JinaEmbeddingsV5Model": _HfExamplesInfo(
"jinaai/jina-embeddings-v5-text-small",
trust_remote_code=True,
),
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False), "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"LlamaBidirectionalModel": _HfExamplesInfo( "LlamaBidirectionalModel": _HfExamplesInfo(
"nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True "nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/jinaai/jina-reranker-v3/blob/main/modeling.py # Adapted from https://huggingface.co/jinaai/jina-reranker-v3/blob/main/modeling.py
import json
import logging
from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
import torch import torch
from safetensors.torch import load as safetensors_load
from torch import nn from torch import nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask from vllm.tasks import PoolingTask
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.pool.metadata import PoolingMetadata
from ..layers.pooler import DispatchPooler from ..layers.pooler import DispatchPooler
...@@ -18,9 +23,12 @@ from ..layers.pooler.tokwise import ( ...@@ -18,9 +23,12 @@ from ..layers.pooler.tokwise import (
TokenPoolingMethodOutputItem, TokenPoolingMethodOutputItem,
) )
from .interfaces import SupportsLateInteraction from .interfaces import SupportsLateInteraction
from .qwen3 import Qwen3Model from .interfaces_base import VllmModelForPooling
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import AutoWeightsLoader, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
logger = logging.getLogger(__name__)
class JinaForRanking(nn.Module, SupportsLateInteraction): class JinaForRanking(nn.Module, SupportsLateInteraction):
is_pooling_model = True is_pooling_model = True
...@@ -108,3 +116,143 @@ class JinaForRankingPool(StepPool): ...@@ -108,3 +116,143 @@ class JinaForRankingPool(StepPool):
embeds_list.append(embeds) embeds_list.append(embeds)
return embeds_list return embeds_list
# jina-embeddings-v5-text-small wraps Qwen3-0.6B-Base with four task-specific
# LoRA adapters. This implementation merges the selected adapter into the base
# weights at load time to avoid any runtime dependency on peft.
#
# Task selection:
# Pass --hf-overrides '{"jina_task": "retrieval"}' to select one of:
# retrieval (default), text-matching, classification, clustering.
# Adapter used when the HF config carries no `jina_task` attribute.
_DEFAULT_TASK = "retrieval"
# Task names accepted for `jina_task`.
_SUPPORTED_TASKS = {"retrieval", "text-matching", "classification", "clustering"}
def _load_adapter(
    model: str,
    task: str,
    revision: str | None,
) -> tuple[dict, dict[str, torch.Tensor]] | None:
    """Fetch the LoRA adapter for ``task`` via ``get_hf_file_bytes``.

    Reads ``adapters/<task>/adapter_config.json`` first and, only if that
    exists, ``adapters/<task>/adapter_model.safetensors``.

    Returns:
        ``(adapter_config, adapter_weights)`` when both files are available,
        otherwise ``None``.
    """
    raw_config = get_hf_file_bytes(
        f"adapters/{task}/adapter_config.json",
        model,
        revision,
    )
    if raw_config is not None:
        # Parse the config before requesting the (much larger) weights file.
        parsed_config = json.loads(raw_config)
        raw_weights = get_hf_file_bytes(
            f"adapters/{task}/adapter_model.safetensors",
            model,
            revision,
        )
        if raw_weights is not None:
            return parsed_config, safetensors_load(raw_weights)
    return None
def _build_lora_pairs(adapter_weights: dict) -> dict:
    """Pair up LoRA A/B tensors under the base weight name they modify.

    An adapter entry such as
        base_model.model.layers.0.self_attn.q_proj.lora_A.weight
    is filed under the base key
        layers.0.self_attn.q_proj.weight

    Returns a plain dict of ``{base_key: {"A": tensor, "B": tensor}}``.
    Entries without a ``.lora_A.`` / ``.lora_B.`` marker are ignored, and a
    base key may hold only one of the two slots if its partner is absent.
    """
    # Checked in this order, so a key containing both markers counts as "A".
    markers = (("A", ".lora_A."), ("B", ".lora_B."))

    paired: dict = {}
    for raw_name, value in adapter_weights.items():
        stripped = raw_name.removeprefix("base_model.model.")
        for slot, marker in markers:
            stem, found, _ = stripped.partition(marker)
            if found:
                paired.setdefault(stem + ".weight", {})[slot] = value
                break
    return paired
class JinaEmbeddingsV5Model(Qwen3ForCausalLM, VllmModelForPooling):
    """Jina Embeddings V5 with task-specific LoRA adapters merged at load time.

    Extends Qwen3ForCausalLM (the underlying architecture) and declares itself
    as a pooling model so that as_embedding_model() does not wrap it.

    The adapter is chosen through the ``jina_task`` attribute of the HF config
    (e.g. ``--hf-overrides '{"jina_task": "retrieval"}'``). A value outside
    ``_SUPPORTED_TASKS`` falls back to ``_DEFAULT_TASK`` with a warning.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        model_config = vllm_config.model_config
        # Remembered so load_weights() can locate the adapter files later.
        self._model_name = model_config.model
        self._revision = model_config.revision
        self._task = getattr(model_config.hf_config, "jina_task", _DEFAULT_TASK)
        if self._task not in _SUPPORTED_TASKS:
            logger.warning(
                "Unknown jina_task=%r. Falling back to %r.",
                self._task,
                _DEFAULT_TASK,
            )
            self._task = _DEFAULT_TASK

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_embedding(pooler_config)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load base weights, folding the selected LoRA adapter into them.

        For every checkpoint tensor that has a complete LoRA pair, the tensor
        handed to the underlying loader is ``W + (B @ A) * scaling``. When no
        adapter is found, the raw base weights are loaded and a warning is
        logged. A warning is also logged if any LoRA pair ends up unused, since
        that means the served weights are not the fully adapted ones.

        Returns:
            The loaded parameter names, each prefixed with ``model.``.
        """
        lora_pairs: dict = {}
        scaling = 1.0

        result = _load_adapter(self._model_name, self._task, self._revision)
        if result is None:
            logger.warning(
                "No adapter found for task %r in %r. Loading raw base weights.",
                self._task,
                self._model_name,
            )
        else:
            adapter_config, adapter_weights = result
            lora_alpha = adapter_config["lora_alpha"]
            rank = adapter_config["r"]
            # PEFT's rank-stabilized LoRA scales by alpha / sqrt(r) instead of
            # alpha / r; honor the flag so the merge matches the adapter.
            if adapter_config.get("use_rslora", False):
                scaling = lora_alpha / rank**0.5
            else:
                scaling = lora_alpha / rank
            lora_pairs = _build_lora_pairs(adapter_weights)
            logger.info(
                "Loaded %d adapter tensors for task %r (scaling=%.4f, %d LoRA pairs)",
                len(adapter_weights),
                self._task,
                scaling,
                len(lora_pairs),
            )

        # Base keys whose LoRA delta was actually applied; filled lazily while
        # the underlying loader consumes the generator below.
        merged: set[str] = set()

        def _merge_weights(
            weights: Iterable[tuple[str, torch.Tensor]],
        ) -> Iterable[tuple[str, torch.Tensor]]:
            for name, tensor in weights:
                # Adapter keys carry no "model." prefix, so match without it
                # but yield the checkpoint name unchanged.
                clean_name = name.removeprefix("model.")
                pair = lora_pairs.get(clean_name)
                if pair is not None and "A" in pair and "B" in pair:
                    lora_A = pair["A"].to(device=tensor.device, dtype=tensor.dtype)
                    lora_B = pair["B"].to(device=tensor.device, dtype=tensor.dtype)
                    tensor = tensor + (lora_B @ lora_A) * scaling
                    merged.add(clean_name)
                yield name, tensor

        loaded = self.model.load_weights(_merge_weights(weights))

        # A pair that never matched a checkpoint tensor (or lacks its A/B
        # partner) would otherwise be dropped without any trace.
        unmerged = sorted(lora_pairs.keys() - merged)
        if unmerged:
            logger.warning(
                "%d of %d LoRA pairs for task %r were not merged into any "
                "base weight (first: %r).",
                len(unmerged),
                len(lora_pairs),
                self._task,
                unmerged[0],
            )

        return {f"model.{name}" for name in loaded}
...@@ -227,6 +227,7 @@ _EMBEDDING_MODELS = { ...@@ -227,6 +227,7 @@ _EMBEDDING_MODELS = {
"GritLM": ("gritlm", "GritLM"), "GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert_with_rope", "SnowflakeGteNewModel"), "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
"GteNewModel": ("bert_with_rope", "GteNewModel"), "GteNewModel": ("bert_with_rope", "GteNewModel"),
"JinaEmbeddingsV5Model": ("jina", "JinaEmbeddingsV5Model"),
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"), "LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"), "LlamaModel": ("llama", "LlamaForCausalLM"),
**{ **{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment