Unverified commit 6e923dbd, authored by Xinyuan Tong and committed by GitHub
Browse files

feat: update multimodal data handling in engine entrypoint (#8002)


Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
parent c268c11c
...@@ -46,9 +46,9 @@ from sglang.srt.managers.io_struct import ( ...@@ -46,9 +46,9 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
GenerateReqInput, GenerateReqInput,
GetWeightsByNameReqInput, GetWeightsByNameReqInput,
ImageDataItem,
InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput, LoadLoRAAdapterReqInput,
MultimodalDataInputFormat,
ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
RpcReqInput, RpcReqInput,
...@@ -148,13 +148,9 @@ class Engine(EngineBase): ...@@ -148,13 +148,9 @@ class Engine(EngineBase):
# - List of images (one per request in a batch) # - List of images (one per request in a batch)
# - List of lists of images (multiple images per request) # - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details. # See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[ image_data: Optional[MultimodalDataInputFormat] = None,
Union[ audio_data: Optional[MultimodalDataInputFormat] = None,
List[List[ImageDataItem]], video_data: Optional[MultimodalDataInputFormat] = None,
List[ImageDataItem],
ImageDataItem,
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
...@@ -187,6 +183,8 @@ class Engine(EngineBase): ...@@ -187,6 +183,8 @@ class Engine(EngineBase):
input_ids=input_ids, input_ids=input_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
image_data=image_data, image_data=image_data,
audio_data=audio_data,
video_data=video_data,
return_logprob=return_logprob, return_logprob=return_logprob,
logprob_start_len=logprob_start_len, logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num, top_logprobs_num=top_logprobs_num,
...@@ -231,13 +229,9 @@ class Engine(EngineBase): ...@@ -231,13 +229,9 @@ class Engine(EngineBase):
# - List of images (one per request in a batch) # - List of images (one per request in a batch)
# - List of lists of images (multiple images per request) # - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details. # See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[ image_data: Optional[MultimodalDataInputFormat] = None,
Union[ audio_data: Optional[MultimodalDataInputFormat] = None,
List[List[ImageDataItem]], video_data: Optional[MultimodalDataInputFormat] = None,
List[ImageDataItem],
ImageDataItem,
]
] = None,
return_logprob: Optional[Union[List[bool], bool]] = False, return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None, logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None, top_logprobs_num: Optional[Union[List[int], int]] = None,
...@@ -272,6 +266,8 @@ class Engine(EngineBase): ...@@ -272,6 +266,8 @@ class Engine(EngineBase):
input_ids=input_ids, input_ids=input_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
image_data=image_data, image_data=image_data,
audio_data=audio_data,
video_data=video_data,
return_logprob=return_logprob, return_logprob=return_logprob,
logprob_start_len=logprob_start_len, logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num, top_logprobs_num=top_logprobs_num,
...@@ -295,19 +291,20 @@ class Engine(EngineBase): ...@@ -295,19 +291,20 @@ class Engine(EngineBase):
def encode( def encode(
self, self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]], prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
image_data: Optional[ image_data: Optional[MultimodalDataInputFormat] = None,
Union[ audio_data: Optional[MultimodalDataInputFormat] = None,
List[List[Union[Image, str]]], video_data: Optional[MultimodalDataInputFormat] = None,
List[Union[Image, str]],
Union[Image, str],
]
] = None,
) -> Dict: ) -> Dict:
""" """
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
Please refer to `EmbeddingReqInput` for the documentation. Please refer to `EmbeddingReqInput` for the documentation.
""" """
obj = EmbeddingReqInput(text=prompt, image_data=image_data) obj = EmbeddingReqInput(
text=prompt,
image_data=image_data,
audio_data=audio_data,
video_data=video_data,
)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
ret = loop.run_until_complete(generator.__anext__()) ret = loop.run_until_complete(generator.__anext__())
...@@ -316,7 +313,9 @@ class Engine(EngineBase): ...@@ -316,7 +313,9 @@ class Engine(EngineBase):
async def async_encode( async def async_encode(
self, self,
prompt: Union[str, List[str], List[Dict], List[List[Dict]]], prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
image_data: Optional[Union[List[str], str]] = None, image_data: Optional[MultimodalDataInputFormat] = None,
audio_data: Optional[MultimodalDataInputFormat] = None,
video_data: Optional[MultimodalDataInputFormat] = None,
) -> Dict: ) -> Dict:
""" """
Asynchronous version of encode method. Asynchronous version of encode method.
...@@ -324,7 +323,12 @@ class Engine(EngineBase): ...@@ -324,7 +323,12 @@ class Engine(EngineBase):
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
Please refer to `EmbeddingReqInput` for the documentation. Please refer to `EmbeddingReqInput` for the documentation.
""" """
obj = EmbeddingReqInput(text=prompt, image_data=image_data) obj = EmbeddingReqInput(
text=prompt,
image_data=image_data,
audio_data=audio_data,
video_data=video_data,
)
generator = self.tokenizer_manager.generate_request(obj, None) generator = self.tokenizer_manager.generate_request(obj, None)
return await generator.__anext__() return await generator.__anext__()
......
...@@ -42,8 +42,21 @@ class SessionParams: ...@@ -42,8 +42,21 @@ class SessionParams:
drop_previous_output: Optional[bool] = None drop_previous_output: Optional[bool] = None
AudioDataItem = Union[str, Dict] # Type definitions for multimodal input data
ImageDataItem = Union[Image, str, Dict] # Individual data item types for each modality
ImageDataInputItem = Union[Image, str, Dict]
AudioDataInputItem = Union[str, Dict]
VideoDataInputItem = Union[str, Dict]
# Union type for any multimodal data item
MultimodalDataInputItem = Union[
ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
]
# Format types supporting single items, lists, or nested lists for batch processing
MultimodalDataInputFormat = Union[
List[List[MultimodalDataInputItem]],
List[MultimodalDataInputItem],
MultimodalDataInputItem,
]
@dataclass @dataclass
...@@ -60,13 +73,11 @@ class GenerateReqInput: ...@@ -60,13 +73,11 @@ class GenerateReqInput:
# - List of images (one per request in a batch) # - List of images (one per request in a batch)
# - List of lists of images (multiple images per request) # - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details. # See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[ image_data: Optional[MultimodalDataInputFormat] = None
Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string. # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[List[str]], List[str], str]] = None video_data: Optional[MultimodalDataInputFormat] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[MultimodalDataInputFormat] = None
# The sampling_params. See descriptions below. # The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id. # The request id.
...@@ -524,13 +535,11 @@ class EmbeddingReqInput: ...@@ -524,13 +535,11 @@ class EmbeddingReqInput:
# - List of images (one per request in a batch) # - List of images (one per request in a batch)
# - List of lists of images (multiple images per request) # - List of lists of images (multiple images per request)
# See also python/sglang/srt/utils.py:load_image for more details. # See also python/sglang/srt/utils.py:load_image for more details.
image_data: Optional[ image_data: Optional[MultimodalDataInputFormat] = None
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string. # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[str], str]] = None video_data: Optional[MultimodalDataInputFormat] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string. # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None audio_data: Optional[MultimodalDataInputFormat] = None
# The token ids for text; one can either specify text or input_ids. # The token ids for text; one can either specify text or input_ids.
input_ids: Optional[Union[List[List[int]], List[int]]] = None input_ids: Optional[Union[List[List[int]], List[int]]] = None
# The request id. # The request id.
...@@ -610,8 +619,6 @@ class EmbeddingReqInput: ...@@ -610,8 +619,6 @@ class EmbeddingReqInput:
if self.is_cross_encoder_request: if self.is_cross_encoder_request:
return EmbeddingReqInput( return EmbeddingReqInput(
text=[self.text[i]] if self.text is not None else None, text=[self.text[i]] if self.text is not None else None,
input_ids=None,
image_data=None,
sampling_params=self.sampling_params[i], sampling_params=self.sampling_params[i],
rid=self.rid[i], rid=self.rid[i],
is_cross_encoder_request=True, is_cross_encoder_request=True,
...@@ -621,6 +628,8 @@ class EmbeddingReqInput: ...@@ -621,6 +628,8 @@ class EmbeddingReqInput:
text=self.text[i] if self.text is not None else None, text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None, input_ids=self.input_ids[i] if self.input_ids is not None else None,
image_data=self.image_data[i] if self.image_data is not None else None, image_data=self.image_data[i] if self.image_data is not None else None,
audio_data=self.audio_data[i] if self.audio_data is not None else None,
video_data=self.video_data[i] if self.video_data is not None else None,
sampling_params=self.sampling_params[i], sampling_params=self.sampling_params[i],
rid=self.rid[i], rid=self.rid[i],
) )
......
...@@ -8,7 +8,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase ...@@ -8,7 +8,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
GenerateReqInput, GenerateReqInput,
ImageDataItem, ImageDataInputItem,
) )
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.vila import VILAForConditionalGeneration from sglang.srt.models.vila import VILAForConditionalGeneration
...@@ -42,7 +42,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor): ...@@ -42,7 +42,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
async def process_mm_data_async( async def process_mm_data_async(
self, self,
image_data: Optional[ImageDataItem | List[ImageDataItem]], image_data: Optional[ImageDataInputItem | List[ImageDataInputItem]],
input_text: str | List[int], input_text: str | List[int],
request_obj: GenerateReqInput | EmbeddingReqInput, request_obj: GenerateReqInput | EmbeddingReqInput,
max_req_input_len: int, max_req_input_len: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment