Unverified commit 3a911b85, authored by Xinyuan Tong, committed by GitHub
Browse files

Refactor mm processors and Enable mixed modality processing (#7629)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent 886d3449
...@@ -110,9 +110,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ...@@ -110,9 +110,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
*args, *args,
**kwargs, **kwargs,
): ):
if not image_data:
return None
modalities = request_obj.modalities or ["image"] modalities = request_obj.modalities or ["image"]
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
grid_pinpoints = ( grid_pinpoints = (
...@@ -122,9 +119,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor): ...@@ -122,9 +119,6 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
else None else None
) )
if isinstance(image_data, str):
image_data = [image_data]
if isinstance(image_data, list) and len(image_data) > 0: if isinstance(image_data, list) and len(image_data) > 0:
if "multi-images" in modalities or "video" in modalities: if "multi-images" in modalities or "video" in modalities:
# Multiple images # Multiple images
......
...@@ -23,19 +23,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor): ...@@ -23,19 +23,12 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
audio_data: List[Union[str, bytes]],
input_text, input_text,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
**kwargs, **kwargs,
): ):
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
if not isinstance(audio_data, list):
audio_data = [audio_data]
base_output = self.load_mm_data( base_output = self.load_mm_data(
prompt=input_text, prompt=input_text,
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
......
...@@ -15,21 +15,11 @@ class MllamaImageProcessor(BaseMultimodalProcessor): ...@@ -15,21 +15,11 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
async def process_mm_data_async( async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
): ):
if not image_data:
return None
if isinstance(input_text, list): if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int) assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text) input_text = self._processor.tokenizer.decode(input_text)
if not isinstance(image_data, list):
image_data = [image_data]
if len(image_data) > 0:
images = [load_image(image)[0] for image in image_data] images = [load_image(image)[0] for image in image_data]
else:
images = load_image(image_data[0])[0]
image_inputs = self.process_mm_data(input_text=input_text, images=images) image_inputs = self.process_mm_data(input_text=input_text, images=images)
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0] image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
image_inputs["mm_items"] = [ image_inputs["mm_items"] = [
......
...@@ -37,9 +37,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor): ...@@ -37,9 +37,6 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
*args, *args,
**kwargs, **kwargs,
): ):
if not image_data:
return None
if isinstance(input_text, list): if isinstance(input_text, list):
assert len(input_text) and isinstance(input_text[0], int) assert len(input_text) and isinstance(input_text[0], int)
input_text = self._processor.tokenizer.decode(input_text) input_text = self._processor.tokenizer.decode(input_text)
......
...@@ -26,22 +26,12 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor): ...@@ -26,22 +26,12 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
audio_data,
input_text, input_text,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
**kwargs, **kwargs,
): ):
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
if not isinstance(audio_data, list):
audio_data = [audio_data]
if audio_data: if audio_data:
logger.warning( logger.warning(
"Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize." "Currently SGLang does not support audio data for Phi4MM. We are working on it. You can file an issue to help us prioritize."
......
...@@ -78,12 +78,6 @@ class PixtralProcessor(BaseMultimodalProcessor): ...@@ -78,12 +78,6 @@ class PixtralProcessor(BaseMultimodalProcessor):
*args, *args,
**kwargs, **kwargs,
): ):
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
mm_data = self.load_mm_data( mm_data = self.load_mm_data(
prompt=input_text, prompt=input_text,
multimodal_tokens=self.multimodal_tokens, multimodal_tokens=self.multimodal_tokens,
......
...@@ -49,9 +49,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -49,9 +49,6 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
*args, *args,
**kwargs, **kwargs,
): ):
if isinstance(image_data, str):
image_data = [image_data]
base_output = self.load_mm_data( base_output = self.load_mm_data(
prompt=input_text, prompt=input_text,
image_data=image_data, image_data=image_data,
...@@ -130,12 +127,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -130,12 +127,13 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
video_grid_thw = None # TODO video_grid_thw = None # TODO
combined_mm_item, input_ids = self.process_and_combine_mm_data(base_output) mm_items, input_ids = self.process_and_combine_mm_data(base_output)
if combined_mm_item is None: if not mm_items:
# Note(Xinyuan): This is the case where image loading fails. # Note(Xinyuan): This is the case where image loading fails.
return None return None
combined_mm_item = mm_items[0] # only image is supported for now
video_grid_thw = None # TODO video_grid_thw = None # TODO
second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None) second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
...@@ -157,7 +155,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor): ...@@ -157,7 +155,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
return { return {
"input_ids": input_ids.tolist(), "input_ids": input_ids.tolist(),
"mm_items": [combined_mm_item], "mm_items": mm_items,
"im_start_id": self.IM_START_TOKEN_ID, "im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID, "im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.IM_TOKEN_ID, "im_token_id": self.IM_TOKEN_ID,
......
...@@ -37,6 +37,8 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): ...@@ -37,6 +37,8 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
_processor: VILAProcessor, _processor: VILAProcessor,
) -> None: ) -> None:
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IM_TOKEN_ID = hf_config.image_token_id
self.VIDEO_TOKEN_ID = hf_config.video_token_id
async def process_mm_data_async( async def process_mm_data_async(
self, self,
...@@ -46,13 +48,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): ...@@ -46,13 +48,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
max_req_input_len: int, max_req_input_len: int,
**kwargs, **kwargs,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
if not image_data: base_output = self.load_mm_data(
return None
if not isinstance(image_data, list):
image_data = [image_data]
mm_data = self.load_mm_data(
prompt=input_text, prompt=input_text,
multimodal_tokens=MultimodalSpecialTokens( multimodal_tokens=MultimodalSpecialTokens(
image_token=self._processor.tokenizer.image_token image_token=self._processor.tokenizer.image_token
...@@ -61,25 +57,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): ...@@ -61,25 +57,11 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
image_data=image_data, image_data=image_data,
) )
inputs = self.process_mm_data( mm_items, input_ids = self.process_and_combine_mm_data(base_output)
input_text=mm_data.input_text,
images=mm_data.images,
)
image_offsets = self.get_mm_items_offset(
input_ids=inputs.input_ids[0],
mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
)
mm_items: List[MultimodalDataItem] = [
MultimodalDataItem(
modality=Modality.IMAGE,
image_offsets=image_offsets,
pixel_values=inputs.pixel_values,
)
]
return dict( return {
input_ids=inputs.input_ids[0].tolist(), "input_ids": input_ids.tolist(),
mm_items=mm_items, "mm_items": mm_items,
) "im_token_id": self.IM_TOKEN_ID,
"video_token_id": self.VIDEO_TOKEN_ID,
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment