Unverified Commit cb5f7501 authored by wang.yuqi's avatar wang.yuqi Committed by GitHub
Browse files

[New Model]: jinaai/jina-reranker-v3 (#38800)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
parent 8d0f908b
...@@ -71,6 +71,14 @@ Models of any architecture can be converted into embedding models using `--conve ...@@ -71,6 +71,14 @@ Models of any architecture can be converted into embedding models using `--conve
If your model is not in the above list, we will try to automatically convert the model using [as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model]. If your model is not in the above list, we will try to automatically convert the model using [as_embedding_model][vllm.model_executor.models.adapters.as_embedding_model].
### Special models
| Architecture | Models | Example HF Models | [LoRA](../../features/lora.md) | [PP](../../serving/parallelism_scaling.md) |
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
| `JinaForRanking` | Qwen3-based | `jinaai/jina-reranker-v3` | | |
jina-reranker-v3 is a listwise document reranker model with a novel `last but not late interaction` architecture. More information can be found at: [examples/pooling/token_embed/jina_reranker_v3_offline.py](../../../examples/pooling/token_embed/jina_reranker_v3_offline.py)
--8<-- [end:supported-token-embed-models] --8<-- [end:supported-token-embed-models]
## Offline Inference ## Offline Inference
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import torch.nn.functional as F
from vllm import LLM
# Sample query that every document below is ranked against.
query = "What are the health benefits of green tea?"
# Candidate passages to rank against the query; the list mixes English,
# Spanish, Chinese and French text.
documents = [
"Green tea contains antioxidants called catechins that may help reduce inflammation and protect cells from damage.",
"El precio del café ha aumentado un 20% este año debido a problemas en la cadena de suministro.",
"Studies show that drinking green tea regularly can improve brain function and boost metabolism.",
"Basketball is one of the most popular sports in the United States.",
"绿茶富含儿茶素等抗氧化剂,可以降低心脏病风险,还有助于控制体重。",
"Le thé vert est riche en antioxydants et peut améliorer la fonction cérébrale.",
]
def main():
    """Rank the sample documents against the query in two ways and print both.

    First through ``llm.score``; then by requesting ``token_embed`` pooling
    and computing the cosine similarity between the query embedding and each
    document embedding by hand.
    """
    # Initialize model
    llm = LLM(model="jinaai/jina-reranker-v3", runner="pooling")

    separator = "-" * 60
    banner = "\nGenerated Outputs:\n" + separator

    # Generate scores.
    score_outputs = llm.score(query, documents)

    # Print the outputs.
    print(banner)
    for document, output in zip(documents, score_outputs):
        print(f"Pair: {[query, document]!r} \nScore: {output.outputs.score}")
        print(separator)

    # Generate embeddings.
    # The JinaForRanking model concatenates docs first, then query.
    # Let's stay consistent with this novel design.
    encoded = llm.encode(documents + [query], pooling_task="token_embed")
    embeds = encoded[0].outputs.data.float()
    # Last row is the query embedding; every row before it is one document.
    scores = F.cosine_similarity(embeds[-1], embeds[:-1])

    # Print the outputs.
    print(banner)
    for document, score in zip(documents, scores):
        print(f"Pair: {[query, document]!r} \nScore: {score}")
        print(separator)


if __name__ == "__main__":
    main()
...@@ -120,7 +120,8 @@ python = "./.venv" ...@@ -120,7 +120,8 @@ python = "./.venv"
[tool.typos.files] [tool.typos.files]
# these files may be written in non english words # these files may be written in non english words
extend-exclude = ["tests/models/fixtures/*", "tests/prompts/*", "tests/tokenizers_/*", extend-exclude = ["tests/models/fixtures/*", "tests/prompts/*", "tests/tokenizers_/*",
"benchmarks/sonnet.txt", "tests/lora/data/*", "examples/pooling/token_embed/*", "build/*", "benchmarks/sonnet.txt", "tests/lora/data/*", "build/*",
"examples/pooling/token_embed/*", "tests/models/language/pooling/*",
"vllm/third_party/*", "vllm/entrypoints/serve/instrumentator/static/*", "tests/entrypoints/openai/speech_to_text/test_transcription_validation.py", "vllm/third_party/*", "vllm/entrypoints/serve/instrumentator/static/*", "tests/entrypoints/openai/speech_to_text/test_transcription_validation.py",
"docs/governance/process.md", "tests/v1/engine/test_fast_incdec_prefix_err.py", ".git/*"] "docs/governance/process.md", "tests/v1/engine/test_fast_incdec_prefix_err.py", ".git/*"]
ignore-hidden = false ignore-hidden = false
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import pytest
import requests
import torch
import torch.nn.functional as F
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.entrypoints.pooling.scoring.protocol import ScoreResponse
# Model under test.
model_name = "jinaai/jina-reranker-v3"
# Query that every document below is ranked against.
query = "What are the health benefits of green tea?"
# Candidate passages; the list mixes English, Spanish, Chinese and French text.
documents = [
"Green tea contains antioxidants called catechins that may help reduce inflammation and protect cells from damage.",
"El precio del café ha aumentado un 20% este año debido a problemas en la cadena de suministro.",
"Studies show that drinking green tea regularly can improve brain function and boost metabolism.",
"Basketball is one of the most popular sports in the United States.",
"绿茶富含儿茶素等抗氧化剂,可以降低心脏病风险,还有助于控制体重。",
"Le thé vert est riche en antioxydants et peut améliorer la fonction cérébrale.",
]
# Expected size of the last dimension of the token embeddings; also compared
# against ``model_config.embedding_size`` in ``test_offline``.
EMBEDDING_SIZE = 512
# Expected scores when the query is paired with one document per prompt
# (used by the 1-v-1 and n-v-n checks), in the same order as ``documents``.
# NOTE(review): presumably produced with the reference HF implementation -- confirm.
REFERENCE_1_VS_1 = [
0.345703125,
-0.10498046,
0.314453125,
-0.1376953125,
0.3398437500,
0.2539062,
]
# Expected scores when all documents are ranked together in a single prompt
# (used by the 1-v-n checks), in the same order as ``documents``.
REFERENCE_1_VS_N = [
0.294921875,
-0.16015625,
0.189453125,
-0.1708984375,
0.2255859375,
0.1640625,
]
# Absolute tolerance passed to ``pytest.approx`` in every score comparison.
TOL = 0.01
def test_offline(vllm_runner):
    """Run every offline check against one shared pooling-runner LLM."""
    offline_checks = (
        _test_offline_1_v_1,
        _test_offline_1_v_n,
        _test_offline_n_v_n,
        _test_offline_token_embed_illegal_inputs,
    )
    with vllm_runner(model_name, runner="pooling") as llm_runner:
        llm = llm_runner.get_llm()
        for check in offline_checks:
            check(llm)
        # The model config must report the projector output width.
        assert llm.model_config.embedding_size == EMBEDDING_SIZE
def test_online():
    """Run every online check against one shared OpenAI-compatible server."""
    online_checks = (
        _test_online_1_v_1,
        _test_online_1_v_n,
        _test_online_n_v_n,
        _test_online_token_embed_illegal_inputs,
    )
    with RemoteOpenAIServer(model_name, ["--runner", "pooling"]) as server:
        for check in online_checks:
            check(server)
def _test_offline_1_v_1(llm):
    """Check one query against one document via ``score`` and ``encode``."""
    first_doc = documents[0]
    expected = pytest.approx(REFERENCE_1_VS_1[0], abs=TOL)

    # test llm.score
    score_outputs = llm.score(query, first_doc)
    assert len(score_outputs) == 1
    assert score_outputs[0].outputs.score == expected

    # test llm.encode
    encoded = llm.encode([first_doc, query], pooling_task="token_embed")
    embeds = encoded[0].outputs.data.float()
    assert embeds.shape[0] == 2
    assert embeds.shape[-1] == EMBEDDING_SIZE

    # Last row is the query embedding; the rows before it are documents.
    scores = F.cosine_similarity(embeds[-1], embeds[:-1])
    assert scores[0] == expected
def _test_offline_1_v_n(llm):
    """Check one query against all documents via ``score`` and ``encode``."""
    n_docs = len(documents)

    # test llm.score
    score_outputs = llm.score(query, documents)
    assert len(score_outputs) == n_docs
    for reference, item in zip(REFERENCE_1_VS_N, score_outputs):
        assert item.outputs.score == pytest.approx(reference, abs=TOL)

    # test llm.encode
    encoded = llm.encode(documents + [query], pooling_task="token_embed")
    embeds = encoded[0].outputs.data.float()
    assert embeds.shape[0] == n_docs + 1

    # Last row is the query embedding; the rows before it are documents.
    scores = F.cosine_similarity(embeds[-1], embeds[:-1])
    assert len(scores) == n_docs
    for reference, similarity in zip(REFERENCE_1_VS_N, scores):
        assert similarity == pytest.approx(reference, abs=TOL)
def _test_offline_n_v_n(llm):
    """Check pairwise (query_i, document_i) scoring via ``score`` and ``encode``."""
    n_docs = len(documents)

    # test llm.score
    score_outputs = llm.score([query] * n_docs, documents)
    assert len(score_outputs) == n_docs
    for reference, item in zip(REFERENCE_1_VS_1, score_outputs):
        assert item.outputs.score == pytest.approx(reference, abs=TOL)

    # test llm.encode -- one (document, query) prompt per pair
    for reference, doc in zip(REFERENCE_1_VS_1, documents):
        encoded = llm.encode([doc, query], pooling_task="token_embed")
        embeds = encoded[0].outputs.data.float()
        assert embeds.shape[0] == 2

        scores = F.cosine_similarity(embeds[-1], embeds[:-1])
        assert scores[0] == pytest.approx(reference, abs=TOL)
def _test_offline_token_embed_illegal_inputs(llm):
    """Check that ``encode`` rejects too-few and non-text inputs."""
    illegal_cases = (
        ([query], "The JinaForRanking model requires at least 2 inputs."),
        ([1, 2, 3], "The JinaForRanking model only supports text as input."),
    )
    for bad_prompts, message in illegal_cases:
        with pytest.raises(ValueError, match=message):
            llm.encode(bad_prompts, pooling_task="token_embed")
def _get_scores(server, query, document):
    """POST to the server's ``score`` endpoint and return the list of scores.

    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    payload = {
        "model": model_name,
        "queries": query,
        "documents": document,
    }
    http_response = requests.post(server.url_for("score"), json=payload)
    http_response.raise_for_status()

    parsed = ScoreResponse.model_validate(http_response.json())
    return [item.score for item in parsed.data]
def _get_embeds(server, prompts: list[str]):
    """POST to the ``pooling`` endpoint and return the first item as a float tensor.

    Raises ``requests.HTTPError`` on a non-2xx response.
    """
    payload = {
        "model": model_name,
        "task": "token_embed",
        "input": prompts,
        "encoding_format": "float",
    }
    http_response = requests.post(server.url_for("pooling"), json=payload)
    http_response.raise_for_status()

    parsed = PoolingResponse.model_validate(http_response.json())
    # Only the first data item is used.
    first_item = parsed.data[0].data
    return torch.as_tensor(first_item).float()
def _test_online_1_v_1(server):
    """Check one query against one document via the scoring and pooling APIs."""
    first_doc = documents[0]
    expected = pytest.approx(REFERENCE_1_VS_1[0], abs=TOL)

    # test scoring api
    api_scores = _get_scores(server, query, first_doc)
    assert len(api_scores) == 1
    assert api_scores[0] == expected

    # test pooling api
    embeds = _get_embeds(server, [first_doc, query])
    assert embeds.shape[0] == 2
    assert embeds.shape[-1] == EMBEDDING_SIZE

    # Last row is the query embedding; the rows before it are documents.
    similarities = F.cosine_similarity(embeds[-1], embeds[:-1])
    assert similarities[0] == expected
def _test_online_1_v_n(server):
    """Check one query against all documents via the scoring and pooling APIs."""
    n_docs = len(documents)

    # test scoring api
    api_scores = _get_scores(server, query, documents)
    assert len(api_scores) == n_docs
    for reference, value in zip(REFERENCE_1_VS_N, api_scores):
        assert value == pytest.approx(reference, abs=TOL)

    # test pooling api
    embeds = _get_embeds(server, documents + [query])
    assert embeds.shape[0] == n_docs + 1

    # Last row is the query embedding; the rows before it are documents.
    similarities = F.cosine_similarity(embeds[-1], embeds[:-1])
    assert len(similarities) == n_docs
    for reference, value in zip(REFERENCE_1_VS_N, similarities):
        assert value == pytest.approx(reference, abs=TOL)
def _test_online_n_v_n(server):
    """Check pairwise (query_i, document_i) scoring via the scoring and pooling APIs."""
    n_docs = len(documents)

    # test scoring api
    api_scores = _get_scores(server, [query] * n_docs, documents)
    assert len(api_scores) == n_docs
    for reference, value in zip(REFERENCE_1_VS_1, api_scores):
        assert value == pytest.approx(reference, abs=TOL)

    # test pooling api -- one (document, query) request per pair
    for reference, doc in zip(REFERENCE_1_VS_1, documents):
        embeds = _get_embeds(server, [doc, query])
        assert embeds.shape[0] == 2

        similarities = F.cosine_similarity(embeds[-1], embeds[:-1])
        assert len(similarities) == 1
        assert similarities[0] == pytest.approx(reference, abs=TOL)
def _test_online_token_embed_illegal_inputs(server):
    """Check the pooling API error message for each kind of illegal request."""
    chat_messages = [
        {
            "role": "user",
            "content": "The cat sat on the mat.",
        }
    ]
    # (request-specific field, expected prefix of the error message)
    illegal_cases = (
        (
            {"input": [query]},
            "The JinaForRanking model requires at least 2 inputs.",
        ),
        (
            {"input": [1, 2, 3]},
            "The JinaForRanking model only supports text as input.",
        ),
        (
            {"messages": chat_messages},
            "The JinaForRanking does not support chat Request.",
        ),
    )
    for body_field, expected_prefix in illegal_cases:
        payload = {
            "model": model_name,
            "task": "token_embed",
            **body_field,
            "encoding_format": "float",
        }
        response = requests.post(server.url_for("pooling"), json=payload)
        assert response.json()["error"]["message"].startswith(expected_prefix)
...@@ -645,6 +645,7 @@ _LATE_INTERACTION_EXAMPLE_MODELS = { ...@@ -645,6 +645,7 @@ _LATE_INTERACTION_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["ColBERTLfm2Model"]}, hf_overrides={"architectures": ["ColBERTLfm2Model"]},
), ),
"JinaForRanking": _HfExamplesInfo("jinaai/jina-reranker-v3"),
# [Multimodal] # [Multimodal]
"ColModernVBertForRetrieval": _HfExamplesInfo( "ColModernVBertForRetrieval": _HfExamplesInfo(
"ModernVBERT/colmodernvbert-merged", "ModernVBERT/colmodernvbert-merged",
......
...@@ -24,7 +24,13 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -24,7 +24,13 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest, EmbeddingChatRequest,
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
) )
from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.entrypoints.pooling.scoring.io_processor import JinaRankingIOProcessorMixin
from vllm.entrypoints.pooling.typing import (
OfflineInputsContext,
PoolingChatLikeRequest,
PoolingCompletionLikeRequest,
PoolingServeContext,
)
from vllm.inputs import EngineInput, tokens_input from vllm.inputs import EngineInput, tokens_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
...@@ -553,3 +559,48 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -553,3 +559,48 @@ class EmbedIOProcessor(PoolingIOProcessor):
class TokenEmbedIOProcessor(PoolingIOProcessor): class TokenEmbedIOProcessor(PoolingIOProcessor):
name = "token_embed" name = "token_embed"
class JinaRankingTokenEmbedIOProcessor(
    TokenEmbedIOProcessor, JinaRankingIOProcessorMixin
):
    """``token_embed`` IO processor for the JinaForRanking model.

    The inputs of one request are folded into a single listwise prompt: every
    input except the last is treated as a document, and the last one as the
    query.
    """

    def _build_listwise_prompt(self, prompts: object) -> str:
        """Validate ``prompts`` and format them into one listwise prompt.

        Raises:
            ValueError: if ``prompts`` is a bare string, is not a sequence,
                holds fewer than two items, or holds a non-string item.
        """
        # A bare ``str`` is itself a ``Sequence``: without this guard a single
        # string of two or more characters would pass the length check and be
        # split into one-character "documents" plus a one-character "query".
        if (
            isinstance(prompts, str)
            or not isinstance(prompts, Sequence)
            or len(prompts) < 2
        ):
            raise ValueError("The JinaForRanking model requires at least 2 inputs.")

        text_prompts = self.ensure_str(prompts)

        # The JinaForRanking model concatenates docs first, then query.
        # Let's stay consistent with this novel design.
        return self.format_docs_prompts_func(
            query=text_prompts[-1], docs=text_prompts[:-1]
        )

    def pre_process_online(self, ctx: PoolingServeContext):
        """Build the engine inputs for an online (served) request.

        Only completion-like requests are accepted; the result is stored on
        ``ctx.engine_inputs``.

        Raises:
            ValueError: for chat-like or unknown request types, or when the
                request input fails validation.
        """
        request = ctx.request

        if isinstance(request, PoolingCompletionLikeRequest):
            ctx.engine_inputs = self._preprocess_completion_online(
                request,
                prompt_input=self._build_listwise_prompt(request.input),
                prompt_embeds=None,
            )
            return

        if isinstance(request, PoolingChatLikeRequest):
            raise ValueError("The JinaForRanking does not support chat Request.")

        raise ValueError(f"Invalid {self.name} request type")

    def pre_process_offline(self, ctx: OfflineInputsContext) -> Sequence[EngineInput]:
        """Replace ``ctx.prompts`` with the listwise prompt, then defer to the parent.

        Raises:
            ValueError: when ``ctx.prompts`` fails validation.
        """
        ctx.prompts = self._build_listwise_prompt(ctx.prompts)
        return super().pre_process_offline(ctx)
...@@ -59,6 +59,13 @@ def init_pooling_io_processors( ...@@ -59,6 +59,13 @@ def init_pooling_io_processors(
if score_type is not None and score_type in ScoringIOProcessors: if score_type is not None and score_type in ScoringIOProcessors:
processors[score_type] = ScoringIOProcessors[score_type] processors[score_type] = ScoringIOProcessors[score_type]
if model_config.architecture == "JinaForRanking":
from .embed.io_processor import JinaRankingTokenEmbedIOProcessor
from .scoring.io_processor import ScoringIOProcessors
processors["token_embed"] = JinaRankingTokenEmbedIOProcessor
processors["late-interaction"] = ScoringIOProcessors["jina-reranking-scoring"]
return { return {
task: processor_cls( task: processor_cls(
vllm_config=vllm_config, vllm_config=vllm_config,
......
...@@ -416,11 +416,137 @@ class CrossEncoderIOProcessor(ScoringIOProcessor): ...@@ -416,11 +416,137 @@ class CrossEncoderIOProcessor(ScoringIOProcessor):
return full_prompt, engine_prompt return full_prompt, engine_prompt
class JinaRankingIOProcessorMixin:
@staticmethod
def sanitize_input(text: str, special_tokens: dict[str, str]) -> str:
for token in special_tokens.values():
text = text.replace(token, "")
return text
@staticmethod
def format_docs_prompts_func(
query: str,
docs: list[str],
special_tokens: dict[str, str] | None = None,
instruction: str | None = None,
no_thinking: bool = True,
) -> str:
# TODO: Try converting the code below into a chat template.
default_special_tokens = {
"query_embed_token": "<|rerank_token|>",
"doc_embed_token": "<|embed_token|>",
}
if special_tokens is None:
special_tokens = default_special_tokens
query = JinaRankingIOProcessorMixin.sanitize_input(query, special_tokens)
docs = [
JinaRankingIOProcessorMixin.sanitize_input(doc, special_tokens)
for doc in docs
]
prefix = (
"<|im_start|>system\n"
"You are a search relevance expert who can determine a ranking of the passages based on how relevant they are to the query. " # noqa: E501
"If the query is a question, how relevant a passage is depends on how well it answers the question. " # noqa: E501
"If not, try to analyze the intent of the query and assess how well each passage satisfies the intent. " # noqa: E501
"If an instruction is provided, you should follow the instruction when determining the ranking." # noqa: E501
"<|im_end|>\n<|im_start|>user\n"
)
suffix = "<|im_end|>\n<|im_start|>assistant\n"
if no_thinking:
suffix += "<think>\n\n</think>\n\n"
doc_emb_token = special_tokens["doc_embed_token"]
query_emb_token = special_tokens["query_embed_token"]
prompt = (
f"I will provide you with {len(docs)} passages, each indicated by a numerical identifier. " # noqa: E501
f"Rank the passages based on their relevance to query: {query}\n"
)
if instruction:
prompt += f"<instruct>\n{instruction}\n</instruct>\n"
doc_prompts = [
f'<passage id="{i}">\n{doc}{doc_emb_token}\n</passage>'
for i, doc in enumerate(docs)
]
prompt += "\n".join(doc_prompts) + "\n"
prompt += f"<query>\n{query}{query_emb_token}\n</query>"
return prefix + prompt + suffix
@staticmethod
def ensure_str(data: Sequence[Any]) -> list[str]:
text: list[str] = []
for prompt in data:
if not isinstance(prompt, str):
raise ValueError(
"The JinaForRanking model only supports text as input."
)
text.append(prompt)
return text
class JinaRankingIOProcessor(LateInteractionIOProcessor, JinaRankingIOProcessorMixin):
    """Scoring IO processor for the JinaForRanking model.

    Builds listwise prompts from the scoring data and turns the pooled token
    embeddings back into one cosine-similarity score per document.
    """

    name = "jina-reranking-scoring"
    pooling_task: PoolingTask = "token_embed"

    def _pre_process(
        self,
        scoring_data: ScoringData,
        tok_params: TokenizeParams,
        prompt_extras: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """Format the scoring data into prompts and tokenize them.

        A single query is ranked against all documents in one prompt;
        otherwise one prompt is built per (query, document) pair.

        Raises:
            ValueError: if any query or document is not a string.
        """
        queries = self.ensure_str(scoring_data.data_1)
        docs = self.ensure_str(scoring_data.data_2)
        build_prompt = self.format_docs_prompts_func

        if len(queries) != 1:
            prompts = [
                build_prompt(query=one_query, docs=[one_doc])
                for one_query, one_doc in zip(queries, docs)
            ]
        else:
            prompts = [build_prompt(query=queries[0], docs=docs)]

        return self._preprocess_completion_offline(
            prompts=prompts, tok_params=tok_params, prompt_extras=prompt_extras
        )

    def _post_process(self, outputs: list[PoolingRequestOutput], n_queries: int):
        """Expand each pooled output into one scored output per document."""
        scored: list[PoolingRequestOutput] = []
        for pooled in outputs:
            embeds = pooled.outputs.data.float()

            # The JinaForRanking model concatenates docs first, then query.
            # Let's stay consistent with this novel design.
            similarities = F.cosine_similarity(embeds[-1], embeds[:-1])

            scored.extend(
                PoolingRequestOutput(
                    request_id=pooled.request_id,
                    outputs=similarity,
                    prompt_token_ids=pooled.prompt_token_ids,
                    num_cached_tokens=pooled.num_cached_tokens,
                    finished=True,
                )
                for similarity in similarities
            )
        return scored
ScoringIOProcessors: dict[str, type[ScoringIOProcessor]] = { ScoringIOProcessors: dict[str, type[ScoringIOProcessor]] = {
p.name: p p.name: p
for p in [ for p in [
BiEncoderIOProcessor, BiEncoderIOProcessor,
LateInteractionIOProcessor, LateInteractionIOProcessor,
JinaRankingIOProcessor,
FlashLateInteractionIOProcessor, FlashLateInteractionIOProcessor,
CrossEncoderIOProcessor, CrossEncoderIOProcessor,
] ]
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from vllm import PoolingParams from vllm import PoolingParams
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import UsageInfo from vllm.entrypoints.openai.engine.protocol import UsageInfo
from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
...@@ -42,25 +41,23 @@ class ServingScores(PoolingServing): ...@@ -42,25 +41,23 @@ class ServingScores(PoolingServing):
enable_flash_late_interaction: bool = True, enable_flash_late_interaction: bool = True,
**kwargs, **kwargs,
): ):
self.score_type = engine_client.model_config.score_type self.io_processor_name: str = engine_client.model_config.score_type
self.enable_flash_late_interaction = ( self.enable_flash_late_interaction = (
self.score_type == "late-interaction" and enable_flash_late_interaction self.io_processor_name == "late-interaction"
and enable_flash_late_interaction
) )
super().__init__(engine_client, *args, **kwargs) if self.enable_flash_late_interaction:
self.io_processor_name = "flash-late-interaction"
def init_io_processor( if engine_client.model_config.architecture == "JinaForRanking":
self, vllm_config: VllmConfig, *args, **kwargs self.io_processor_name = "jina-reranking-scoring"
) -> PoolingIOProcessor: self.enable_flash_late_interaction = False
model_config = vllm_config.model_config
score_type: str = model_config.score_type super().__init__(engine_client, *args, **kwargs)
if self.enable_flash_late_interaction:
score_type = "flash-late-interaction"
assert score_type in ScoringIOProcessors def init_io_processor(self, *args, **kwargs) -> PoolingIOProcessor:
processor_cls = ScoringIOProcessors[score_type] return ScoringIOProcessors[self.io_processor_name](*args, **kwargs)
return processor_cls(vllm_config, *args, **kwargs)
async def __call__(self, *args, **kwargs) -> Response: async def __call__(self, *args, **kwargs) -> Response:
if not self.enable_flash_late_interaction: if not self.enable_flash_late_interaction:
......
...@@ -100,7 +100,7 @@ class StepPool(AllPool): ...@@ -100,7 +100,7 @@ class StepPool(AllPool):
): ):
# for unfinished chunked prefill # for unfinished chunked prefill
if data is None: if data is None:
pass pooled_data.append(None)
else: else:
step_tag_id = pooling_param.step_tag_id step_tag_id = pooling_param.step_tag_id
returned_token_ids = pooling_param.returned_token_ids returned_token_ids = pooling_param.returned_token_ids
......
...@@ -58,7 +58,7 @@ class TokenPooler(Pooler): ...@@ -58,7 +58,7 @@ class TokenPooler(Pooler):
def __init__( def __init__(
self, self,
pooling: TokenPoolingMethod | TokenPoolingFn, pooling: TokenPoolingMethod | TokenPoolingFn,
head: TokenPoolerHead | TokenPoolingHeadFn, head: TokenPoolerHead | TokenPoolingHeadFn | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -89,7 +89,8 @@ class TokenPooler(Pooler): ...@@ -89,7 +89,8 @@ class TokenPooler(Pooler):
pooling_metadata: PoolingMetadata, pooling_metadata: PoolingMetadata,
) -> TokenPoolerOutput: ) -> TokenPoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata) pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata) if self.head is not None:
pooled_data = self.head(pooled_data, pooling_metadata)
return pooled_data return pooled_data
......
...@@ -192,6 +192,12 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): ...@@ -192,6 +192,12 @@ class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig):
pooler_config.use_activation = False pooler_config.use_activation = False
class JinaForRankingConfig(VerifyAndUpdateConfig):
    """Model-config hook registered for the ``JinaForRanking`` architecture."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Set ``embedding_size`` to 512 on the model's HF config."""
        hf_config = model_config.hf_config
        hf_config.embedding_size = 512
class JinaRobertaModelConfig(VerifyAndUpdateConfig): class JinaRobertaModelConfig(VerifyAndUpdateConfig):
@staticmethod @staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None: def verify_and_update_model_config(model_config: "ModelConfig") -> None:
...@@ -612,6 +618,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { ...@@ -612,6 +618,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteNewForSequenceClassification": GteNewModelConfig, "GteNewForSequenceClassification": GteNewModelConfig,
"GteNewModel": GteNewModelConfig, "GteNewModel": GteNewModelConfig,
"JambaForSequenceClassification": JambaForSequenceClassificationConfig, "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
"JinaForRanking": JinaForRankingConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig, "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
"LlamaBidirectionalModel": LlamaBidirectionalConfig, "LlamaBidirectionalModel": LlamaBidirectionalConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://huggingface.co/jinaai/jina-reranker-v3/blob/main/modeling.py
from collections.abc import Iterable
import torch
from torch import nn
from vllm.config import VllmConfig
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from vllm.v1.pool.metadata import PoolingMetadata
from ..layers.pooler import DispatchPooler
from ..layers.pooler.tokwise import (
StepPool,
TokenPooler,
TokenPoolingMethodOutputItem,
)
from .interfaces import SupportsLateInteraction
from .qwen3 import Qwen3Model
from .utils import AutoWeightsLoader, maybe_prefix
class JinaForRanking(nn.Module, SupportsLateInteraction):
    """jina-reranker-v3: a Qwen3 backbone with a projector used for ``token_embed``.

    The projector is two bias-free linear layers with a ReLU in between. It
    maps hidden states to ``config.embedding_size`` dimensions and is applied
    inside the pooler, not in ``forward``.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.config = hf_config
        self.projector_dim: int = hf_config.embedding_size
        self.vllm_config = vllm_config
        self.quant_config = vllm_config.quant_config

        self.model = Qwen3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        hidden_size = hf_config.hidden_size
        bottleneck = hidden_size // 2
        self.projector = nn.Sequential(
            nn.Linear(hidden_size, bottleneck, bias=False),
            nn.ReLU(),
            nn.Linear(bottleneck, self.projector_dim, bias=False),
        )

        # The pooling method receives the projector and applies it itself;
        # no separate pooler head is passed.
        token_pooler = TokenPooler(pooling=JinaForRankingPool(self.projector))
        self.pooler = DispatchPooler({"token_embed": token_pooler})

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate input-id embedding to the Qwen3 backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader``, skipping ``lm_head.`` entries."""
        weight_loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."])
        return weight_loader.load_weights(weights)
class JinaForRankingPool(StepPool):
    """Pooling method that projects the hidden states at the embed-token positions.

    For each request, the positions of the document and query marker tokens
    are located in the prompt token ids, and the hidden states at those
    positions are passed through the projector.
    """

    def __init__(self, projector: nn.Sequential):
        super().__init__()
        # NOTE(review): presumably the ids of <|embed_token|> and
        # <|rerank_token|> in the model's tokenizer -- confirm.
        self.doc_token_id = 151670
        self.query_token_id = 151671
        self.projector = projector

    def get_supported_tasks(self) -> set[PoolingTask]:
        """Return the only supported pooling task, ``token_embed``."""
        return {"token_embed"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Return one projected embedding tensor (or ``None``) per request.

        Rows are ordered as all document-marker positions followed by all
        query-marker positions.
        """
        per_request = super().forward(hidden_states, pooling_metadata)
        all_token_ids = pooling_metadata.get_prompt_token_ids()

        projected: list[torch.Tensor | None] = []
        for data, token_ids in zip(per_request, all_token_ids):
            # for unfinished chunked prefill
            if data is None:
                projected.append(None)
                continue

            doc_positions = torch.where(torch.eq(token_ids, self.doc_token_id))[0]
            query_positions = torch.where(torch.eq(token_ids, self.query_token_id))[0]

            # The JinaForRanking model concatenates docs first, then query.
            # Let's stay consistent with this novel design.
            positions = torch.cat([doc_positions, query_positions])
            projected.append(self.projector(data[positions]))
        return projected
...@@ -273,6 +273,7 @@ _LATE_INTERACTION_MODELS = { ...@@ -273,6 +273,7 @@ _LATE_INTERACTION_MODELS = {
"ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"), "ColBERTModernBertModel": ("colbert", "ColBERTModernBertModel"),
"ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"), "ColBERTJinaRobertaModel": ("colbert", "ColBERTJinaRobertaModel"),
"ColBERTLfm2Model": ("colbert", "ColBERTLfm2Model"), "ColBERTLfm2Model": ("colbert", "ColBERTLfm2Model"),
"JinaForRanking": ("jina", "JinaForRanking"),
# [Multimodal] # [Multimodal]
"ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"), "ColModernVBertForRetrieval": ("colmodernvbert", "ColModernVBertForRetrieval"),
"ColPaliForRetrieval": ("colpali", "ColPaliModel"), "ColPaliForRetrieval": ("colpali", "ColPaliModel"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment