Unverified Commit 681fdc26 authored by Xinyuan Tong, committed by GitHub

Refactor vlm embedding routine to use precomputed feature (#6543)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent 0d477880
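Note: with this change, a caller can run the vision encoder ahead of time and hand the resulting features to the server instead of raw pixel values. A minimal sketch of that usage, adapted from the tests added in this commit (the engine instance and the HF processor_output are assumed to be set up as in the new test file below; run inside an async context on a CUDA device):

import torch
from transformers import Qwen2_5_VLForConditionalGeneration

# Build only the vision tower once; the language model stays inside the sglang Engine.
visual = (
    Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype=torch.bfloat16
    )
    .eval()
    .visual.to("cuda")
)

async def generate_with_precomputed_features(engine, processor_output):
    # Encode the image(s) outside the server.
    with torch.inference_mode():
        features = visual(
            processor_output["pixel_values"], processor_output["image_grid_thw"]
        )
    # Pass the features (instead of pixel_values) through image_data.
    return await engine.async_generate(
        input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
        image_data=[
            dict(
                modality="IMAGE",
                image_grid_thws=processor_output["image_grid_thw"],
                precomputed_features=features,
            )
        ],
        sampling_params=dict(temperature=0.0),
    )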
@@ -252,40 +252,36 @@ def get_embedding_chunk(
    return embedding_chunk, start_index, end_index
def get_embedding_and_mask(
def _get_precomputed_embedding(
items: List[MultimodalDataItem],
) -> Optional[torch.Tensor]:
"""
If all items have precomputed_features, return their concatenation.
If some but not all have precomputed_features, raise NotImplementedError.
If none have precomputed_features, return None.
"""
precomputed_features = [item.precomputed_features for item in items]
if any(feature is not None for feature in precomputed_features):
if not all(feature is not None for feature in precomputed_features):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
result = torch.concat(precomputed_features)
# some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
result = result.reshape(-1, result.shape[-1])
return result
return None
def _get_chunked_prefill_embedding(
    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
    embedding_items: List[MultimodalDataItem],
    placeholder_tensor: torch.Tensor,
    input_ids: torch.Tensor,
    items_size: List[int],
    prefix_length: List[int],
    extend_length: List[int],
    items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Optional[torch.Tensor]:
    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
    """
Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
Args:
data_embedding_func: Function that generates embeddings for multimodal items
embedding_items: List of multimodal items to embed
placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
input_ids: The input token IDs tensor
items_size: Cumulative sizes of multimodal items per request
prefix_length: Prefix lengths for each request
extend_length: Sequence lengths for each request
items_offset_list: List of offset ranges for multimodal items in each request
Returns:
A tuple containing:
- The generated embeddings tensor
- A boolean mask tensor indicating where these embeddings should be placed
Raises:
AssertionError: If the number of multimodal tokens in input_ids doesn't match
the number of tokens in the generated embeddings
"""
# 1. Get the embedding
# Calculate embedding for each request, try to get it from cache to avoid repeated calculation
    embedding_list = []
    for i in range(len(items_size) - 1):
        if items_size[i] == items_size[i + 1]:
@@ -321,21 +317,28 @@ def get_embedding_and_mask(
            embedding_cache.free(embedding_items_hash)
        embedding_list.append(embedding_per_req_chunk)
    if len(embedding_list) == 0:
        return None, None
        return None
    embedding = torch.concat(embedding_list, dim=0)
    return torch.concat(embedding_list, dim=0)
    # 2. Check the embedding
    num_mm_tokens_in_embedding = embedding.shape[0]
    special_multimodal_mask = torch.isin(
        input_ids,
        placeholder_tensor,
    ).unsqueeze(-1)
    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
def _get_multimodal_mask(
    input_ids: torch.Tensor, placeholder_tensor: torch.Tensor
) -> torch.Tensor:
    return torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)
def _adjust_embedding_length(
embedding: torch.Tensor,
mask: torch.Tensor,
logger,
) -> torch.Tensor:
num_mm_tokens_in_embedding = embedding.shape[0]
num_mm_tokens_in_input_ids = mask.sum().item()
    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
        logger.warning(
            f"Number of tokens in multimodal embedding does not match those in the input text. "
            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
            "tokens from multimodal embeddings."
            f"tokens from multimodal embeddings."
        )
        if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
@@ -353,7 +356,54 @@ def get_embedding_and_mask(
            raise RuntimeError(
                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
            )
return embedding
def get_embedding_and_mask(
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
embedding_items: List[MultimodalDataItem],
placeholder_tensor: torch.Tensor,
input_ids: torch.Tensor,
items_size: List[int],
prefix_length: List[int],
extend_length: List[int],
items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
Args:
data_embedding_func: Function that generates embeddings for multimodal items
embedding_items: List of multimodal items to embed
placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
input_ids: The input token IDs tensor
items_size: Cumulative sizes of multimodal items per request
prefix_length: Prefix lengths for each request
extend_length: Sequence lengths for each request
items_offset_list: List of offset ranges for multimodal items in each request
Returns:
A tuple containing:
- The generated embeddings tensor
- A boolean mask tensor indicating where these embeddings should be placed
"""
# 1. Get embedding
embedding = _get_precomputed_embedding(embedding_items)
if embedding is None:
embedding = _get_chunked_prefill_embedding(
data_embedding_func,
embedding_items,
items_size,
prefix_length,
extend_length,
items_offset_list,
)
if embedding is None:
return None, None
# 2. Get mask
special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
# 3. Adjust embedding length if needed
embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
    return embedding, special_multimodal_mask
...
@@ -144,12 +144,11 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
        if base_output.images:
            if images_are_preprocessed:
                image_grid_thw = torch.concat(
                    [
                        torch.as_tensor(item.image_grid_thws)
                        for item in base_output.images
                    ]
                )
                all_image_grid_thws = [
                    item.image_grid_thws
                    for item in base_output.images
                    if item.image_grid_thws is not None
                ]
                all_pixel_values = [
                    item.pixel_values
                    for item in base_output.images
@@ -160,6 +159,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                    for item in base_output.images
                    if item.precomputed_features is not None
                ]
image_grid_thw = (
torch.concat(all_image_grid_thws) if all_image_grid_thws else None
)
                pixel_values = (
                    torch.concat(all_pixel_values) if all_pixel_values else None
                )
...
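For reference, the new tests below feed preprocessed Qwen2.5-VL images to the engine in two forms, which is why the processor must now tolerate items that carry no image_grid_thws (the image_grid_thw fallback to None above). A rough sketch of the two image_data payloads, assuming processor_output and precomputed_features as in the note near the top of this page:

# Form 1: raw processor outputs; the server still runs the vision tower.
pixel_values_item = dict(
    modality="IMAGE",
    image_grid_thws=processor_output["image_grid_thw"],
    pixel_values=processor_output["pixel_values"],
)

# Form 2: precomputed vision features; no pixel_values or grid info is required.
precomputed_item = dict(
    modality="IMAGE",
    precomputed_features=precomputed_features,
)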
@@ -282,13 +282,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
        # Process images one by one to handle flatten_batch=True constraint in vision_tower
        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
        vision_outputs_list = []
...
@@ -499,12 +499,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
        return pattern.pad_input_tokens(input_ids, mm_inputs)
    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
            self.visual.dtype
...
@@ -486,12 +486,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        return pattern.pad_input_tokens(input_ids, mm_inputs)
    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
if any(item.precomputed_features is not None for item in items):
if not all(item.precomputed_features is not None for item in items):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
return torch.concat([item.precomputed_features for item in items])
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
            self.visual.dtype
...
@@ -81,7 +81,7 @@ suites = {
        TestFile("test_update_weights_from_tensor.py", 48),
        TestFile("test_vertex_endpoint.py", 31),
        TestFile("test_vision_chunked_prefill.py", 175),
        TestFile("test_vlm_accuracy.py", 60),
        TestFile("test_vlm_input_format.py", 300),
        TestFile("test_vision_openai_server_a.py", 700),
        TestFile("test_vision_openai_server_b.py", 700),
        TestFile("test_w8a8_quantization.py", 46),
...
@@ -10,15 +10,8 @@ import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from transformers import (
AutoModel,
AutoProcessor,
AutoTokenizer,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
@@ -41,9 +34,6 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
    def setUpClass(cls):
        cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls.model_path = ""
cls.chat_template = ""
cls.processor = ""
        response = requests.get(cls.image_url)
        cls.main_image = Image.open(BytesIO(response.content))
@@ -274,131 +264,3 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
        )
        self.compare_outputs(sglang_output, hf_output)
class TestQwenVLUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
cls.chat_template = "qwen2-vl"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
cls.visual = (
Qwen2_5_VLForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
.eval()
.visual.to(cls.device)
)
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.8,
)
def tearDown(self):
self.engine.shutdown()
async def test_qwen_vl_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_qwen_vl_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(
processor_output["pixel_values"], processor_output["image_grid_thw"]
)
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_thw"],
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
class TestGemmaUnderstandsImage(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = "google/gemma-3-4b-it"
cls.chat_template = "gemma-it"
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
model = Gemma3ForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
cls.vision_tower = model.vision_tower.eval().to(cls.device)
cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
@classmethod
def visual(cls, pixel_values):
vision_outputs = cls.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = cls.mm_projector(vision_outputs)
return image_features
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.5,
enable_multimodal=True,
)
def tearDown(self):
self.engine.shutdown()
async def test_gemma_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_gemma_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.visual(processor_output["pixel_values"])
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
dict(
modality="IMAGE",
precomputed_features=precomputed_features,
)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
if __name__ == "__main__":
unittest.main()
...
import json
import unittest
from io import BytesIO
from typing import Optional
import requests
import torch
from PIL import Image
from transformers import (
AutoProcessor,
Gemma3ForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
)
from sglang import Engine
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.openai_api.protocol import ChatCompletionRequest
TEST_IMAGE_URL = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
class VLMInputTestBase:
model_path = None
chat_template = None
processor = None
visual = None # Should be a callable for precomputed features
@classmethod
def setUpClass(cls):
assert cls.model_path is not None, "Set model_path in subclass"
assert cls.chat_template is not None, "Set chat_template in subclass"
cls.image_url = TEST_IMAGE_URL
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
response = requests.get(cls.image_url)
cls.main_image = Image.open(BytesIO(response.content))
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True, use_fast=True
)
cls._init_visual()
@classmethod
def _init_visual(cls):
"""Override in subclass to set up cls.visual as a callable for precomputed features."""
raise NotImplementedError
def setUp(self):
self.engine = Engine(
model_path=self.model_path,
chat_template=self.chat_template,
device=self.device.type,
mem_fraction_static=0.8,
enable_multimodal=True,
disable_cuda_graph=True,
)
def tearDown(self):
self.engine.shutdown()
def get_completion_request(self) -> ChatCompletionRequest:
json_structure = {
"model": self.model_path,
"messages": [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": self.image_url}},
{"type": "text", "text": "What's in this picture?"},
],
}
],
}
json_str = json.dumps(json_structure)
return ChatCompletionRequest.model_validate_json(json_str)
def get_processor_output(self, req: Optional[ChatCompletionRequest] = None):
if req is None:
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
# Process inputs using processor
inputs = self.processor(
text=[text],
images=[self.main_image],
return_tensors="pt",
).to(self.device)
return inputs
async def test_understands_image(self):
req = self.get_completion_request()
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
output = await self.engine.async_generate(
prompt=text,
image_data=[self.main_image],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_understands_precomputed_features(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
with torch.inference_mode():
precomputed_features = self.__class__.visual(processor_output)
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[
self._precomputed_image_data(processor_output, precomputed_features)
],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
async def test_understands_pixel_values(self):
req = self.get_completion_request()
processor_output = self.get_processor_output(req=req)
output = await self.engine.async_generate(
input_ids=processor_output["input_ids"][0].detach().cpu().tolist(),
image_data=[self._pixel_values_image_data(processor_output)],
sampling_params=dict(temperature=0.0),
)
self.assertIn("taxi", output["text"].lower())
def _precomputed_image_data(self, processor_output, precomputed_features):
"""This should not be overridden."""
return dict(
modality="IMAGE",
precomputed_features=precomputed_features,
)
def _pixel_values_image_data(self, processor_output):
"""Override in subclass to pass the correct set of arguments."""
raise NotImplementedError
class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
model_path = "Qwen/Qwen2.5-VL-3B-Instruct"
chat_template = "qwen2-vl"
@classmethod
def _init_visual(cls):
cls.visual_model = (
Qwen2_5_VLForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
.eval()
.visual.to(cls.device)
)
cls.visual = lambda processor_output: cls.visual_model(
processor_output["pixel_values"], processor_output["image_grid_thw"]
)
def _pixel_values_image_data(self, processor_output):
return dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_thw"],
pixel_values=processor_output["pixel_values"],
)
class TestGemmaUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestCase):
model_path = "google/gemma-3-4b-it"
chat_template = "gemma-it"
@classmethod
def _init_visual(cls):
model = Gemma3ForConditionalGeneration.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16
)
cls.vision_tower = model.vision_tower.eval().to(cls.device)
cls.mm_projector = model.multi_modal_projector.eval().to(cls.device)
cls.visual = lambda processor_output: cls.mm_projector(
cls.vision_tower(
pixel_values=processor_output["pixel_values"]
).last_hidden_state
)
def _pixel_values_image_data(self, processor_output):
return dict(
modality="IMAGE",
pixel_values=processor_output["pixel_values"][0],
)
if __name__ == "__main__":
unittest.main()