Unverified Commit 3a911b85, authored by Xinyuan Tong, committed by GitHub

Refactor mm processors and Enable mixed modality processing (#7629)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent 886d3449
@@ -125,74 +125,38 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """

-    def __init__(self, token_ids: List[int]) -> None:
-        self.token_ids = token_ids
-
     def pad_input_tokens(
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
-        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
+        Replaces multimodal tokens in input_ids with corresponding pad_values from mm_items.
+        Each modality (image, audio, video) is handled separately based on its token_id.
         """
-        pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        if not pad_values:
-            # No multimodal items, return original input_ids
+        if not input_ids or not mm_inputs.mm_items:
             return input_ids
-        if not input_ids:
-            return []

         input_ids_tensor = torch.tensor(input_ids)
-        device = input_ids_tensor.device
-        token_ids_tensor = torch.tensor(self.token_ids, device=device)
-        mask = torch.isin(input_ids_tensor, token_ids_tensor)
-
-        if not mask.any():
-            # No tokens match token_ids, return original input_ids
-            return input_ids
-
-        # Find contiguous regions
-        padded_mask = torch.cat(
-            (
-                torch.tensor([False], device=device),
-                mask,
-                torch.tensor([False], device=device),
-            )
-        )
-        # Find indices where the mask value changes
-        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
-        # Start indices are where False changes to True
-        starts = diff_indices[::2]
-        # End indices are where True changes to False (exclusive index)
-        ends = diff_indices[1::2]
-        # Check if the number of regions matches the number of pad values
-        if len(starts) != len(pad_values):
-            # Maybe log a warning here?
-            num_regions = len(starts)
-            num_pad_values = len(pad_values)
-            if num_regions > 0 and num_pad_values > 0:
-                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
-                    :num_regions
-                ]
-            else:  # If no regions or no pad_values, this loop won't run anyway.
-                pad_values = []  # Ensure pad_values is empty if starts is empty
-        # Create a copy to modify
-        output_ids_tensor = input_ids_tensor.clone()
-        # Replace tokens in each region with the corresponding pad value
-        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
-        for i in range(min(len(starts), len(pad_values))):
-            start_idx = starts[i]
-            end_idx = ends[i]
-            pad_value = pad_values[i]
-            if pad_value is not None:  # Ensure pad_value is not None before assignment
-                output_ids_tensor[start_idx:end_idx] = pad_value
-            else:
-                logger.warning(f"Skipping region {i} due to None pad_value.")
-        return output_ids_tensor.tolist()
+
+        # Create mapping of token_ids to pad_values for each modality
+        token_to_pad_mapping = {}
+        for item in mm_inputs.mm_items:
+            if item.is_image() and mm_inputs.im_token_id is not None:
+                token_to_pad_mapping[mm_inputs.im_token_id] = item.pad_value
+            elif item.is_audio() and mm_inputs.audio_token_id is not None:
+                token_to_pad_mapping[mm_inputs.audio_token_id] = item.pad_value
+            elif item.is_video() and mm_inputs.video_token_id is not None:
+                token_to_pad_mapping[mm_inputs.video_token_id] = item.pad_value
+            else:
+                raise ValueError(f"No multimodal token id provided for {item.modality}")
+
+        # Apply replacements for all tokens at once
+        for token_id, pad_value in token_to_pad_mapping.items():
+            input_ids_tensor[input_ids_tensor == token_id] = pad_value
+
+        ret_input_ids = input_ids_tensor.tolist()
+        return ret_input_ids

 embedding_cache = None
...
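The new implementation replaces multimodal tokens by value instead of scanning for contiguous regions. A minimal standalone sketch of that vectorized replacement (illustrative only; the token ids and pad values below are invented, not taken from any model in this PR):

import torch

IM_TOKEN_ID, AUDIO_TOKEN_ID = 900001, 900002  # hypothetical multimodal token ids
input_ids = [1, IM_TOKEN_ID, IM_TOKEN_ID, 2, AUDIO_TOKEN_ID, 3]

# One pad value per modality, mirroring item.pad_value in the hunk above
token_to_pad_mapping = {IM_TOKEN_ID: -100, AUDIO_TOKEN_ID: -200}

ids = torch.tensor(input_ids)
for token_id, pad_value in token_to_pad_mapping.items():
    # Boolean-mask assignment replaces every occurrence in one vectorized step
    ids[ids == token_id] = pad_value

print(ids.tolist())  # [1, -100, -100, 2, -200, 3]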
@@ -174,6 +174,15 @@ class Modality(Enum):
     VIDEO = auto()
     AUDIO = auto()

+    @staticmethod
+    def from_str(modality_str: str):
+        try:
+            return Modality[modality_str.upper()]
+        except KeyError:
+            raise ValueError(
+                f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
+            )

 @dataclasses.dataclass
 class MultimodalDataItem:
...
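The helper simply maps a case-insensitive name onto the enum member. A self-contained usage sketch (re-declaring a reduced copy of the enum purely for illustration, with the members referenced elsewhere in this diff):

from enum import Enum, auto

class Modality(Enum):
    IMAGE = auto()
    MULTI_IMAGES = auto()
    VIDEO = auto()
    AUDIO = auto()

    @staticmethod
    def from_str(modality_str: str) -> "Modality":
        try:
            return Modality[modality_str.upper()]
        except KeyError:
            raise ValueError(
                f"Invalid modality string: {modality_str}. "
                f"Valid modalities are: {[m.name for m in Modality]}"
            )

print(Modality.from_str("image"))  # Modality.IMAGE
print(Modality.from_str("AUDIO"))  # Modality.AUDIO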
@@ -482,20 +482,25 @@ class TokenizerManager:
         token_type_ids = encoded.get("token_type_ids", [None])[0]

         if self.mm_processor and obj.contains_mm_input():
-            image_inputs: Dict = await self.mm_processor.process_mm_data_async(
+            if not isinstance(obj.image_data, list):
+                obj.image_data = [obj.image_data]
+            if not isinstance(obj.audio_data, list):
+                obj.audio_data = [obj.audio_data]
+            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
+                audio_data=obj.audio_data,
                 input_text=input_text or input_ids,
                 request_obj=obj,
                 max_req_input_len=self.max_req_input_len,
             )
-            if image_inputs and "input_ids" in image_inputs:
-                input_ids = image_inputs["input_ids"]
+            if mm_inputs and "input_ids" in mm_inputs:
+                input_ids = mm_inputs["input_ids"]
         else:
-            image_inputs: Optional[Dict] = None
+            mm_inputs = None

         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
+            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )

     def _validate_one_request(
@@ -559,7 +564,7 @@ class TokenizerManager:
         input_text: str,
         input_ids: List[int],
         input_embeds: Optional[Union[List[float], None]] = None,
-        image_inputs: Optional[Dict] = None,
+        mm_inputs: Optional[Dict] = None,
         token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
@@ -584,7 +589,7 @@ class TokenizerManager:
                 obj.rid,
                 input_text,
                 input_ids,
-                image_inputs,
+                mm_inputs,
                 sampling_params,
                 obj.return_logprob,
                 obj.logprob_start_len,
@@ -606,7 +611,7 @@ class TokenizerManager:
                 obj.rid,
                 input_text,
                 input_ids,
-                image_inputs,
+                mm_inputs,
                 token_type_ids,
                 sampling_params,
             )
@@ -644,9 +649,9 @@ class TokenizerManager:
     ) -> None:
         """Validate constraints for batch tokenization processing."""
         for i in range(batch_size):
-            if self.is_generation and obj[i].image_data:
+            if self.is_generation and obj[i].contains_mm_input():
                 raise ValueError(
-                    "For image input processing do not set `enable_tokenizer_batch_encode`."
+                    "For multimodal input processing do not set `enable_tokenizer_batch_encode`."
                 )
             if obj[i].input_ids is not None:
                 raise ValueError(
...
@@ -253,11 +253,9 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader = getattr(param, "weight_loader", default_weight_loader)
             weights_loader(param, loaded_weight)

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        helper = MultiModalityDataPaddingPatternMultimodalTokens(
-            [image_inputs.im_token_id]
-        )
-        return helper.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]):
...
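The model-side changes in the hunks below all follow this same shape: pad_input_ids no longer passes per-model token ids, because pad_input_tokens now reads im_token_id / audio_token_id / video_token_id from mm_inputs. A generic sketch of that shared shape (HypotheticalVLModel is a placeholder class, not part of this PR; the import paths are the ones shown in this diff):

from typing import List

from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
from sglang.srt.managers.schedule_batch import MultimodalInputs

class HypotheticalVLModel:
    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs) -> List[int]:
        # No per-model token ids: the pattern reads them from mm_inputs
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)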
@@ -21,7 +21,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
@@ -244,26 +244,11 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def pad_input_ids(
         self,
         input_ids: List[int],
-        mm_inputs: Optional[MultimodalInputs] = None,
+        mm_inputs: MultimodalInputs,
     ) -> List[int]:
         """Pad input IDs with image and audio tokens."""
-        if mm_inputs is None:
-            return input_ids
-
-        # Collect available media token pairs
-        media_token_pairs = []
-        for attr_name in ["im_start_id", "audio_start_id"]:
-            if hasattr(mm_inputs, attr_name):
-                start_id = getattr(mm_inputs, attr_name)
-                end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
-                media_token_pairs.append((start_id, end_id))
-
-        # Apply padding pattern if we have media tokens
-        if media_token_pairs:
-            pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
-            return pattern.pad_input_tokens(input_ids, mm_inputs)
-
-        return input_ids
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()
@@ -431,7 +416,6 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             )
             positions += 1

         if input_ids is not None:
-
             # Prepare per-layer inputs from inputs_ids
             per_layer_inputs_mask = torch.logical_and(
...
@@ -154,8 +154,7 @@ class KimiVLForConditionalGeneration(nn.Module):
         return res

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def forward(
...
@@ -50,10 +50,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
...
@@ -446,9 +446,7 @@ class Phi4MMForCausalLM(nn.Module):
         return hidden_states

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def should_apply_lora(self, module_name: str) -> bool:
...
@@ -268,15 +268,14 @@ class PixtralHFVisionModel(nn.Module):
     DEFAULT_IMAGE_TOKEN_ID = 10

-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, mm_inputs)

     def __init__(
         self,
         config: PixtralVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         *,
-        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
         num_hidden_layers_override: Optional[int] = None,
         prefix: str = "",
     ) -> None:
@@ -314,11 +313,8 @@ class PixtralHFVisionModel(nn.Module):
         )

         # Initialize patch position embedding
-        self.image_token_id = image_token_id
         self.patch_positional_embedding = PixtralRotaryEmbedding(config)
-        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
-            [self.image_token_id]
-        )
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens()

     @property
     def dtype(self):
...
@@ -493,9 +493,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
...
@@ -479,10 +479,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-        # Get all special token IDs
-        im_token_id: int = mm_inputs.im_token_id
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
...
@@ -270,15 +270,10 @@ class VILAForConditionalGeneration(nn.Module):
             weight_loader(param, loaded_weight)

     def pad_input_ids(
-        self,
-        input_ids: List[int],
-        image_inputs: MultimodalInputs,
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens(
-            token_ids=[self.config.image_token_id],
-        )
-        return pattern.pad_input_tokens(input_ids, image_inputs)
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)

     ##### BEGIN COPY modeling_vila.py #####
...
@@ -17,15 +17,6 @@ from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.utils import encode_video, load_audio, load_image

-class MultimodalInputFormat(Enum):
-    """Enum for different multimodal input formats."""
-
-    RAW_IMAGES = "raw_images"
-    PRECOMPUTED_FEATURES = "precomputed_features"
-    PIXEL_VALUES = "pixel_values"
-    AUDIO = "audio"

 @dataclasses.dataclass
 class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
@@ -110,18 +101,45 @@ class BaseMultimodalProcessor(ABC):
             max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )

+        # Mapping from attribute names to modality types
+        self.ATTR_NAME_TO_MODALITY = {
+            # Image-related attributes
+            "pixel_values": Modality.IMAGE,
+            "image_sizes": Modality.IMAGE,
+            "image_grid_thw": Modality.IMAGE,
+            "image_emb_mask": Modality.IMAGE,
+            "image_spatial_crop": Modality.IMAGE,
+            "tgt_size": Modality.IMAGE,
+            "image_grid_hws": Modality.IMAGE,
+            "aspect_ratio_id": Modality.IMAGE,
+            "aspect_ratio_mask": Modality.IMAGE,
+            "second_per_grid_ts": Modality.IMAGE,
+            # Audio-related attributes
+            "audio_features": Modality.AUDIO,
+            "audio_feature_lens": Modality.AUDIO,
+            "input_features": Modality.AUDIO,
+            "input_features_mask": Modality.AUDIO,
+            # Video-related attributes
+            "video_grid_thws": Modality.VIDEO,
+            # Generic attributes that could apply to multiple modalities
+            # "precomputed_features" - handled specially as it can be any modality
+        }

     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
     ):
         """
         process multimodal data with transformers AutoProcessor
         """
-        if images is not None:
+        if images:
             kwargs["images"] = images
-        if videos is not None:
+        if videos:
             kwargs["videos"] = videos
-        if audios is not None:
+        if audios:
             kwargs["audios"] = audios
+            if self.__class__.__name__ == "Gemma3nSGLangProcessor":
+                # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
+                kwargs["audio"] = audios

         processor = self._processor
         if hasattr(processor, "image_processor") and isinstance(
@@ -144,6 +162,7 @@ class BaseMultimodalProcessor(ABC):
     async def process_mm_data_async(
         self,
         image_data,
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
@@ -418,175 +437,137 @@ class BaseMultimodalProcessor(ABC):
                 values[k] = v
         return values

+    def collect_mm_items_from_processor_output(
+        self, data_dict: dict
+    ) -> List[MultimodalDataItem]:
+        """Create mm_items directly from processor output."""
+        items = {}  # modality -> MultimodalDataItem
+        for attr_name, value in data_dict.items():
+            if attr_name == "input_ids":
+                continue
+
+            # Get modality for this attribute
+            modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
+
+            if not modality and attr_name == "precomputed_features":
+                modality_str = data_dict.get("modality")
+                try:
+                    modality = (
+                        Modality.from_str(modality_str)
+                        if modality_str
+                        else Modality.IMAGE
+                    )
+                except ValueError:
+                    modality = Modality.IMAGE
+
+            if modality:
+                # Create item if needed
+                if modality not in items:
+                    items[modality] = MultimodalDataItem(modality=modality)
+
+                # Set attribute
+                if hasattr(items[modality], attr_name):
+                    setattr(items[modality], attr_name, value)
+
+        return list(items.values())
+
+    def _process_and_collect_mm_items(
+        self, input_text: str, images=None, audios=None, videos=None, **kwargs
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+        """
+        Helper method to process multimodal data and create mm_items in one step.
+
+        Returns:
+            Tuple of (created mm_items, input_ids)
+        """
+        ret = self.process_mm_data(
+            input_text=input_text, images=images, audios=audios, videos=videos, **kwargs
+        )
+
+        input_ids = ret["input_ids"].flatten()
+        collected_items = self.collect_mm_items_from_processor_output(ret)
+
+        return collected_items, input_ids
+
     def process_and_combine_mm_data(
         self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[Optional[MultimodalDataItem], torch.Tensor]:
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
         """
-        Process multimodal data and return the combined multimodal item and input_ids.
-        Handles all three input formats at the same abstraction level.
+        Process multimodal data and return the combined multimodal items and input_ids.
+        Supports mixed modalities (images and audio in the same request).

         Returns:
-            Tuple of (combined_mm_item, input_ids)
+            Tuple of (list of mm_items, input_ids)
         """
+        # Collect all items and categorize them
+        all_items = (base_output.images or []) + (base_output.audios or [])

-        def tokenize_text(input_text: str) -> torch.Tensor:
-            """Tokenize input text."""
-            return self._processor.tokenizer(
-                input_text,
-                return_tensors="pt",
-                add_special_tokens=True,
-            ).input_ids.flatten()
+        # Handle text-only case
+        if not all_items:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+            return [], input_ids

-        def categorize_mm_inputs(mm_inputs: List) -> MultimodalInputFormat:
-            """Categorize multimodal inputs and validate consistency."""
-            try:
-                has_image = False
-                has_pixel_values = False
-                has_precomputed_features = False
-                has_audio = False
-
-                for mm_input in mm_inputs:
-                    if isinstance(mm_input, Image.Image):
-                        has_image = True
-                    elif isinstance(mm_input, np.ndarray):
-                        has_audio = True
-                    elif isinstance(mm_input, dict):
-                        if mm_input.get("precomputed_features", None) is not None:
-                            has_precomputed_features = True
-                        elif mm_input.get("pixel_values", None) is not None:
-                            has_pixel_values = True
-                        else:
-                            raise ValueError(
-                                f"Invalid multimodal input: {mm_input}, expected dict with pixel_values or precomputed_features"
-                            )
-                    else:
-                        raise ValueError(
-                            f"Invalid multimodal input: {mm_input}, expected Image.Image or dict"
-                        )
-
-                # Validate format consistency
-                format_count = sum(
-                    [has_image, has_pixel_values, has_precomputed_features, has_audio]
-                )
-                if format_count > 1:
-                    raise ValueError(
-                        "Unsupported: mixture of multimodal input formats. "
-                        f"Found formats: image={has_image}, pixel_values={has_pixel_values}, "
-                        f"precomputed_features={has_precomputed_features}, audio={has_audio}"
-                    )
-
-                if has_image:
-                    return MultimodalInputFormat.RAW_IMAGES
-                elif has_precomputed_features:
-                    return MultimodalInputFormat.PRECOMPUTED_FEATURES
-                elif has_pixel_values:
-                    return MultimodalInputFormat.PIXEL_VALUES
-                elif has_audio:
-                    return MultimodalInputFormat.AUDIO
-                else:
-                    raise ValueError("No valid multimodal input format found")
-            except Exception as e:
-                raise ValueError(f"Failed to categorize inputs: {e}")
-
-        def process_raw_images(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process raw Image.Image objects using transformers processor."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                images=base_output.images,
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-            # Copy all fields from processor output except input_ids
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def process_precomputed_features(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with precomputed features."""
-            combined_mm_item = MultimodalDataItem(modality=Modality.IMAGE)
-            combined_mm_item.precomputed_features = self._extract_processor_features(
-                base_output.images, "precomputed_features"
-            )
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_pixel_values(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with pixel values."""
-            values = self._extract_processor_features_from_all_attributes(
-                base_output.images
-            )
-            combined_mm_item = MultimodalDataItem.from_dict(values)
-            input_ids = tokenize_text(base_output.input_text)
-            return combined_mm_item, input_ids
-
-        def process_audio(
-            base_output: BaseMultiModalProcessorOutput,
-        ) -> Tuple[MultimodalDataItem, torch.Tensor]:
-            """Process inputs with audio."""
-            ret = self.process_mm_data(
-                input_text=base_output.input_text,
-                audio=base_output.audios,  # Note: "audio" is for gemma3n only
-            )
-            combined_mm_item = MultimodalDataItem(modality=Modality.AUDIO)
-            for key, value in ret.items():
-                if key != "input_ids" and hasattr(combined_mm_item, key):
-                    setattr(combined_mm_item, key, value)
-            input_ids = ret["input_ids"].flatten()
-            return combined_mm_item, input_ids
-
-        def finalize_mm_item(
-            combined_mm_item: MultimodalDataItem, input_ids: torch.Tensor
-        ) -> MultimodalDataItem:
-            """Apply common post-processing to the multimodal item."""
-            if combined_mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                combined_mm_item.image_offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.IM_TOKEN_ID,
-                )
-            elif combined_mm_item.modality == Modality.AUDIO:
-                combined_mm_item.audio_offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.AUDIO_TOKEN_ID,
-                )
-            elif combined_mm_item.modality == Modality.VIDEO:
-                combined_mm_item.video_offsets = self.get_mm_items_offset(
-                    input_ids=input_ids,
-                    mm_token_id=self.VIDEO_TOKEN_ID,
-                )
-            else:
-                raise ValueError(f"Unknown modality: {combined_mm_item.modality}")
-            return combined_mm_item
-
-        # Main logic - determine input type and handle text-only case
-        mm_inputs = base_output.images or base_output.audios
-        if not mm_inputs:
-            input_ids = tokenize_text(base_output.input_text)
-            return None, input_ids
-
-        # Categorize input formats
-        input_format = categorize_mm_inputs(mm_inputs)
-
-        # Process based on format
-        if input_format == MultimodalInputFormat.RAW_IMAGES:
-            combined_mm_item, input_ids = process_raw_images(base_output)
-        elif input_format == MultimodalInputFormat.PRECOMPUTED_FEATURES:
-            combined_mm_item, input_ids = process_precomputed_features(base_output)
-        elif input_format == MultimodalInputFormat.PIXEL_VALUES:
-            combined_mm_item, input_ids = process_pixel_values(base_output)
-        elif input_format == MultimodalInputFormat.AUDIO:
-            combined_mm_item, input_ids = process_audio(base_output)
-        else:
-            raise ValueError(f"Unknown input format: {input_format}")
-
-        # Finalize with common processing
-        combined_mm_item = finalize_mm_item(combined_mm_item, input_ids)
-        return combined_mm_item, input_ids
+        dict_items, raw_images, raw_audios = [], [], []
+        for item in all_items:
+            if isinstance(item, dict):
+                dict_items.append(item)
+            elif isinstance(item, Image.Image):
+                raw_images.append(item)
+            elif isinstance(item, np.ndarray):
+                raw_audios.append(item)
+            else:
+                raise ValueError(f"Unknown multimodal item type: {type(item)}")
+
+        # Process items and get input_ids
+        all_collected_items = []
+        input_ids = None
+
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
+        # Handle raw items (need processing)
+        if raw_images or raw_audios:
+            collected_items, input_ids = self._process_and_collect_mm_items(
+                input_text=base_output.input_text,
+                images=raw_images,
+                audios=raw_audios,
+            )
+            all_collected_items.extend(collected_items)
+
+        # Fallback tokenization if no raw items were processed
+        if input_ids is None:
+            input_ids = self._processor.tokenizer(
+                base_output.input_text,
+                return_tensors="pt",
+                add_special_tokens=True,
+            ).input_ids.flatten()
+
+        # Add offsets to all items
+        for mm_item in all_collected_items:
+            if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
+                mm_item.image_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.IM_TOKEN_ID,
+                )
+            elif mm_item.modality == Modality.AUDIO:
+                mm_item.audio_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.AUDIO_TOKEN_ID,
+                )
+            elif mm_item.modality == Modality.VIDEO:
+                mm_item.video_offsets = self.get_mm_items_offset(
+                    input_ids=input_ids,
+                    mm_token_id=self.VIDEO_TOKEN_ID,
+                )
+            else:
+                raise ValueError(f"Unknown modality: {mm_item.modality}")
+
+        return all_collected_items, input_ids
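The heart of the mixed-modality support is the single pass that buckets incoming items by Python type before dispatching them. A standalone sketch of just that step, using the same type conventions as the new code (dicts are already-processed features, PIL images and numpy arrays are raw media); the helper name bucket_items is made up for illustration:

import numpy as np
from PIL import Image

def bucket_items(all_items):
    dict_items, raw_images, raw_audios = [], [], []
    for item in all_items:
        if isinstance(item, dict):           # precomputed / already-processed features
            dict_items.append(item)
        elif isinstance(item, Image.Image):  # raw image, still needs the HF processor
            raw_images.append(item)
        elif isinstance(item, np.ndarray):   # raw audio waveform
            raw_audios.append(item)
        else:
            raise ValueError(f"Unknown multimodal item type: {type(item)}")
    return dict_items, raw_images, raw_audios

# Mixed request: one raw image and one raw audio clip in the same call
dicts, imgs, auds = bucket_items([Image.new("RGB", (8, 8)), np.zeros(16000)])
print(len(dicts), len(imgs), len(auds))  # 0 1 1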
@@ -15,20 +15,11 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)

-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
+        images = [load_image(image)[0] for image in image_data]

         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]
...
@@ -44,17 +44,10 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs
     ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
             max_req_input_len=max_req_input_len,
         )

         res = self.process_mm_data(
...
@@ -36,11 +36,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -51,11 +46,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             discard_alpha_channel=True,
         )

-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
@@ -59,17 +59,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
         **kwargs,
     ):
         """Process multimodal data including images and audio."""
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-        if isinstance(audio_data, str):
-            audio_data = [audio_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -83,13 +72,11 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             ),
         )

-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
-            "im_start_id": self.IM_START_TOKEN_ID,
-            "im_end_id": self.IM_END_TOKEN_ID,
-            "audio_start_id": self.AUDIO_START_TOKEN_ID,
-            "audio_end_id": self.AUDIO_END_TOKEN_ID,
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "audio_token_id": self.AUDIO_TOKEN_ID,
         }
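With this change the Gemma3n processor (like the other refactored processors) reports flat multimodal token ids instead of start/end token pairs, which is exactly what the new pad_input_tokens consumes. A sketch of the expected return shape (all values below are placeholders, not real Gemma3n ids):

example_processor_output = {
    "input_ids": [2, 106, 107, 108],  # placeholder tokenized prompt
    "mm_items": [],                   # list of MultimodalDataItem, possibly mixed modalities
    "im_token_id": 111111,            # placeholder; the real id is model-specific
    "audio_token_id": 222222,         # placeholder
}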
@@ -172,13 +172,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, max_req_input_len, **kwargs
     ):
-        if not image_data:
-            return None
-
-        # Ensure image_data is a list
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
...
@@ -22,12 +22,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         max_req_input_len,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
         processor = self._processor

         base_out = self.load_mm_data(
...
@@ -30,11 +30,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -44,10 +39,10 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )

-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item] if combined_mm_item is not None else [],
+            "mm_items": mm_items,
             "im_token_id": self.IM_TOKEN_ID,
         }