Unverified Commit f6f71379 authored by James Xu, committed by GitHub

Add support for Qwen2-VL-based embedding models (#2055)

parent f35cb46c
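In short, this commit lets a Qwen2-VL-based checkpoint (e.g. mcdse) be served for embeddings: `forward` gains a `get_embedding` flag that routes hidden states through a last-token `Pooler` instead of the logits processor. A minimal client-side sketch of how such a model might then be queried, assuming sglang's OpenAI-compatible `/v1/embeddings` endpoint and `--is-embedding` launch flag (pre-existing sglang features, not part of this diff):

```python
# Hypothetical usage sketch, not part of this commit. Assumed launch command:
#   python -m sglang.launch_server --model-path marco/mcdse-2b-v1 --is-embedding
import requests

resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={"model": "marco/mcdse-2b-v1", "input": "Which figure shows revenue by quarter?"},
)
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # dimensionality of the pooled, L2-normalized vector
```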
@@ -37,7 +37,7 @@ The core features include:
 - **Fast Backend Runtime**: Provides efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, chunked prefill, and quantization (INT4/FP8/AWQ/GPTQ).
 - **Flexible Frontend Language**: Offers an intuitive interface for programming LLM applications, including chained generation calls, advanced prompting, control flow, multi-modal inputs, parallelism, and external interactions.
-- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte) and reward models (Skywork), with easy extensibility for integrating new models.
+- **Extensive Model Support**: Supports a wide range of generative models (Llama, Gemma, Mistral, QWen, DeepSeek, LLaVA, etc.), embedding models (e5-mistral, gte, mcdse) and reward models (Skywork), with easy extensibility for integrating new models.
 - **Active Community**: SGLang is open-source and backed by an active community with industry adoption.

 ## Getting Started
...
@@ -44,6 +44,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 )
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -559,6 +560,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -577,6 +579,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        get_embedding: bool = False,
     ):
         """Run forward pass for Qwen2-VL.
@@ -599,8 +602,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             image_inputs = [
                 img for img in forward_batch.image_inputs if img is not None
             ]
-        positions = forward_batch.mrope_positions
+        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            positions = forward_batch.mrope_positions
         if (
             forward_batch.forward_mode.is_decode()
             or image_inputs is None
@@ -655,9 +658,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
         )
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
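The `Pooler(pooling_type=PoolingType.LAST, normalize=True)` wired in above reduces the final hidden states to one embedding per sequence: each sequence's last-token hidden state, L2-normalized. A standalone sketch of that pooling rule (the packed layout and function name are illustrative assumptions, not sglang's internal `Pooler` API):

```python
import torch

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """Pick each sequence's last hidden state and L2-normalize it.

    hidden_states: packed tokens of all sequences, shape [total_tokens, hidden]
    seq_lens:      per-sequence token counts, shape [num_seqs]
    """
    # Index of the last token of each sequence in the packed layout.
    last_idx = torch.cumsum(seq_lens, dim=0) - 1
    pooled = hidden_states[last_idx]  # [num_seqs, hidden]
    return torch.nn.functional.normalize(pooled, p=2, dim=-1)

# Example: two sequences of lengths 3 and 2 packed into 5 token states.
states = torch.randn(5, 8)
print(last_token_pool(states, torch.tensor([3, 2])).shape)  # torch.Size([2, 8])
```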
@@ -58,6 +58,28 @@ def get_top_logprobs(logits, k):
     return logprobs


+def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.util import is_sentence_transformer_model
+
+    if is_sentence_transformer_model(model_path):
+        model = SentenceTransformer(
+            model_path,
+            model_kwargs={"torch_dtype": torch_dtype},
+        )
+    else:  # if no pre-trained sentence-transformers model
+        from sentence_transformers import models
+
+        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
+        pooling_model = models.Pooling(
+            word_embedding_model.get_word_embedding_dimension(),
+            pooling_mode="lasttoken",
+        )
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+    return model.cuda()
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -114,12 +136,9 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            from sentence_transformers import SentenceTransformer
-
-            self.model = SentenceTransformer(
-                model_path,
-                model_kwargs={"torch_dtype": torch_dtype},
-            ).cuda()
+            self.model = _get_sentence_transformer_embedding_model(
+                model_path, torch_dtype
+            )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
...
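The helper above first checks whether the checkpoint ships a sentence-transformers config; if not, it assembles an equivalent pipeline by hand with last-token pooling, which matches how Qwen2-VL-based retrievers like mcdse pool. A hedged usage sketch (the model path and dtype simply mirror the test values below; a GPU and the sentence-transformers package are required):

```python
import torch

# Hypothetical call site mirroring HFRunner's embedding branch.
model = _get_sentence_transformer_embedding_model("marco/mcdse-2b-v1", torch.float16)
vectors = model.encode(["a query", "a document"])  # standard SentenceTransformer API
print(vectors.shape)  # (2, embedding_dim)
```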
@@ -25,6 +25,7 @@ from sglang.test.test_utils import get_similarities
 MODELS = [
     ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
     ("intfloat/e5-mistral-7b-instruct", 1, 1e-5),
+    ("marco/mcdse-2b-v1", 1, 1e-5),
 ]

 TORCH_DTYPES = [torch.float16]
...
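Each `MODELS` entry reads as (model path, likely a tensor-parallel size, tolerance), and judging by the `get_similarities` import, the test compares HuggingFace and SGLang embeddings by cosine similarity. A minimal sketch of such a check (the "similarity within tolerance of 1.0" criterion is an assumption, not the test's exact code):

```python
import torch

def embeddings_match(a: torch.Tensor, b: torch.Tensor, tolerance: float) -> bool:
    # Matching embeddings should have cosine similarity within `tolerance`
    # of 1.0; both inputs are assumed L2-normalized already.
    similarity = torch.nn.functional.cosine_similarity(a, b, dim=-1)
    return bool(((1.0 - similarity).abs() < tolerance).all())

print(embeddings_match(torch.tensor([[1.0, 0.0]]), torch.tensor([[1.0, 1e-6]]), 1e-5))
```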