Commit 9bcbaafc authored by zhuwenwen

[New Model]: Support Qwen3 Embedding & Reranker

parent 2f6f5bb3
......@@ -15,7 +15,7 @@ vLLM is a fast and easy-to-use library for LLM inference and serving, using PageAttention
| Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3ForCausalLM | QWen3,Qwen3-Embedding | Yes | - | - | v0.8.4 | Yes |
| Qwen3ForCausalLM | QWen3,Qwen3-Embedding,Qwen3-Reranker | Yes | - | - | v0.8.4 | Yes |
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
......@@ -36,6 +36,7 @@ vLLM is a fast and easy-to-use library for LLM inference and serving, using PageAttention
| LlavaForConditionalGeneration | LLaMA,LLaMA-2,LLaMA-3 | Yes | No | - | v0.6.2 | No |
| Qwen2VLForConditionalGeneration | Qwen2-VL | Yes | No | Yes | v0.6.2 | No |
| Qwen2_5_VLForConditionalGeneration | Qwen2.5-VL | Yes | No | Yes | v0.7.2 | No |
| Mistral3ForConditionalGeneration | Mistral3 | Yes | No | - | v0.8.5.post1 | No |
| Gemma3ForConditionalGeneration | Gemma 3 | Yes | - | - | v0.8.5.post1 | No |
| MiniCPMV | MiniCPM-V | Yes | No | - | v0.6.2 | No |
| Phi3VForCausalLM | Phi-3.5-vision | Yes | No | - | v0.6.2 | No |
......
......@@ -636,6 +636,11 @@ you should explicitly specify the task type to ensure that the model is used in
* `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc.
* ✅︎
* ✅︎
- * `Qwen3Model`, `Qwen3ForCausalLM`
* Qwen3-based
* `Qwen/Qwen3-Embedding-0.6B`, etc.
* ✅︎
* ✅︎
- * `RobertaModel`, `RobertaForMaskedLM`
* RoBERTa-based
* `sentence-transformers/all-roberta-large-v1`, etc.
......@@ -1282,4 +1287,4 @@ We have the following levels of testing for models:
1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to [models tests](https://github.com/vllm-project/vllm/blob/main/tests/models) for the models that have passed this test.
2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to [functionality tests](gh-dir:tests) and [examples](gh-dir:examples) for the models that have passed this test.
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
from vllm import LLM
model_name = "Qwen3-Reranker-0.6B"
# What is the difference between the official original version and one
# that has been converted into a sequence classification model?
# Qwen3-Reranker is a language model that performs reranking by using the
# logits of the "no" and "yes" tokens.
# It needs to compute logits for all 151669 tokens, making this method
# extremely inefficient, not to mention incompatible with the vllm score API.
# A method for converting the original model into a sequence classification
# model was proposed. See: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3
# Models converted offline using this method are not only more efficient
# and compatible with the vllm score API, but also need more concise init
# parameters, for example:
# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score")
# If you want to load the official original version, the init parameters are
# as follows.
model = LLM(
    model=model_name,
    task="score",
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)
# Why do we need hf_overrides for the official original version:
# vllm converts it to Qwen3ForSequenceClassification when loaded for
# better performance.
# - First, we need to use `"architectures": ["Qwen3ForSequenceClassification"]`
# to manually route to Qwen3ForSequenceClassification.
# - Then, we extract the vectors corresponding to classifier_from_token
# from lm_head using `"classifier_from_token": ["no", "yes"]`.
# - Third, we convert these two vectors into one vector. This conversion
# logic is enabled by `"is_original_qwen3_reranker": True`.
# Please use the query_template and document_template to format the query and
# document for better reranker results.
prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
query_template = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
document_template = "<Document>: {doc}{suffix}"
if __name__ == "__main__":
    instruction = (
        "Given a web search query, retrieve relevant passages that answer the query"
    )
    queries = [
        "What is the capital of China?",
        "Explain gravity",
    ]
    documents = [
        "The capital of China is Beijing.",
        "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
    ]
    queries = [
        query_template.format(prefix=prefix, instruction=instruction, query=query)
        for query in queries
    ]
    documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents]
    outputs = model.score(queries, documents)
    print([output.outputs.score for output in outputs])
\ No newline at end of file
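The comments in the example say that the two `lm_head` vectors for "no" and "yes" are converted into one vector. A minimal sketch of why this loses nothing, assuming the relevance score is the softmax probability of "yes" over the two logits (the function names below are hypothetical and not part of vLLM): a two-way softmax equals a sigmoid of the logit difference, so a single logit is enough.

```python
import math


def yes_probability_two_logits(logit_no: float, logit_yes: float) -> float:
    """Softmax over the ["no", "yes"] logits, returning P("yes")."""
    # Subtract the max for numerical stability before exponentiating.
    m = max(logit_no, logit_yes)
    e_no = math.exp(logit_no - m)
    e_yes = math.exp(logit_yes - m)
    return e_yes / (e_no + e_yes)


def yes_probability_one_logit(logit_diff: float) -> float:
    """Sigmoid of the single logit difference (yes minus no)."""
    return 1.0 / (1.0 + math.exp(-logit_diff))


if __name__ == "__main__":
    # Illustrative logit pairs only; these are not model outputs.
    for logit_no, logit_yes in [(0.3, 2.1), (-1.5, -0.2), (4.0, 1.0)]:
        two = yes_probability_two_logits(logit_no, logit_yes)
        one = yes_probability_one_logit(logit_yes - logit_no)
        print(f"two-logit softmax: {two:.6f}  one-logit sigmoid: {one:.6f}")
```

Both columns should agree up to floating-point error, which is the property the single-label classification head relies on in this sketch.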
......@@ -36,13 +36,15 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsCrossEncoding, SupportsLoRA, SupportsPP
from .qwen2 import Qwen2MLP as Qwen3MLP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix
......@@ -317,3 +319,122 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
class Qwen3ForSequenceClassification(nn.Module, SupportsLoRA,
                                     SupportsCrossEncoding):

    def __init__(
        self,
        vllm_config: "VllmConfig",
        prefix: str = "",
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = quant_config
        self.prefix = prefix
        self.model = Qwen3Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.score = RowParallelLinear(config.hidden_size,
                                       config.num_labels,
                                       quant_config=quant_config,
                                       input_is_parallel=False,
                                       bias=False,
                                       prefix=maybe_prefix(prefix, "score"))
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=False,
            softmax=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.model(input_ids=input_ids,
                          positions=positions,
                          inputs_embeds=inputs_embeds,
                          intermediate_tensors=intermediate_tensors)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        hidden_states = self._pooler.extract_states(hidden_states,
                                                    pooling_metadata)
        logits, _ = self.score(hidden_states)
        pooled_data = self._pooler.head(logits, pooling_metadata)
        pooled_outputs = [
            self._pooler.build_output(data.squeeze(-1)) for data in pooled_data
        ]
        return PoolerOutput(outputs=pooled_outputs)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        is_original_qwen3_reranker = getattr(self.config,
                                             "is_original_qwen3_reranker",
                                             False)
        if not is_original_qwen3_reranker:
            loader = AutoWeightsLoader(self)
            return loader.load_weights(weights)
        return self.load_weights_from_original_qwen3_reranker(weights)

    def load_weights_from_original_qwen3_reranker(
            self, weights: Iterable[tuple[str, torch.Tensor]]):
        tokens = getattr(self.config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, \
            ("Try loading the original Qwen3 Reranker?, see: "
             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
        self.config.num_labels = 1
        model_config = self.vllm_config.model_config
        device = self.score.weight.device
        self.score = RowParallelLinear(self.config.hidden_size,
                                       self.config.num_labels,
                                       quant_config=self.quant_config,
                                       input_is_parallel=False,
                                       bias=False,
                                       prefix=maybe_prefix(
                                           self.prefix, "score")).to(device)
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size,
                                          quant_config=self.quant_config,
                                          prefix=maybe_prefix(
                                              self.prefix, "lm_head"))
        loader = AutoWeightsLoader(self)
        loaded_weights = loader.load_weights(weights)
        from vllm.transformers_utils.tokenizer import get_tokenizer
        tokenizer = get_tokenizer(
            model_config.tokenizer,
            revision=model_config.tokenizer_revision,
            tokenizer_mode=model_config.tokenizer_mode,
            trust_remote_code=model_config.trust_remote_code)
        a = tokenizer.convert_tokens_to_ids(tokens[0])
        b = tokenizer.convert_tokens_to_ids(tokens[1])
        weight = self.lm_head.weight.data[b].to(
            device) - self.lm_head.weight.data[a].to(device)
        self.score.weight.data.copy_(weight)
        del self.lm_head
        loaded_weights.add("classifier.weight")
        loaded_weights.discard("lm_head.weight")
\ No newline at end of file
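The weight conversion in `load_weights_from_original_qwen3_reranker` can be illustrated without vLLM or PyTorch. A minimal sketch, assuming a toy `lm_head` stored as a list of rows and made-up token ids (`collapse_to_score_weight`, `dot`, and all numbers are hypothetical): subtracting the "no" row from the "yes" row yields one vector whose dot product with a hidden state equals the difference of the two original logits.

```python
def dot(u, v):
    """Plain dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(u, v))


def collapse_to_score_weight(lm_head, no_id, yes_id):
    """Return lm_head[yes_id] - lm_head[no_id] as a single weight row."""
    return [y - n for y, n in zip(lm_head[yes_id], lm_head[no_id])]


if __name__ == "__main__":
    # Toy 3-token vocabulary with hidden size 3; values are illustrative.
    lm_head = [
        [0.5, -1.0, 2.0],  # id 0: some unrelated token
        [1.0, 0.0, -1.0],  # id 1: stands in for "no"
        [0.0, 2.0, 1.0],   # id 2: stands in for "yes"
    ]
    hidden = [1.0, 2.0, 3.0]  # stand-in for the last-token hidden state

    score_weight = collapse_to_score_weight(lm_head, no_id=1, yes_id=2)
    single_logit = dot(hidden, score_weight)
    logit_diff = dot(hidden, lm_head[2]) - dot(hidden, lm_head[1])
    print(score_weight, single_logit, logit_diff)
```

In this sketch only one row of the vocabulary-sized matrix survives, which is why the converted head avoids computing logits for every token.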
......@@ -168,6 +168,7 @@ _CROSS_ENCODER_MODELS = {
"RobertaForSequenceClassification"),
"ModernBertForSequenceClassification": ("modernbert",
"ModernBertForSequenceClassification"),
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
}
_MULTIMODAL_MODELS = {
......
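The `_CROSS_ENCODER_MODELS` entry maps an architecture name to a (module, class) pair. A minimal sketch of how a table of that shape can be resolved lazily with `importlib`, using standard-library modules as stand-ins; `_TOY_REGISTRY` and `resolve_architecture` are hypothetical names, and this is not vLLM's actual registry code.

```python
import importlib

# Architecture name -> (module name, class name), mirroring the table shape.
# Standard-library classes are used so the sketch runs anywhere.
_TOY_REGISTRY = {
    "OrderedDict": ("collections", "OrderedDict"),
    "Fraction": ("fractions", "Fraction"),
}


def resolve_architecture(name: str):
    """Import the module only when the architecture is requested."""
    try:
        module_name, class_name = _TOY_REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unsupported architecture: {name!r}") from None
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


if __name__ == "__main__":
    cls = resolve_architecture("Fraction")
    print(cls.__name__, cls(1, 2) + cls(1, 2))
```

Keeping only strings in the table means that adding a new entry does not import the model module until it is needed.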