Unverified commit 911d38ed, authored by wang.yuqi and committed by GitHub

[Model] Let more models to support the score template. (#31335)


Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io>
Signed-off-by: wang.yuqi <noooop@126.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent caaa482a
@@ -540,21 +540,28 @@ If your model is not in the above list, we will try to automatically convert the
 Cross-encoder and reranker models are a subset of classification models that accept two prompts as input.
 These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) API.
-| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
-|--------------|--------|-------------------|----------------------|---------------------------|
-| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | |
-| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | ✅︎ | ✅︎ |
-| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | | |
-| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2` (see note), etc. | ✅︎ | ✅︎ |
-| `Qwen2ForSequenceClassification`<sup>C</sup> | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | ✅︎ |
-| `Qwen3ForSequenceClassification`<sup>C</sup> | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | ✅︎ |
-| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | |
-| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | |
-| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
+| Architecture | Models | Example HF Models | Score template (see note) | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
+|--------------|--------|-------------------|---------------------------|-----------------------------|-----------------------------------------|
+| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | N/A | | |
+| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma`(see note), etc. | [bge-reranker-v2-gemma.jinja](../../examples/pooling/score/template/bge-reranker-v2-gemma.jinja) | ✅︎ | ✅︎ |
+| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | N/A | | |
+| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2`, etc. | [nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja) | ✅︎ | ✅︎ |
+| `Qwen2ForSequenceClassification`<sup>C</sup> | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2`(see note), etc. | [mxbai_rerank_v2.jinja](../../examples/pooling/score/template/mxbai_rerank_v2.jinja) | ✅︎ | ✅︎ |
+| `Qwen3ForSequenceClassification`<sup>C</sup> | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B`(see note), etc. | [qwen3_reranker.jinja](../../examples/pooling/score/template/qwen3_reranker.jinja) | ✅︎ | ✅︎ |
+| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | N/A | | |
+| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | N/A | | |
+| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | N/A | \* | \* |
 <sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
 \* Feature support is the same as that of the original model.
+!!! note
+    Some models require a specific prompt format to work correctly.
+    You can find Example HF Models's corresponding score template in [examples/pooling/score/template/](../../examples/pooling/score/template)
+    Examples : [examples/pooling/score/using_template_offline.py](../../examples/pooling/score/using_template_offline.py) [examples/pooling/score/using_template_online.py](../../examples/pooling/score/using_template_online.py)
 !!! note
     Load the official original `BAAI/bge-reranker-v2-gemma` by using the following command.
@@ -565,11 +572,6 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
 !!! note
     The second-generation GTE model (mGTE-TRM) is named `NewForSequenceClassification`. The name `NewForSequenceClassification` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewForSequenceClassification"]}'` to specify the use of the `GteNewForSequenceClassification` architecture.
-!!! note
-    `nvidia/llama-nemotron-rerank-1b-v2` require a specific prompt format to work correctly.
-    Examples : [offline_using_template.py](../../examples/pooling/score/offline_using_template.py) [online_using_template.py](../../examples/pooling/score/online_using_template.py)
 !!! note
     Load the official original `mxbai-rerank-v2` by using the following command.
@@ -578,7 +580,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
     ```
 !!! note
-    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/offline_reranker.py](../../examples/pooling/score/offline_reranker.py).
+    Load the official original `Qwen3 Reranker` by using the following command. More information can be found at: [examples/pooling/score/qwen3_reranker_offline.py](../../examples/pooling/score/qwen3_reranker_offline.py) [examples/pooling/score/qwen3_reranker_online.py](../../examples/pooling/score/qwen3_reranker_online.py).
     ```bash
     vllm serve Qwen/Qwen3-Reranker-0.6B --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}'
......
@@ -2,35 +2,70 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # ruff: noqa: E501
+"""
+Script to convert Large Language Models (LLMs) to Sequence Classification models.
+This is particularly useful for converting reranker models that use next-token
+prediction to a sequence classification format for compatibility with standard
+classification and rerank pipelines.
+Usage examples:
+- For BAAI/bge-reranker-v2-gemma:
+  python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma \
+  --classifier_from_tokens '["Yes"]' --method no_post_processing \
+  --path ./bge-reranker-v2-gemma-seq-cls
+- For mxbai-rerank-v2:
+  python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 \
+  --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax \
+  --path ./mxbai-rerank-base-v2-seq-cls
+- For Qwen3-Reranker:
+  python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B \
+  --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax \
+  --path ./Qwen3-Reranker-0.6B-seq-cls
+Note: For BAAI/bge-reranker-v2-gemma, "Yes" and "yes" are different tokens.
+"""
 import argparse
 import json
 import torch
 import transformers
-# Usage:
-# for BAAI/bge-reranker-v2-gemma
-# Caution: "Yes" and "yes" are two different tokens
-# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls
-# for mxbai-rerank-v2
-# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls
-# for Qwen3-Reranker
-# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls
 def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
-    # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
-    assert len(tokens) == 2
+    """
+    This method extracts the difference between weights for 'true' and 'false' tokens
+    from the language model head to create a single classification weight vector.
+    Args:
+        causal_lm: The original causal language model
+        seq_cls_model: The target sequence classification model
+        tokenizer: Model tokenizer
+        tokens: List of two tokens representing [false_token, true_token]
+        device: Target device (cpu/cuda)
+    Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
+    """
+    assert len(tokens) == 2, (
+        "Method requires exactly two tokens for binary classification"
+    )
+    # Get the language model head weights (vocabulary_size x hidden_size)
     lm_head_weights = causal_lm.lm_head.weight
+    # Convert token strings to their corresponding token IDs
     false_id = tokenizer.convert_tokens_to_ids(tokens[0])
     true_id = tokenizer.convert_tokens_to_ids(tokens[1])
+    # Compute the classification weight as the difference between true and false token weights
+    # This follows the approach in: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
     score_weight = lm_head_weights[true_id].to(device).to(
         torch.float32
     ) - lm_head_weights[false_id].to(device).to(torch.float32)
+    # Copy the computed weights to the sequence classification model
     with torch.no_grad():
         seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0))
         if seq_cls_model.score.bias is not None:
@@ -38,12 +73,29 @@ def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device):
 def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device):
+    """
+    Directly use token weights from the language model head for classification.
+    This method maps each classification label directly to a corresponding token
+    in the vocabulary without additional transformation.
+    Args:
+        causal_lm: The original causal language model
+        seq_cls_model: The target sequence classification model
+        tokenizer: Model tokenizer
+        tokens: List of tokens representing class labels
+        device: Target device (cpu/cuda)
+    """
+    # Get the language model head weights (vocabulary_size x hidden_size)
     lm_head_weights = causal_lm.lm_head.weight
+    # Convert all tokens to their corresponding token IDs
     token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens]
+    # Extract weights for the specific tokens (num_tokens x hidden_size)
     score_weight = lm_head_weights[token_ids].to(device)
+    # Copy the weights to the sequence classification model
     with torch.no_grad():
         seq_cls_model.score.weight.copy_(score_weight)
         if seq_cls_model.score.bias is not None:
@@ -58,19 +110,33 @@ method_map = {
 def converting(
     model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu"
 ):
-    assert method in method_map
+    """
+    Main conversion function to transform a CausalLM model to SequenceClassification.
+    Args:
+        model_name: Name or path of the pretrained model
+        classifier_from_tokens: List of tokens used for classification
+        path: Output path to save the converted model
+        method: Conversion method ('from_2_way_softmax' or 'no_post_processing')
+        use_pad_token: Whether to use padding token in the sequence classification model
+        device: Device to load the model on ('cpu' or 'cuda')
+    """
+    assert method in method_map, f"Unknown method: {method}"
+    # Determine number of labels based on conversion method
     if method == "from_2_way_softmax":
         assert len(classifier_from_tokens) == 2
         num_labels = 1
     else:
         num_labels = len(classifier_from_tokens)
+    # Load tokenizer and original causal language model
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
     causal_lm = transformers.AutoModelForCausalLM.from_pretrained(
         model_name, device_map=device
     )
+    # Load an empty sequence classification model with the same architecture
     seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained(
         model_name,
         num_labels=num_labels,
@@ -78,14 +144,17 @@ def converting(
         device_map=device,
     )
+    # Apply the selected conversion method to transfer weights
     method_map[method](
         causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device
     )
-    # `llm as reranker` defaults to not using pad_token
+    # Configure padding token settings
+    # Note: Reranker models typically don't use padding tokens by default
     seq_cls_model.config.use_pad_token = use_pad_token
     seq_cls_model.config.pad_token_id = tokenizer.pad_token_id
+    # Save the converted model and tokenizer
     seq_cls_model.save_pretrained(path)
     tokenizer.save_pretrained(path)
@@ -99,25 +168,30 @@ def parse_args():
         "--model_name",
         type=str,
         default="BAAI/bge-reranker-v2-gemma",
-        help="Model name",
+        help="HuggingFace model name or local path",
     )
     parser.add_argument(
         "--classifier_from_tokens",
         type=str,
         default='["Yes"]',
-        help="classifier from tokens",
+        help="JSON string of tokens used for classification labels",
     )
     parser.add_argument(
-        "--method", type=str, default="no_post_processing", help="Converting converting"
+        "--method",
+        type=str,
+        default="no_post_processing",
+        help="Conversion method to use",
     )
     parser.add_argument(
-        "--use-pad-token", action="store_true", help="Whether to use pad_token"
+        "--use-pad-token",
+        action="store_true",
+        help="Enable padding token in the sequence classification model",
     )
     parser.add_argument(
         "--path",
         type=str,
         default="./bge-reranker-v2-gemma-seq-cls",
-        help="Path to save converted model",
+        help="Output directory to save the converted model",
     )
     return parser.parse_args()
......
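The `from_2_way_softmax` conversion in the script above keeps only the difference between two `lm_head` rows. The sketch below checks numerically why one vector is enough. A softmax over the two candidate logits gives the "true" token the probability `sigmoid(z_true - z_false)`, and `z_true - z_false` is the dot product of the hidden state with `w_true - w_false`. It assumes the converted single-logit head is followed by a sigmoid. The vectors are made-up illustration values, not real model weights, and the helper names are hypothetical.

```python
import math


def dot(u: list[float], v: list[float]) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def two_way_softmax_true_prob(logit_false: float, logit_true: float) -> float:
    """Probability of the 'true' token under a softmax over two logits."""
    m = max(logit_false, logit_true)  # subtract the max for numerical stability
    e_false = math.exp(logit_false - m)
    e_true = math.exp(logit_true - m)
    return e_true / (e_false + e_true)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


# Made-up hidden state and lm_head rows (illustration only).
hidden = [0.5, -1.0, 2.0]
w_false = [0.1, 0.3, -0.2]
w_true = [0.4, -0.1, 0.6]

# Original formulation: two logits, then a 2-way softmax.
p_softmax = two_way_softmax_true_prob(dot(w_false, hidden), dot(w_true, hidden))

# Converted formulation: one weight vector (true minus false), then a sigmoid.
score_weight = [t - f for t, f in zip(w_true, w_false)]
p_sigmoid = sigmoid(dot(score_weight, hidden))

print(p_softmax, p_sigmoid)
```

The two printed values agree up to floating-point error, which is consistent with the script setting `num_labels = 1` for this method.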
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from vllm import LLM
model_name = "Qwen/Qwen3-Reranker-0.6B"
# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking by using the
# logits of the "no" and "yes" tokens.
# It needs to compute logits for all 151669 tokens, which makes this method
# extremely inefficient, not to mention incompatible with the vllm score API.
# A method for converting the original model into a sequence classification
# model was proposed. See: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline using this method are not only more efficient
# and compatible with the vllm score API, but also need fewer init
# parameters, for example:
# llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")

# If you want to load the official original version, the init parameters are
# as follows.


def get_llm() -> LLM:
    """Initializes and returns the LLM model for Qwen3-Reranker."""
    return LLM(
        model=model_name,
        runner="pooling",
        hf_overrides={
            "architectures": ["Qwen3ForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
    )


# Why do we need hf_overrides for the official original version:
# vllm converts it to Qwen3ForSequenceClassification when loaded for
# better performance.
# - First, we need to use `"architectures": ["Qwen3ForSequenceClassification"]`
#   to manually route to Qwen3ForSequenceClassification.
# - Then, we extract the vectors corresponding to classifier_from_token
#   from lm_head using `"classifier_from_token": ["no", "yes"]`.
# - Third, we convert these two vectors into one vector. This conversion
#   logic is enabled by `"is_original_qwen3_reranker": True`.

# Please use the query_template and document_template to format the query and
# document for better reranker results.
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"
def main() -> None:
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )
    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]
    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]
    queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=query)
        for query in queries
    ]
    documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
    llm = get_llm()
    outputs = llm.score(queries, documents)
    print("-" * 30)
    print([output.outputs.score for output in outputs])
    print("-" * 30)


if __name__ == "__main__":
    main()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from pathlib import Path
from vllm import LLM
model_name = "nvidia/llama-nemotron-rerank-1b-v2"
# Path to template file
template_path = Path(__file__).parent / "template" / "nemotron-rerank.jinja"
chat_template = template_path.read_text()
llm = LLM(model=model_name, runner="pooling", trust_remote_code=True)
query = "how much protein should a female eat?"
documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
]
outputs = llm.score(query, documents, chat_template=chat_template)
print("-" * 30)
print([output.outputs.score for output in outputs])
print("-" * 30)
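The example above prints one raw score per document. A common next step for a reranker is to order the documents by that score. The sketch below shows one way to do it. The helper `rank_by_score` and the score values are hypothetical, and the shortened document strings stand in for the passages above.

```python
def rank_by_score(
    documents: list[str], scores: list[float]
) -> list[tuple[float, str]]:
    """Pair each document with its score and sort from most to least relevant."""
    if len(documents) != len(scores):
        raise ValueError("documents and scores must have the same length")
    return sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)


# Shortened stand-ins for the three passages; the scores are made-up values.
documents = ["protein guideline", "definition of summit", "calorie intake"]
scores = [0.92, 0.01, 0.35]

ranked = rank_by_score(documents, scores)
for score, doc in ranked:
    print(f"{score:.2f}  {doc}")
```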
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
What is the difference between the official original version and one
that has been converted into a sequence classification model?
Qwen3-Reranker is a language model that performs reranking by using the
logits of the "no" and "yes" tokens.
This requires computing logits for all 151,669 tokens in the vocabulary,
making it inefficient and incompatible with vLLM's score() API.
A conversion method has been proposed to transform the original model into a
sequence classification model. This converted model:
1. Is significantly more efficient
2. Fully supports vLLM's score() API
3. Simplifies initialization parameters
Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
For the converted model, initialization would simply be:
llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", runner="pooling")
This example demonstrates loading the ORIGINAL model with special overrides
to make it compatible with vLLM's score API.
"""
from pathlib import Path
from vllm import LLM
model_name = "Qwen/Qwen3-Reranker-0.6B"
def get_llm() -> LLM:
    """
    Initializes and returns the LLM model for Qwen3-Reranker.

    Returns:
        LLM: Configured vLLM instance for reranking tasks.

    Note:
        This function loads the ORIGINAL Qwen3-Reranker model with specific
        overrides to make it compatible with vLLM's score API.
    """
    return LLM(
        # Specify the original model from HuggingFace
        model=model_name,
        # Use pooling runner for score task
        runner="pooling",
        # HuggingFace model configuration overrides required for compatibility
        hf_overrides={
            # Manually route to sequence classification architecture
            # This tells vLLM to use Qwen3ForSequenceClassification instead of
            # the default Qwen3ForCausalLM
            "architectures": ["Qwen3ForSequenceClassification"],
            # Specify which token logits to extract from the language model head
            # The original reranker uses "no" and "yes" token logits for scoring
            "classifier_from_token": ["no", "yes"],
            # Enable special handling for original Qwen3-Reranker models
            # This flag triggers conversion logic that transforms the two token
            # vectors into a single classification vector
            "is_original_qwen3_reranker": True,
        },
    )
def main() -> None:
    # Load the Jinja template for formatting query-document pairs
    # The template ensures proper formatting for the reranker model
    template_home = Path(__file__).parent / "template"
    template_path = "qwen3_reranker.jinja"
    chat_template = (template_home / template_path).read_text()
    # Sample queries for testing the reranker
    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]
    # Corresponding documents to be scored against each query
    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]
    # Initialize the LLM model with the original Qwen3-Reranker configuration
    llm = get_llm()
    # Compute relevance scores for each query-document pair
    # The score() method returns a relevance score for each pair
    # Higher scores indicate better relevance
    outputs = llm.score(queries, documents, chat_template=chat_template)
    # Extract and print the relevance scores from the outputs
    # Each output contains a score representing query-document relevance
    print("-" * 30)
    print("Relevance scores:", [output.outputs.score for output in outputs])
    print("-" * 30)


if __name__ == "__main__":
    main()
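The offline example passes `hf_overrides` as a Python dict, while the `vllm serve` command in the docs passes the same settings as a JSON string. The sketch below derives the JSON form from the dict with the standard library, so Python `True` becomes JSON `true`. It only builds and prints a command string. The variable names are illustrative, and the flags are copied from the commands shown in this commit.

```python
import json
import shlex

# Same overrides as in get_llm() above.
hf_overrides = {
    "architectures": ["Qwen3ForSequenceClassification"],
    "classifier_from_token": ["no", "yes"],
    "is_original_qwen3_reranker": True,
}

# json.dumps turns Python True into JSON true, as the CLI expects.
overrides_json = json.dumps(hf_overrides)

# shlex.quote wraps the JSON in single quotes so the shell passes it as one argument.
command = " ".join(
    [
        "vllm",
        "serve",
        "Qwen/Qwen3-Reranker-0.6B",
        "--runner",
        "pooling",
        "--hf_overrides",
        shlex.quote(overrides_json),
    ]
)
print(command)
```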
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
What is the difference between the official original version and one
that has been converted into a sequence classification model?
Qwen3-Reranker is a language model that performs reranking by using the
logits of the "no" and "yes" tokens.
This requires computing logits for all 151,669 tokens in the vocabulary,
making it inefficient and incompatible with vLLM's score() API.
A conversion method has been proposed to transform the original model into a
sequence classification model. This converted model:
1. Is significantly more efficient
2. Fully supports vLLM's score() API
3. Simplifies initialization parameters
Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
For the converted model, initialization would simply be:
vllm serve tomaarsen/Qwen3-Reranker-0.6B-seq-cls --runner pooling --chat-template examples/pooling/score/template/qwen3_reranker.jinja
This example demonstrates loading the ORIGINAL model with special overrides
to make it compatible with vLLM's score API.
vllm serve Qwen/Qwen3-Reranker-0.6B --runner pooling --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' --chat-template examples/pooling/score/template/qwen3_reranker.jinja
"""
import json
import requests
# URL of the vLLM server's score endpoint
# Default vLLM server runs on localhost port 8000
url = "http://127.0.0.1:8000/score"
# HTTP headers for the request
headers = {"accept": "application/json", "Content-Type": "application/json"}
# Example queries & documents
queries = [
"What is the capital of China?",
"Explain gravity",
]
documents = [
"The capital of China is Beijing.",
"Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
]
# Request payload for the score API
data = {
"model": "Qwen/Qwen3-Reranker-0.6B",
"text_1": queries,
"text_2": documents,
}
def main():
    """Main function to send a score request to the vLLM server.

    This function sends a POST request to the /score endpoint with
    the query and documents, then prints the relevance scores.
    """
    # Send POST request to the vLLM server's score endpoint
    response = requests.post(url, headers=headers, json=data)
    # Check if the request was successful
    if response.status_code == 200:
        print("Request successful!")
        # Pretty print the JSON response containing relevance scores
        # The response includes scores for each document's relevance to the query
        print(json.dumps(response.json(), indent=2))
    else:
        # Handle request failure
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    main()
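The online example prints the whole JSON response. The sketch below pulls out only the scores, in input order. The response shape is an assumption and should be checked against a real server response: a top-level `data` list whose items carry `index` and `score` fields. The helper `extract_scores` and the sample payload are hypothetical.

```python
def extract_scores(payload: dict) -> list[float]:
    """Return scores ordered by their `index` field.

    Assumes payload["data"] is a list of {"index": int, "score": float} items.
    """
    items = sorted(payload["data"], key=lambda item: item["index"])
    return [item["score"] for item in items]


# Hand-written sample in the assumed shape, deliberately out of order.
sample_response = {
    "data": [
        {"index": 1, "object": "score", "score": 0.98},
        {"index": 0, "object": "score", "score": 0.99},
    ]
}

scores = extract_scores(sample_response)
print(scores)
```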
A: {{ (messages | selectattr("role", "eq", "query") | first).content }}
B: {{ (messages | selectattr("role", "eq", "document") | first).content }}
Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'.
\ No newline at end of file
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
query: {{ (messages | selectattr("role", "eq", "query") | first).content }}
document: {{ (messages | selectattr("role", "eq", "document") | first).content }}
You are a search relevance expert who evaluates how well documents match search queries. For each query-document pair, carefully analyze the semantic relationship between them, then provide your binary relevance judgment (0 for not relevant, 1 for relevant).
Relevance:<|im_end|>
<|im_start|>assistant
<|im_start|>system
Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>
<|im_start|>user
<Instruct>: {{ messages | selectattr("role", "eq", "system") | map(attribute="content") | first | default("Given a web search query, retrieve relevant passages that answer the query") }}
<Query>: {{ messages | selectattr("role", "eq", "query") | map(attribute="content") | first }}
<Document>: {{ messages | selectattr("role", "eq", "document") | map(attribute="content") | first }}<|im_end|>
<|im_start|>assistant
<think>
</think>
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from argparse import Namespace
from pathlib import Path
from typing import Any
from vllm import LLM, EngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args():
    """Parse command line arguments for the reranking example.

    This function sets up the argument parser with default values
    specific to reranking models, including the model name and
    runner type.
    """
    parser = FlexibleArgumentParser()
    # Add all EngineArgs command line arguments to the parser
    parser = EngineArgs.add_cli_args(parser)
    # Set default values specific to this reranking example
    # These defaults ensure the script works out-of-the-box for reranking tasks
    parser.set_defaults(
        model="nvidia/llama-nemotron-rerank-1b-v2",  # Default reranking model
        runner="pooling",  # Required for cross-encoder/reranking models
        trust_remote_code=True,  # Allow loading models with custom code
    )
    return parser.parse_args()
def get_chat_template(model: str) -> str:
    """Load the appropriate chat template for the specified model.

    Reranking models require specific prompt templates to format
    query-document pairs correctly. This function maps model names
    to their corresponding template files.
    """
    # Directory containing all chat template files
    template_home = Path(__file__).parent / "template"
    # Mapping from model names to their corresponding template files
    # Each reranking model has its own specific prompt format
    model_name_to_template_path_map = {
        "BAAI/bge-reranker-v2-gemma": "bge-reranker-v2-gemma.jinja",
        "Qwen/Qwen3-Reranker-0.6B": "qwen3_reranker.jinja",
        "Qwen/Qwen3-Reranker-4B": "qwen3_reranker.jinja",
        "Qwen/Qwen3-Reranker-8B": "qwen3_reranker.jinja",
        "tomaarsen/Qwen3-Reranker-0.6B-seq-cls": "qwen3_reranker.jinja",
        "tomaarsen/Qwen3-Reranker-4B-seq-cls": "qwen3_reranker.jinja",
        "tomaarsen/Qwen3-Reranker-8B-seq-cls": "qwen3_reranker.jinja",
        "mixedbread-ai/mxbai-rerank-base-v2": "mxbai_rerank_v2.jinja",
        "mixedbread-ai/mxbai-rerank-large-v2": "mxbai_rerank_v2.jinja",
        "nvidia/llama-nemotron-rerank-1b-v2": "nemotron-rerank.jinja",
    }
    # Get the template filename for the specified model
    template_path = model_name_to_template_path_map.get(model)
    if template_path is None:
        raise ValueError(f"This demo does not support model name: {model}.")
    # Read and return the template content
    return (template_home / template_path).read_text()


def get_hf_overrides(model: str) -> dict[str, Any]:
    """Convert Large Language Models (LLMs) to Sequence Classification models.

    note:
        Some reranking models require special configuration overrides to work
        correctly with vLLM's score API.
        Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/qwen3_reranker_offline.py
        Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
    """
    model_name_to_hf_overrides_map = {
        "BAAI/bge-reranker-v2-gemma": {
            "architectures": ["GemmaForSequenceClassification"],
            "classifier_from_token": ["Yes"],
            "method": "no_post_processing",
        },
        "Qwen/Qwen3-Reranker-0.6B": {
            "architectures": ["Qwen3ForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
        "Qwen/Qwen3-Reranker-4B": {
            "architectures": ["Qwen3ForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
        "Qwen/Qwen3-Reranker-8B": {
            "architectures": ["Qwen3ForSequenceClassification"],
            "classifier_from_token": ["no", "yes"],
            "is_original_qwen3_reranker": True,
        },
        "tomaarsen/Qwen3-Reranker-0.6B-seq-cls": {},
        "tomaarsen/Qwen3-Reranker-4B-seq-cls": {},
        "tomaarsen/Qwen3-Reranker-8B-seq-cls": {},
        "mixedbread-ai/mxbai-rerank-base-v2": {
            "architectures": ["Qwen2ForSequenceClassification"],
            "classifier_from_token": ["0", "1"],
            "method": "from_2_way_softmax",
        },
        "mixedbread-ai/mxbai-rerank-large-v2": {
            "architectures": ["Qwen2ForSequenceClassification"],
            "classifier_from_token": ["0", "1"],
            "method": "from_2_way_softmax",
        },
        "nvidia/llama-nemotron-rerank-1b-v2": {},
    }

    hf_overrides = model_name_to_hf_overrides_map.get(model)

    if hf_overrides is None:
        raise ValueError(f"This demo does not support model name: {model}.")

    return hf_overrides


def main(args: Namespace):
    """Main execution function for the reranking example."""
    # Get the overrides for the specified model
    args.hf_overrides = get_hf_overrides(args.model)

    # Initialize the LLM with all provided arguments
    llm = LLM(**vars(args))

    # Example query for demonstration
    query = "how much protein should a female eat?"

    # Example documents to be reranked based on relevance to the query
    documents = [
        "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
        "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.",
        "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
    ]

    # Load the appropriate chat template for the selected model
    # The template formats query-document pairs for the reranking model
    chat_template = get_chat_template(args.model)

    # Score documents based on relevance to the query
    # The score method returns relevance scores for each document
    outputs = llm.score(query, documents, chat_template=chat_template)

    # Display the relevance scores
    # Higher scores indicate more relevant documents
    print("-" * 30)
    print([output.outputs.score for output in outputs])
    print("-" * 30)


if __name__ == "__main__":
    args = parse_args()
    main(args)
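
The `from_2_way_softmax` override used for the mxbai models in this example can be read as collapsing a two-token softmax into a single classification logit. The sketch below only illustrates that arithmetic identity. It assumes the method scores with the difference between the "true" and "false" token weights, which is an assumption about vLLM's conversion and not something this diff shows. The hidden state and weight rows are made-up toy values.

```python
import math


def softmax_true_probability(false_logit: float, true_logit: float) -> float:
    """Probability of the "true" token under a softmax over [false, true]."""
    m = max(false_logit, true_logit)  # subtract the max for numerical stability
    e_false = math.exp(false_logit - m)
    e_true = math.exp(true_logit - m)
    return e_true / (e_false + e_true)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


# Hypothetical last-token hidden state and two LM-head rows (toy values).
hidden = [0.5, -1.0, 2.0]
w_false = [0.1, 0.3, -0.2]  # row for the "0" token
w_true = [0.4, -0.1, 0.6]  # row for the "1" token

# Two-way softmax over the two token logits.
two_way = softmax_true_probability(dot(hidden, w_false), dot(hidden, w_true))

# Single-logit classifier whose weight is the difference of the two rows.
w_diff = [t - f for t, f in zip(w_true, w_false)]
single = sigmoid(dot(hidden, w_diff))

print(two_way, single)
```

Both values agree up to floating-point rounding, which is why a causal LM that answers with one of two tokens can be served as a one-label sequence classifier.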
...@@ -4,18 +4,37 @@ ...@@ -4,18 +4,37 @@
""" """
Example of using the rerank API with template. Example of using the rerank API with template.
This script demonstrates how to interact with a vLLM server running
a reranking model via the REST API.
Before running this script, start the vLLM server with one of the
supported reranking models using the commands below.
note:
Some reranking models require special configuration overrides to work correctly
with vLLM's score API.
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/qwen3_reranker_online.py
Reference: https://github.com/vllm-project/vllm/blob/main/examples/pooling/score/convert_model_to_seq_cls.py
run: run:
vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' --chat-template examples/pooling/score/template/bge-reranker-v2-gemma.jinja
vllm serve tomaarsen/Qwen3-Reranker-0.6B-seq-cls --chat-template examples/pooling/score/template/qwen3_reranker.jinja
vllm serve mixedbread-ai/mxbai-rerank-base-v2 --hf_overrides '{"architectures": ["Qwen2ForSequenceClassification"],"classifier_from_token": ["0", "1"], "method": "from_2_way_softmax"}' --chat-template examples/pooling/score/template/mxbai_rerank_v2.jinja
vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja vllm serve nvidia/llama-nemotron-rerank-1b-v2 --runner pooling --trust-remote-code --chat-template examples/pooling/score/template/nemotron-rerank.jinja
vllm serve Qwen/Qwen3-Reranker-0.6B --runner pooling --hf_overrides '{"architectures": ["Qwen3ForSequenceClassification"],"classifier_from_token": ["no", "yes"],"is_original_qwen3_reranker": true}' --chat-template examples/pooling/score/template/qwen3_reranker.jinja
""" """
import json import json
import requests import requests
# URL of the vLLM server's rerank endpoint
# Default vLLM server runs on localhost port 8000
url = "http://127.0.0.1:8000/rerank" url = "http://127.0.0.1:8000/rerank"
# HTTP headers for the request
headers = {"accept": "application/json", "Content-Type": "application/json"} headers = {"accept": "application/json", "Content-Type": "application/json"}
# Example query & documents
query = "how much protein should a female eat?" query = "how much protein should a female eat?"
documents = [ documents = [
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.", "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
...@@ -23,21 +42,31 @@ documents = [ ...@@ -23,21 +42,31 @@ documents = [
"Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.", "Calorie intake should not fall below 1,200 a day in women or 1,500 a day in men, except under the supervision of a health professional.",
] ]
# Request payload for the rerank API
data = { data = {
"model": "nvidia/llama-nemotron-rerank-1b-v2", "model": "nvidia/llama-nemotron-rerank-1b-v2", # Model to use for reranking
"query": query, "query": query, # The query to score documents against
"documents": documents, "documents": documents, # List of documents to be scored
} }
def main(): def main():
"""Main function to send a rerank request to the vLLM server.
This function sends a POST request to the /rerank endpoint with
the query and documents, then prints the relevance scores.
"""
# Send POST request to the vLLM server's rerank endpoint
response = requests.post(url, headers=headers, json=data) response = requests.post(url, headers=headers, json=data)
# Check the response # Check if the request was successful
if response.status_code == 200: if response.status_code == 200:
print("Request successful!") print("Request successful!")
# Pretty print the JSON response containing relevance scores
# The response includes scores for each document's relevance to the query
print(json.dumps(response.json(), indent=2)) print(json.dumps(response.json(), indent=2))
else: else:
# Handle request failure
print(f"Request failed with status code: {response.status_code}") print(f"Request failed with status code: {response.status_code}")
print(response.text) print(response.text)
......
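
As a small follow-up to the online example, the sketch below shows one way a client might turn the rerank response into a ranked list. The `results` list and its `relevance_score` field appear elsewhere in this change (`response["results"][0]["relevance_score"]`). The `index` field and every value in `sample_response` are assumptions made for illustration, and `rank_documents` is a hypothetical helper.

```python
# Hypothetical response body; the scores are invented for illustration only.
sample_response = {
    "results": [
        {"index": 0, "relevance_score": 0.91},
        {"index": 1, "relevance_score": 0.02},
        {"index": 2, "relevance_score": 0.47},
    ]
}

sample_documents = ["doc A", "doc B", "doc C"]


def rank_documents(response_json: dict, documents: list[str]) -> list[tuple[str, float]]:
    """Return (document, score) pairs sorted from most to least relevant."""
    ranked = sorted(
        response_json["results"],
        key=lambda item: item["relevance_score"],
        reverse=True,
    )
    # "index" is assumed to point back into the submitted documents list.
    return [(documents[item["index"]], item["relevance_score"]) for item in ranked]


ranking = rank_documents(sample_response, sample_documents)
for document, score in ranking:
    print(f"{score:.2f}  {document}")
```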
...@@ -45,7 +45,11 @@ from transformers import ( ...@@ -45,7 +45,11 @@ from transformers import (
) )
from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.models.auto.auto_factory import _BaseAutoModelClass
from tests.models.utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs from tests.models.utils import (
TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
softmax,
)
from vllm import LLM, SamplingParams, envs from vllm import LLM, SamplingParams, envs
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
...@@ -513,7 +517,7 @@ class HfRunner: ...@@ -513,7 +517,7 @@ class HfRunner:
elif problem_type == "multi_label_classification": elif problem_type == "multi_label_classification":
logits = output.logits.sigmoid()[0].tolist() logits = output.logits.sigmoid()[0].tolist()
else: else:
logits = output.logits.softmax(dim=-1)[0].tolist() logits = softmax(output.logits)[0].tolist()
outputs.append(logits) outputs.append(logits)
return outputs return outputs
......
...@@ -3,13 +3,16 @@ ...@@ -3,13 +3,16 @@
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any
import mteb import mteb
import numpy as np import numpy as np
import requests import requests
import torch
from mteb.models import ModelMeta from mteb.models import ModelMeta
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tests.conftest import HfRunner
from tests.models.utils import ( from tests.models.utils import (
RerankModelInfo, RerankModelInfo,
get_vllm_extra_kwargs, get_vllm_extra_kwargs,
...@@ -67,6 +70,12 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin): ...@@ -67,6 +70,12 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
queries = [text for batch in inputs1 for text in batch["text"]] queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]] corpus = [text for batch in inputs2 for text in batch["text"]]
# Hoping to discover potential scheduling
# issues by randomizing the order.
r = self.rng.permutation(len(queries))
queries = [queries[i] for i in r]
corpus = [corpus[i] for i in r]
outputs = self.llm.score( outputs = self.llm.score(
queries, queries,
corpus, corpus,
...@@ -75,6 +84,7 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin): ...@@ -75,6 +84,7 @@ class VllmMtebCrossEncoder(MtebCrossEncoderMixin):
chat_template=self.chat_template, chat_template=self.chat_template,
) )
scores = np.array(outputs) scores = np.array(outputs)
scores = scores[np.argsort(r)]
return scores return scores
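
The shuffle-then-restore step added to `VllmMtebCrossEncoder.predict` relies on the fact that the argsort of a permutation is its inverse. A minimal pure-Python sketch of that idea follows. It uses `random` and `sorted` in place of `rng.permutation` and `np.argsort`, and `score_in_random_order` and `toy_score_fn` are hypothetical names that stand in for the real scoring call.

```python
import random


def score_in_random_order(queries, corpus, score_fn, seed=42):
    """Score pairs in a shuffled order, then return scores in the original order."""
    rng = random.Random(seed)
    order = list(range(len(queries)))
    rng.shuffle(order)  # order[k] is the original index placed at position k

    shuffled_scores = score_fn(
        [queries[i] for i in order],
        [corpus[i] for i in order],
    )

    # Argsort of a permutation gives its inverse:
    # inverse[i] is the shuffled position that holds original item i.
    inverse = sorted(range(len(order)), key=order.__getitem__)
    return [shuffled_scores[k] for k in inverse]


def toy_score_fn(queries, corpus):
    # Stand-in scorer: any function that scores pairs independently works here.
    return [float(len(q) + len(d)) for q, d in zip(queries, corpus)]


queries = ["a", "bb", "ccc", "dddd"]
corpus = ["x", "x", "x", "x"]

restored = score_in_random_order(queries, corpus, toy_score_fn)
expected = toy_score_fn(queries, corpus)
print(restored == expected)
```

Because each pair is scored independently, the restored scores should match the unshuffled ones. A mismatch would point at an ordering or scheduling problem in the scorer.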
...@@ -84,7 +94,6 @@ class ScoreClientMtebEncoder(MtebCrossEncoderMixin): ...@@ -84,7 +94,6 @@ class ScoreClientMtebEncoder(MtebCrossEncoderMixin):
def __init__(self, model_name: str, url): def __init__(self, model_name: str, url):
self.model_name = model_name self.model_name = model_name
self.url = url self.url = url
self.rng = np.random.default_rng(seed=42)
def predict( def predict(
self, self,
...@@ -130,6 +139,50 @@ class RerankClientMtebEncoder(ScoreClientMtebEncoder): ...@@ -130,6 +139,50 @@ class RerankClientMtebEncoder(ScoreClientMtebEncoder):
return response["results"][0]["relevance_score"] return response["results"][0]["relevance_score"]
class HFMtebCrossEncoder(MtebCrossEncoderMixin, HfRunner):
chat_template: str | None = None
def __init__(self, model_name: str, dtype: str = "auto", **kwargs: Any) -> None:
HfRunner.__init__(
self, model_name=model_name, is_cross_encoder=True, dtype=dtype, **kwargs
)
@torch.no_grad
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
if self.chat_template is not None:
tokenizer = self.model.tokenizer
prompts = []
for query, document in zip(queries, corpus):
conversation = [
{"role": "query", "content": query},
{"role": "document", "content": document},
]
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tools=None,
chat_template=self.chat_template,
tokenize=False,
)
prompts.append(prompt)
outputs_list = HfRunner.classify(self, prompts)
scores = np.array(outputs_list).squeeze(-1)
return scores
else:
prompts = list(zip(queries, corpus))
outputs_tensor = HfRunner.predict(self, prompts, show_progress_bar=False)
return outputs_tensor.cpu().numpy()
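
The input handling in `HFMtebCrossEncoder.predict` can be sketched without any model: flatten the batched texts, pair them positionally, and build one two-message conversation per pair using the `query` and `document` roles from the code above. The helper names and toy batches below are hypothetical, and plain dicts stand in for the `DataLoader` batches.

```python
def flatten_texts(batches):
    """Collect the "text" entries of every batch into one flat list."""
    return [text for batch in batches for text in batch["text"]]


def build_conversations(query_batches, document_batches):
    """Pair queries and documents by position, one conversation per pair."""
    queries = flatten_texts(query_batches)
    corpus = flatten_texts(document_batches)
    return [
        [
            {"role": "query", "content": query},
            {"role": "document", "content": document},
        ]
        for query, document in zip(queries, corpus)
    ]


# Toy batches; batch boundaries need not line up between the two inputs.
query_batches = [{"text": ["q1", "q2"]}, {"text": ["q3"]}]
document_batches = [{"text": ["d1"]}, {"text": ["d2", "d3"]}]

conversations = build_conversations(query_batches, document_batches)
for conversation in conversations:
    print(conversation)
```

Each conversation is what would then be passed to the tokenizer's `apply_chat_template` together with the score template.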
def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages): def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages):
with tempfile.TemporaryDirectory() as prediction_folder: with tempfile.TemporaryDirectory() as prediction_folder:
bm25s = mteb.get_model("bm25s") bm25s = mteb.get_model("bm25s")
...@@ -168,31 +221,21 @@ def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages): ...@@ -168,31 +221,21 @@ def run_mteb_rerank(cross_encoder: mteb.CrossEncoderProtocol, tasks, languages):
return main_score return main_score
def mteb_test_rerank_models_hf(
hf_runner, model_name, hf_dtype="float32", hf_model_callback=None
):
with hf_runner(model_name, is_cross_encoder=True, dtype=hf_dtype) as hf_model:
if hf_model_callback is not None:
hf_model_callback(hf_model)
st_main_score = run_mteb_rerank(
hf_model, tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS
)
st_dtype = next(hf_model.model.model.parameters()).dtype
return st_main_score, st_dtype
def mteb_test_rerank_models( def mteb_test_rerank_models(
hf_runner,
vllm_runner, vllm_runner,
model_info: RerankModelInfo, model_info: RerankModelInfo,
hf_runner=HFMtebCrossEncoder,
vllm_extra_kwargs=None, vllm_extra_kwargs=None,
hf_model_callback=None,
vllm_mteb_encoder=VllmMtebCrossEncoder, vllm_mteb_encoder=VllmMtebCrossEncoder,
atol=MTEB_RERANK_TOL, atol=MTEB_RERANK_TOL,
): ):
vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs) vllm_extra_kwargs = get_vllm_extra_kwargs(model_info, vllm_extra_kwargs)
# Maybe load chat_template.
chat_template: str | None = None
if model_info.chat_template_name is not None:
chat_template = (template_home / model_info.chat_template_name).read_text()
with vllm_runner( with vllm_runner(
model_info.name, model_info.name,
runner="pooling", runner="pooling",
...@@ -201,6 +244,7 @@ def mteb_test_rerank_models( ...@@ -201,6 +244,7 @@ def mteb_test_rerank_models(
**vllm_extra_kwargs, **vllm_extra_kwargs,
) as vllm_model: ) as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config model_config = vllm_model.llm.llm_engine.model_config
vllm_model.chat_template = chat_template
# Confirm whether vllm is using the correct architecture # Confirm whether vllm is using the correct architecture
if model_info.architecture: if model_info.architecture:
...@@ -209,12 +253,6 @@ def mteb_test_rerank_models( ...@@ -209,12 +253,6 @@ def mteb_test_rerank_models(
# Score API is only enabled for num_labels == 1 # Score API is only enabled for num_labels == 1
assert model_config.hf_config.num_labels == 1 assert model_config.hf_config.num_labels == 1
# Maybe load chat_template.
chat_template: str | None = None
if model_info.chat_template_name is not None:
chat_template = (template_home / model_info.chat_template_name).read_text()
vllm_model.chat_template = chat_template
# Confirm whether the important configs in model_config are correct. # Confirm whether the important configs in model_config are correct.
if model_info.pooling_type is not None: if model_info.pooling_type is not None:
assert model_config.pooler_config.pooling_type == model_info.pooling_type assert model_config.pooler_config.pooling_type == model_info.pooling_type
...@@ -242,9 +280,14 @@ def mteb_test_rerank_models( ...@@ -242,9 +280,14 @@ def mteb_test_rerank_models(
# Accelerate mteb test by setting # Accelerate mteb test by setting
# SentenceTransformers mteb score to a constant # SentenceTransformers mteb score to a constant
if model_info.mteb_score is None: if model_info.mteb_score is None:
st_main_score, st_dtype = mteb_test_rerank_models_hf( with hf_runner(model_info.name, dtype=model_info.hf_dtype) as hf_model:
hf_runner, model_info.name, model_info.hf_dtype, hf_model_callback hf_model.chat_template = chat_template
) st_main_score = run_mteb_rerank(
hf_model,
tasks=MTEB_RERANK_TASKS,
languages=MTEB_RERANK_LANGS,
)
st_dtype = next(hf_model.model.model.parameters()).dtype
else: else:
st_main_score = model_info.mteb_score st_main_score = model_info.mteb_score
st_dtype = "Constant" st_dtype = "Constant"
......
...@@ -112,7 +112,5 @@ def test_embed_models_correctness( ...@@ -112,7 +112,5 @@ def test_embed_models_correctness(
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb( def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
hf_runner, vllm_runner, model_info: RerankModelInfo mteb_test_rerank_models(vllm_runner, model_info)
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
...@@ -11,40 +11,60 @@ from torch.utils.data import DataLoader ...@@ -11,40 +11,60 @@ from torch.utils.data import DataLoader
from tests.conftest import HfRunner from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo from tests.models.utils import RerankModelInfo
from .mteb_score_utils import VllmMtebCrossEncoder, mteb_test_rerank_models from .mteb_score_utils import (
MtebCrossEncoderMixin,
mteb_test_rerank_models,
)
RERANK_MODELS = [ RERANK_MODELS = [
RerankModelInfo( RerankModelInfo(
"BAAI/bge-reranker-v2-gemma", "BAAI/bge-reranker-v2-gemma",
architecture="GemmaForSequenceClassification", architecture="GemmaForSequenceClassification",
mteb_score=0.33757,
hf_overrides={ hf_overrides={
"architectures": ["GemmaForSequenceClassification"], "architectures": ["GemmaForSequenceClassification"],
"classifier_from_token": ["Yes"], "classifier_from_token": ["Yes"],
"method": "no_post_processing", "method": "no_post_processing",
}, },
mteb_score=0.33757,
pooling_type="LAST", pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
chat_template_name="bge-reranker-v2-gemma.jinja",
), ),
] ]
PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 PROMPT = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501
class GemmaRerankerHfRunner(HfRunner): class GemmaRerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
def __init__( def __init__(
self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
) -> None: ) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) HfRunner.__init__(
self,
model_name=model_name,
auto_cls=AutoModelForCausalLM,
dtype=dtype,
**kwargs,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes") self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes")
@torch.no_grad() @torch.no_grad
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
def get_inputs(pairs, tokenizer, prompt=None): def get_inputs(pairs, tokenizer, prompt=None):
if prompt is None: if prompt is None:
prompt = PROMPT prompt = PROMPT
...@@ -89,8 +109,8 @@ class GemmaRerankerHfRunner(HfRunner): ...@@ -89,8 +109,8 @@ class GemmaRerankerHfRunner(HfRunner):
) )
scores = [] scores = []
for query, doc, *_ in prompts: for query, document in zip(queries, corpus):
pairs = [(query, doc)] pairs = [(query, document)]
inputs = get_inputs(pairs, self.tokenizer) inputs = get_inputs(pairs, self.tokenizer)
inputs = inputs.to(self.model.device) inputs = inputs.to(self.model.device)
_n_tokens = inputs["input_ids"].shape[1] _n_tokens = inputs["input_ids"].shape[1]
...@@ -107,41 +127,10 @@ class GemmaRerankerHfRunner(HfRunner): ...@@ -107,41 +127,10 @@ class GemmaRerankerHfRunner(HfRunner):
return torch.Tensor(scores) return torch.Tensor(scores)
class GemmaMtebEncoder(VllmMtebCrossEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.query_template = "A: {query}\n"
self.document_template = "B: {doc}\n{prompt}"
def predict(
self,
inputs1: DataLoader[mteb.types.BatchedInput],
inputs2: DataLoader[mteb.types.BatchedInput],
*args,
**kwargs,
) -> np.ndarray:
queries = [
self.query_template.format(query=text)
for batch in inputs1
for text in batch["text"]
]
corpus = [
self.document_template.format(doc=text, prompt=PROMPT)
for batch in inputs2
for text in batch["text"]
]
outputs = self.llm.score(
queries, corpus, truncate_prompt_tokens=-1, use_tqdm=False
)
scores = np.array(outputs)
return scores
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models( mteb_test_rerank_models(
GemmaRerankerHfRunner,
vllm_runner, vllm_runner,
model_info, model_info,
vllm_mteb_encoder=GemmaMtebEncoder, hf_runner=GemmaRerankerHfRunner,
) )
...@@ -11,27 +11,26 @@ from .mteb_score_utils import mteb_test_rerank_models ...@@ -11,27 +11,26 @@ from .mteb_score_utils import mteb_test_rerank_models
RERANK_MODELS = [ RERANK_MODELS = [
RerankModelInfo( RerankModelInfo(
"cross-encoder/ms-marco-TinyBERT-L-2-v2", "cross-encoder/ms-marco-TinyBERT-L-2-v2",
mteb_score=0.32898,
architecture="BertForSequenceClassification", architecture="BertForSequenceClassification",
pooling_type="CLS", pooling_type="CLS",
attn_type="encoder_only", attn_type="encoder_only",
is_prefix_caching_supported=False, is_prefix_caching_supported=False,
is_chunked_prefill_supported=False, is_chunked_prefill_supported=False,
mteb_score=0.32898,
), ),
RerankModelInfo( RerankModelInfo(
"tomaarsen/Qwen3-Reranker-0.6B-seq-cls", "tomaarsen/Qwen3-Reranker-0.6B-seq-cls",
mteb_score=0.25736,
architecture="Qwen3ForSequenceClassification", architecture="Qwen3ForSequenceClassification",
pooling_type="LAST", pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
chat_template_name="qwen3_reranker.jinja",
mteb_score=0.33459,
), ),
] ]
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb( def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
hf_runner, vllm_runner, model_info: RerankModelInfo mteb_test_rerank_models(vllm_runner, model_info)
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
...@@ -143,7 +143,5 @@ def test_embed_models_correctness( ...@@ -143,7 +143,5 @@ def test_embed_models_correctness(
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb( def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
hf_runner, vllm_runner, model_info: RerankModelInfo mteb_test_rerank_models(vllm_runner, model_info)
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
...@@ -72,10 +72,8 @@ def test_embed_models_correctness( ...@@ -72,10 +72,8 @@ def test_embed_models_correctness(
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb( def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
hf_runner, vllm_runner, model_info: RerankModelInfo mteb_test_rerank_models(vllm_runner, model_info)
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS) @pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
......
...@@ -2,13 +2,16 @@ ...@@ -2,13 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any from typing import Any
import mteb
import numpy as np
import pytest import pytest
import torch import torch
from torch.utils.data import DataLoader
from tests.conftest import HfRunner from tests.conftest import HfRunner
from tests.models.utils import RerankModelInfo from tests.models.utils import RerankModelInfo
from .mteb_score_utils import mteb_test_rerank_models from .mteb_score_utils import MtebCrossEncoderMixin, mteb_test_rerank_models
mxbai_rerank_hf_overrides = { mxbai_rerank_hf_overrides = {
"architectures": ["Qwen2ForSequenceClassification"], "architectures": ["Qwen2ForSequenceClassification"],
...@@ -21,50 +24,69 @@ RERANK_MODELS = [ ...@@ -21,50 +24,69 @@ RERANK_MODELS = [
"mixedbread-ai/mxbai-rerank-base-v2", "mixedbread-ai/mxbai-rerank-base-v2",
architecture="Qwen2ForSequenceClassification", architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides, hf_overrides=mxbai_rerank_hf_overrides,
mteb_score=0.273,
pooling_type="LAST", pooling_type="LAST",
attn_type="decoder", attn_type="decoder",
is_prefix_caching_supported=True, is_prefix_caching_supported=True,
is_chunked_prefill_supported=True, is_chunked_prefill_supported=True,
chat_template_name="mxbai_rerank_v2.jinja",
mteb_score=0.33651,
enable_test=True, enable_test=True,
), ),
RerankModelInfo( RerankModelInfo(
"mixedbread-ai/mxbai-rerank-large-v2", "mixedbread-ai/mxbai-rerank-large-v2",
architecture="Qwen2ForSequenceClassification", architecture="Qwen2ForSequenceClassification",
hf_overrides=mxbai_rerank_hf_overrides, hf_overrides=mxbai_rerank_hf_overrides,
chat_template_name="mxbai_rerank_v2.jinja",
enable_test=False, enable_test=False,
), ),
] ]
class MxbaiRerankerHfRunner(HfRunner): class MxbaiRerankerHfRunner(MtebCrossEncoderMixin, HfRunner):
def __init__( def __init__(
self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any self, model_name: str, dtype: str = "auto", *args: Any, **kwargs: Any
) -> None: ) -> None:
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) HfRunner.__init__(
self,
model_name=model_name,
auto_cls=AutoModelForCausalLM,
dtype=dtype,
**kwargs,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
self.yes_loc = self.tokenizer.convert_tokens_to_ids("1") self.yes_loc = self.tokenizer.convert_tokens_to_ids("1")
self.no_loc = self.tokenizer.convert_tokens_to_ids("0") self.no_loc = self.tokenizer.convert_tokens_to_ids("0")
def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor: @torch.no_grad
def process_inputs(pairs): def predict(
inputs = self.tokenizer( self,
pairs, inputs1: DataLoader[mteb.types.BatchedInput],
padding=False, inputs2: DataLoader[mteb.types.BatchedInput],
truncation="longest_first", *args,
return_attention_mask=False, **kwargs,
) -> np.ndarray:
queries = [text for batch in inputs1 for text in batch["text"]]
corpus = [text for batch in inputs2 for text in batch["text"]]
tokenizer = self.tokenizer
prompts = []
for query, document in zip(queries, corpus):
conversation = [
{"role": "query", "content": query},
{"role": "document", "content": document},
]
prompt = tokenizer.apply_chat_template(
conversation=conversation,
tools=None,
chat_template=self.chat_template,
tokenize=False,
) )
for i, ele in enumerate(inputs["input_ids"]): prompts.append(prompt)
inputs["input_ids"][i] = ele
inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt")
for key in inputs:
inputs[key] = inputs[key].to(self.model.device)
return inputs
@torch.no_grad()
def compute_logits(inputs): def compute_logits(inputs):
logits = self.model(**inputs).logits[:, -1, :] logits = self.model(**inputs).logits[:, -1, :]
yes_logits = logits[:, self.yes_loc] yes_logits = logits[:, self.yes_loc]
...@@ -74,9 +96,9 @@ class MxbaiRerankerHfRunner(HfRunner): ...@@ -74,9 +96,9 @@ class MxbaiRerankerHfRunner(HfRunner):
return scores return scores
scores = [] scores = []
for query, doc, *_ in prompts: for prompt in prompts:
pairs = [(query, doc)] inputs = tokenizer([prompt], return_tensors="pt")
inputs = process_inputs(pairs) inputs = self.wrap_device(inputs)
score = compute_logits(inputs) score = compute_logits(inputs)
scores.append(score[0].item()) scores.append(score[0].item())
return torch.Tensor(scores) return torch.Tensor(scores)
...@@ -84,4 +106,4 @@ class MxbaiRerankerHfRunner(HfRunner): ...@@ -84,4 +106,4 @@ class MxbaiRerankerHfRunner(HfRunner):
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
mteb_test_rerank_models(MxbaiRerankerHfRunner, vllm_runner, model_info) mteb_test_rerank_models(vllm_runner, model_info, hf_runner=MxbaiRerankerHfRunner)
...@@ -46,7 +46,5 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) - ...@@ -46,7 +46,5 @@ def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -
@pytest.mark.parametrize("model_info", RERANK_MODELS) @pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb( def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None:
hf_runner, vllm_runner, model_info: RerankModelInfo mteb_test_rerank_models(vllm_runner, model_info)
) -> None:
mteb_test_rerank_models(hf_runner, vllm_runner, model_info)