Unverified Commit 3a911b85 authored by Xinyuan Tong, committed by GitHub

Refactor mm processors and Enable mixed modality processing (#7629)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent 886d3449
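
The hunks below share one pattern: per-processor guards (`if not image_data: return None`) and str-to-list coercion are deleted, so every processor can now assume `image_data` and `audio_data` arrive as lists, normalized once before dispatch. A minimal sketch of that centralized normalization, assuming a helper on the base processor; the name `normalize_mm_input` is hypothetical, not the API this PR adds:

    from typing import List, Optional, Union

    MMInput = Union[str, bytes]

    def normalize_mm_input(data: Optional[Union[MMInput, List[MMInput]]]) -> List[MMInput]:
        # Hypothetical helper: coerce a single URL/path/bytes payload (or None)
        # into the list form the refactored processors expect.
        if data is None:
            return []
        if isinstance(data, (str, bytes)):
            return [data]
        return list(data)

With inputs normalized in one place, a request that mixes modalities (images plus audio) can flow through a single processor call instead of tripping per-modality early returns.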
@@ -110,9 +110,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
         modalities = request_obj.modalities or ["image"]
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
@@ -122,9 +119,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             else None
         )
 
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         if isinstance(image_data, list) and len(image_data) > 0:
             if "multi-images" in modalities or "video" in modalities:
                 # Multiple images
......
@@ -23,19 +23,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
+        audio_data: List[Union[str, bytes]],
         input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
     ):
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             max_req_input_len=max_req_input_len,
......
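Note on the MiniCPM change above: `audio_data` is now an explicit parameter rather than being read from `request_obj` inside the method, which makes the mixed image+audio path visible at the call site. A hedged usage sketch; `processor` and `req` are assumed to exist, and the URL is a placeholder:

    async def demo(processor, req):
        # Hypothetical call site: both modalities are passed as lists
        # under the new signature.
        return await processor.process_mm_data_async(
            image_data=["https://example.com/cat.png"],  # placeholder
            audio_data=[],                               # no audio in this request
            input_text="Describe the image.",
            request_obj=req,
            max_req_input_len=4096,
        )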
@@ -15,21 +15,11 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
 
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        if len(image_data) > 0:
-            images = [load_image(image)[0] for image in image_data]
-        else:
-            images = load_image(image_data[0])[0]
-
+        images = [load_image(image)[0] for image in image_data]
         image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
         image_inputs["mm_items"] = [
......
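The Mllama change above collapses the old one-vs-many branch into a single comprehension, which is safe now that `image_data` is guaranteed to be a list; the `[0]` index is consistent with `load_image` returning a tuple whose first element is the decoded image. A pure-Python analogue with a stand-in loader:

    def fake_load_image(src):
        # Stand-in for load_image: returns an (image, size)-style tuple.
        return (f"<decoded {src}>", (224, 224))

    images = [fake_load_image(s)[0] for s in ["a.png", "b.png"]]
    assert images == ["<decoded a.png>", "<decoded b.png>"]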
@@ -37,9 +37,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
         if isinstance(input_text, list):
             assert len(input_text) and isinstance(input_text[0], int)
             input_text = self._processor.tokenizer.decode(input_text)
......
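Both Mllama processors keep the same preamble: if the prompt arrives pre-tokenized (a list of ints), it is decoded back to text so the HF processor can re-tokenize it together with the image placeholders. Illustration with a generic tokenizer (gpt2 is used only as an example model):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer only
    ids = tok.encode("Hello world")
    assert ids and isinstance(ids[0], int)       # mirrors the assert in the diff
    text = tok.decode(ids)                       # back to "Hello world"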
@@ -26,22 +26,12 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
+        audio_data,
         input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
     ):
-        audio_data = request_obj.audio_data
-        if not image_data and not audio_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-        if not isinstance(audio_data, list):
-            audio_data = [audio_data]
-
         if audio_data:
             logger.warning(
                 "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
......
@@ -78,12 +78,6 @@ class PixtralProcessor(BaseMultimodalProcessor):
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
-
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         mm_data = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=self.multimodal_tokens,
......
@@ -49,9 +49,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
-        if isinstance(image_data, str):
-            image_data = [image_data]
-
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -130,12 +127,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
-        video_grid_thw = None  # TODO
-        combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
-        if combined_mm_item is None:
+        if not mm_items:
             # Note(Xinyuan): This is the case where image loading fails.
             return None
 
+        combined_mm_item = mm_items[0]  # only image is supported for now
+        video_grid_thw = None  # TODO
         second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
@@ -157,7 +155,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         return {
             "input_ids": input_ids.tolist(),
-            "mm_items": [combined_mm_item],
+            "mm_items": mm_items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.IM_TOKEN_ID,
......
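The Qwen2.5-VL change above switches `process_and_combine_mm_data` to return a list of combined items plus the input ids: an empty list signals that loading failed, and this processor currently consumes only the first (image) item. A hypothetical illustration of why a per-modality list enables mixed inputs; `Item` and `combine` are stand-ins, not SGLang APIs:

    from dataclasses import dataclass, field
    from typing import Any, Dict, List

    @dataclass
    class Item:                     # stand-in for MultimodalDataItem
        modality: str
        data: List[Any] = field(default_factory=list)

    def combine(items: List[Item]) -> List[Item]:
        # One combined item per modality, returned as a list so callers
        # can handle any mix (image-only, audio-only, or both).
        merged: Dict[str, Item] = {}
        for it in items:
            merged.setdefault(it.modality, Item(it.modality)).data.extend(it.data)
        return list(merged.values())

    out = combine([Item("image", [1]), Item("audio", [2]), Item("image", [3])])
    assert [(i.modality, i.data) for i in out] == [("image", [1, 3]), ("audio", [2])]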
@@ -37,6 +37,8 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         _processor: VILAProcessor,
     ) -> None:
         super().__init__(hf_config, server_args, _processor)
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
 
     async def process_mm_data_async(
         self,
@@ -46,13 +48,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
         max_req_input_len: int,
         **kwargs,
     ) -> Optional[Dict[str, Any]]:
-        if not image_data:
-            return None
-
-        if not isinstance(image_data, list):
-            image_data = [image_data]
-
-        mm_data = self.load_mm_data(
+        base_output = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=self._processor.tokenizer.image_token
@@ -61,25 +57,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
         )
 
-        inputs = self.process_mm_data(
-            input_text=mm_data.input_text,
-            images=mm_data.images,
-        )
+        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
 
-        image_offsets = self.get_mm_items_offset(
-            input_ids=inputs.input_ids[0],
-            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
-        )
-        mm_items: List[MultimodalDataItem] = [
-            MultimodalDataItem(
-                modality=Modality.IMAGE,
-                image_offsets=image_offsets,
-                pixel_values=inputs.pixel_values,
-            )
-        ]
-
-        return dict(
-            input_ids=inputs.input_ids[0].tolist(),
-            mm_items=mm_items,
-        )
+        return {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.IM_TOKEN_ID,
+            "video_token_id": self.VIDEO_TOKEN_ID,
+        }
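
The VILA processor now returns the image/video token ids alongside `mm_items`, matching the dict shape the Qwen path uses, so downstream code can locate placeholder positions generically instead of via the removed per-processor `get_mm_items_offset` call. A small sketch (token id values are made up):

    IM_TOKEN_ID = 32000                      # made-up id for illustration
    input_ids = [1, 32000, 32000, 15, 7]
    offsets = [i for i, t in enumerate(input_ids) if t == IM_TOKEN_ID]
    assert offsets == [1, 2]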