"src/vscode:/vscode.git/clone" did not exist on "12358622e5637b7c4e01969b1089b66b92fb3d14"
Unverified commit 361971b8, authored by uylnap and committed by GitHub

Add Support for Qwen2-VL Multi-modal Embedding Models (#3694)

parent 13bc39c5
@@ -38,6 +38,8 @@
- Mistral embedding models
- Qwen embedding models
  - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding`
- Multi-modal embedding models
  - `python -m sglang.launch_server --model-path Alibaba-NLP/gme-Qwen2-VL-2B-Instruct --is-embedding --chat-template gme-qwen2-vl`
## Reward Models
...
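As an illustration (not part of the committed docs), a client request against the OpenAI-compatible `/v1/embeddings` endpoint of a server launched with the command above might look like the sketch below. The image URL and port are placeholders, and each input item follows the `MultimodalEmbeddingInput` schema (`text`, `image`) introduced later in this diff.

```python
import requests

# Hypothetical client call against a locally launched server (default port 30000).
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={
        "model": "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
        "input": [
            # Each item may carry text, an image (file name, URL, or base64 string), or both.
            {"text": "a cat sitting on a windowsill", "image": None},
            {"text": None, "image": "https://example.com/cat.jpg"},  # placeholder URL
        ],
    },
)
print(len(resp.json()["data"][0]["embedding"]))  # embedding dimension of the first item
```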
@@ -44,6 +44,7 @@ class SeparatorStyle(IntEnum):
    CHATGLM3 = auto()
    DEEPSEEK_CHAT = auto()
    METAMATH = auto()
    QWEN2_VL_EMBED = auto()

@dataclasses.dataclass
@@ -110,6 +111,15 @@ class Conversation:
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.QWEN2_VL_EMBED:
            ret = "" if system_prompt == "" else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            ret += self.stop_str
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
@@ -366,6 +376,46 @@ def chat_template_exists(template_name: str) -> bool:
    return template_name in chat_templates


def generate_embedding_convs(
    texts: List[str], images: List[str], template_name: str
) -> List[Conversation]:
    conv_template = chat_templates[template_name].copy()
    convs = []
    for text, image in zip(texts, images):
        conv = Conversation(
            name=conv_template.name,
            system_template=conv_template.system_template,
            system_message=conv_template.system_message,
            roles=conv_template.roles,
            messages=list(conv_template.messages),  # prevent in-place modification
            offset=conv_template.offset,
            sep_style=SeparatorStyle(conv_template.sep_style),
            sep=conv_template.sep,
            sep2=conv_template.sep2,
            stop_str=conv_template.stop_str,
            image_data=[],
            modalities=[],
            image_token=conv_template.image_token,
        )
        real_content = ""
        if image is not None:
            image_token = (
                conv.image_token + "\n"
                if conv.name != "gme-qwen2-vl"
                else conv.image_token
            )
            real_content += image_token
        if text is not None:
            real_content += text
        conv.append_message(conv.roles[0], real_content)
        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        convs.append(conv)
    return convs


def generate_chat_conv(
    request: ChatCompletionRequest, template_name: str
) -> Conversation:
@@ -555,6 +605,20 @@ register_conv_template(
    )
)

# Reference: https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct#usage
register_conv_template(
    Conversation(
        name="gme-qwen2-vl",
        system_message="You are a helpful assistant.",
        system_template="<|im_start|>system\n{system_message}",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        sep="<|im_end|>\n",
        sep_style=SeparatorStyle.QWEN2_VL_EMBED,
        stop_str="<|endoftext|>",
        image_token="<|vision_start|><|image_pad|><|vision_end|>",
    )
)
# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
register_conv_template(
    Conversation(
...
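For orientation, here is a rough sketch of how the new `generate_embedding_convs` helper and the `gme-qwen2-vl` template above render a prompt for one text-plus-image item. The image URL is a placeholder; the commented output is derived from the separators and stop token registered above.

```python
from sglang.srt.conversation import generate_embedding_convs

# Build one conversation for a text+image pair using the template added in this commit.
convs = generate_embedding_convs(
    texts=["a cat sitting on a windowsill"],
    images=["https://example.com/cat.jpg"],  # placeholder URL
    template_name="gme-qwen2-vl",
)
print(convs[0].get_prompt())
# Expected shape of the rendered prompt:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# <|vision_start|><|image_pad|><|vision_end|>a cat sitting on a windowsill<|im_end|>
# <|im_start|>assistant
# <|endoftext|>
```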
@@ -214,13 +214,13 @@ class Engine:
    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
        image_data: Optional[Union[List[str], str]] = None,
    ) -> Dict:
        """
        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
        Please refer to `EmbeddingReqInput` for the documentation.
        """
        obj = EmbeddingReqInput(text=prompt, image_data=image_data)
        loop = asyncio.get_event_loop()
        generator = self.tokenizer_manager.generate_request(obj, None)
        ret = loop.run_until_complete(generator.__anext__())
...
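A hedged sketch of offline usage of the extended `Engine.encode`: the constructor keywords are assumed to mirror the `launch_server` flags shown in the docs hunk, and the image URL is a placeholder.

```python
from sglang.srt.server import Engine

# Assumed constructor arguments, mirroring the launch_server flags above.
engine = Engine(
    model_path="Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
    is_embedding=True,
    chat_template="gme-qwen2-vl",
)

# Prompt pre-rendered with the gme-qwen2-vl template (see the test at the end of this diff).
prompt = (
    "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n"
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
    "<|im_start|>assistant\n<|endoftext|>"
)
out = engine.encode(prompt=prompt, image_data="https://example.com/cat.jpg")
# Per the handling in SRTRunner below, a single prompt is expected to yield a dict
# with an "embedding" key, while a batch yields a list of such dicts.
print(len(out["embedding"]))
```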
@@ -293,6 +293,8 @@ class TokenizedGenerateReqInput:
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The image input. It can be a file name, a url, or base64 encoded string.
    image_data: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
@@ -303,28 +305,40 @@ class EmbeddingReqInput:
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
    log_metrics: bool = True
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None

    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
            raise ValueError(
                "At least one of text, input_ids, or image should be provided"
            )

        # text and input_ids cannot be provided at the same time
        if self.text is not None and self.input_ids is not None:
            raise ValueError("text and input_ids cannot be provided at the same time")

        # Derive the batch size
        self.batch_size = 0
        self.is_single = True

        # check the batch size of text
        if self.text is not None:
            if isinstance(self.text, list):
                self.batch_size += len(self.text)
            else:
                self.batch_size += 1

        # check the batch size of input_ids
        if self.input_ids is not None:
            if isinstance(self.input_ids[0], list):
                self.batch_size += len(self.input_ids)
            else:
                self.batch_size += 1

        if self.batch_size > 1:
            self.is_single = False

        # Fill in default arguments
        if self.is_single:
@@ -352,6 +366,7 @@ class EmbeddingReqInput:
        return EmbeddingReqInput(
            text=self.text[i] if self.text is not None else None,
            input_ids=self.input_ids[i] if self.input_ids is not None else None,
            image_data=self.image_data[i] if self.image_data is not None else None,
            sampling_params=self.sampling_params[i],
            rid=self.rid[i],
        )
@@ -365,6 +380,8 @@ class TokenizedEmbeddingReqInput:
    input_text: str
    # The input token ids
    input_ids: List[int]
    # The image inputs
    image_inputs: dict
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
...
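As a quick, hypothetical illustration of the batch-size derivation in `EmbeddingReqInput.normalize_batch_and_arguments` above (behavior inferred from the snippet, not an authoritative test; file names are placeholders):

```python
from sglang.srt.managers.io_struct import EmbeddingReqInput

# A single text prompt paired with one image: batch_size stays 1, is_single stays True.
single = EmbeddingReqInput(text="a cat", image_data="cat.jpg")  # placeholder file name
single.normalize_batch_and_arguments()
assert single.batch_size == 1 and single.is_single

# Two text prompts with two images: batch_size becomes 2, is_single flips to False.
batch = EmbeddingReqInput(text=["a cat", "a dog"], image_data=["cat.jpg", "dog.jpg"])
batch.normalize_batch_and_arguments()
assert batch.batch_size == 2 and not batch.is_single
```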
@@ -767,6 +767,30 @@ class Scheduler:
        )
        req.tokenizer = self.tokenizer

        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
            )
            req.extend_image_inputs(image_inputs)

            if len(req.origin_input_ids) >= self.max_req_input_len:
                error_msg = (
                    "Multimodal prompt is too long after expanding multimodal tokens. "
                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.image_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
                )
                self.waiting_queue.append(req)
                return

        # Validate prompts length
        error_msg = validate_input_length(
            req,
...
@@ -372,13 +372,12 @@ class TokenizerManager:
            )
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Dict = await self.image_processor.process_images_async(
            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
        )
        if image_inputs and "input_ids" in image_inputs:
            input_ids = image_inputs["input_ids"]

        if self.is_generation:
            return_logprob = obj.return_logprob
            logprob_start_len = obj.logprob_start_len
            top_logprobs_num = obj.top_logprobs_num

@@ -438,6 +437,7 @@ class TokenizerManager:
                obj.rid,
                input_text,
                input_ids,
                image_inputs,
                sampling_params,
            )
...
@@ -38,6 +38,7 @@ from sglang.srt.conversation import (
    SeparatorStyle,
    chat_template_exists,
    generate_chat_conv,
    generate_embedding_convs,
    register_conv_template,
)
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
@@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import (
    FileResponse,
    FunctionResponse,
    LogProbs,
    MultimodalEmbeddingInput,
    ToolCall,
    TopLogprob,
    UsageInfo,
@@ -1556,11 +1558,37 @@ def v1_embedding_request(all_requests, tokenizer_manager):
        prompt = prompts[0]
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            prompt_kwargs = {"text": prompt}
        elif isinstance(prompt, list) and isinstance(
            prompt[0], MultimodalEmbeddingInput
        ):
            assert (
                chat_template_name is not None
            ), "chat_template_name is required for multimodal inputs"
            texts = []
            images = []
            for item in prompt:
                texts.append(item.text if item.text is not None else None)
                images.append(item.image if item.image is not None else None)
            convs = generate_embedding_convs(texts, images, chat_template_name)
            generate_prompts = []
            for conv in convs:
                generate_prompts.append(conv.get_prompt())
            if len(generate_prompts) == 1:
                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
            else:
                prompt_kwargs = {"text": generate_prompts, "image_data": images}
        else:
            prompt_kwargs = {"input_ids": prompt}
    else:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts}
        elif isinstance(prompts[0], list) and isinstance(
            prompts[0][0], MultimodalEmbeddingInput
        ):
            # TODO: multiple requests
            raise NotImplementedError(
                "Multiple requests with multimodal inputs are not supported yet"
            )
        else:
            prompt_kwargs = {"input_ids": prompts}
...
@@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel):
    usage: Optional[UsageInfo] = None


class MultimodalEmbeddingInput(BaseModel):
    text: Optional[str] = None
    image: Optional[str] = None


class EmbeddingRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings/create
    input: Union[
        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
    ]
    model: str
    encoding_format: str = "float"
    dimensions: int = None
...
@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor

from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.server import Engine

@@ -135,6 +135,76 @@ class HFRunner:
            return True
        return False
    # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
    def _get_gme_qwen2_vl_embeddings(
        self, prompts, image_data: Optional[List[str]] = None
    ):
        from sglang.srt.utils import load_image

        images = None
        if image_data is not None:
            images = [load_image(image)[0] for image in image_data]

        inputs = self.processor(
            text=prompts,
            images=images,
            padding=True,
            truncation=True,
            max_length=1800,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self._forward_gme_qwen2_vl(**inputs)
        return embeddings.tolist()

    def _forward_gme_qwen2_vl(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        pooling_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.model.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.model.visual.get_dtype())
                image_embeds = self.model.visual(
                    pixel_values, grid_thw=image_grid_thw
                ).to(inputs_embeds.device)
                image_mask = input_ids == self.model.config.image_token_id
                inputs_embeds[image_mask] = image_embeds
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.model.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
        )

        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
        if left_padding:
            embeddings = outputs.last_hidden_state[:, -1]
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            batch_size = outputs.last_hidden_state.shape[0]
            embeddings = outputs.last_hidden_state[
                torch.arange(batch_size, device=outputs.last_hidden_state.device),
                sequence_lengths,
            ]
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings.contiguous()

    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        # Apply model-specific patches
        monkey_patch_gemma2_sdpa()
@@ -148,9 +218,18 @@ class HFRunner:
                low_cpu_mem_usage=True,
            ).cuda()
        elif self.model_type == "embedding":
            if "gme-qwen2-vl" in model_path.lower():
                self.model = AutoModelForVision2Seq.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    trust_remote_code=False,
                    low_cpu_mem_usage=True,
                ).cuda()
                self.processor = AutoProcessor.from_pretrained(model_path)
            else:
                self.model = _get_sentence_transformer_embedding_model(
                    model_path, torch_dtype
                )
        elif self.model_type == "reward":
            from transformers import AutoModelForSequenceClassification
@@ -169,7 +248,9 @@ class HFRunner:
        # Run forward
        while True:
            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
                in_queue.get()
            )
            if lora_paths is not None:
                assert len(prompts) == len(lora_paths)

@@ -189,7 +270,10 @@ class HFRunner:
                )
            elif self.model_type == "embedding":
                assert not self.output_str_only
                if "gme-qwen2-vl" in model_path.lower():
                    logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
                else:
                    logits = self.model.encode(prompts).tolist()
                out_queue.put(ModelOutput(embed_logits=logits))

            elif self.model_type == "reward":
@@ -211,11 +295,14 @@ class HFRunner:
    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        token_ids_logprob: Optional[int] = None,
    ):
        self.in_queue.put(
            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
        )
        return self.out_queue.get()

    def terminate(self):
@@ -396,6 +483,7 @@ class SRTRunner:
    def forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
        lora_paths: Optional[List[str]] = None,
        logprob_start_len: int = 0,

@@ -413,17 +501,23 @@ class SRTRunner:
                token_ids_logprob=token_ids_logprob,
            )
        else:
            if self.model_type == "embedding":
                response = self.engine.encode(prompt=prompts, image_data=image_data)
                if isinstance(response, list):
                    logits = [x["embedding"] for x in response]
                else:
                    logits = [response["embedding"]]
                return ModelOutput(embed_logits=logits)
            # reward model
            else:
                response = self.engine.encode(prompts)
                scores = [x["embedding"][0] for x in response]
                return ModelOutput(scores=scores)

    def batch_forward(
        self,
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        image_data: Optional[List[str]] = None,
        max_new_tokens=8,
        lora_paths=None,
    ):
@@ -439,7 +533,7 @@ class SRTRunner:
                lora_paths=lora_paths,
            )
        else:
            response = self.engine.encode(prompts, image_data)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
...
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import get_similarities
TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
MODELS = [
    ("Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", 1e-3),
]
TORCH_DTYPES = [torch.float16]
class TestQmeQwenModels(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        mp.set_start_method("spawn", force=True)

    def assert_close_embeddings(self, model, prefill_tolerance, torch_dtype):
        prompts_no_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n{TEXTS}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
        prompts_with_image = f"<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n<|im_start|>assistant\n<|endoftext|>"

        with HFRunner(
            model,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as hf_runner:
            hf_text_embeddings = hf_runner.forward(prompts=[prompts_no_image])
            hf_image_embeddings = hf_runner.forward(
                prompts=[prompts_with_image], image_data=[IMAGES]
            )

        with SRTRunner(
            model,
            tp_size=1,
            torch_dtype=torch_dtype,
            model_type="embedding",
        ) as srt_runner:
            srt_text_embeddings = srt_runner.forward(prompts=prompts_no_image)
            srt_image_embeddings = srt_runner.forward(
                prompts=prompts_with_image, image_data=IMAGES
            )

        similarity = get_similarities(
            hf_text_embeddings.embed_logits[0], srt_text_embeddings.embed_logits[0]
        )
        print("texts similarity diff", abs(similarity - 1))
        assert torch.all(
            abs(similarity - 1) < prefill_tolerance
        ), "embeddings are not all close"

        similarity = get_similarities(
            hf_image_embeddings.embed_logits[0], srt_image_embeddings.embed_logits[0]
        )
        print("images similarity diff", abs(similarity - 1))
        assert torch.all(
            abs(similarity - 1) < prefill_tolerance
        ), "embeddings are not all close"

    def test_accuracy(self):
        for model, prefill_tolerance in MODELS:
            for torch_dtype in TORCH_DTYPES:
                self.assert_close_embeddings(model, prefill_tolerance, torch_dtype)


if __name__ == "__main__":
    unittest.main()
@@ -13,6 +13,7 @@ suites = {
    "models/test_qwen_models.py",
    "models/test_reward_models.py",
    "test_gptqmodel_dynamic.py",
    "models/test_gme_qwen_models.py",
    "test_abort.py",
    "test_chunked_prefill.py",
    "test_custom_allreduce.py",
...