Unverified Commit 8430bfe3 authored by Xinyuan Tong, committed by GitHub

[Refactor] simplify multimodal data processing (#8107)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent c9e8613c
@@ -126,14 +126,14 @@
 "    images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
 ")\n",
 "input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
-"precomputed_features = vision(\n",
+"precomputed_embeddings = vision(\n",
 "    processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
 ")\n",
 "\n",
 "mm_item = dict(\n",
 "    modality=\"IMAGE\",\n",
 "    image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
-"    precomputed_features=precomputed_features,\n",
+"    precomputed_embeddings=precomputed_embeddings,\n",
 ")\n",
 "out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
 "print(out[\"text\"])"
...
@@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions):
 class DictOutput(object):
+    def items(self):
+        return self.__dict__.items()
+
     def keys(self):
         return self.__dict__.keys()
@@ -59,7 +62,9 @@ class DictOutput(object):
 class VLChatProcessorOutput(DictOutput):
     input_ids: torch.LongTensor
     target_ids: torch.LongTensor
-    images: torch.Tensor
+    pixel_values: (
+        torch.Tensor
+    )  # rename from "images" to "pixel_values" for compatibility
     images_seq_mask: torch.BoolTensor
     images_spatial_crop: torch.LongTensor
@@ -312,10 +317,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
         images = torch.stack(images_list, dim=0)
         images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+        images_spatial_crop = torch.stack(
+            [images_spatial_crop], dim=0
+        )  # stack the tensor to make it a batch of 1

         prepare = VLChatProcessorOutput(
             input_ids=input_ids,
             target_ids=target_ids,
-            images=images,
+            pixel_values=images,
             images_seq_mask=images_seq_mask,
             images_spatial_crop=images_spatial_crop,
         )
...
@@ -284,6 +284,9 @@ class VLMImageProcessor(BaseImageProcessor):
 class DictOutput(object):
+    def items(self):
+        return self.__dict__.items()
+
     def keys(self):
         return self.__dict__.keys()
...
@@ -221,17 +221,17 @@ def _get_precomputed_embedding(
     items: List[MultimodalDataItem],
 ) -> Optional[torch.Tensor]:
     """
-    If all items have precomputed_features, return their concatenation.
-    If some but not all have precomputed_features, raise NotImplementedError.
-    If none have precomputed_features, return None.
+    If all items have precomputed_embeddings, return their concatenation.
+    If some but not all have precomputed_embeddings, raise NotImplementedError.
+    If none have precomputed_embeddings, return None.
     """
-    precomputed_features = [item.precomputed_features for item in items]
-    if any(feature is not None for feature in precomputed_features):
-        if not all(feature is not None for feature in precomputed_features):
+    precomputed_embeddings = [item.precomputed_embeddings for item in items]
+    if any(feature is not None for feature in precomputed_embeddings):
+        if not all(feature is not None for feature in precomputed_embeddings):
             raise NotImplementedError(
                 "MM inputs where only some items are precomputed."
             )
-        result = torch.concat(precomputed_features)
+        result = torch.concat(precomputed_embeddings)
         # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
         result = result.reshape(-1, result.shape[-1])
         return result
...
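Note on the helper above: the following is a minimal, self-contained sketch of the all-or-nothing rule the renamed `precomputed_embeddings` field follows. It is not the repository code; `Item` and the tensor shapes are hypothetical stand-ins for `MultimodalDataItem`.

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class Item:  # hypothetical stand-in for MultimodalDataItem
    precomputed_embeddings: Optional[torch.Tensor] = None


def get_precomputed_embedding(items: List[Item]) -> Optional[torch.Tensor]:
    embeddings = [it.precomputed_embeddings for it in items]
    if any(e is not None for e in embeddings):
        if not all(e is not None for e in embeddings):
            # Mixing precomputed and raw items in one request is not supported.
            raise NotImplementedError("MM inputs where only some items are precomputed.")
        result = torch.concat(embeddings)
        # Collapse a possible 3-dim layout into [num_tokens, hidden_size].
        return result.reshape(-1, result.shape[-1])
    return None


# Two precomputed items are concatenated into one [num_tokens, hidden_size] tensor.
items = [Item(torch.zeros(4, 8)), Item(torch.zeros(2, 8))]
print(get_precomputed_embedding(items).shape)  # torch.Size([6, 8])
```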
@@ -201,7 +201,7 @@ class MultimodalDataItem:
     For example, if there are 3 images and 1 audio inputs, there will be 2 MultimodalDataItem.
     One for images and one for audio.
-    We put the common fields first and the model-specific fields last.
+    We put the common fields first and the model-specific fields in model_specific_data.
     """

     modality: Modality
@@ -211,37 +211,31 @@ class MultimodalDataItem:
     # the raw features returned by processor, e.g. pixel_values or audio_features
     feature: Union[torch.Tensor, np.ndarray] = None

-    image_sizes: Tuple[int, int] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None
-    audio_offsets: Optional[List[Tuple[int, int]]] = None
-    precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
-
-    # For qwen-vl
-    image_grid_thw: Union[torch.Tensor, np.ndarray] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
-
-    # For deepseek-vl
-    image_emb_mask: Optional[torch.Tensor] = None
-    image_spatial_crop: Optional[torch.Tensor] = None
-
-    # For minicpmv
-    # [num_images, (n, w, h)]
-    tgt_size: Tuple[int, int] = None
-
-    # For mllama
-    aspect_ratio_id: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
-    # For kimi-vl
-    image_grid_hws: Optional[List[torch.Tensor]] = None
-
-    # For gemma3n
-    input_features_mask: Optional[torch.Tensor] = None
-
-    # For phi4-mm
-    image_attention_mask: Optional[torch.Tensor] = None
-    audio_attention_mask: Optional[torch.Tensor] = None
+    # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
+    precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
+
+    # Model-specific data stored in a dictionary
+    model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
+
+    def __getattr__(self, name: str):
+        if (
+            "model_specific_data" in self.__dict__
+            and name in self.__dict__["model_specific_data"]
+        ):
+            return self.__dict__["model_specific_data"][name]
+        else:
+            raise AttributeError(
+                f"'{self.__class__.__name__}' object has no attribute '{name}'"
+            )
+
+    def __setitem__(self, key: str, value: Any):
+        if key in self.__dict__:
+            self.__dict__[key] = value
+        else:
+            self.model_specific_data[key] = value
+
+    def set(self, key: str, value: Any):
+        self.__setitem__(key, value)

     @staticmethod
     def is_empty_list(l):
@@ -259,7 +253,7 @@ class MultimodalDataItem:
         if self.feature is not None:
             hashed_feature = self.feature
         else:
-            hashed_feature = self.precomputed_features
+            hashed_feature = self.precomputed_embeddings
         self.hash = hash_feature(hashed_feature)
         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)
@@ -268,24 +262,13 @@ class MultimodalDataItem:
         return self.modality == modality

     def is_audio(self):
-        return (self.modality == Modality.AUDIO) and (
-            self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.feature)
-        )
+        return self.modality == Modality.AUDIO

     def is_image(self):
-        return (
-            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
-        ) and (
-            self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.feature)
-        )
+        return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]

     def is_video(self):
-        return (self.modality == Modality.VIDEO) and (
-            self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.feature)
-        )
+        return self.modality == Modality.VIDEO

     def is_valid(self) -> bool:
         return self.is_image() or self.is_video() or self.is_audio()
@@ -306,8 +289,7 @@ class MultimodalDataItem:
     def merge(self, other):
         self.feature += other.feature
-        self.image_sizes += other.image_sizes
-        self.image_offsets += other.image_offsets
+        self.offsets += other.offsets
         self.hash = hash((self.hash, other.hash))
         self.set_pad_value()
...
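To make the storage scheme above easier to follow, here is a minimal standalone sketch (simplified field names, not the repository class) of how declared fields and model-specific extras coexist: reads of extras are served by `__getattr__`, and `set()` routes unknown keys into `model_specific_data`.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Item:  # hypothetical stand-in for MultimodalDataItem
    modality: str
    feature: Any = None
    precomputed_embeddings: Any = None
    model_specific_data: dict = dataclasses.field(default_factory=dict)

    def __getattr__(self, name: str):
        # Called only when normal attribute lookup fails, so declared fields are unaffected.
        data = self.__dict__.get("model_specific_data", {})
        if name in data:
            return data[name]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )

    def set(self, key: str, value: Any):
        # Declared fields are assigned directly; everything else goes into the dict.
        if key in self.__dict__:
            self.__dict__[key] = value
        else:
            self.model_specific_data[key] = value


item = Item(modality="IMAGE")
item.set("feature", [1, 2, 3])            # declared field, stored on the dataclass
item.set("image_grid_thw", (1, 2, 2))     # model-specific extra, stored in the dict
print(item.feature, item.image_grid_thw)  # [1, 2, 3] (1, 2, 2)
print(hasattr(item, "tgt_size"))          # False: unset extras raise AttributeError
```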
@@ -260,7 +260,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
     def get_image_feature(self, items: List[MultimodalDataItem]):
         images_spatial_crop = torch.cat(
-            [item.image_spatial_crop for item in items], dim=0
+            [item.images_spatial_crop for item in items], dim=0
         )
         assert images_spatial_crop.dim() == 3
@@ -278,8 +278,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
             _, hw, n_dim = images_embeds.shape
             h = w = int(hw**0.5)
             tile_index = 0
-            for jdx in range(item.image_spatial_crop.shape[1]):
-                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+            for jdx in range(item.images_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx]
                 if num_width_tiles == 0 or num_height_tiles == 0:
                     break
                 num_tiles_in_image = num_width_tiles * num_height_tiles
...
@@ -81,6 +81,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(
             config.text_config if hasattr(config, "text_config") else config
         )
+        self.padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens()

     def _has_vision_weights(self, config) -> bool:
         """Check if the model has vision components by examining the checkpoint."""
@@ -135,8 +136,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         return False

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.padding_pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
         self,
...
@@ -435,7 +435,12 @@ class Phi4MMForCausalLM(nn.Module):
         dtype = next(self.vision_encoder.parameters()).dtype
         pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
         image_attention_mask = torch.cat(
-            [item.image_attention_mask for item in items], dim=0
+            [
+                item.image_attention_mask
+                for item in items
+                if hasattr(item, "image_attention_mask")
+            ],
+            dim=0,
         )
         image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
         image_embeds = self.vision_encoder(
@@ -456,7 +461,7 @@ class Phi4MMForCausalLM(nn.Module):
                 audio_features=item.feature.to(device).type(dtype),
                 audio_attention_mask=(
                     item.audio_attention_mask.to(device)
-                    if item.audio_attention_mask is not None
+                    if hasattr(item, "audio_attention_mask")
                     else None
                 ),
             )
...
@@ -5,7 +5,7 @@ import multiprocessing as mp
 import os
 import re
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -155,17 +155,15 @@ class BaseMultimodalProcessor(ABC):
         self.ATTR_NAME_TO_MODALITY = {
             # Image-related attributes
             "pixel_values": Modality.IMAGE,
-            "pixel_values_videos": Modality.VIDEO,
             "image_sizes": Modality.IMAGE,
             "image_grid_thw": Modality.IMAGE,
             "image_attention_mask": Modality.IMAGE,
             "image_emb_mask": Modality.IMAGE,
-            "image_spatial_crop": Modality.IMAGE,
+            "images_spatial_crop": Modality.IMAGE,
             "tgt_size": Modality.IMAGE,
             "image_grid_hws": Modality.IMAGE,
-            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_ids": Modality.IMAGE,
             "aspect_ratio_mask": Modality.IMAGE,
-            "second_per_grid_ts": Modality.IMAGE,
             # Audio-related attributes
             "audio_features": Modality.AUDIO,
             "audio_feature_lens": Modality.AUDIO,
@@ -173,9 +171,11 @@ class BaseMultimodalProcessor(ABC):
             "input_features_mask": Modality.AUDIO,
             "audio_attention_mask": Modality.AUDIO,
             # Video-related attributes
+            "pixel_values_videos": Modality.VIDEO,
+            "second_per_grid_ts": Modality.VIDEO,
             "video_grid_thw": Modality.VIDEO,
             # Generic attributes that could apply to multiple modalities
-            # "precomputed_features" - handled specially as it can be any modality
+            # "precomputed_embeddings" - handled specially as it can be any modality
         }

         # name of the feature filed
@@ -222,7 +222,6 @@ class BaseMultimodalProcessor(ABC):
         audio_data,
         input_text,
         request_obj,
-        max_req_input_len,
         **kwargs,
     ) -> Optional[Dict[str, Any]]:
         pass
@@ -283,7 +282,7 @@ class BaseMultimodalProcessor(ABC):
         self,
         text_parts: List[str],
         multimodal_tokens: MultimodalSpecialTokens,
-        data_iterators: dict,
+        data_iterators: dict[Modality, Iterator[Any]],
         discard_alpha_channel: bool = True,
         image_estimated_frames_iter: Optional[iter] = None,
         image_scaling_factor: float = 1.0,
@@ -354,7 +353,6 @@ class BaseMultimodalProcessor(ABC):
         self,
         prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
-        max_req_input_len: int,
         image_data: Optional[list] = None,
         video_data: Optional[list] = None,
         audio_data: Optional[list] = None,
@@ -489,50 +487,11 @@ class BaseMultimodalProcessor(ABC):
         return list(zip(indices_start.tolist(), indices_end.tolist()))

-    @staticmethod
-    def _extract_processor_features(
-        items: List[dict], attr_name: str
-    ) -> Optional[torch.Tensor]:
-        """
-        Helper function to concat extracted attributes from processor output.
-        """
-        values = [value for item in items if (value := item.get(attr_name)) is not None]
-        return torch.cat(values) if values else None
-
-    # When we assume that all the items have the same attributes
-    def _extract_processor_features_from_all_attributes(
-        self, items: List[dict]
-    ) -> dict:
-        values = {}
-        # Verify all items have the same keys
-        first_keys = set(items[0].keys())
-        for item in items[1:]:
-            if set(item.keys()) != first_keys:
-                raise ValueError(
-                    f"All items must have the same attributes. "
-                    f"First item has {first_keys}, but found {set(item.keys())}"
-                )
-        # Process each attribute
-        for k, v in items[0].items():
-            if isinstance(v, list):
-                values[k] = self._extract_processor_features(items, k)
-            else:
-                # Verify all items have the same value for non-list attributes
-                for item in items[1:]:
-                    if item[k] != v:
-                        raise ValueError(
-                            f"All items must have the same value for attribute {k}. "
-                            f"First item has {v}, but found {item[k]}"
-                        )
-                values[k] = v
-        return values
-
     def collect_mm_items_from_processor_output(
         self, data_dict: dict
     ) -> List[MultimodalDataItem]:
         """Create mm_items directly from processor output."""
-        items = {}  # modality -> MultimodalDataItem
+        items: dict[Modality, MultimodalDataItem] = {}
         for attr_name, value in data_dict.items():
             if attr_name == "input_ids":
@@ -541,16 +500,15 @@ class BaseMultimodalProcessor(ABC):
             # Get modality for this attribute
             modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
-            if not modality and attr_name == "precomputed_features":
+            if attr_name == "precomputed_embeddings":
                 modality_str = data_dict.get("modality")
-                try:
-                    modality = (
-                        Modality.from_str(modality_str)
-                        if modality_str
-                        else Modality.IMAGE
-                    )
-                except ValueError:
-                    modality = Modality.IMAGE
+                modality = Modality.IMAGE
+                if modality_str:
+                    try:
+                        modality = Modality.from_str(modality_str)
+                    except ValueError:
+                        pass

             if modality:
                 # Create item if needed
                 if modality not in items:
@@ -559,8 +517,7 @@ class BaseMultimodalProcessor(ABC):
                 if attr_name in self.FEATURE_NAMES:
                     attr_name = "feature"

-                # Set attribute
-                setattr(items[modality], attr_name, value)
+                items[modality].set(attr_name, value)

         return list(items.values())
@@ -586,6 +543,7 @@ class BaseMultimodalProcessor(ABC):
         self,
         base_output: BaseMultiModalProcessorOutput,
         mm_tokens: MultimodalSpecialTokens,
+        **kwargs,
     ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
         """
         Process multimodal data and return the combined multimodal items and input_ids.
@@ -618,7 +576,7 @@ class BaseMultimodalProcessor(ABC):
             else:
                 raise ValueError(f"Unknown multimodal item type: {type(item)}")

         # Process items and get input_ids
-        all_collected_items = []
+        all_collected_items: list[MultimodalDataItem] = []
         input_ids = None

         # Handle dict items (already processed)
@@ -634,6 +592,7 @@ class BaseMultimodalProcessor(ABC):
                 images=raw_images,
                 audios=raw_audios,
                 videos=raw_videos,
+                **kwargs,
             )
             all_collected_items.extend(collected_items)
         else:
...
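As context for `collect_mm_items_from_processor_output` above, this is a rough, simplified sketch of how processor output keys are regrouped into per-modality items. It uses plain dicts and strings instead of `MultimodalDataItem` and `Modality` and is not the repository implementation.

```python
from typing import Any, Dict, List

# Simplified stand-ins for the attribute-to-modality map and feature-name set.
ATTR_NAME_TO_MODALITY = {
    "pixel_values": "IMAGE",
    "image_grid_thw": "IMAGE",
    "pixel_values_videos": "VIDEO",
    "second_per_grid_ts": "VIDEO",
    "audio_features": "AUDIO",
}
FEATURE_NAMES = {"pixel_values", "pixel_values_videos", "audio_features"}


def collect_items(data_dict: Dict[str, Any]) -> List[dict]:
    items: Dict[str, dict] = {}
    for attr_name, value in data_dict.items():
        if attr_name in ("input_ids", "modality"):
            continue
        modality = ATTR_NAME_TO_MODALITY.get(attr_name)
        if attr_name == "precomputed_embeddings":
            # Precomputed embeddings can belong to any modality; fall back to IMAGE.
            modality = data_dict.get("modality", "IMAGE")
        if modality is None:
            continue
        item = items.setdefault(modality, {"modality": modality, "extras": {}})
        if attr_name in FEATURE_NAMES:
            item["feature"] = value
        else:
            item["extras"][attr_name] = value
    return list(items.values())


out = {"input_ids": [1, 2, 3], "pixel_values": "tensor", "image_grid_thw": (1, 2, 2)}
print(collect_items(out))
# [{'modality': 'IMAGE', 'extras': {'image_grid_thw': (1, 2, 2)}, 'feature': 'tensor'}]
```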
 from typing import List, Union

-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
-from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
-from sglang.srt.utils import load_image
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)

 class ClipImageProcessor(BaseMultimodalProcessor):
@@ -11,23 +12,24 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
+            _processor
+        )

     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
-        images = [load_image(image)[0] for image in image_data]
-        image_inputs = self.process_mm_data(input_text=input_text, images=images)
-        image_inputs["data_hashes"] = [hash(str(image_data))]
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
-        image_inputs["mm_items"] = [
-            MultimodalDataItem(
-                feature=image_inputs["pixel_values"], modality=Modality.IMAGE
-            )
-        ]
-
-        return image_inputs
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=self.mm_tokens,
+            image_data=image_data,
+        )
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens
+        )
+
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+        }
@@ -33,9 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
-        self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
-            _processor
-        )
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token="<image>", image_token_id=self._processor.image_token_id
+        ).build(_processor)

     async def process_mm_data_async(
         self,
@@ -50,36 +50,16 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             input_text,
             image_data=image_data,
             multimodal_tokens=self.mm_tokens,
-            max_req_input_len=max_req_input_len,
         )
-        res = self.process_mm_data(
-            input_text=base_output.input_text,
-            images=base_output.images,
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_output,
+            self.mm_tokens,
             max_req_input_len=max_req_input_len,
             conversations=base_output.input_text,
         )
-        images_seq_mask = res["images_seq_mask"]
-        images_spatial_crop = res["images_spatial_crop"]
-        batched_images_spatial_crop = []
-        batched_images_spatial_crop.append(images_spatial_crop)
-        batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
-        items = []
-        input_ids = res["input_ids"]
-        image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids, mm_token_id=self._processor.image_token_id
-        )
-        item = MultimodalDataItem(
-            feature=res["images"],
-            offsets=image_offsets,
-            modality=Modality.IMAGE,
-            image_emb_mask=images_seq_mask,
-            image_spatial_crop=batched_images_spatial_crop,
-        )
-        items += [item]
+
         return {
-            "mm_items": items,
+            "mm_items": mm_items,
             "input_ids": input_ids.tolist(),
             "im_token_id": self._processor.image_token_id,
         }
...
@@ -33,7 +33,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
-        max_req_input_len,
         *args,
         **kwargs,
     ):
@@ -41,7 +40,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             prompt=input_text,
             image_data=image_data,
             multimodal_tokens=self.mm_tokens,
-            max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )
...
@@ -54,7 +54,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
         audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
         input_text: str = "",
         request_obj=None,
-        max_req_input_len: int = 0,
         *args,
         **kwargs,
     ):
@@ -63,7 +62,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             prompt=input_text,
             image_data=image_data,
             audio_data=audio_data,
-            max_req_input_len=max_req_input_len,
             multimodal_tokens=self.mm_tokens,
         )
...
@@ -170,13 +170,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         return pixel_values, num_patches_list

     async def process_mm_data_async(
-        self, image_data, input_text, request_obj, max_req_input_len, **kwargs
+        self, image_data, input_text, request_obj, **kwargs
     ):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
             multimodal_tokens=self.mm_tokens,
-            max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )
...
@@ -11,52 +11,35 @@ from sglang.srt.multimodal.processors.base_processor import (
 class JanusProImageProcessor(BaseMultimodalProcessor):
     models = [MultiModalityCausalLM]

-    def __init__(self, hf_config, server_args, processor):
-        super().__init__(hf_config, server_args, processor)
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)

         self.mm_tokens = MultimodalSpecialTokens(
-            image_token=processor.image_token
-        ).build(processor)
+            image_token=_processor.image_token,
+            image_token_id=_processor.image_id,
+        ).build(_processor)

     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_text,
         request_obj,
-        max_req_input_len,
         **kwargs,
     ):
-        processor = self._processor
-
         base_out = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
             multimodal_tokens=self.mm_tokens,
-            max_req_input_len=max_req_input_len,
         )

-        images = base_out.images
-        res = self.process_mm_data(
-            input_text=base_out.input_text,
-            prompt=base_out.input_text,
-            images=images,
-        )
-
-        input_ids = res["input_ids"].flatten()
-        image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids, mm_token_id=processor.image_id
-        )
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_out, self.mm_tokens, prompt=base_out.input_text
+        )

         return {
-            "mm_items": [
-                MultimodalDataItem(
-                    feature=res["pixel_values"],
-                    image_emb_mask=res["images_emb_mask"],
-                    offsets=image_offsets,
-                    modality=Modality.IMAGE,
-                )
-            ],
+            "mm_items": mm_items,
             "input_ids": input_ids.tolist(),
-            "im_start_id": processor.image_start_id,
-            "im_end_id": processor.image_end_id,
-            "im_token_id": processor.image_id,
+            "im_start_id": self._processor.image_start_id,
+            "im_end_id": self._processor.image_end_id,
+            "im_token_id": self.mm_tokens.image_token_id,
         }
@@ -26,7 +26,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         image_data: List[Union[str, bytes, Dict]],
         input_text,
         request_obj,
-        max_req_input_len,
         *args,
         **kwargs,
     ):
@@ -34,7 +33,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         prompt=input_text,
             image_data=image_data,
             multimodal_tokens=self.mm_tokens,
-            max_req_input_len=max_req_input_len,
         )
         mm_items, input_ids, _ = self.process_and_combine_mm_data(
...
@@ -159,7 +159,9 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             "mm_items": [
                 MultimodalDataItem(
                     feature=pixel_values,
-                    image_sizes=image_sizes,
+                    model_specific_data={
+                        "image_sizes": image_sizes,
+                    },
                     modality=modality,
                 )
             ],
...
@@ -17,10 +17,21 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        # Collect special token ids
+        tokenizer = self._processor.tokenizer
+        self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
+        self.slice_end_id = getattr(tokenizer, "slice_end_id", None)
+        self.audio_start_id = getattr(tokenizer, "audio_start_id", None)
+        self.audio_end_id = getattr(tokenizer, "audio_end_id", None)
+        self.im_start_id = getattr(tokenizer, "im_start_id", None)
+        self.im_end_id = getattr(tokenizer, "im_end_id", None)
+        self.im_token_id = getattr(tokenizer, "unk_id", None)
+
         self.mm_tokens = MultimodalSpecialTokens(
             image_token="(<image>./</image>)",
             audio_token="(<audio>./</audio>)",
             video_token="(<video>./</video>)",
+            image_token_id=self.im_token_id,
         ).build(_processor)

     async def process_mm_data_async(
@@ -29,12 +40,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         audio_data: List[Union[str, bytes]],
         input_text,
         request_obj,
-        max_req_input_len,
         **kwargs,
     ):
         base_output = self.load_mm_data(
             prompt=input_text,
-            max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
             multimodal_tokens=self.mm_tokens,
@@ -48,24 +57,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audios=base_output.audios,
         )

-        # Collect special token ids
-        tokenizer = self._processor.tokenizer
-        slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
-            None,
-            None,
-            None,
-            None,
-        )
-        if tokenizer.slice_start_id:
-            slice_start_id = tokenizer.slice_start_id
-            slice_end_id = tokenizer.slice_end_id
-        if hasattr(tokenizer, "audio_start_id"):
-            audio_start_id = tokenizer.audio_start_id
-            audio_end_id = tokenizer.audio_end_id
-        im_start_id = tokenizer.im_start_id
-        im_end_id = tokenizer.im_end_id
-        im_token_id = tokenizer.unk_id
-
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
@@ -102,10 +93,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         items = []
         input_ids = res["input_ids"].flatten()
         image_offsets = self.get_mm_items_offset_by_pair(
-            input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
+            input_ids=input_ids, mm_start_id=self.im_start_id, mm_end_id=self.im_end_id
         )
         slice_offsets = self.get_mm_items_offset_by_pair(
-            input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
+            input_ids=input_ids,
+            mm_start_id=self.slice_start_id,
+            mm_end_id=self.slice_end_id,
         )
         image_offsets.extend(slice_offsets)
         image_offsets = sorted(image_offsets)
@@ -114,7 +107,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         item = MultimodalDataItem(
             feature=pixel_values,
             offsets=image_offsets,
-            tgt_size=tgt_sizes_flat,
+            model_specific_data={"tgt_size": tgt_sizes_flat},
             modality=Modality.IMAGE,
         )
         items += [item]
@@ -124,17 +117,17 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             and res["audio_features"] is not None
             and len(res["audio_features"]) != 0
         ):
-            if audio_start_id is not None and audio_end_id is not None:
+            if self.audio_start_id is not None and self.audio_end_id is not None:
                 audio_offsets = self.get_mm_items_offset_by_pair(
                     input_ids=input_ids,
-                    mm_start_id=audio_start_id,
-                    mm_end_id=audio_end_id,
+                    mm_start_id=self.audio_start_id,
+                    mm_end_id=self.audio_end_id,
                 )
             else:
                 audio_offsets = None
             item = MultimodalDataItem(
                 feature=[res["audio_features"]],
-                audio_feature_lens=res["audio_feature_lens"],
+                model_specific_data={"audio_feature_lens": res["audio_feature_lens"]},
                 offsets=audio_offsets,
                 modality=Modality.AUDIO,
             )
@@ -142,11 +135,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         return {
             "mm_items": items,
             "input_ids": input_ids.tolist(),
-            "audio_start_id": audio_start_id,
-            "audio_end_id": audio_end_id,
-            "im_token_id": im_token_id,
-            "im_start_id": im_start_id,
-            "im_end_id": im_end_id,
-            "slice_start_id": slice_start_id,
-            "slice_end_id": slice_end_id,
+            "audio_start_id": self.audio_start_id,
+            "audio_end_id": self.audio_end_id,
+            "im_token_id": self.im_token_id,
+            "im_start_id": self.im_start_id,
+            "im_end_id": self.im_end_id,
+            "slice_start_id": self.slice_start_id,
+            "slice_end_id": self.slice_end_id,
         }
 from typing import List, Union

-from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama import MllamaForConditionalGeneration
-from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
-from sglang.srt.utils import load_image
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)

 class MllamaImageProcessor(BaseMultimodalProcessor):
@@ -11,24 +12,26 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self._processor.image_token,
+            image_token_id=self._processor.image_token_id,
+        ).build(_processor)

     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
-        images = [load_image(image)[0] for image in image_data]
-        image_inputs = self.process_mm_data(input_text=input_text, images=images)
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
-        image_inputs["mm_items"] = [
-            MultimodalDataItem(
-                feature=image_inputs["pixel_values"],
-                aspect_ratio_id=image_inputs["aspect_ratio_ids"],
-                aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
-                modality=Modality.IMAGE,
-            )
-        ]
-
-        return image_inputs
+        base_out = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(
+            base_out, self.mm_tokens
+        )
+
+        return {
+            "mm_items": mm_items,
+            "input_ids": input_ids.tolist(),
+            "im_token_id": self.mm_tokens.image_token_id,
+        }
@@ -27,13 +27,13 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token,
+            image_token_id=self.image_token_index,
         ).build(_processor)

     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_text,
-        max_req_input_len=None,
         *args,
         **kwargs,
     ):
@@ -45,7 +45,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         processed_data = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=self.multimodal_tokens,
-            max_req_input_len=max_req_input_len or 4096,
             image_data=image_data,
             return_text=True,
         )
...