Commit c8b678e5 authored by Jakub Zakrzewski, committed by GitHub

[Model] Add support for nvidia/llama-nemotron-rerank-vl-1b-v2 (#35735)


Signed-off-by: Jakub Zakrzewski <jzakrzewski@nvidia.com>
parent 18c29c74
@@ -498,7 +498,9 @@ curl -s http://localhost:8000/pooling -H "Content-Type: application/json" -d '{
- Multi-vector retrieval: [examples/pooling/token_embed/colqwen3_token_embed_online.py](../../examples/pooling/token_embed/colqwen3_token_embed_online.py)
- Reranking (text + multi-modal): [examples/pooling/score/colqwen3_rerank_online.py](../../examples/pooling/score/colqwen3_rerank_online.py)

-### Llama Nemotron Multimodal Embedding Models
+### Llama Nemotron Multimodal
+
+#### Embedding Model

Llama Nemotron VL Embedding models combine the bidirectional Llama embedding backbone
(from `nvidia/llama-nemotron-embed-1b-v2`) with SigLIP as the vision encoder to produce
@@ -559,6 +561,70 @@ curl -s http://localhost:8000/v1/embeddings -H "Content-Type: application/json"
}'
``` ```
#### Reranker Model

Llama Nemotron VL reranker models combine the same bidirectional Llama + SigLIP
backbone with a sequence-classification head for cross-encoder scoring and reranking.

| Architecture | Backbone | Example HF Models |
|---|---|---|
| `LlamaNemotronVLForSequenceClassification` | Bidirectional Llama + SigLIP | `nvidia/llama-nemotron-rerank-vl-1b-v2` |

Start the server:

```shell
vllm serve nvidia/llama-nemotron-rerank-vl-1b-v2 \
--runner pooling \
--trust-remote-code \
--chat-template examples/pooling/score/template/nemotron-vl-rerank.jinja
```

!!! note
    The chat template bundled with this checkpoint's tokenizer is not suitable
    for the Score/Rerank APIs. Use the provided override template when serving:
    `examples/pooling/score/template/nemotron-vl-rerank.jinja`.

Score a text query against an image document:
```shell
curl -s http://localhost:8000/score -H "Content-Type: application/json" -d '{
  "model": "nvidia/llama-nemotron-rerank-vl-1b-v2",
  "data_1": "Find diagrams about autonomous robots",
  "data_2": [
    {
      "content": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64>"}},
        {"type": "text", "text": "Robotics workflow diagram."}
      ]
    }
  ]
}'
```
Rerank image documents by a text query:
```shell
curl -s http://localhost:8000/rerank -H "Content-Type: application/json" -d '{
  "model": "nvidia/llama-nemotron-rerank-vl-1b-v2",
  "query": "Find diagrams about autonomous robots",
  "documents": [
    {
      "content": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_1>"}},
        {"type": "text", "text": "Robotics workflow diagram."}
      ]
    },
    {
      "content": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,<BASE64_2>"}},
        {"type": "text", "text": "General skyline photo."}
      ]
    }
  ],
  "top_n": 2
}'
```
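
The score and rerank requests above can also be assembled programmatically. The following is a minimal Python sketch using only the standard library. The helper names (`image_document`, `build_score_payload`, `build_rerank_payload`, `post_json`) are hypothetical and not part of vLLM; the payload fields simply mirror the JSON bodies shown in the curl examples.

```python
import base64
import json
import urllib.request

MODEL = "nvidia/llama-nemotron-rerank-vl-1b-v2"


def image_document(png_bytes: bytes, caption: str) -> dict:
    """Build one multimodal document: a base64 PNG data URI plus a caption."""
    b64 = base64.b64encode(png_bytes).decode("ascii")
    return {
        "content": [
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
            {"type": "text", "text": caption},
        ]
    }


def build_score_payload(query: str, documents: list) -> dict:
    """Request body for the /score endpoint (text query vs. documents)."""
    return {"model": MODEL, "data_1": query, "data_2": documents}


def build_rerank_payload(query: str, documents: list, top_n: int) -> dict:
    """Request body for the /rerank endpoint."""
    return {"model": MODEL, "query": query, "documents": documents, "top_n": top_n}


def post_json(url: str, payload: dict) -> dict:
    """POST a JSON body and decode the JSON response (hypothetical helper)."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)
```

With a server started as shown above, a call such as `post_json("http://localhost:8000/rerank", build_rerank_payload(query, docs, top_n=2))` would send the same request as the curl example; the shape of the response is not covered here.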

### BAAI/bge-m3

The `BAAI/bge-m3` model comes with extra weights for sparse and colbert embeddings but unfortunately in its `config.json`
......
@@ -842,6 +842,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | ✅︎ | ✅︎ |
| `LlamaNemotronVLForSequenceClassification` | Llama Nemotron Reranker + SigLIP | T + I<sup>E+</sup> | `nvidia/llama-nemotron-rerank-vl-1b-v2` | | |
| `Qwen3VLForSequenceClassification` | Qwen3-VL-Reranker | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen3-VL-Reranker-2B`(see note), etc. | ✅︎ | ✅︎ |

<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
......
{%- set query_msg = (messages | selectattr('role', 'equalto', 'query') | list | first) -%}
{%- set doc_msg = (messages | selectattr('role', 'equalto', 'document') | list | first) -%}
{%- set q = query_msg['content'] -%}
{%- set d = doc_msg['content'] -%}
{# If the doc contains <image> anywhere, hoist a single <image> to the front #}
{%- set has_image = ("<image>" in d) -%}
{%- set d_clean = d | replace("<image>", "") -%}
{%- set q_clean = q | replace("<image>", "") -%}
{%- if has_image -%}<image>{{ " " }}{%- endif -%}
question:{{ q_clean }}{{ " " }}
{{ " " }}
{{ " " }}passage:{{ d_clean }}
\ No newline at end of file
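
The template above can be read as three steps: detect an `<image>` placeholder in the document, strip placeholders from both query and document, then emit a single leading `<image>` followed by the `question:` and `passage:` segments. A plain-Python sketch of that logic follows. The function name `render_rerank_prompt` is hypothetical, and the separator between segments is simplified to one configurable string, so the exact whitespace may differ from what the Jinja template emits.

```python
IMAGE_TOKEN = "<image>"


def render_rerank_prompt(query: str, document: str, sep: str = " ") -> str:
    """Mirror the template logic: hoist at most one <image> to the front.

    `sep` stands in for the whitespace the template writes between
    segments; it is an assumption, not the template's exact output.
    """
    # Only the document decides whether an image placeholder is emitted.
    has_image = IMAGE_TOKEN in document
    # Placeholders are removed from both sides before formatting.
    q_clean = query.replace(IMAGE_TOKEN, "")
    d_clean = document.replace(IMAGE_TOKEN, "")
    prefix = IMAGE_TOKEN + sep if has_image else ""
    return f"{prefix}question:{q_clean}{sep}passage:{d_clean}"
```

For example, `render_rerank_prompt("what?", "<image>caption")` yields a prompt that starts with a single `<image>`, while a text-only document produces no placeholder at all, even if the query happened to contain one.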
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
-Tests for LlamaNemotronVL embedding model (nvidia/llama-nemotron-embed-vl-1b-v2).
+Tests for the LlamaNemotronVL model family:
+- nvidia/llama-nemotron-embed-vl-1b-v2 (LlamaNemotronVLForCausalLM / embed)
+- nvidia/llama-nemotron-rerank-vl-1b-v2
+  (LlamaNemotronVLForSequenceClassification / rerank)

-This model uses SigLIP vision encoder with bidirectional LLaMA for embeddings.
+Both variants share a SigLIP vision encoder with a bidirectional LLaMA backbone.
"""

import base64
from io import BytesIO
from pathlib import Path

import pytest
import torch
-from transformers import AutoModel
+from transformers import AutoModel, AutoModelForSequenceClassification, AutoProcessor

from vllm.entrypoints.chat_utils import (
    ChatCompletionContentPartImageParam,
    ChatCompletionContentPartTextParam,
)
from vllm.entrypoints.pooling.score.utils import ScoreMultiModalParam

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ...utils import check_embeddings_close
@@ -99,7 +112,7 @@ def _run_test(
@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["half"])
def test_models_text(
    hf_runner,
    vllm_runner,
@@ -123,7 +136,7 @@ def test_models_text(
@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("dtype", ["half"])
def test_models_image(
    hf_runner,
    vllm_runner,
@@ -146,3 +159,197 @@ def test_models_image(
        model,
        dtype=dtype,
    )


# ---------------------------------------------------------------------------
# Reranker tests — nvidia/llama-nemotron-rerank-vl-1b-v2
# ---------------------------------------------------------------------------

RERANKER_MODELS = ["nvidia/llama-nemotron-rerank-vl-1b-v2"]

# The tokenizer's built-in chat template is not suitable for the Score/Rerank
# APIs (it's inherited from the base LLM). We must use the provided override.
_RERANKER_SCORE_TEMPLATE = (
    Path(__file__).parents[4]
    / "examples/pooling/score/template/nemotron-vl-rerank.jinja"
).read_text()

RERANKER_TEXT_QUERY = "How is AI improving the intelligence and capabilities of robots?"
RERANKER_TEXT_DOCS = [
    "AI enables robots to perceive, plan, and act autonomously.",
    (
        "A biological foundation model designed to analyze DNA, RNA, "
        "and protein sequences."
    ),
]

RERANKER_IMAGE_QUERY = "photo of a red stop sign on a street"


def _pil_to_data_uri(image) -> str:
    buf = BytesIO()
    image.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return f"data:image/png;base64,{b64}"


def _run_hf_reranker(
    hf_runner: type[HfRunner],
    model: str,
    dtype: str,
    query: str,
    docs: list,
) -> list[float]:
    """Run HF reranker inference; docs is a list of (doc_text, doc_image|None)."""
    with hf_runner(
        model,
        dtype=dtype,
        trust_remote_code=True,
        auto_cls=AutoModelForSequenceClassification,
    ) as hf_model:
        processor = AutoProcessor.from_pretrained(
            model,
            trust_remote_code=True,
            max_input_tiles=6,
            use_thumbnail=True,
            rerank_max_length=2048,
        )
        examples = [
            {
                "question": query,
                "doc_text": doc_text if doc_text is not None else "",
                "doc_image": doc_image if doc_image is not None else "",
            }
            for doc_text, doc_image in docs
        ]
        batch_dict = processor.process_queries_documents_crossencoder(examples)
        batch_dict = {
            k: v.to(hf_model.model.device) if isinstance(v, torch.Tensor) else v
            for k, v in batch_dict.items()
        }
        with torch.inference_mode():
            logits = hf_model.model(**batch_dict, return_dict=True).logits
        # vLLM applies sigmoid activation to the raw logits before returning
        # scores; apply the same here so both sides are comparable.
        scores = torch.sigmoid(logits.squeeze(-1).float())
        return scores.detach().cpu().tolist()


def _run_vllm_reranker(
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
    query: str,
    docs: list,
) -> list[float]:
    """Run vLLM reranker inference; docs is a list of (doc_text, doc_image|None)."""
    with vllm_runner(
        model,
        runner="pooling",
        dtype=dtype,
        max_model_len=2048,
        enforce_eager=True,
        trust_remote_code=True,
    ) as vllm_model:
        has_images = any(img is not None for _, img in docs)

        if not has_images:
            # Text-only path: use the simple string score API.
            queries = [query] * len(docs)
            doc_texts = [doc_text for doc_text, _ in docs]
            outputs = vllm_model.score(
                queries,
                doc_texts,
                chat_template=_RERANKER_SCORE_TEMPLATE,
            )
        else:
            # Multimodal path: build ScoreMultiModalParam for each pair.
            query_params = [
                ScoreMultiModalParam(
                    content=[
                        ChatCompletionContentPartTextParam(
                            type="text",
                            text=query,
                        )
                    ]
                )
            ] * len(docs)

            doc_params = []
            for doc_text, doc_image in docs:
                content: list = []
                if doc_image is not None:
                    content.append(
                        ChatCompletionContentPartImageParam(
                            type="image_url",
                            image_url={"url": _pil_to_data_uri(doc_image)},
                        )
                    )
                if doc_text:
                    content.append(
                        ChatCompletionContentPartTextParam(
                            type="text",
                            text=doc_text,
                        )
                    )
                doc_params.append(ScoreMultiModalParam(content=content))

            raw_outputs = vllm_model.llm.score(
                query_params,
                doc_params,
                chat_template=_RERANKER_SCORE_TEMPLATE,
            )
            outputs = [o.outputs.score for o in raw_outputs]

    return outputs


def _run_reranker_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    model: str,
    dtype: str,
    query: str,
    docs: list,
) -> None:
    """Compare HF and vLLM reranker scores.

    NOTE: Run vLLM first to avoid CUDA initialization issues with multiprocessing.
    """
    vllm_scores = _run_vllm_reranker(vllm_runner, model, dtype, query, docs)
    hf_scores = _run_hf_reranker(hf_runner, model, dtype, query, docs)

    assert len(hf_scores) == len(vllm_scores), (
        f"Output length mismatch: HF={len(hf_scores)}, vLLM={len(vllm_scores)}"
    )

    for i, (hf_score, vllm_score) in enumerate(zip(hf_scores, vllm_scores)):
        assert hf_score == pytest.approx(vllm_score, rel=0.02), (
            f"Score mismatch at index {i}: HF={hf_score:.4f}, vLLM={vllm_score:.4f}"
        )


@pytest.mark.parametrize("model", RERANKER_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_reranker_text(
    hf_runner,
    vllm_runner,
    model: str,
    dtype: str,
) -> None:
    """Test reranking with text-only query and text documents."""
    docs = [(text, None) for text in RERANKER_TEXT_DOCS]
    _run_reranker_test(hf_runner, vllm_runner, model, dtype, RERANKER_TEXT_QUERY, docs)


@pytest.mark.parametrize("model", RERANKER_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_reranker_image_doc(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Test reranking with text query against image documents."""
    docs = [(None, asset.pil_image) for asset in image_assets]
    _run_reranker_test(hf_runner, vllm_runner, model, dtype, RERANKER_IMAGE_QUERY, docs)
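
The comparison in these tests rests on two small pieces of arithmetic: the HF logits are passed through a sigmoid so they are on the same scale as vLLM's scores, and each pair must then agree within a 2% relative tolerance. The sketch below restates that in plain Python without torch or pytest. The names `sigmoid` and `first_mismatch` are hypothetical, and the tolerance check only approximates `pytest.approx(expected, rel=0.02)` by ignoring its very small absolute floor.

```python
import math


def sigmoid(logit: float) -> float:
    """Numerically stable logistic function mapping a logit into (0, 1)."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)


def first_mismatch(hf_scores, vllm_scores, rel=0.02):
    """Return the index of the first pair outside the relative tolerance.

    The vLLM score plays the role of the expected value, so the allowed
    deviation is rel * |vllm_score|. Returns None when every pair agrees.
    """
    if len(hf_scores) != len(vllm_scores):
        raise ValueError("score lists differ in length")
    for i, (hf, expected) in enumerate(zip(hf_scores, vllm_scores)):
        if abs(hf - expected) > rel * abs(expected):
            return i
    return None
```

A relative tolerance scales with the score itself, so scores near zero must agree much more tightly in absolute terms than scores near one.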
@@ -653,6 +653,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
    "LlamaBidirectionalForSequenceClassification": _HfExamplesInfo(
        "nvidia/llama-nemotron-rerank-1b-v2", trust_remote_code=True
    ),
    "LlamaNemotronVLForSequenceClassification": _HfExamplesInfo(
        "nvidia/llama-nemotron-rerank-vl-1b-v2", trust_remote_code=True
    ),
    "ModernBertForSequenceClassification": _HfExamplesInfo(
        "Alibaba-NLP/gte-reranker-modernbert-base"
    ),
......
@@ -664,6 +664,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
    "LlamaBidirectionalModel": LlamaBidirectionalConfig,
    "LlamaNemotronVLModel": LlamaNemotronVLConfig,
    "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
    "Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
......
@@ -7,6 +7,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import math
from abc import ABC
from collections.abc import Iterable
@@ -18,6 +19,7 @@ from transformers import AutoModel, PretrainedConfig
from transformers.image_processing_utils_fast import BaseImageProcessorFast

from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
@@ -42,6 +44,7 @@ from vllm.transformers_utils.repo_utils import get_hf_file_to_dict
from .interfaces import (
    MultiModalEmbeddings,
    SupportsCrossEncoding,
    SupportsLoRA,
    SupportsMultiModal,
    SupportsPP,
@@ -883,3 +886,57 @@ class LlamaNemotronVLForEmbedding(LlamaNemotronVLChatModel, VllmModelForPooling)
        """Override to use different weight mapping for SigLIP."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.weight_mapper)


class LlamaNemotronVLForSequenceClassification(
    LlamaNemotronVLForEmbedding, SupportsCrossEncoding
):
    """LlamaNemotronVL model variant for sequence classification / reranking."""

    # Reranker checkpoint places base model weights under `model.*`,
    # while `score.*` remains at the top level.
    weight_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) | (
        LlamaNemotronVLForEmbedding.weight_mapper
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        text_config = vllm_config.model_config.hf_config.get_text_config()
        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config

        self.score = ReplicatedLinear(
            model_config.get_hidden_size(),
            text_config.num_labels,
            bias=False,
            params_dtype=model_config.head_dtype,
            quant_config=quant_config,
            return_bias=False,
            prefix=maybe_prefix(prefix, "score"),
        )

        pooler_config = model_config.pooler_config
        assert pooler_config is not None
        self.pooler = DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded_weights = super().load_weights(weights)

        # reranker checkpoint omits the inner LM seq-cls head
        # (`language_model.score.*`). It is unused by this outer model, but
        # the default loader expects all parameters to be initialized.
        for name, param in self.named_parameters():
            if not name.startswith("language_model.score.") or name in loaded_weights:
                continue

            if name.endswith(".weight"):
                torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            elif name.endswith(".bias"):
                torch.nn.init.zeros_(param)
            else:
                torch.nn.init.normal_(param, mean=0.0, std=0.02)

            loaded_weights.add(name)

        return loaded_weights
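
The two loading details in this class, stripping the `model.` prefix from checkpoint keys and filling in the unused `language_model.score.*` parameters that the checkpoint omits, can be illustrated without torch. The sketch below is plain Python and is not vLLM's `WeightsMapper`; the helper names and the example key names are hypothetical and chosen only to show the bookkeeping.

```python
def remap_prefix(name: str, prefix_map: dict) -> str:
    """Rewrite the first matching leading prefix; leave other names alone."""
    for old, new in prefix_map.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name


def names_needing_init(param_names, loaded, head_prefix="language_model.score."):
    """List parameters under `head_prefix` that no checkpoint tensor covered.

    These are the ones the class above initializes by hand so that every
    parameter counts as loaded.
    """
    return [n for n in param_names if n.startswith(head_prefix) and n not in loaded]


# Illustrative key names (assumed, not read from the real checkpoint).
checkpoint_keys = ["model.language_model.layers.0.weight", "score.weight"]
loaded = {remap_prefix(k, {"model.": ""}) for k in checkpoint_keys}
```

Here `loaded` ends up holding `language_model.layers.0.weight` and `score.weight`: the base weights lose their `model.` prefix while the top-level `score.*` head is untouched, which matches the comment on `weight_mapper`.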
@@ -284,6 +284,10 @@ _CROSS_ENCODER_MODELS = {
        "llama",
        "LlamaBidirectionalForSequenceClassification",
    ),
    "LlamaNemotronVLForSequenceClassification": (
        "nemotron_vl",
        "LlamaNemotronVLForSequenceClassification",
    ),
    "ModernBertForSequenceClassification": (
        "modernbert",
        "ModernBertForSequenceClassification",
......