Unverified Commit 65736dc5 authored by Zhihao Liu's avatar Zhihao Liu Committed by GitHub
Browse files

[Model] Support Qwen3ForSequenceClassification for Qwen3-Embed Model (#7957)

parent 7b56e494
...@@ -3,20 +3,53 @@ ...@@ -3,20 +3,53 @@
SGLang provides robust support for embedding models by integrating efficient serving mechanisms with its flexible programming interface. This integration allows for streamlined handling of embedding tasks, facilitating faster and more accurate retrieval and semantic search operations. SGLang's architecture enables better resource utilization and reduced latency in embedding model deployment. SGLang provides robust support for embedding models by integrating efficient serving mechanisms with its flexible programming interface. This integration allows for streamlined handling of embedding tasks, facilitating faster and more accurate retrieval and semantic search operations. SGLang's architecture enables better resource utilization and reduced latency in embedding model deployment.
```{important} ```{important}
They are executed with `--is-embedding` and some may require `--trust-remote-code` Embedding models are executed with `--is-embedding` flag and some may require `--trust-remote-code`
``` ```
## Example Launch Command ## Quick Start
### Launch Server
```shell ```shell
python3 -m sglang.launch_server \ python3 -m sglang.launch_server \
--model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct \ --model-path Qwen/Qwen3-Embedding-4B \
--is-embedding \ --is-embedding \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 30000
```
### Client Request
```python
import requests
url = "http://127.0.0.1:30000"
payload = {
"model": "Qwen/Qwen3-Embedding-4B",
"input": "What is the capital of France?",
"encoding_format": "float"
}
response = requests.post(url + "/v1/embeddings", json=payload).json()
print("Embedding:", response["data"][0]["embedding"])
```
## Multimodal Embedding Example
For multimodal models like GME that support both text and images:
```shell
python3 -m sglang.launch_server \
--model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct \
--is-embedding \
--chat-template gme-qwen2-vl \ --chat-template gme-qwen2-vl \
--host 0.0.0.0 \
--port 30000 --port 30000
``` ```
## Example Client Request
```python ```python
import requests import requests
...@@ -42,13 +75,13 @@ response = requests.post(url + "/v1/embeddings", json=payload).json() ...@@ -42,13 +75,13 @@ response = requests.post(url + "/v1/embeddings", json=payload).json()
print("Embeddings:", [x.get("embedding") for x in response.get("data", [])]) print("Embeddings:", [x.get("embedding") for x in response.get("data", [])])
``` ```
## Supported Models
## Supported models | Model Family | Example Model | Chat Template | Description |
| ------------------------------------------ | -------------------------------------- | ------------- | --------------------------------------------------------------------------- |
| Model Family (Embedding) | Example HuggingFace Identifier | Chat Template | Description | | **E5 (Llama/Mistral based)** | `intfloat/e5-mistral-7b-instruct` | N/A | High-quality text embeddings based on Mistral/Llama architectures |
|-------------------------------------------------|-----------------------------------------------|---------------|--------------------------------------------------------------------------------------------------------------------------------------| | **GTE-Qwen2** | `Alibaba-NLP/gte-Qwen2-7B-instruct` | N/A | Alibaba's text embedding model with multilingual support |
| **Llama/Mistral based (E5EmbeddingModel)** | `intfloat/e5-mistral-7b-instruct` | N/A | Mistral/Llama-based embedding model fine‑tuned for high‑quality text embeddings (top‑ranked on the MTEB benchmark). | | **Qwen3-Embedding** | `Qwen/Qwen3-Embedding-4B` | N/A | Latest Qwen3-based text embedding model for semantic representation |
| **GTE (QwenEmbeddingModel)** | `Alibaba-NLP/gte-Qwen2-7B-instruct` | N/A | Alibaba’s general text embedding model (7B), achieving state‑of‑the‑art multilingual performance in English and Chinese. | | **BGE** | `BAAI/bge-large-en-v1.5` | N/A | BAAI's text embeddings (requires `attention-backend` triton/torch_native) |
| **GME (MultimodalEmbedModel)** | `Alibaba-NLP/gme-Qwen2-VL-2B-Instruct` | `gme-qwen2-vl` | Multimodal embedding model (2B) based on Qwen2‑VL, encoding image + text into a unified vector space for cross‑modal retrieval. | | **GME (Multimodal)** | `Alibaba-NLP/gme-Qwen2-VL-2B-Instruct`| `gme-qwen2-vl`| Multimodal embedding for text and image cross-modal tasks |
| **CLIP (CLIPEmbeddingModel)** | `openai/clip-vit-large-patch14-336` | N/A | OpenAI’s CLIP model (ViT‑L/14) for embedding images (and text) into a joint latent space; widely used for image similarity search. | | **CLIP** | `openai/clip-vit-large-patch14-336` | N/A | OpenAI's CLIP for image and text embeddings |
| **BGE (BgeEmbeddingModel)** | `BAAI/bge-large-en-v1.5` | N/A | Currently only support `attention-backend` `triton` and `torch_native`. BAAI's BGE embedding models optimized for retrieval and reranking tasks. |
...@@ -642,6 +642,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal ...@@ -642,6 +642,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
or "InternLM2ForRewardModel" in model_architectures or "InternLM2ForRewardModel" in model_architectures
or "Qwen2ForRewardModel" in model_architectures or "Qwen2ForRewardModel" in model_architectures
or "Qwen2ForSequenceClassification" in model_architectures or "Qwen2ForSequenceClassification" in model_architectures
or "Qwen3ForSequenceClassification" in model_architectures
or "CLIPModel" in model_architectures or "CLIPModel" in model_architectures
or "BertModel" in model_architectures or "BertModel" in model_architectures
or "Contriever" in model_architectures or "Contriever" in model_architectures
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config # Qwen3 uses Qwen2Config
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
from sglang.srt.utils import add_prefix
class Qwen3ForSequenceClassification(nn.Module):
def __init__(
self,
config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = Qwen3Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
self.score = nn.Linear(config.hidden_size, config.num_labels)
# Use normalize=True for qwen3 embedding based on official implementation
# Reference: https://github.com/QwenLM/Qwen3-Embedding/blob/main/examples/qwen3_embedding_transformers.py#L55
# Official code: output = F.normalize(output, p=2, dim=1)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.eos_token_id = config.eos_token_id
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: Optional[torch.Tensor] = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "Qwen3ForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
logits = self.score(hidden_states)
pooled_logits = self.pooler(logits, forward_batch).embeddings
return EmbeddingPoolerOutput(pooled_logits)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Filter out lm_head weights of Qwen3ForCausalLM
filtered_weights = [
(name, w) for name, w in weights if not name.startswith("lm_head")
]
return Qwen3ForCausalLM.load_weights(self, filtered_weights)
EntryClass = [
Qwen3ForSequenceClassification,
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment