Unverified Commit cf9815ba authored by Xinyuan Tong, committed by GitHub

[Refactor] Multimodal data processing for VLM (#6659)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent bd75690f
@@ -132,7 +132,7 @@
"\n",
"mm_item = dict(\n",
" modality=\"IMAGE\",\n",
" image_grid_thws=processed_prompt[\"image_grid_thw\"],\n",
" image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
" precomputed_features=precomputed_features,\n",
")\n",
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
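For readers skimming the JSON-escaped notebook hunk above: the cell builds a dict-form image item whose grid key is now the singular image_grid_thw. A plain-Python sketch of that path, assuming the offline Engine llm, the tokenized input_ids, the Hugging Face processor output processed_prompt, and the precomputed vision features precomputed_features are all defined earlier in the notebook:

    # Sketch only; variable names follow the notebook cell above.
    mm_item = dict(
        modality="IMAGE",
        image_grid_thw=processed_prompt["image_grid_thw"],  # singular key after this commit
        precomputed_features=precomputed_features,          # vision features, reused as-is
    )
    out = llm.generate(input_ids=input_ids, image_data=[mm_item])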
@@ -5,7 +5,8 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple, Union
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -16,16 +17,24 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image
class MultimodalInputFormat(Enum):
"""Enum for different multimodal input formats."""
RAW_IMAGES = "raw_images"
PRECOMPUTED_FEATURES = "precomputed_features"
PIXEL_VALUES = "pixel_values"
@dataclasses.dataclass
class BaseMultiModalProcessorOutput:
# input_text, with each image/video frame represented by an image_token
input_text: str
# frames loaded from image and video, in given order
images: Optional[list[Union[Image.Image, MultimodalDataItem]]] = None
images: Optional[list[Union[Image.Image, dict]]] = None
# audios
audios: Optional[list[Union[np.ndarray, MultimodalDataItem]]] = None
audios: Optional[list[Union[np.ndarray, dict]]] = None
def normalize(self):
for field_name in ["images", "audios"]:
@@ -170,8 +179,6 @@ class BaseMultimodalProcessor(ABC):
):
"""Static method that can be pickled for multiprocessing"""
if isinstance(data, dict):
return MultimodalDataItem.from_dict(data)
if isinstance(data, MultimodalDataItem):
return data
try:
if is_audio:
@@ -370,29 +377,180 @@ class BaseMultimodalProcessor(ABC):
return list(zip(indices_start.tolist(), indices_end.tolist()))
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
if not mm_inputs:
return True
ret = any(isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs)
if ret and not all(
isinstance(mm_input, MultimodalDataItem) for mm_input in mm_inputs
):
raise ValueError(
"Unsupported: mixture of multimodal inputs where some but not all are preprocessed."
)
return ret
@staticmethod
def _extract_processor_features(
items: List[Any], attr_name: str
items: List[dict], attr_name: str
) -> Optional[torch.Tensor]:
"""
Helper to concatenate a single attribute extracted from each processor output item.
"""
values = [
getattr(item, attr_name)
for item in items
if getattr(item, attr_name) is not None
]
return torch.concat(values) if values else None
values = [value for item in items if (value := item.get(attr_name)) is not None]
return torch.cat(values) if values else None
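A standalone illustration of this helper's behavior with the new dict-based items (pure PyTorch; the key name and tensor shapes below are placeholders, not values from this diff):

    import torch

    # One dict per image, as handed over by the caller.
    items = [
        {"pixel_values": torch.randn(4, 1176)},
        {"pixel_values": torch.randn(4, 1176)},
        {"pixel_values": None},  # items missing the attribute are skipped
    ]

    values = [v for item in items if (v := item.get("pixel_values")) is not None]
    pixel_values = torch.cat(values) if values else None  # -> shape [8, 1176]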
# Assumes that all items share the same set of attributes
def _extract_processor_features_from_all_attributes(
self, items: List[dict]
) -> dict:
values = {}
# Verify all items have the same keys
first_keys = set(items[0].keys())
for item in items[1:]:
if set(item.keys()) != first_keys:
raise ValueError(
f"All items must have the same attributes. "
f"First item has {first_keys}, but found {set(item.keys())}"
)
# Process each attribute
for k, v in items[0].items():
if isinstance(v, list):
values[k] = self._extract_processor_features(items, k)
else:
# Verify all items have the same value for non-list attributes
for item in items[1:]:
if item[k] != v:
raise ValueError(
f"All items must have the same value for attribute {k}. "
f"First item has {v}, but found {item[k]}"
)
values[k] = v
return values
def process_and_combine_mm_data(
self, base_output: BaseMultiModalProcessorOutput
) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
"""
Process multimodal data and return the combined multimodal item and input_ids.
Handles all three input formats at the same abstraction level.
Returns:
Tuple of (combined_mm_item, input_ids)
"""
def tokenize_text(input_text: str) -> torch.Tensor:
"""Tokenize input text."""
return self._processor.tokenizer(
input_text,
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
"""Categorize multimodal inputs and validate consistency."""
try:
has_image = False
has_pixel_values = False
has_precomputed_features = False
for mm_input in mm_inputs:
if isinstance(mm_input, Image.Image):
has_image = True
elif isinstance(mm_input, dict):
if mm_input.get("precomputed_features", None) is not None:
has_precomputed_features = True
elif mm_input.get("pixel_values", None) is not None:
has_pixel_values = True
else:
raise ValueError(
f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
)
else:
raise ValueError(
f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
)
# Validate format consistency
format_count = sum(
[has_image, has_pixel_values, has_precomputed_features]
)
if format_count > 1:
raise ValueError(
"Unsupported: mixture of multimodal input formats. "
f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
f"precomputed_features={has_precomputed_features}"
)
if has_image:
return MultimodalInputFormat.RAW_IMAGES
elif has_precomputed_features:
return MultimodalInputFormat.PRECOMPUTED_FEATURES
elif has_pixel_values:
return MultimodalInputFormat.PIXEL_VALUES
else:
raise ValueError("No valid multimodal input format found")
except Exception as e:
raise ValueError(f"Failed to categorize inputs: {e}")
def process_raw_images(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process raw Image.Image objects using transformers processor."""
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
)
combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
# Copy all fields from processor output except input_ids
for key, value in ret.items():
if key != "input_ids" and hasattr(combined_mm_item, key):
setattr(combined_mm_item, key, value)
input_ids = ret["input_ids"].flatten()
return combined_mm_item, input_ids
def process_precomputed_features(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process inputs with precomputed features."""
combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
combined_mm_item.precomputed_features = self._extract_processor_features(
base_output.images, "precomputed_features"
)
input_ids = tokenize_text(base_output.input_text)
return combined_mm_item, input_ids
def process_pixel_values(
base_output: BaseMultiModalProcessorOutput,
) -> Tuple[MultimodalDataItem, torch.Tensor]:
"""Process inputs with pixel values."""
values = self._extract_processor_features_from_all_attributes(
base_output.images
)
combined_mm_item = MultimodalDataItem.from_dict(values)
input_ids = tokenize_text(base_output.input_text)
return combined_mm_item, input_ids
def finalize_mm_item(
combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
) -> MultimodalDataItem:
"""Apply common post-processing to the multimodal item."""
combined_mm_item.image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.IM_TOKEN_ID,
)
return combined_mm_item
# Main logic
mm_inputs = base_output.images
if not mm_inputs:
# Return text-only case
input_ids = tokenize_text(base_output.input_text)
return None, input_ids
# Categorize input formats
input_format = categorize_mm_inputs(mm_inputs)
# Process based on format
if input_format == MultimodalInputFormat.RAW_IMAGES:
combined_mm_item, input_ids = process_raw_images(base_output)
elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
combined_mm_item, input_ids = process_precomputed_features(base_output)
elif input_format == MultimodalInputFormat.PIXEL_VALUES:
combined_mm_item, input_ids = process_pixel_values(base_output)
else:
raise ValueError(f"Unknown input format: {input_format}")
# Finalize with common processing
combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
return combined_mm_item, input_ids
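Taken together, the helper accepts three per-item input shapes. A hedged sketch of what a caller may pass for each (placeholder tensors; keys beyond pixel_values/precomputed_features, such as image_grid_thw, are model-specific, matching the Qwen-style tests further down):

    import torch
    from PIL import Image

    # Placeholders standing in for real processor / vision-tower outputs.
    pixel_values = torch.randn(4, 1176)
    image_grid_thw = torch.tensor([[1, 2, 2]])
    precomputed_features = torch.randn(4, 3584)

    # 1) RAW_IMAGES: a PIL image; the HF processor runs inside process_raw_images.
    raw_item = Image.new("RGB", (64, 64))

    # 2) PIXEL_VALUES: processor output forwarded as a plain dict.
    pixel_item = dict(modality="IMAGE",
                      pixel_values=pixel_values,
                      image_grid_thw=image_grid_thw)

    # 3) PRECOMPUTED_FEATURES: the vision tower has already been applied.
    feature_item = dict(modality="IMAGE",
                        image_grid_thw=image_grid_thw,
                        precomputed_features=precomputed_features)

Mixing these formats within one request raises a ValueError, as enforced by categorize_mm_inputs above.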
@@ -27,6 +27,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
)
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
self.IM_TOKEN_ID = hf_config.image_token_index
async def process_mm_data_async(
self,
@@ -42,49 +43,21 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
image_token_regex = self.IMAGE_TOKEN_REGEX
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=image_token, image_token_regex=image_token_regex
image_token=self.IMAGE_TOKEN, image_token_regex=self.IMAGE_TOKEN_REGEX
),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=None if images_are_preprocessed else base_output.images,
)
items = []
input_ids = ret["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.hf_config.image_token_index,
)
for i, image in enumerate(base_output.images):
if images_are_preprocessed:
pixel_values = image.pixel_values
precomputed_features = image.precomputed_features
else:
pixel_values = ret["pixel_values"][i]
precomputed_features = None
item = MultimodalDataItem(
pixel_values=pixel_values,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
image_offsets=image_offsets[i],
)
items += [item]
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
return {
"mm_items": items,
"input_ids": input_ids.tolist(),
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
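With the shared helper in place, the model-specific processors collapse to the same skeleton. A sketch of that pattern, not standalone code: it assumes a BaseMultimodalProcessor subclass with IMAGE_TOKEN and IM_TOKEN_ID set, and the token-id keys in the returned dict vary per model, as the Gemma3 and KimiVL hunks show:

    async def process_mm_data_async(self, image_data, input_text, max_req_input_len, **kwargs):
        if isinstance(image_data, str):
            image_data = [image_data]
        base_output = self.load_mm_data(
            prompt=input_text,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
        return {
            "input_ids": input_ids.tolist(),
            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
            "im_token_id": self.IM_TOKEN_ID,
        }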
@@ -21,7 +21,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|media_pad|>"
self.IMAGE_TOKEN_REGEX = re.compile(r"(?:<\|media_pad\|>)+")
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
self.IM_TOKEN_ID = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
async def process_mm_data_async(
self,
@@ -46,48 +46,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
max_req_input_len=max_req_input_len,
)
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
if not images_are_preprocessed:
ret = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
)
input_ids = ret["input_ids"].flatten()
image_grid_thws = ret["image_grid_hws"]
pixel_values = ret["pixel_values"]
precomputed_features = None
else:
input_ids = self._processor.tokenizer(
base_output.input_text,
return_tensors="pt",
add_special_tokens=True,
).input_ids.flatten()
image_grid_thws = self._extract_processor_features(
base_output.images, "image_grid_thws"
)
precomputed_features = self._extract_processor_features(
base_output.images, "precomputed_features"
)
pixel_values = self._extract_processor_features(
base_output.images, "pixel_values"
)
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.im_token_id,
)
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
return {
"input_ids": input_ids.tolist(),
"mm_items": [
MultimodalDataItem(
pixel_values=pixel_values,
image_grid_thws=image_grid_thws,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
],
"im_token_id": self.im_token_id,
"mm_items": [combined_mm_item] if combined_mm_item is not None else [],
"im_token_id": self.IM_TOKEN_ID,
}
@@ -32,8 +32,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
)
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
self.image_token_id = hf_config.image_token_id
self.video_token_id = hf_config.video_token_id
self.IM_TOKEN_ID = hf_config.image_token_id
self.VIDEO_TOKEN_ID = hf_config.video_token_id
self.vision_start_token_id = hf_config.vision_start_token_id
self.vision_end_token_id = hf_config.vision_end_token_id
self.NUM_TOKEN_PER_FRAME = 770
@@ -125,72 +125,45 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
async def resize_image_async(image):
return resize_image(image)
images_are_preprocessed = self.mm_inputs_are_preprocessed(base_output.images)
if base_output.images and not images_are_preprocessed:
# Qwen-specific: resize images if they are raw Image objects
if base_output.images and isinstance(base_output.images[0], Image.Image):
resize_tasks = [resize_image_async(image) for image in base_output.images]
base_output.images = await asyncio.gather(*resize_tasks)
ret = self.process_mm_data(
input_text=base_output.input_text,
images=None if images_are_preprocessed else base_output.images,
)
input_ids = ret["input_ids"].flatten().tolist()
image_offsets = self.get_mm_items_offset(
input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
)
image_grid_thw = None
video_grid_thw = None # TODO
items = []
if base_output.images:
if images_are_preprocessed:
image_grid_thw = self._extract_processor_features(
base_output.images, "image_grid_thws"
)
precomputed_features = self._extract_processor_features(
base_output.images, "precomputed_features"
)
pixel_values = self._extract_processor_features(
base_output.images, "pixel_values"
)
else:
image_grid_thw = ret["image_grid_thw"]
pixel_values = ret["pixel_values"]
precomputed_features = None
items += [
MultimodalDataItem(
pixel_values=pixel_values,
image_grid_thws=image_grid_thw,
video_grid_thws=video_grid_thw,
precomputed_features=precomputed_features,
image_offsets=image_offsets,
modality=Modality.IMAGE,
)
]
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
if combined_mm_item is None:
# Note(Xinyuan): This is the case where image loading fails.
return None
video_grid_thw = None # TODO
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
image_token_id=self.image_token_id,
video_token_id=self.video_token_id,
image_token_id=self.IM_TOKEN_ID,
video_token_id=self.VIDEO_TOKEN_ID,
vision_start_token_id=self.vision_start_token_id,
model_type=self.hf_config.model_type,
tokens_per_second=getattr(
self.hf_config.vision_config, "tokens_per_second", None
),
input_ids=torch.tensor(input_ids).unsqueeze(0),
image_grid_thw=image_grid_thw,
input_ids=input_ids.unsqueeze(0),
image_grid_thw=combined_mm_item.image_grid_thw,
video_grid_thw=video_grid_thw,
second_per_grid_ts=ret.get("second_per_grid_ts", None),
second_per_grid_ts=second_per_grid_ts,
)
mrope_positions = mrope_positions.squeeze(1)
return {
"input_ids": input_ids,
"mm_items": items,
"input_ids": input_ids.tolist(),
"mm_items": [combined_mm_item],
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"im_token_id": self.IM_TOKEN_ID,
"video_token_id": self.VIDEO_TOKEN_ID,
"mrope_positions": mrope_positions,
"mrope_position_delta": mrope_position_delta,
}
@@ -188,7 +188,7 @@ class MultimodalDataItem:
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
pixel_values: Union[torch.Tensor, np.ndarray] = None
image_grid_thws: Union[torch.Tensor, np.ndarray] = None
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
video_grid_thws: Union[torch.Tensor, np.ndarray] = None
image_emb_mask: Optional[torch.Tensor] = None
@@ -198,6 +198,9 @@ class MultimodalDataItem:
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None
# kimi-vl related
image_grid_hws: Optional[List[torch.Tensor]] = None
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
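After the rename, downstream code refers to the singular field name. A minimal construction sketch (assumes a post-refactor sglang install; tensor shapes are placeholders):

    import torch

    from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

    item = MultimodalDataItem(
        modality=Modality.IMAGE,
        pixel_values=torch.randn(4, 1176),         # placeholder
        image_grid_thw=torch.tensor([[1, 2, 2]]),  # was image_grid_thws before this commit
    )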
@@ -286,14 +286,26 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
vision_outputs_list = []
for pixel_value in all_pixel_values:
# Add batch dimension for single image processing
pixel_value_batch = pixel_value.unsqueeze(0)
pixel_value_batch = pixel_value_batch.to(device=self.vision_tower.device)
pixel_value_batch = pixel_value_batch.to(dtype=self.language_model.dtype())
vision_output = self.vision_tower(pixel_values=pixel_value_batch)
vision_outputs_list.append(vision_output)
for pixel_values_batch in all_pixel_values:
# Normalize input shape to [batch_size, channels, height, width]
if pixel_values_batch.dim() == 5:
pixel_values_batch = pixel_values_batch.squeeze(0)
elif pixel_values_batch.dim() == 3:
pixel_values_batch = pixel_values_batch.unsqueeze(0)
elif pixel_values_batch.dim() != 4:
raise ValueError(
f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
)
# Process each image in the batch
batch_size = pixel_values_batch.shape[0]
for i in range(batch_size):
pixel_value = pixel_values_batch[i : i + 1] # Keep batch dimension as 1
pixel_value = pixel_value.to(
device=self.vision_tower.device, dtype=self.language_model.dtype()
)
vision_output = self.vision_tower(pixel_values=pixel_value)
vision_outputs_list.append(vision_output)
# Concatenate all vision outputs
vision_outputs = torch.cat(vision_outputs_list, dim=0)
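The new shape handling in isolation (pure PyTorch; the 5-D branch assumes preprocessed inputs arrive with an extra leading batch dimension of size 1, mirroring the squeeze above):

    import torch

    def normalize_pixel_values(x: torch.Tensor) -> torch.Tensor:
        """Bring pixel values to [batch_size, channels, height, width]."""
        if x.dim() == 5:       # e.g. [1, n, c, h, w] -> [n, c, h, w]
            x = x.squeeze(0)
        elif x.dim() == 3:     # single image [c, h, w] -> [1, c, h, w]
            x = x.unsqueeze(0)
        elif x.dim() != 4:
            raise ValueError(f"Unexpected pixel_values shape: {x.shape}")
        return x

    assert normalize_pixel_values(torch.randn(3, 224, 224)).shape == (1, 3, 224, 224)
    assert normalize_pixel_values(torch.randn(1, 2, 3, 224, 224)).shape == (2, 3, 224, 224)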
@@ -144,10 +144,10 @@ class KimiVLForConditionalGeneration(nn.Module):
.type(self.vision_tower.dtype)
.to(self.vision_tower.device)
)
image_grid_thws = torch.concat(
[item.image_grid_thws for item in items], dim=0
).to(self.vision_tower.device)
image_features = self.vision_tower(pixel_values, image_grid_thws)
image_grid_hws = torch.cat([item.image_grid_hws for item in items], dim=0).to(
self.vision_tower.device
)
image_features = self.vision_tower(pixel_values, image_grid_hws)
assert isinstance(image_features, list)
# lengths = [x.shape[0] for x in image_features]
res = self.multi_modal_projector(torch.cat(image_features)) # .split(lengths)
@@ -503,10 +503,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
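The grid concatenation used by both Qwen2.5-VL above and Qwen2-VL below, in isolation (illustrative values; each row of image_grid_thw holds the (t, h, w) patch-grid sizes of one image, per the Qwen-VL convention):

    import torch

    class _Item:  # minimal stand-in for MultimodalDataItem
        def __init__(self, grid_thw):
            self.image_grid_thw = grid_thw

    items = [_Item(torch.tensor([[1, 4, 6]])), _Item(torch.tensor([[1, 8, 8]]))]
    image_grid_thw = torch.cat([item.image_grid_thw for item in items], dim=0)
    assert image_grid_thw.dim() == 2 and image_grid_thw.shape == (2, 3)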
@@ -490,10 +490,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -156,7 +156,7 @@ class TestQwenVLUnderstandsImage(VLMInputTestBase, unittest.IsolatedAsyncioTestC
def _pixel_values_image_data(self, processor_output):
return dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_thw"],
image_grid_thw=processor_output["image_grid_thw"],
pixel_values=processor_output["pixel_values"],
)
@@ -207,8 +207,8 @@ class TestKimiVLImageUnderstandsImage(
def _pixel_values_image_data(self, processor_output):
return dict(
modality="IMAGE",
image_grid_thws=processor_output["image_grid_hws"],
pixel_values=processor_output["pixel_values"],
image_grid_hws=processor_output["image_grid_hws"],
)