Unverified commit 2cdf8604, authored by Abhijit Roy, committed by GitHub
Browse files

Add Jina Embeddings v5 model support (fixes #38633) (#39575)


Signed-off-by: Abhijit <abroy@redhat.com>
Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 78453792
......@@ -45,6 +45,7 @@ You can compute pairwise similarity scores to build a similarity matrix using th
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | |
| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | |
| `JinaEmbeddingsV5Model`<sup>C</sup> | Qwen3-based with task-specific LoRA adapters | `jinaai/jina-embeddings-v5-text-small` (see note) | ✅︎ | ✅︎ |
| `LlamaBidirectionalModel`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-embed-1b-v2`, etc. | ✅︎ | ✅︎ |
| `LlamaModel`<sup>C</sup>, `LlamaForCausalLM`<sup>C</sup>, `MistralModel`<sup>C</sup>, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | |
......@@ -73,6 +74,12 @@ You can compute pairwise similarity scores to build a similarity matrix using th
!!! note
`jinaai/jina-embeddings-v3` supports multiple tasks through LoRA adapters, but vLLM currently supports only the text-matching task, which it provides by merging the corresponding LoRA weights into the base model.
!!! note
`jinaai/jina-embeddings-v5-text-small` ships with four task-specific LoRA adapters
(`retrieval`, `text-matching`, `classification`, `clustering`). vLLM merges the
selected adapter into the base weights at load time. Choose the task with
`--hf-overrides '{"jina_task": "<task>"}'`; the default is `retrieval`.
### Multimodal Models
!!! note
......
......@@ -364,6 +364,7 @@ class HfRunner:
model_name: str,
dtype: str = "auto",
*,
revision: str | None = None,
model_kwargs: dict[str, Any] | None = None,
trust_remote_code: bool = True,
is_sentence_transformer: bool = False,
......@@ -383,6 +384,7 @@ class HfRunner:
self._init(
model_name=model_name,
dtype=dtype,
revision=revision,
model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code,
is_sentence_transformer=is_sentence_transformer,
......@@ -396,6 +398,7 @@ class HfRunner:
model_name: str,
dtype: str = "auto",
*,
revision: str | None = None,
model_kwargs: dict[str, Any] | None = None,
trust_remote_code: bool = True,
is_sentence_transformer: bool = False,
......@@ -437,6 +440,7 @@ class HfRunner:
self.model = SentenceTransformer(
model_name,
revision=revision,
device=self.device,
model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code,
......@@ -447,6 +451,7 @@ class HfRunner:
self.model = CrossEncoder(
model_name,
revision=revision,
device=self.device,
automodel_args=model_kwargs,
trust_remote_code=trust_remote_code,
......@@ -456,6 +461,7 @@ class HfRunner:
nn.Module,
auto_cls.from_pretrained(
model_name,
revision=revision,
trust_remote_code=trust_remote_code,
**model_kwargs,
),
......
......@@ -74,10 +74,25 @@ class MtebEmbedMixin(mteb.EncoderProtocol):
return sim
class HfMtebEncoder(MtebEmbedMixin):
    """Expose a wrapped model through the MTEB encoder interface.

    The wrapped object only needs an ``encode`` method that accepts a list
    of strings.
    """

    def __init__(self, model):
        # Only ``model.encode`` is used by this wrapper.
        self.model = model

    def encode(
        self,
        inputs: DataLoader[mteb.types.BatchedInput],
        *args,
        **kwargs,
    ) -> np.ndarray:
        """Encode every text of every batch in a single ``model.encode`` call.

        Args:
            inputs: Iterable of batches; each batch is indexed with ``"text"``
                to obtain its texts.
            *args: Accepted for interface compatibility; ignored.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Whatever ``self.model.encode`` returns for the flattened texts
            (annotated as a NumPy array).
        """
        flat_texts = []
        for batch in inputs:
            # Preserve batch order, then in-batch order.
            flat_texts.extend(batch["text"])
        return self.model.encode(flat_texts)
class VllmMtebEncoder(MtebEmbedMixin):
def __init__(self, vllm_model):
def __init__(self, vllm_model, prompt_prefix: str | None = None):
self.llm = vllm_model
self.rng = np.random.default_rng(seed=42)
self.prompt_prefix = prompt_prefix
def encode(
self,
......@@ -87,7 +102,11 @@ class VllmMtebEncoder(MtebEmbedMixin):
) -> np.ndarray:
# Hoping to discover potential scheduling
# issues by randomizing the order.
sentences = [text for batch in inputs for text in batch["text"]]
sentences = [
self.prompt_prefix + text if self.prompt_prefix else text
for batch in inputs
for text in batch["text"]
]
r = self.rng.permutation(len(sentences))
sentences = [sentences[i] for i in r]
outputs = self.llm.embed(sentences, use_tqdm=False)
......@@ -143,6 +162,7 @@ def mteb_test_embed_models(
vllm_extra_kwargs=None,
hf_model_callback=None,
atol=MTEB_EMBED_TOL,
prompt_prefix: str | None = None,
):
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
......@@ -182,7 +202,7 @@ def mteb_test_embed_models(
)
vllm_main_score = run_mteb_embed_task(
VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS
VllmMtebEncoder(vllm_model, prompt_prefix=prompt_prefix), MTEB_EMBED_TASKS
)
vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype
head_dtype = model_config.head_dtype
......@@ -210,7 +230,9 @@ def mteb_test_embed_models(
if hf_model_callback is not None:
hf_model_callback(hf_model)
st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
st_main_score = run_mteb_embed_task(
HfMtebEncoder(hf_model), MTEB_EMBED_TASKS
)
st_dtype = next(hf_model.model.parameters()).dtype
# Check embeddings close to hf outputs
......
......@@ -28,7 +28,16 @@ EMBEDDING_MODELS = [
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
)
),
EmbedModelInfo(
"jinaai/jina-embeddings-v5-text-small",
mteb_score=0.794535707854956,
architecture="JinaEmbeddingsV5Model",
seq_pooling_type="LAST",
attn_type="decoder",
is_prefix_caching_supported=True,
is_chunked_prefill_supported=True,
),
]
RERANK_MODELS = [
......@@ -46,11 +55,18 @@ RERANK_MODELS = [
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
task = "retrieval" if "v5" in model_info.name else "text-matching"
prompt_prefix: str | None = "Document: " if "v5" in model_info.name else None
def hf_model_callback(model):
model.encode = partial(model.encode, task="text-matching")
model.encode = partial(model.encode, task=task)
mteb_test_embed_models(
hf_runner, vllm_runner, model_info, hf_model_callback=hf_model_callback
hf_runner,
vllm_runner,
model_info,
hf_model_callback=hf_model_callback,
prompt_prefix=prompt_prefix,
)
......@@ -58,8 +74,10 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -
def test_embed_models_correctness(
hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts
) -> None:
task = "retrieval" if "v5" in model_info.name else "text-matching"
def hf_model_callback(model):
model.encode = partial(model.encode, task="text-matching")
model.encode = partial(model.encode, task=task)
correctness_test_embed_models(
hf_runner,
......@@ -97,12 +115,14 @@ def test_matryoshka(
# ST will strip the input texts, see test_embedding.py
example_prompts = [str(s).strip() for s in example_prompts]
task = "retrieval" if "v5" in model_info.name else "text-matching"
with hf_runner(
model_info.name,
dtype=dtype,
is_sentence_transformer=True,
) as hf_model:
hf_outputs = hf_model.encode(example_prompts, task="text-matching")
hf_outputs = hf_model.encode(example_prompts, task=task)
hf_outputs = matryoshka_fy(hf_outputs, dimensions)
with vllm_runner(
......
......@@ -609,6 +609,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
trust_remote_code=True,
hf_overrides={"architectures": ["GteNewModel"]},
),
"JinaEmbeddingsV5Model": _HfExamplesInfo(
"jinaai/jina-embeddings-v5-text-small",
trust_remote_code=True,
),
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
"LlamaBidirectionalModel": _HfExamplesInfo(
"nvidia/llama-nemotron-embed-1b-v2", trust_remote_code=True
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/jinaai/jina-reranker-v3/blob/main/modeling.py
import json
import logging
from collections import defaultdict
from collections.abc import Iterable
import torch
from safetensors.torch import load as safetensors_load
from torch import nn
from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
from vllm.v1.pool.metadata import PoolingMetadata
from ..layers.pooler import DispatchPooler
......@@ -18,9 +23,12 @@ from ..layers.pooler.tokwise import (
TokenPoolingMethodOutputItem,
)
from .interfaces import SupportsLateInteraction
from .qwen3 import Qwen3Model
from .interfaces_base import VllmModelForPooling
from .qwen3 import Qwen3ForCausalLM, Qwen3Model
from .utils import AutoWeightsLoader, maybe_prefix
logger = logging.getLogger(__name__)
class JinaForRanking(nn.Module, SupportsLateInteraction):
is_pooling_model = True
......@@ -108,3 +116,143 @@ class JinaForRankingPool(StepPool):
embeds_list.append(embeds)
return embeds_list
# jina-embeddings-v5-text-small wraps Qwen3-0.6B-Base with four task-specific
# LoRA adapters. This implementation merges the selected adapter into the base
# weights at load time to avoid any runtime dependency on peft.
#
# Task selection:
# Pass --hf-overrides '{"jina_task": "retrieval"}' to select one of:
# retrieval (default), text-matching, classification, clustering.
# Task whose adapter is used when the HF config has no `jina_task` attribute,
# or when the configured value is not one of the supported tasks.
_DEFAULT_TASK = "retrieval"
# Task names accepted for `jina_task`; each is also the sub-directory name
# looked up under `adapters/` when the adapter files are fetched.
_SUPPORTED_TASKS = {"retrieval", "text-matching", "classification", "clustering"}
def _load_adapter(
    model: str,
    task: str,
    revision: str | None,
) -> tuple[dict, dict[str, torch.Tensor]] | None:
    """Fetch and parse the LoRA adapter stored under ``adapters/<task>/``.

    Both files are obtained through ``get_hf_file_bytes``; the config is
    fetched and parsed before the weights are requested.

    Args:
        model: Model identifier forwarded to ``get_hf_file_bytes``.
        task: Adapter sub-directory name (e.g. ``"retrieval"``).
        revision: Optional revision forwarded to ``get_hf_file_bytes``.

    Returns:
        ``(adapter_config, adapter_weights)``, or ``None`` when either the
        config file or the weights file could not be obtained.
    """
    if (
        config_blob := get_hf_file_bytes(
            f"adapters/{task}/adapter_config.json",
            model,
            revision,
        )
    ) is None:
        return None
    parsed_config = json.loads(config_blob)

    if (
        weights_blob := get_hf_file_bytes(
            f"adapters/{task}/adapter_model.safetensors",
            model,
            revision,
        )
    ) is None:
        return None

    # safetensors deserialises directly from the in-memory bytes.
    return parsed_config, safetensors_load(weights_blob)
def _build_lora_pairs(adapter_weights: dict) -> dict:
"""Group raw adapter tensors into {base_key: {"A": tensor, "B": tensor}} pairs.
Transforms adapter keys like:
base_model.model.layers.0.self_attn.q_proj.lora_A.weight
Into base keys like:
layers.0.self_attn.q_proj.weight
"""
lora_pairs = defaultdict(dict)
for key, tensor in adapter_weights.items():
clean_key = key
if clean_key.startswith("base_model.model."):
clean_key = clean_key[len("base_model.model.") :]
if ".lora_A." in clean_key:
base_key = clean_key.split(".lora_A.")[0] + ".weight"
lora_pairs[base_key]["A"] = tensor
elif ".lora_B." in clean_key:
base_key = clean_key.split(".lora_B.")[0] + ".weight"
lora_pairs[base_key]["B"] = tensor
return dict(lora_pairs)
class JinaEmbeddingsV5Model(Qwen3ForCausalLM, VllmModelForPooling):
    """Jina Embeddings V5 with a task-specific LoRA adapter merged at load time.

    Extends Qwen3ForCausalLM (the underlying architecture) and declares itself
    as a pooling model so that as_embedding_model() does not wrap it.

    The adapter is chosen through the ``jina_task`` attribute of the HF config
    (settable with ``--hf-overrides '{"jina_task": "<task>"}'``).
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, resolve the adapter task and create the pooler.

        An unrecognised ``jina_task`` is not fatal: a warning is logged and
        the default task is used instead.
        """
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        model_config = vllm_config.model_config
        # Kept so that load_weights() can locate the adapter files later.
        self._model_name = model_config.model
        self._revision = model_config.revision

        self._task = getattr(model_config.hf_config, "jina_task", _DEFAULT_TASK)
        if self._task not in _SUPPORTED_TASKS:
            logger.warning(
                "Unknown jina_task=%r. Falling back to %r.",
                self._task,
                _DEFAULT_TASK,
            )
            self._task = _DEFAULT_TASK

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_embedding(pooler_config)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load base weights, folding the selected LoRA adapter into them.

        For every checkpoint tensor that has a complete (A, B) adapter pair,
        ``W + (B @ A) * scaling`` is loaded instead of ``W``. If no adapter
        can be fetched, the raw base weights are loaded and a warning is
        logged.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names reported as loaded by ``self.model.load_weights``, each
            prefixed with ``model.``.
        """
        lora_pairs: dict = {}
        scaling = 1.0

        result = _load_adapter(self._model_name, self._task, self._revision)
        if result is None:
            logger.warning(
                "No adapter found for task %r in %r. Loading raw base weights.",
                self._task,
                self._model_name,
            )
        else:
            adapter_config, adapter_weights = result
            rank = adapter_config["r"]
            # PEFT's rank-stabilised LoRA scales by alpha / sqrt(r) instead of
            # alpha / r; honour the flag when the adapter config sets it.
            divisor = rank**0.5 if adapter_config.get("use_rslora", False) else rank
            scaling = adapter_config["lora_alpha"] / divisor
            lora_pairs = _build_lora_pairs(adapter_weights)
            logger.info(
                "Loaded %d adapter tensors for task %r (scaling=%.4f, %d LoRA pairs)",
                len(adapter_weights),
                self._task,
                scaling,
                len(lora_pairs),
            )

        # Filled in lazily while the generator below is consumed.
        merged_keys: set[str] = set()

        def _merge_weights(
            weights: Iterable[tuple[str, torch.Tensor]],
        ) -> Iterable[tuple[str, torch.Tensor]]:
            """Yield checkpoint tensors, adding the LoRA delta where one exists."""
            for name, tensor in weights:
                # Adapter keys carry no "model." prefix; checkpoint names may.
                clean_name = name.removeprefix("model.")
                pair = lora_pairs.get(clean_name)
                if pair is not None and "A" in pair and "B" in pair:
                    lora_A = pair["A"].to(device=tensor.device, dtype=tensor.dtype)
                    lora_B = pair["B"].to(device=tensor.device, dtype=tensor.dtype)
                    tensor = tensor + (lora_B @ lora_A) * scaling
                    merged_keys.add(clean_name)
                # The original name is forwarded unchanged.
                yield name, tensor

        # NOTE(review): self.model is presumably the Qwen3Model backbone set up
        # by the parent constructor - confirm against Qwen3ForCausalLM.
        loaded = self.model.load_weights(_merge_weights(weights))

        # Surface adapter tensors that were never applied; otherwise a naming
        # mismatch would silently produce un-adapted embeddings.
        unmerged = sorted(lora_pairs.keys() - merged_keys)
        if unmerged:
            logger.warning(
                "%d of %d LoRA pairs for task %r were not merged (missing A/B "
                "tensor or no matching base weight), e.g. %s",
                len(unmerged),
                len(lora_pairs),
                self._task,
                unmerged[:5],
            )

        return {f"model.{name}" for name in loaded}
......@@ -227,6 +227,7 @@ _EMBEDDING_MODELS = {
"GritLM": ("gritlm", "GritLM"),
"GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
"GteNewModel": ("bert_with_rope", "GteNewModel"),
"JinaEmbeddingsV5Model": ("jina", "JinaEmbeddingsV5Model"),
"LlamaBidirectionalModel": ("llama", "LlamaBidirectionalModel"),
"LlamaModel": ("llama", "LlamaForCausalLM"),
**{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment