Unverified Commit 8430bfe3 authored by Xinyuan Tong, committed by GitHub

[Refactor] simplify multimodal data processing (#8107)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent c9e8613c
......@@ -126,14 +126,14 @@
" images=[image], text=conv.get_prompt(), return_tensors=\"pt\"\n",
")\n",
"input_ids = processed_prompt[\"input_ids\"][0].detach().cpu().tolist()\n",
"precomputed_features = vision(\n",
"precomputed_embeddings = vision(\n",
" processed_prompt[\"pixel_values\"].cuda(), processed_prompt[\"image_grid_thw\"].cuda()\n",
")\n",
"\n",
"mm_item = dict(\n",
" modality=\"IMAGE\",\n",
" image_grid_thw=processed_prompt[\"image_grid_thw\"],\n",
" precomputed_features=precomputed_features,\n",
" precomputed_embeddings=precomputed_embeddings,\n",
")\n",
"out = llm.generate(input_ids=input_ids, image_data=[mm_item])\n",
"print(out[\"text\"])"
......
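For readers migrating notebooks or client code, a minimal sketch of the field rename on the item dict passed to `llm.generate` (the `processed_prompt`, `precomputed_embeddings`, `input_ids`, and `llm` objects are those defined in the notebook cell above):

# before this commit the embeddings travelled under "precomputed_features"
mm_item = dict(
    modality="IMAGE",
    image_grid_thw=processed_prompt["image_grid_thw"],
    precomputed_features=precomputed_embeddings,  # old field name used before this commit
)

# after this commit the field is "precomputed_embeddings"
mm_item = dict(
    modality="IMAGE",
    image_grid_thw=processed_prompt["image_grid_thw"],
    precomputed_embeddings=precomputed_embeddings,
)
out = llm.generate(input_ids=input_ids, image_data=[mm_item])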
......@@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions):
class DictOutput(object):
def items(self):
return self.__dict__.items()
def keys(self):
return self.__dict__.keys()
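DictOutput previously exposed only items(); adding keys() lets callers treat processor outputs more like mappings. A self-contained sketch (the Output dataclass below is hypothetical, for illustration only):

from dataclasses import dataclass

import torch


class DictOutput(object):
    def items(self):
        return self.__dict__.items()

    def keys(self):
        return self.__dict__.keys()


@dataclass
class Output(DictOutput):
    pixel_values: torch.Tensor


out = Output(pixel_values=torch.zeros(1, 3, 384, 384))
print(list(out.keys()))                             # ['pixel_values']
print({k: tuple(v.shape) for k, v in out.items()})  # {'pixel_values': (1, 3, 384, 384)}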
......@@ -59,7 +62,9 @@ class DictOutput(object):
class VLChatProcessorOutput(DictOutput):
input_ids: torch.LongTensor
target_ids: torch.LongTensor
images: torch.Tensor
pixel_values: (
torch.Tensor
) # renamed from "images" to "pixel_values" for compatibility
images_seq_mask: torch.BoolTensor
images_spatial_crop: torch.LongTensor
......@@ -312,10 +317,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
images = torch.stack(images_list, dim=0)
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
images_spatial_crop = torch.stack(
[images_spatial_crop], dim=0
) # stack the tensor to make it a batch of 1
prepare = VLChatProcessorOutput(
input_ids=input_ids,
target_ids=target_ids,
images=images,
pixel_values=images,
images_seq_mask=images_seq_mask,
images_spatial_crop=images_spatial_crop,
)
......
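The extra torch.stack in the hunk above wraps images_spatial_crop into a batch of size 1, which is what the dim() == 3 assertion in DeepseekVL2ForCausalLM.get_image_feature further down expects. A quick shape check with placeholder crop counts:

import torch

images_spatial_crop = torch.tensor([[2, 2], [1, 1]], dtype=torch.long)  # (num_images, 2)
batched = torch.stack([images_spatial_crop], dim=0)                     # (1, num_images, 2)
assert batched.dim() == 3
print(tuple(batched.shape))  # (1, 2, 2)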
......@@ -284,6 +284,9 @@ class VLMImageProcessor(BaseImageProcessor):
class DictOutput(object):
def items(self):
return self.__dict__.items()
def keys(self):
return self.__dict__.keys()
......
......@@ -221,17 +221,17 @@ def _get_precomputed_embedding(
items: List[MultimodalDataItem],
) -> Optional[torch.Tensor]:
"""
If all items have precomputed_features, return their concatenation.
If some but not all have precomputed_features, raise NotImplementedError.
If none have precomputed_features, return None.
If all items have precomputed_embeddings, return their concatenation.
If some but not all have precomputed_embeddings, raise NotImplementedError.
If none have precomputed_embeddings, return None.
"""
precomputed_features = [item.precomputed_features for item in items]
if any(feature is not None for feature in precomputed_features):
if not all(feature is not None for feature in precomputed_features):
precomputed_embeddings = [item.precomputed_embeddings for item in items]
if any(feature is not None for feature in precomputed_embeddings):
if not all(feature is not None for feature in precomputed_embeddings):
raise NotImplementedError(
"MM inputs where only some items are precomputed."
)
result = torch.concat(precomputed_features)
result = torch.concat(precomputed_embeddings)
# some models' embeddings are 3-dim; reshape to 2-dim (similar to get_embedding_chunk)
result = result.reshape(-1, result.shape[-1])
return result
......
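The docstring above pins down the helper's contract; below is a self-contained sketch of the same all/some/none logic over items carrying an optional precomputed_embeddings attribute (a stand-in for the real MultimodalDataItem, for illustration only):

from typing import List, Optional

import torch


def get_precomputed_embedding(items: List[object]) -> Optional[torch.Tensor]:
    embs = [getattr(item, "precomputed_embeddings", None) for item in items]
    if any(e is not None for e in embs):
        if not all(e is not None for e in embs):
            # mixing precomputed and raw items in one request is unsupported
            raise NotImplementedError("MM inputs where only some items are precomputed.")
        result = torch.concat(embs)
        # some models return 3-dim embeddings; flatten to (num_tokens, hidden_size)
        return result.reshape(-1, result.shape[-1])
    return None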
......@@ -201,7 +201,7 @@ class MultimodalDataItem:
For example, if there are 3 image inputs and 1 audio input, there will be 2 MultimodalDataItems:
one for images and one for audio.
We put the common fields first and the model-specific fields last.
We put the common fields first and the model-specific fields in model_specific_data.
"""
modality: Modality
......@@ -211,37 +211,31 @@ class MultimodalDataItem:
# the raw features returned by processor, e.g. pixel_values or audio_features
feature: Union[torch.Tensor, np.ndarray] = None
image_sizes: Tuple[int, int] = None
# the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
# Model-specific data stored in a dictionary
model_specific_data: dict[str, Any] = dataclasses.field(default_factory=dict)
# For qwen-vl
image_grid_thw: Union[torch.Tensor, np.ndarray] = None
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# For deepseek-vl
image_emb_mask: Optional[torch.Tensor] = None
image_spatial_crop: Optional[torch.Tensor] = None
# For minicpmv
# [num_images, (n, w, h)]
tgt_size: Tuple[int, int] = None
# For mllama
aspect_ratio_id: Optional[List[torch.Tensor]] = None
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# For kimi-vl
image_grid_hws: Optional[List[torch.Tensor]] = None
def __getattr__(self, name: str):
if (
"model_specific_data" in self.__dict__
and name in self.__dict__["model_specific_data"]
):
return self.__dict__["model_specific_data"][name]
else:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
# For gemma3n
input_features_mask: Optional[torch.Tensor] = None
def __setitem__(self, key: str, value: Any):
if key in self.__dict__:
self.__dict__[key] = value
else:
self.model_specific_data[key] = value
# For phi4-mm
image_attention_mask: Optional[torch.Tensor] = None
audio_attention_mask: Optional[torch.Tensor] = None
def set(self, key: str, value: Any):
self.__setitem__(key, value)
@staticmethod
def is_empty_list(l):
......@@ -259,7 +253,7 @@ class MultimodalDataItem:
if self.feature is not None:
hashed_feature = self.feature
else:
hashed_feature = self.precomputed_features
hashed_feature = self.precomputed_embeddings
self.hash = hash_feature(hashed_feature)
assert self.hash is not None
self.pad_value = self.hash % (1 << 30)
......@@ -268,24 +262,13 @@ class MultimodalDataItem:
return self.modality == modality
def is_audio(self):
return (self.modality == Modality.AUDIO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.feature)
)
return self.modality == Modality.AUDIO
def is_image(self):
return (
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.feature)
)
return self.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]
def is_video(self):
return (self.modality == Modality.VIDEO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.feature)
)
return self.modality == Modality.VIDEO
def is_valid(self) -> bool:
return self.is_image() or self.is_video() or self.is_audio()
......@@ -306,8 +289,7 @@ class MultimodalDataItem:
def merge(self, other):
self.feature += other.feature
self.image_sizes += other.image_sizes
self.image_offsets += other.image_offsets
self.offsets += other.offsets
self.hash = hash((self.hash, other.hash))
self.set_pad_value()
......
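After this refactor, model-specific attributes live in the model_specific_data dict: writes go through set()/__setitem__ and reads are resolved by __getattr__, while unknown names raise AttributeError instead of returning a dataclass default of None. A minimal usage sketch (import path as used elsewhere in this diff; tensor values are placeholders):

import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(modality=Modality.IMAGE)
item.feature = torch.zeros(1, 3, 336, 336)                # common field, a normal attribute
item.set("image_grid_thw", torch.tensor([[1, 24, 24]]))   # model-specific, stored in model_specific_data

print(item.image_grid_thw.shape)        # torch.Size([1, 3]); resolved via __getattr__
print(list(item.model_specific_data))   # ['image_grid_thw']
print(hasattr(item, "tgt_size"))        # False: no longer a None-defaulted dataclass field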
......@@ -260,7 +260,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
def get_image_feature(self, items: List[MultimodalDataItem]):
images_spatial_crop = torch.cat(
[item.image_spatial_crop for item in items], dim=0
[item.images_spatial_crop for item in items], dim=0
)
assert images_spatial_crop.dim() == 3
......@@ -278,8 +278,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
for jdx in range(item.image_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
for jdx in range(item.images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = item.images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
num_tiles_in_image = num_width_tiles * num_height_tiles
......
......@@ -81,6 +81,7 @@ class Llama4ForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(
config.text_config if hasattr(config, "text_config") else config
)
self.padding_pattern = MultiModalityDataPaddingPatternMultimodalTokens()
def _has_vision_weights(self, config) -> bool:
"""Check if the model has vision components by examining the checkpoint."""
......@@ -135,8 +136,7 @@ class Llama4ForConditionalGeneration(nn.Module):
return False
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pattern = MultiModalityDataPaddingPatternMultimodalTokens()
return pattern.pad_input_tokens(input_ids, mm_inputs)
return self.padding_pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(
self,
......
......@@ -435,7 +435,12 @@ class Phi4MMForCausalLM(nn.Module):
dtype = next(self.vision_encoder.parameters()).dtype
pixel_values = torch.cat([item.feature for item in items], dim=0).type(dtype)
image_attention_mask = torch.cat(
[item.image_attention_mask for item in items], dim=0
[
item.image_attention_mask
for item in items
if hasattr(item, "image_attention_mask")
],
dim=0,
)
image_sizes = torch.cat([item.image_sizes for item in items], dim=0)
image_embeds = self.vision_encoder(
......@@ -456,7 +461,7 @@ class Phi4MMForCausalLM(nn.Module):
audio_features=item.feature.to(device).type(dtype),
audio_attention_mask=(
item.audio_attention_mask.to(device)
if item.audio_attention_mask is not None
if hasattr(item, "audio_attention_mask")
else None
),
)
......
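The hasattr() probes above follow from the MultimodalDataItem change: image_attention_mask and audio_attention_mask are no longer None-defaulted dataclass fields, so reading them on an item that never set them raises AttributeError. A short sketch of the pattern (same import path as above):

import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(modality=Modality.AUDIO)

# Old check, relying on a default of None, now raises AttributeError:
#     item.audio_attention_mask is not None

# New check mirrors the Phi4MMForCausalLM change above:
audio_attention_mask = (
    item.audio_attention_mask if hasattr(item, "audio_attention_mask") else None
)

item.set("audio_attention_mask", torch.ones(1, 8, dtype=torch.bool))
assert hasattr(item, "audio_attention_mask")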
......@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -155,17 +155,15 @@ class BaseMultimodalProcessor(ABC):
self.ATTR_NAME_TO_MODALITY = {
# Image-related attributes
"pixel_values": Modality.IMAGE,
"pixel_values_videos": Modality.VIDEO,
"image_sizes": Modality.IMAGE,
"image_grid_thw": Modality.IMAGE,
"image_attention_mask": Modality.IMAGE,
"image_emb_mask": Modality.IMAGE,
"image_spatial_crop": Modality.IMAGE,
"images_spatial_crop": Modality.IMAGE,
"tgt_size": Modality.IMAGE,
"image_grid_hws": Modality.IMAGE,
"aspect_ratio_id": Modality.IMAGE,
"aspect_ratio_ids": Modality.IMAGE,
"aspect_ratio_mask": Modality.IMAGE,
"second_per_grid_ts": Modality.IMAGE,
# Audio-related attributes
"audio_features": Modality.AUDIO,
"audio_feature_lens": Modality.AUDIO,
......@@ -173,9 +171,11 @@ class BaseMultimodalProcessor(ABC):
"input_features_mask": Modality.AUDIO,
"audio_attention_mask": Modality.AUDIO,
# Video-related attributes
"pixel_values_videos": Modality.VIDEO,
"second_per_grid_ts": Modality.VIDEO,
"video_grid_thw": Modality.VIDEO,
# Generic attributes that could apply to multiple modalities
# "precomputed_features" - handled specially as it can be any modality
# "precomputed_embeddings" - handled specially as it can be any modality
}
# name of the feature field
......@@ -222,7 +222,6 @@ class BaseMultimodalProcessor(ABC):
audio_data,
input_text,
request_obj,
max_req_input_len,
**kwargs,
) -> Optional[Dict[str, Any]]:
pass
......@@ -283,7 +282,7 @@ class BaseMultimodalProcessor(ABC):
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
data_iterators: dict,
data_iterators: dict[Modality, Iterator[Any]],
discard_alpha_channel: bool = True,
image_estimated_frames_iter: Optional[iter] = None,
image_scaling_factor: float = 1.0,
......@@ -354,7 +353,6 @@ class BaseMultimodalProcessor(ABC):
self,
prompt: str,
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
video_data: Optional[list] = None,
audio_data: Optional[list] = None,
......@@ -489,50 +487,11 @@ class BaseMultimodalProcessor(ABC):
return list(zip(indices_start.tolist(), indices_end.tolist()))
@staticmethod
def _extract_processor_features(
items: List[dict], attr_name: str
) -> Optional[torch.Tensor]:
"""
Helper function to concat extracted attributes from processor output.
"""
values = [value for item in items if (value := item.get(attr_name)) is not None]
return torch.cat(values) if values else None
# When we assume that all the items have the same attributes
def _extract_processor_features_from_all_attributes(
self, items: List[dict]
) -> dict:
values = {}
# Verify all items have the same keys
first_keys = set(items[0].keys())
for item in items[1:]:
if set(item.keys()) != first_keys:
raise ValueError(
f"All items must have the same attributes. "
f"First item has {first_keys}, but found {set(item.keys())}"
)
# Process each attribute
for k, v in items[0].items():
if isinstance(v, list):
values[k] = self._extract_processor_features(items, k)
else:
# Verify all items have the same value for non-list attributes
for item in items[1:]:
if item[k] != v:
raise ValueError(
f"All items must have the same value for attribute {k}. "
f"First item has {v}, but found {item[k]}"
)
values[k] = v
return values
def collect_mm_items_from_processor_output(
self, data_dict: dict
) -> List[MultimodalDataItem]:
"""Create mm_items directly from processor output."""
items = {} # modality -> MultimodalDataItem
items: dict[Modality, MultimodalDataItem] = {}
for attr_name, value in data_dict.items():
if attr_name == "input_ids":
......@@ -541,16 +500,15 @@ class BaseMultimodalProcessor(ABC):
# Get modality for this attribute
modality = self.ATTR_NAME_TO_MODALITY.get(attr_name)
if not modality and attr_name == "precomputed_features":
if attr_name == "precomputed_embeddings":
modality_str = data_dict.get("modality")
try:
modality = (
Modality.from_str(modality_str)
if modality_str
else Modality.IMAGE
)
except ValueError:
modality = Modality.IMAGE
modality = Modality.IMAGE
if modality_str:
try:
modality = Modality.from_str(modality_str)
except ValueError:
pass
if modality:
# Create item if needed
if modality not in items:
......@@ -559,8 +517,7 @@ class BaseMultimodalProcessor(ABC):
if attr_name in self.FEATURE_NAMES:
attr_name = "feature"
# Set attribute
setattr(items[modality], attr_name, value)
items[modality].set(attr_name, value)
return list(items.values())
......@@ -586,6 +543,7 @@ class BaseMultimodalProcessor(ABC):
self,
base_output: BaseMultiModalProcessorOutput,
mm_tokens: MultimodalSpecialTokens,
**kwargs,
) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
"""
Process multimodal data and return the combined multimodal items and input_ids.
......@@ -618,7 +576,7 @@ class BaseMultimodalProcessor(ABC):
else:
raise ValueError(f"Unknown multimodal item type: {type(item)}")
# Process items and get input_ids
all_collected_items = []
all_collected_items: list[MultimodalDataItem] = []
input_ids = None
# Handle dict items (already processed)
......@@ -634,6 +592,7 @@ class BaseMultimodalProcessor(ABC):
images=raw_images,
audios=raw_audios,
videos=raw_videos,
**kwargs,
)
all_collected_items.extend(collected_items)
else:
......
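collect_mm_items_from_processor_output now builds one MultimodalDataItem per modality by looking each attribute up in ATTR_NAME_TO_MODALITY, with precomputed_embeddings falling back to the request's "modality" string (default IMAGE). A trimmed-down sketch of that grouping, using a hypothetical attribute map and plain dicts instead of the real classes:

from typing import Any, Dict

# hypothetical subset of ATTR_NAME_TO_MODALITY, for illustration only
ATTR_NAME_TO_MODALITY = {
    "pixel_values": "IMAGE",
    "image_grid_thw": "IMAGE",
    "pixel_values_videos": "VIDEO",
    "audio_features": "AUDIO",
}


def group_by_modality(data_dict: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    items: Dict[str, Dict[str, Any]] = {}
    for attr_name, value in data_dict.items():
        if attr_name in ("input_ids", "modality"):
            continue
        modality = ATTR_NAME_TO_MODALITY.get(attr_name)
        if attr_name == "precomputed_embeddings":
            # fall back to the request-supplied modality string, default IMAGE
            modality = data_dict.get("modality") or "IMAGE"
        if modality:
            items.setdefault(modality, {})[attr_name] = value
    return items


print(group_by_modality({"pixel_values": "px", "image_grid_thw": "thw", "audio_features": "af"}))
# {'IMAGE': {'pixel_values': 'px', 'image_grid_thw': 'thw'}, 'AUDIO': {'audio_features': 'af'}}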
from typing import List, Union
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.clip import CLIPModel
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.utils import load_image
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
class ClipImageProcessor(BaseMultimodalProcessor):
......@@ -11,23 +12,24 @@ class ClipImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
_processor
)
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
images = [load_image(image)[0] for image in image_data]
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
feature=image_inputs["pixel_values"], modality=Modality.IMAGE
)
]
return image_inputs
base_output = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.mm_tokens,
image_data=image_data,
)
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output, self.mm_tokens
)
return {
"input_ids": input_ids.tolist(),
"mm_items": mm_items,
}
......@@ -33,9 +33,9 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(image_token="<image>").build(
_processor
)
self.mm_tokens = MultimodalSpecialTokens(
image_token="<image>", image_token_id=self._processor.image_token_id
).build(_processor)
async def process_mm_data_async(
self,
......@@ -50,36 +50,16 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
)
res = self.process_mm_data(
input_text=base_output.input_text,
images=base_output.images,
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_output,
self.mm_tokens,
max_req_input_len=max_req_input_len,
conversations=base_output.input_text,
)
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
batched_images_spatial_crop = []
batched_images_spatial_crop.append(images_spatial_crop)
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
items = []
input_ids = res["input_ids"]
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=self._processor.image_token_id
)
item = MultimodalDataItem(
feature=res["images"],
offsets=image_offsets,
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,
)
items += [item]
return {
"mm_items": items,
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"im_token_id": self._processor.image_token_id,
}
......@@ -33,7 +33,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
......@@ -41,7 +40,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
......
......@@ -54,7 +54,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
audio_data: Optional[List[Union[str, bytes, Dict]]] = None,
input_text: str = "",
request_obj=None,
max_req_input_len: int = 0,
*args,
**kwargs,
):
......@@ -63,7 +62,6 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
prompt=input_text,
image_data=image_data,
audio_data=audio_data,
max_req_input_len=max_req_input_len,
multimodal_tokens=self.mm_tokens,
)
......
......@@ -170,13 +170,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
return pixel_values, num_patches_list
async def process_mm_data_async(
self, image_data, input_text, request_obj, max_req_input_len, **kwargs
self, image_data, input_text, request_obj, **kwargs
):
base_output = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
......
......@@ -11,52 +11,35 @@ from sglang.srt.multimodal.processors.base_processor import (
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, processor):
super().__init__(hf_config, server_args, processor)
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(
image_token=processor.image_token
).build(processor)
image_token=_processor.image_token,
image_token_id=_processor.image_id,
).build(_processor)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
processor = self._processor
base_out = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
)
images = base_out.images
res = self.process_mm_data(
input_text=base_out.input_text,
prompt=base_out.input_text,
images=images,
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_out, self.mm_tokens, prompt=base_out.input_text
)
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=processor.image_id
)
return {
"mm_items": [
MultimodalDataItem(
feature=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
offsets=image_offsets,
modality=Modality.IMAGE,
)
],
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
"im_start_id": self._processor.image_start_id,
"im_end_id": self._processor.image_end_id,
"im_token_id": self.mm_tokens.image_token_id,
}
......@@ -26,7 +26,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
image_data: List[Union[str, bytes, Dict]],
input_text,
request_obj,
max_req_input_len,
*args,
**kwargs,
):
......@@ -34,7 +33,6 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
max_req_input_len=max_req_input_len,
)
mm_items, input_ids, _ = self.process_and_combine_mm_data(
......
......@@ -159,7 +159,9 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
"mm_items": [
MultimodalDataItem(
feature=pixel_values,
image_sizes=image_sizes,
model_specific_data={
"image_sizes": image_sizes,
},
modality=modality,
)
],
......
......@@ -17,10 +17,21 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
# Collect special token ids
tokenizer = self._processor.tokenizer
self.slice_start_id = getattr(tokenizer, "slice_start_id", None)
self.slice_end_id = getattr(tokenizer, "slice_end_id", None)
self.audio_start_id = getattr(tokenizer, "audio_start_id", None)
self.audio_end_id = getattr(tokenizer, "audio_end_id", None)
self.im_start_id = getattr(tokenizer, "im_start_id", None)
self.im_end_id = getattr(tokenizer, "im_end_id", None)
self.im_token_id = getattr(tokenizer, "unk_id", None)
self.mm_tokens = MultimodalSpecialTokens(
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
video_token="(<video>./</video>)",
image_token_id=self.im_token_id,
).build(_processor)
async def process_mm_data_async(
......@@ -29,12 +40,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_data: List[Union[str, bytes]],
input_text,
request_obj,
max_req_input_len,
**kwargs,
):
base_output = self.load_mm_data(
prompt=input_text,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
......@@ -48,24 +57,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audios=base_output.audios,
)
# Collect special token ids
tokenizer = self._processor.tokenizer
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
None,
None,
None,
None,
)
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
if hasattr(tokenizer, "audio_start_id"):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_start_id = tokenizer.im_start_id
im_end_id = tokenizer.im_end_id
im_token_id = tokenizer.unk_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
......@@ -102,10 +93,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
items = []
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
input_ids=input_ids, mm_start_id=self.im_start_id, mm_end_id=self.im_end_id
)
slice_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
input_ids=input_ids,
mm_start_id=self.slice_start_id,
mm_end_id=self.slice_end_id,
)
image_offsets.extend(slice_offsets)
image_offsets = sorted(image_offsets)
......@@ -114,7 +107,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
item = MultimodalDataItem(
feature=pixel_values,
offsets=image_offsets,
tgt_size=tgt_sizes_flat,
model_specific_data={"tgt_size": tgt_sizes_flat},
modality=Modality.IMAGE,
)
items += [item]
......@@ -124,17 +117,17 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
and res["audio_features"] is not None
and len(res["audio_features"]) != 0
):
if audio_start_id is not None and audio_end_id is not None:
if self.audio_start_id is not None and self.audio_end_id is not None:
audio_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids,
mm_start_id=audio_start_id,
mm_end_id=audio_end_id,
mm_start_id=self.audio_start_id,
mm_end_id=self.audio_end_id,
)
else:
audio_offsets = None
item = MultimodalDataItem(
feature=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
model_specific_data={"audio_feature_lens": res["audio_feature_lens"]},
offsets=audio_offsets,
modality=Modality.AUDIO,
)
......@@ -142,11 +135,11 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
return {
"mm_items": items,
"input_ids": input_ids.tolist(),
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"im_token_id": im_token_id,
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
"audio_start_id": self.audio_start_id,
"audio_end_id": self.audio_end_id,
"im_token_id": self.im_token_id,
"im_start_id": self.im_start_id,
"im_end_id": self.im_end_id,
"slice_start_id": self.slice_start_id,
"slice_end_id": self.slice_end_id,
}
from typing import List, Union
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.multimodal.processors.base_processor import BaseMultimodalProcessor
from sglang.srt.utils import load_image
from sglang.srt.multimodal.processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
class MllamaImageProcessor(BaseMultimodalProcessor):
......@@ -11,24 +12,26 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.mm_tokens = MultimodalSpecialTokens(
image_token=self._processor.image_token,
image_token_id=self._processor.image_token_id,
).build(_processor)
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text)
base_out = self.load_mm_data(
prompt=input_text,
image_data=image_data,
multimodal_tokens=self.mm_tokens,
)
images = [load_image(image)[0] for image in image_data]
image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [
MultimodalDataItem(
feature=image_inputs["pixel_values"],
aspect_ratio_id=image_inputs["aspect_ratio_ids"],
aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
modality=Modality.IMAGE,
)
]
mm_items, input_ids, _ = self.process_and_combine_mm_data(
base_out, self.mm_tokens
)
return image_inputs
return {
"mm_items": mm_items,
"input_ids": input_ids.tolist(),
"im_token_id": self.mm_tokens.image_token_id,
}
......@@ -27,13 +27,13 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
self.image_token_index = hf_config.image_token_index
self.multimodal_tokens = MultimodalSpecialTokens(
image_token=_processor.image_token,
image_token_id=self.image_token_index,
).build(_processor)
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
max_req_input_len=None,
*args,
**kwargs,
):
......@@ -45,7 +45,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
processed_data = self.load_mm_data(
prompt=input_text,
multimodal_tokens=self.multimodal_tokens,
max_req_input_len=max_req_input_len or 4096,
image_data=image_data,
return_text=True,
)
......