Unverified commit 1e86457c authored by Mick, committed by GitHub

model: Minicpmo (#3023)

parent 64129fa6
import argparse
+import PIL.Image
import torch
from data_utils import save_json
from eval_utils import (
@@ -10,22 +11,38 @@ from eval_utils import (
    process_result,
)
from tqdm import tqdm
-from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
+from transformers import AutoModel, AutoProcessor, GenerationConfig


@torch.no_grad()
def eval_mmmu(args):
    eval_args = EvalArgs.from_cli_args(args)

+    try:
+        from transformers import AutoModelForImageTextToText
+
-    model = AutoModelForImageTextToText.from_pretrained(
-        args.model_path,
-        torch_dtype="auto",
-        trust_remote_code=True,
-    )
+        model = AutoModelForImageTextToText.from_pretrained(
+            args.model_path,
+            torch_dtype="auto",
+            trust_remote_code=True,
+        )
+    except Exception as first_exception:
+        try:
+            model = AutoModel.from_pretrained(
+                args.model_path,
+                torch_dtype="auto",
+                trust_remote_code=True,
+                init_tts=False,
+            )
+        except Exception as second_exception:
+            raise RuntimeError(
+                f"Failed to load model: First attempt failed with {first_exception}, "
+                f"second attempt failed with {second_exception}"
+            ) from second_exception
+
    model = model.eval().cuda()
    processor = AutoProcessor.from_pretrained(
-        args.model_path, torch_dtype="auto", device_map="auto"
+        args.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
    )
    samples = prepare_samples(eval_args)
...
@@ -24,7 +24,7 @@
- InternLM 2
- Exaone 3
- BaiChuan2
-- MiniCPM / MiniCPM 3 / MiniCPMV
+- MiniCPM / MiniCPM 3 / MiniCPM-v / MiniCPM-o
- XVERSE / XVERSE MoE
- SmolLM
- GLM-4
@@ -70,9 +70,9 @@ LLM.
1. **Register your new model as multimodal**: Extend `is_multimodal_model` in [
   `model_config.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/configs/model_config.py) to
   return True for your model.
-2. **Process Images**: Create a new `ImageProcessor` class that inherits from `BaseImageProcessor` and register this
+2. **Process Images**: Define a new `Processor` class that inherits from `BaseProcessor` and register this
   processor as your model's dedicated processor. See [
-   `image_processor.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/image_processor.py)
+   `multimodal_processor.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/multimodal_processor.py)
   for more details.
3. **Handle Image Tokens**: Implement a `pad_input_ids` function for your new model, in which image tokens in the prompt
   should be expanded and replaced with image-hashes, so that SGLang can recognize different images for
@@ -80,7 +80,7 @@ LLM.
4. Replace the multi-headed `Attention` of ViT with SGLang's `VisionAttention`.
   You can refer to [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or other
-   vLMs. These models demonstrate how to properly handle both visual and textual inputs.
+   vLMs. These models demonstrate how to properly handle both multimodal and textual inputs.

You should test the new vLM locally against hf models. See [`mmmu`](https://github.com/sgl-project/sglang/tree/main/benchmark/mmmu) for an example.
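As a rough sketch of steps 1–3, a processor for a hypothetical `MyVLMForCausalLM` would subclass the base processor, list the architectures it serves, and rewrite the prompt via `load_mm_data`; the names and exact signatures below are illustrative only and should be checked against the files linked above.

```python
from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class MyVLMProcessor(BaseMultimodalProcessor):
    # Architectures this processor serves (step 2); picked up by import_processors().
    models = [MyVLMForCausalLM]  # hypothetical model class

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
        self.IMAGE_TOKEN = "<image>"

    async def process_mm_data_async(
        self, image_data, input_ids, request_obj, max_req_input_len, **kwargs
    ):
        if not image_data:
            return None
        # Load frames and rewrite the prompt so each frame keeps one image token.
        base_output = self.load_mm_data(
            input_ids=input_ids,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
            max_req_input_len=max_req_input_len,
        )
        # Run the HF processor on base_output.images / base_output.input_text and
        # return a dict of tensors plus special token ids; step 3's pad_input_ids
        # on the model then expands image tokens into per-image hash pad values.
        ...
```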
...
@@ -34,6 +34,7 @@ runtime_common = [
    "pydantic",
    "python-multipart",
    "pyzmq>=25.1.2",
+    "soundfile==0.13.1",
    "torchao>=0.7.0",
    "transformers==4.50.0",
    "uvicorn",
...
@@ -15,6 +15,7 @@ class ChatTemplate:
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    stop_str: List[str] = ()
    image_token: str = "<image>"
+    audio_token: str = "<audio>"
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def get_prefix_and_suffix(
@@ -253,6 +254,22 @@ register_chat_template(
    )
)

+# https://huggingface.co/openbmb/MiniCPM-o-2_6
+register_chat_template(
+    ChatTemplate(
+        name="minicpmo",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", " "),
+            "user": ("user:", " "),
+            "assistant": ("assistant:", "</s>"),
+        },
+        stop_str=("<|im_end|>", "<|endoftext|>"),
+        image_token="(<image>./</image>)",
+        audio_token="(<audio>./</audio>)",
+    )
+)
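For reference, once registered the template can be looked up by name; the snippet below only illustrates the expected fields (module path assumed to be the existing `sglang.lang.chat_template` registry that this hunk edits):

```python
from sglang.lang.chat_template import get_chat_template

tmpl = get_chat_template("minicpmo")
print(tmpl.image_token)  # (<image>./</image>)
print(tmpl.audio_token)  # (<audio>./</audio>)
print(tmpl.stop_str)     # ("<|im_end|>", "<|endoftext|>")
```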
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
    ChatTemplate(
@@ -474,12 +491,6 @@ def match_chat_ml(model_path: str):
        return get_chat_template("chatml-llava")


-@register_chat_template_matching_function
-def match_chat_minicpm(model_path: str):
-    if "minicpm" in model_path:
-        return get_chat_template("minicpmv")
-
-
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    model_path = model_path.lower()
@@ -499,8 +510,10 @@ def match_gemma_it(model_path: str):
@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
    model_path = model_path.lower()
-    if "minicpm" in model_path:
+    if "minicpm-v" in model_path:
        return get_chat_template("minicpmv")
+    elif "minicpm-o" in model_path:
+        return get_chat_template("minicpmo")


@register_chat_template_matching_function
...
@@ -462,18 +462,19 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [
    "DeepseekVL2ForCausalLM",
-    "LlavaLlamaForCausalLM",
-    "LlavaQwenForCausalLM",
-    "LlavaMistralForCausalLM",
-    "LlavaVidForCausalLM",
    "Gemma3ForConditionalGeneration",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
+    "LlavaLlamaForCausalLM",
+    "LlavaMistralForCausalLM",
+    "LlavaQwenForCausalLM",
+    "LlavaVidForCausalLM",
+    "MiniCPMO",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
-    "MiniCPMV",
-    "MultiModalityCausalLM",
]
...
@@ -73,11 +73,14 @@ class Conversation:
    stop_str: Union[str, List[str]] = None
    # The string that represents an image token in the prompt
    image_token: str = "<image>"
+    audio_token: str = "<audio>"

    image_data: Optional[List[str]] = None
    modalities: Optional[List[str]] = None
    stop_token_ids: Optional[int] = None
+    audio_data: Optional[List[str]] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        system_prompt = self.system_template.format(system_message=self.system_message)
@@ -327,6 +330,10 @@ class Conversation:
        """Append a new message."""
        self.image_data.append(image)

+    def append_audio(self, audio: str):
+        """Append a new message."""
+        self.audio_data.append(audio)
+
    def update_last_message(self, message: str):
        """Update the last output.
@@ -373,6 +380,7 @@ class Conversation:
            sep2=self.sep2,
            stop_str=self.stop_str,
            image_token=self.image_token,
+            audio_token=self.audio_token,
        )

    def dict(self):
@@ -459,8 +467,10 @@ def generate_chat_conv(
        sep2=conv.sep2,
        stop_str=conv.stop_str,
        image_data=[],
+        audio_data=[],
        modalities=[],
        image_token=conv.image_token,
+        audio_token=conv.audio_token,
    )

    if isinstance(request.messages, str):
@@ -498,6 +508,7 @@ def generate_chat_conv(
                    if conv.name != "qwen2-vl"
                    else conv.image_token
                )
+                audio_token = conv.audio_token
                for content in message.content:
                    if content.type == "text":
                        if num_image_url > 16:
@@ -507,6 +518,10 @@ def generate_chat_conv(
                        # NOTE: Only works for llava
                        real_content += image_token
                        conv.append_image(content.image_url.url)
+                    elif content.type == "audio_url":
+                        real_content += audio_token
+                        conv.append_audio(content.audio_url.url)
                conv.append_message(conv.roles[0], real_content)
        elif msg_role == "assistant":
            parsed_content = ""
@@ -704,3 +719,18 @@ register_conv_template(
        image_token="<image_placeholder>",
    )
)

+# Reference: https://huggingface.co/openbmb/MiniCPM-o-2_6#usage
+register_conv_template(
+    Conversation(
+        name="minicpmo",
+        system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=("<|im_end|>", "<|endoftext|>"),
+        image_token="(<image>./</image>)",
+        audio_token="(<audio>./</audio>)",
+    )
+)
@@ -45,6 +45,8 @@ class GenerateReqInput:
    # The image input. It can be a file name, a url, or base64 encoded string.
    # See also python/sglang/srt/utils.py:load_image.
    image_data: Optional[Union[List[str], str]] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
    # The sampling_params. See descriptions below.
    sampling_params: Optional[Union[List[Dict], Dict]] = None
    # The request id.
@@ -167,6 +169,13 @@ class GenerateReqInput:
        elif isinstance(self.image_data, list):
            pass

+        if self.audio_data is None:
+            self.audio_data = [None] * num
+        elif not isinstance(self.audio_data, list):
+            self.audio_data = [self.audio_data] * num
+        elif isinstance(self.audio_data, list):
+            pass
+
        if self.sampling_params is None:
            self.sampling_params = [{}] * num
        elif not isinstance(self.sampling_params, list):
@@ -231,6 +240,7 @@ class GenerateReqInput:
                text=self.text[i] if self.text is not None else None,
                input_ids=self.input_ids[i] if self.input_ids is not None else None,
                image_data=self.image_data[i],
+                audio_data=self.audio_data[i],
                sampling_params=self.sampling_params[i],
                rid=self.rid[i],
                return_logprob=self.return_logprob[i],
@@ -259,8 +269,8 @@ class TokenizedGenerateReqInput:
    input_text: str
    # The input token ids
    input_ids: List[int]
-    # The image inputs
-    image_inputs: dict
+    # The multimodal inputs
+    mm_inputs: dict
    # The sampling parameters
    sampling_params: SamplingParams
    # Whether to return the logprobs
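A sketch of how the new field is meant to be used from a client: the payload mirrors the existing `image_data` convention, and the server URL and file names below are placeholders.

```python
import requests

payload = {
    "text": "user:(<image>./</image>)(<audio>./</audio>)Describe what you see and hear. assistant:",
    "image_data": "example_image.png",   # file name, URL, or base64; str or list
    "audio_data": "example_clip.wav",    # normalized to a per-request list internally
    "sampling_params": {"max_new_tokens": 64},
}
# Assumes a local SGLang server on the default port.
resp = requests.post("http://localhost:30000/generate", json=payload)
print(resp.json())
```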
...
@@ -9,7 +9,7 @@ import torch
from torch import nn

from sglang.srt.managers.schedule_batch import (
-    ImageInputs,
+    MultimodalInputs,
    global_server_args_dict,
    logger,
)
@@ -26,7 +26,7 @@ class MultiModalityDataPaddingPattern:
    @abstractmethod
    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
+        self, input_ids: List[int], image_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values
@@ -44,16 +44,16 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        This function will replace the data-tokens inbetween with pad_values accordingly
        """
-        pad_values = image_inputs.pad_values
+        pad_values = mm_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
-        image_inputs.image_offsets = []
+        mm_inputs.image_offsets = []
        if data_token_pairs is None:
-            data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
+            data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
        if data_token_pairs is None:
            logger.warning(
                "No data_token_pairs provided, RadixAttention might be influenced."
@@ -61,8 +61,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
            return input_ids
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
-        # First start token marks new data
-        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
@@ -77,9 +75,12 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

-            if input_ids[start_idx] == data_start_token:
+            if input_ids[start_idx] in start_token_ids:
                data_idx += 1
-                image_inputs.image_offsets += [start_idx]
+                mm_inputs.image_offsets += [start_idx]
+
+            if data_idx >= len(mm_inputs.pad_values):
+                data_idx = len(mm_inputs.pad_values) - 1

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
@@ -89,7 +90,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
        padded_ids.extend(input_ids[last_idx:])

-        assert len(input_ids) == len(padded_ids)
+        assert len(input_ids) == len(padded_ids), "Length validation fails"
        return padded_ids
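A toy walk-through of the token-pair padding above (all ids are made up):

```python
# im_start = 100, im_end = 101; two images whose pad_values (hashes) are [777, 888].
input_ids = [1, 2, 100, 5, 5, 5, 101, 3, 100, 6, 6, 101, 4]
# pad_input_tokens keeps the start/end tokens, replaces everything in between with
# that image's pad value, and records each start index in mm_inputs.image_offsets:
#   padded_ids    == [1, 2, 100, 777, 777, 777, 101, 3, 100, 888, 888, 101, 4]
#   image_offsets == [2, 8]
# The total length is unchanged, which is what the final assert checks.
```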
@@ -107,26 +108,25 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: ImageInputs
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        This function will follow the procedure of:
        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
        2. the padded data tokens will be replaced with their pad_values
        """
-        image_grid_thws = image_inputs.image_grid_thws
-        pad_values = image_inputs.pad_values
+        image_grid_thws = mm_inputs.image_grid_thws
+        pad_values = mm_inputs.pad_values

        image_indices = [
-            idx
-            for idx, token in enumerate(input_ids)
-            if token == image_inputs.im_token_id
+            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
        ]

-        image_inputs.image_offsets = []
+        mm_inputs.image_offsets = []

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
-            # print(f"image_cnt {image_cnt}")
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
@@ -135,7 +135,7 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
-            image_inputs.image_offsets.append(len(input_ids_with_image))
+            mm_inputs.image_offsets.append(len(input_ids_with_image))
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
@@ -170,11 +170,11 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
        return input_ids_tensor.tolist()


-def embed_image_inputs(
-    image_input: ImageInputs,
+def embed_mm_inputs(
+    mm_input: MultimodalInputs,
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
-    image_embedding_func,
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
    """
@@ -184,10 +184,10 @@ def embed_image_inputs(
    Returns:
        final embedding: Optional[torch.Tensor]
    """
-    if image_input is None:
+    if mm_input is None:
        return None

-    placeholder_token_ids = placeholder_token_ids or image_input.pad_values
+    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values

    # boolean masking the special tokens
    special_image_mask = torch.isin(
@@ -196,12 +196,18 @@ def embed_image_inputs(
    ).unsqueeze(-1)

    num_image_tokens_in_input_ids = special_image_mask.sum()
+    # print(f"{num_image_tokens_in_input_ids}")
+    # print(f"{input_ids}")
+    # return

    if num_image_tokens_in_input_ids == 0:
        # unexpected
        inputs_embeds = input_embedding(input_ids)
    else:
-        image_embedding = image_embedding_func(image_input)
+        # print(f"Getting image feature")
+        image_embedding = mm_data_embedding_func(mm_input)
+        # print(f"image_embedding: {image_embedding.shape}")

        if image_embedding.dim() == 2:
            num_image_tokens_in_embedding = image_embedding.shape[0]
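The masking step can be pictured with a small self-contained tensor example (values are invented; the real code also reconciles mismatched placeholder and embedding counts):

```python
import torch

hidden = 4
input_ids = torch.tensor([[1, 2, 777, 777, 3]])      # 777 is the placeholder pad value
inputs_embeds = torch.randn(1, 5, hidden)            # text embeddings
image_embedding = torch.randn(2, hidden)             # one row per placeholder token

mask = torch.isin(input_ids, torch.tensor([777])).unsqueeze(-1)   # [1, 5, 1]
inputs_embeds = inputs_embeds.masked_scatter(
    mask.expand_as(inputs_embeds), image_embedding.to(inputs_embeds.dtype)
)
```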
@@ -273,31 +279,95 @@ def embed_image_embedding(

def general_mm_embed_routine(
    input_ids: torch.Tensor,
-    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    embed_tokens: nn.Embedding,
-    image_embedding_func: Callable[[ImageInputs], torch.Tensor],
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
):
    """
    a general wrapper function to get final input embeds from multimodal models
    with a language model as causal model
+
+    Args:
+        placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
    """
    if (
-        forward_batch.forward_mode.is_decode()
-        or not forward_batch.contains_image_inputs()
+        not forward_batch.forward_mode.is_decode()
+        and forward_batch.contains_mm_inputs()
    ):
-        inputs_embeds = embed_tokens(input_ids)
-    else:
-        image = forward_batch.merge_image_inputs()
-        inputs_embeds = embed_image_inputs(
-            image_input=image,
+        image = forward_batch.merge_mm_inputs()
+        inputs_embeds = embed_mm_inputs(
+            mm_input=image,
            input_ids=input_ids,
            input_embedding=embed_tokens,
-            image_embedding_func=image_embedding_func,
+            mm_data_embedding_func=mm_data_embedding_func,
            placeholder_token_ids=placeholder_token_ids,
        )
-        # once used, image_inputs is useless
+        # once used, mm_inputs is useless
        # just being defensive here
-        forward_batch.image_inputs = None
+        forward_batch.mm_inputs = None
+    else:
+        inputs_embeds = embed_tokens(input_ids)
    return inputs_embeds
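A sketch of how a multimodal model's forward pass would call this routine; class and method names here are illustrative, not taken from this PR:

```python
def forward(self, input_ids, positions, forward_batch):
    inputs_embeds = general_mm_embed_routine(
        input_ids=input_ids,
        forward_batch=forward_batch,
        embed_tokens=self.language_model.get_input_embeddings(),
        mm_data_embedding_func=self.get_image_feature,  # MultimodalInputs -> [N, hidden]
        placeholder_token_ids=None,  # defaults to mm_input.pad_values
    )
    return self.language_model(
        input_ids=None,
        positions=positions,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
    )
```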
+def get_multimodal_data_bounds(
+    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
+) -> torch.Tensor:
+    """
+    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
+
+    Returns:
+        [bounds_count, 2]
+    """
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    start_tokens = [s for s, _e in token_pairs]
+    end_tokens = [e for _s, e in token_pairs]
+
+    assert all(isinstance(t, int) for t in start_tokens)
+    assert all(isinstance(t, int) for t in end_tokens)
+    # print(input_ids)
+    start_cond = torch.isin(
+        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
+
+    (data_start_tokens,) = torch.where(start_cond)
+    (data_end_tokens,) = torch.where(end_cond)
+
+    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
+    if len(data_start_tokens) != len(data_end_tokens):
+        if (
+            len(data_start_tokens) + 1 == len(data_end_tokens)
+            and input_ids[0] in pad_values
+            and data_end_tokens[0] < data_start_tokens[0]
+        ):
+            data_start_tokens = torch.cat(
+                [
+                    torch.tensor([0], device=data_start_tokens.device),
+                    data_start_tokens,
+                ]
+            )
+    valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
+
+    if valid_image_nums == 0:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Filter out pairs where start_token >= end_token
+    valid_pairs = []
+    for i in range(valid_image_nums):
+        start_token = data_start_tokens[i]
+        end_token = data_end_tokens[i]
+        if start_token < end_token:
+            valid_pairs.append((start_token + 1, end_token - 1))
+
+    if not valid_pairs:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Convert valid pairs to tensor
+    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    return valid_pairs_tensor
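A small made-up example of the bounds helper:

```python
import torch

# im_start = 100, im_end = 101, pad value 777
ids = torch.tensor([1, 100, 777, 777, 101, 2, 100, 777, 101])
get_multimodal_data_bounds(ids, pad_values=[777], token_pairs=[(100, 101)])
# tensor([[2, 3],
#         [7, 7]])   # inclusive index bounds of the data tokens inside each pair
```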
@@ -4,46 +4,41 @@ import inspect
import logging
import pkgutil
from functools import lru_cache
-from typing import Union

-from torch import Tensor
-from transformers import IMAGE_PROCESSOR_MAPPING
+from transformers import PROCESSOR_MAPPING

-from sglang.srt.managers.image_processors.base_image_processor import (
-    BaseImageProcessor,
-    DummyImageProcessor,
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
)
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)

+PROCESSOR_MAPPING = {}
-IMAGE_PROCESSOR_MAPPING = {}


+class DummyMultimodalProcessor(BaseMultimodalProcessor):
+    def __init__(self):
+        pass
+
-def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
+    async def process_mm_data_async(self, *args, **kwargs):
-    for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
+        return None
-        if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
-    raise ValueError(
-        f"No image processor found for architecture: {hf_config.architectures}"
-    )


-def get_dummy_image_processor():
-    return DummyImageProcessor()
+def get_dummy_processor():
+    return DummyMultimodalProcessor()


@lru_cache()
-def import_image_processors():
-    package_name = "sglang.srt.managers.image_processors"
+def import_processors():
+    package_name = "sglang.srt.managers.multimodal_processors"
    package = importlib.import_module(package_name)
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            try:
                module = importlib.import_module(name)
            except Exception as e:
-                logger.warning(f" Ignore import error when loading {name}: " f"{e}")
+                logger.warning(f"Ignore import error when loading {name}: " f"{e}")
                continue
            all_members = inspect.getmembers(module, inspect.isclass)
            classes = [
@@ -51,11 +46,23 @@ def import_image_processors():
                for name, member in all_members
                if member.__module__ == module.__name__
            ]
-            for cls in classes:
-                if issubclass(cls, BaseImageProcessor):
+            for cls in (
+                cls for cls in classes if issubclass(cls, BaseMultimodalProcessor)
+            ):
+                assert hasattr(cls, "models")
                for arch in getattr(cls, "models"):
-                    IMAGE_PROCESSOR_MAPPING[arch] = cls
+                    PROCESSOR_MAPPING[arch] = cls


+def get_mm_processor(
+    hf_config, server_args: ServerArgs, processor
+) -> BaseMultimodalProcessor:
+    for model_cls, processor_cls in PROCESSOR_MAPPING.items():
+        if model_cls.__name__ in hf_config.architectures:
+            return processor_cls(hf_config, server_args, processor)
+    raise ValueError(
+        f"No processor registered for architecture: {hf_config.architectures}.\n"
+        f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
+    )


-# also register processors self.image_proce
-import_image_processors()
@@ -4,16 +4,16 @@ import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import Optional

+import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from openai import BadRequestError
from PIL import Image

-from sglang.srt.utils import load_image
-from sglang.utils import logger
+from sglang.srt.utils import load_audio, load_image, logger

global global_processor
@@ -24,21 +24,41 @@ def get_global_processor():

@dataclasses.dataclass
-class BaseImageProcessorOutput:
-    image_hashes: list[int]
-    image_sizes: list[tuple[int, int]]
-    all_frames: [PIL.Image]
-    # input_text, with each frame of video/image represented as an image_token
+class BaseMultiModalProcessorOutput:
+    # input_text, with each frame of video/image represented with a image_token
    input_text: str
+    mm_data_hashes: Optional[list[int]]
+    # images
+    image_sizes: Optional[list[int]]
+    # frames loaded from image and video, in given order
+    images: Optional[list[PIL.Image]] = None
+    # audios
+    audios: Optional[list[np.ndarray]] = None

    def normalize(self):
-        for field_name in ["data_hashes", "image_sizes", "all_frames"]:
+        for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
            field = getattr(self, field_name, None)
            if field is not None and isinstance(field, list) and len(field) == 0:
                setattr(self, field_name, None)


-class BaseImageProcessor(ABC):
+@dataclasses.dataclass
+class MultimodalSpecialTokens:
+    image_token: Optional[str] = None
+    video_token: Optional[str] = None
+    audio_token: Optional[str] = None
+
+    def collect(self) -> list[str]:
+        return [
+            token
+            for token in [self.image_token, self.video_token, self.audio_token]
+            if token
+        ]
+
+
+class BaseMultimodalProcessor(ABC):
    models = []

    def __init__(self, hf_config, server_args, _processor):
@@ -72,7 +92,7 @@ class BaseImageProcessor(ABC):
        )

    @abstractmethod
-    async def process_images_async(
+    async def process_mm_data_async(
        self, image_data, input_text, max_req_input_len, **kwargs
    ):
        pass
@@ -120,29 +140,33 @@ class BaseImageProcessor(ABC):
            frames = [Image.fromarray(v.astype("uint8")) for v in frames]
        return frames

-    def load_images(
+    def load_mm_data(
        self,
        input_ids: list[int],
-        image_data,
-        image_token: Union[int, str],
+        multimodal_tokens: MultimodalSpecialTokens,
        max_req_input_len: int,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
        return_text: Optional[bool] = True,
        discard_alpha_channel: bool = True,
-    ) -> BaseImageProcessorOutput:
+    ) -> BaseMultiModalProcessorOutput:
        """
        Each frame of video/image will be replaced by a single image token

        Args:
-            image_token: The token ID representing the image placeholder.
+            multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
+                e.g. image token or audio token
            discard_alpha_channel: if True, discards the alpha channel in the returned images
        """
-        if isinstance(image_token, int):
-            image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
-                image_token
+        if isinstance(multimodal_tokens.image_token, int):
+            multimodal_tokens.image_token = (
+                self._processor.tokenizer.convert_ids_to_tokens(
+                    multimodal_tokens.image_token
+                )
            )
        else:
-            image_token_str = image_token
+            multimodal_tokens.image_token = multimodal_tokens.image_token

        if isinstance(input_ids, list) and return_text:
            assert len(input_ids) and isinstance(input_ids[0], int)
@@ -152,7 +176,11 @@ class BaseImageProcessor(ABC):
        if return_text:
            import re

-            pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
+            pattern = (
+                "("
+                + "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+                + ")"
+            )
            # split text into list of normal text and special tokens
            text_parts = re.split(pattern, input_text)
@@ -162,7 +190,7 @@ class BaseImageProcessor(ABC):
        total_frame_count = sum(estimated_frames_list)
        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        _scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))

        assert len(image_data) == len(estimated_frames_list)
@@ -171,9 +199,16 @@ class BaseImageProcessor(ABC):
        new_text = ""
        for index, text_part in enumerate(text_parts):
            try:
-                if text_part == image_token:
+                if text_part == multimodal_tokens.image_token:
                    # load as image
-                    frames_to_process = estimated_frames_list[image_index]
+                    if len(images) >= MAX_NUM_FRAMES:
+                        frames_to_process = 0
+                    else:
+                        estimated_frames = estimated_frames_list[image_index]
+                        frames_to_process = max(
+                            1, int(estimated_frames * scaling_factor)
+                        )

                    if frames_to_process == 0:
                        frames = []
                    else:
@@ -183,7 +218,7 @@ class BaseImageProcessor(ABC):
                        ):
                            # video
                            path = image_file[len("video:") :]
-                            frames = self.encode_video(
+                            frames = BaseMultimodalProcessor.encode_video(
                                path, frame_count_limit=frames_to_process
                            )
                        else:
@@ -200,40 +235,41 @@ class BaseImageProcessor(ABC):
                    images += frames
                    image_index += 1
                    if frames_to_process != 0:
-                        new_text += image_token * len(frames)
+                        new_text += multimodal_tokens.image_token * len(frames)
                        assert frames_to_process == len(frames)
+                elif text_part == multimodal_tokens.audio_token:
+                    # load as audio
+                    audio_file = audio_data[audio_index]
+                    audio = load_audio(audio_file)
+                    hashes += [hash(audio_file)]
+                    audios += [audio]
+                    audio_index += 1
+                    new_text += multimodal_tokens.audio_token
                else:
                    # TODO(mick): handle video
                    # normal text
                    new_text += text_part
            except Exception as e:
                logger.error(f"An exception occurred while loading images: {e}")
                raise BadRequestError(
                    f"An exception occurred while loading images: {e}"
                )

-        return BaseImageProcessorOutput(
-            image_hashes=hashes,
+        out = BaseMultiModalProcessorOutput(
+            mm_data_hashes=hashes,
            image_sizes=image_sizes,
-            all_frames=images,
+            images=images,
+            audios=audios,
            input_text=new_text,
        )
        out.normalize()
        return out
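A sketch of how a concrete processor is expected to call `load_mm_data` (file names are placeholders; compare the MiniCPM processor further below):

```python
base_output = self.load_mm_data(
    input_ids=input_ids,
    image_data=["https://example.com/cat.png"],
    audio_data=["example_clip.wav"],
    multimodal_tokens=MultimodalSpecialTokens(
        image_token="(<image>./</image>)",
        audio_token="(<audio>./</audio>)",
    ),
    max_req_input_len=max_req_input_len,
)
# base_output.images     -> PIL.Image frames, in prompt order
# base_output.audios     -> np.ndarray waveforms from load_audio
# base_output.input_text -> prompt with one special token per loaded frame/clip
```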
-class DummyImageProcessor(BaseImageProcessor):
-    def __init__(self):
-        pass
-
-    async def process_images_async(self, *args, **kwargs):
-        return None
-
-
-def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
-    """Init the global processor for multi-modal models."""
+def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
+    """
+    Init the global processor for multimodal models."""
    global global_processor
    transformers.logging.set_verbosity_error()
-    global_processor = sglang_image_processor._build_processor(server_args=server_args)
+    global_processor = sglang_processor._build_processor(server_args=server_args)
@@ -20,14 +20,15 @@ import asyncio

import torch

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
    get_global_processor,
)
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM


-class DeepseekVL2ImageProcessor(BaseImageProcessor):
+class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
    models = [DeepseekVL2ForCausalLM]

    def __init__(self, hf_config, server_args, _processor):
@@ -63,7 +64,23 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
        return image_inputs

-    async def process_images_async(
+    async def _process_images(self, image_data, input_text, max_req_input_len):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                DeepseekVL2ImageProcessor._process_images_task,
+                image_data,
+                input_text,
+                max_req_input_len,
+            )
+        else:
+            image_inputs = self._process_images_task(
+                image_data, input_text, max_req_input_len
+            )
+        return image_inputs
+
+    async def process_mm_data_async(
        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
    ):
        if not image_data:
@@ -75,11 +92,14 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
        images, image_sizes = [], []
        image_token = self.IMAGE_TOKEN
-        base_output = self.load_images(
-            input_ids, image_data, image_token, max_req_input_len
+        base_output = self.load_mm_data(
+            input_ids,
+            image_data=image_data,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
+            max_req_input_len=max_req_input_len,
        )
        res = await self._process_images(
-            base_output.all_frames, base_output.input_text, max_req_input_len
+            base_output.images, base_output.input_text, max_req_input_len
        )
        images_seq_mask = res["images_seq_mask"]
        images_spatial_crop = res["images_spatial_crop"]
@@ -91,7 +111,7 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
            "input_ids": res["input_ids"].tolist(),
            "pixel_values": res["images"],
            "im_token_id": res["im_token_id"],
-            "image_hashes": base_output.image_hashes,
+            "data_hashes": base_output.mm_data_hashes,
            "image_sizes": image_sizes,
            "images_emb_mask": images_seq_mask,
            "image_spatial_crop": batched_images_spatial_crop,
...
import asyncio
from typing import List, Union

from transformers.utils import logging

-from sglang.srt.managers.image_processor import (
-    BaseImageProcessor as SGLangBaseImageProcessor,
+from sglang.srt.managers.multimodal_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
)
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    MultimodalSpecialTokens,
    get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
@@ -16,7 +16,7 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
logger = logging.get_logger(__name__)


-class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
+class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
    models = [Gemma3ForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
@@ -47,7 +47,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
            "pixel_values": pixel_values,
        }

-    async def process_images_async(
+    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
@@ -62,22 +62,22 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
            image_data = [image_data]

        image_token = self.IMAGE_TOKEN
-        base_output = self.load_images(
+        base_output = self.load_mm_data(
            input_ids=input_ids,
            image_data=image_data,
-            image_token=image_token,
+            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
            max_req_input_len=max_req_input_len,
            discard_alpha_channel=True,
        )
        ret = await self._process_single_image(
-            input_text=base_output.input_text, images=base_output.all_frames
+            input_text=base_output.input_text, images=base_output.images
        )
        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "pixel_values": ret["pixel_values"],
-            "image_hashes": base_output.image_hashes,
+            "data_hashes": base_output.mm_data_hashes,
            "im_start_id": self.IM_START_TOKEN_ID,
            "im_end_id": self.IM_END_TOKEN_ID,
        }
import asyncio
from typing import List, Union

-from sglang.srt.managers.image_processors.base_image_processor import (
-    BaseImageProcessor as SGLangBaseImageProcessor,
-)
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
    get_global_processor,
)
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM


-class JanusProProcessor(SGLangBaseImageProcessor):
+class JanusProImageProcessor(BaseMultimodalProcessor):
    models = [MultiModalityCausalLM]

    def __init__(self, hf_config, server_args, _processor):
@@ -36,7 +35,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
            loop = asyncio.get_event_loop()
            image_inputs = await loop.run_in_executor(
                self.executor,
-                JanusProProcessor._process_images_task,
+                JanusProImageProcessor._process_images_task,
                images,
                input_text,
            )
@@ -47,7 +46,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
        return image_inputs

-    async def process_images_async(
+    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
@@ -61,20 +60,24 @@ class JanusProProcessor(SGLangBaseImageProcessor):
        if not isinstance(image_data, list):
            image_data = [image_data]

-        base_out = self.load_images(
+        base_out = self.load_mm_data(
            input_ids=input_ids,
            image_data=image_data,
-            image_token="<image_placeholder>",
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token="<image_placeholder>"
+            ),
            max_req_input_len=max_req_input_len,
        )
-        images = base_out.all_frames
+        images = base_out.images
        res = await self._process_images(images=images, input_text=base_out.input_text)
+        # print(res)
+        # print(base_out)
+        # print("", res["images_emb_mask"].shape)
        return {
            "input_ids": res["input_ids"].flatten().tolist(),
            "pixel_values": res["pixel_values"],
            "images_emb_mask": res["images_emb_mask"],
-            "image_hashes": base_out.image_hashes,
+            "data_hashes": base_out.mm_data_hashes,
            "im_start_id": res["im_start_id"],
            "im_end_id": res["im_end_id"],
            "im_token_id": res["im_token_id"],
...
@@ -3,8 +3,8 @@ from typing import List, Optional, Union

import numpy as np

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
    get_global_processor,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
@@ -14,7 +14,7 @@ from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback


-class LlavaImageProcessor(BaseImageProcessor):
+class LlavaImageProcessor(BaseMultimodalProcessor):
    models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]

    def __init__(self, hf_config, server_args, _processor):
@@ -86,7 +86,7 @@ class LlavaImageProcessor(BaseImageProcessor):
            image_data, aspect_ratio, grid_pinpoints
        )

-    async def process_images_async(
+    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_text,
@@ -113,7 +113,7 @@ class LlavaImageProcessor(BaseImageProcessor):
        if "multi-images" in modalities or "video" in modalities:
            # Multiple images
            aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-            pixel_values, image_hashes, image_sizes = [], [], []
+            pixel_values, data_hashes, image_sizes = [], [], []
            res = []
            for img_data in image_data:
                res.append(
@@ -124,7 +124,7 @@ class LlavaImageProcessor(BaseImageProcessor):
            res = await asyncio.gather(*res)
            for pixel_v, image_h, image_s in res:
                pixel_values.append(pixel_v)
-                image_hashes.append(image_h)
+                data_hashes.append(image_h)
                image_sizes.append(image_s)

            if isinstance(pixel_values[0], np.ndarray):
@@ -134,14 +134,14 @@ class LlavaImageProcessor(BaseImageProcessor):
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data[0], aspect_ratio, grid_pinpoints
            )
-            image_hashes = [image_hash]
+            data_hashes = [image_hash]
            image_sizes = [image_size]
        else:
            raise ValueError(f"Invalid image data: {image_data}")

        return {
            "pixel_values": pixel_values,
-            "image_hashes": image_hashes,
+            "data_hashes": data_hashes,
            "image_sizes": image_sizes,
            "modalities": request_obj.modalities or ["image"],
        }
@@ -3,82 +3,113 @@ from typing import List, Union

import torch

-from sglang.srt.managers.image_processor import BaseImageProcessor
-from sglang.srt.managers.image_processors.base_image_processor import (
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
    get_global_processor,
)
+from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV


-class MiniCPMVImageProcessor(BaseImageProcessor):
-    models = [MiniCPMV]
+# Compatible with both 'O' and 'V'
+class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
+    models = [MiniCPMV, MiniCPMO]

    def __init__(self, hf_config, server_args, _processor):
        super().__init__(hf_config, server_args, _processor)
-        self.IMAGE_TOKEN = "(<image>./</image>)"
+        self.image_token = "(<image>./</image>)"
+        self.audio_token = "(<audio>./</audio>)"

    @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(text=input_text, images=images, return_tensors="pt")
+    def _process_data_task(input_text, images=None, audios=None):
+
+        if isinstance(images, list) and len(images) == 0:
+            images = None
+        if isinstance(audios, list) and len(audios) == 0:
+            audios = None
+        result = get_global_processor().__call__(
+            text=input_text,
+            images=images,
+            audios=audios,
+            return_tensors="pt",
+            chunk_input=True,
+        )
        return {
            "input_ids": result.input_ids,
-            "pixel_values": result.pixel_values,
-            "tgt_sizes": result.tgt_sizes,
+            "pixel_values": getattr(result, "pixel_values", None),
+            "tgt_sizes": getattr(result, "tgt_sizes", None),
+            "audio_features": getattr(result, "audio_features", None),
+            "audio_feature_lens": getattr(result, "audio_feature_lens", None),
+            "audio_bounds": getattr(result, "audio_bounds", None),
        }

-    async def _process_images(self, images, input_text):
+    async def _process_data(self, images, input_text, audios=None):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
+            multimodal_data_inputs = await loop.run_in_executor(
                self.executor,
-                MiniCPMVImageProcessor._process_images_task,
-                images,
+                MiniCPMMultimodalProcessor._process_data_task,
                input_text,
+                images,
+                audios,
            )
        else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
+            multimodal_data_inputs = self._processor(
+                images=images, text=input_text, audios=audios, return_tensors="pt"
            )
-        return image_inputs
+        return multimodal_data_inputs
async def process_images_async( async def process_mm_data_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
input_ids, input_ids,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
): ):
if not image_data: audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None return None
if not isinstance(image_data, list): if not isinstance(image_data, list):
image_data = [image_data] image_data = [image_data]
if not isinstance(audio_data, list):
audio_data = [audio_data]
base_output = self.load_images( base_output = self.load_mm_data(
input_ids=input_ids, input_ids=input_ids,
image_data=image_data,
image_token=self.IMAGE_TOKEN,
max_req_input_len=max_req_input_len, max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token, audio_token=self.audio_token
),
) )
if base_output is None: if base_output is None:
return None return None
if len(base_output.all_frames) == 0: res = await self._process_data(
return None images=base_output.images,
res = await self._process_images( input_text=base_output.input_text,
images=base_output.all_frames, input_text=base_output.input_text audios=base_output.audios,
) )
# Collect special token ids # Collect special token ids
tokenizer = self._processor.tokenizer tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
im_token_id = tokenizer.unk_token_id None,
im_end_id = tokenizer.im_end_id None,
None,
None,
)
if tokenizer.slice_start_id: if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id slice_end_id = tokenizer.slice_end_id
if hasattr(tokenizer, "audio_start_id"):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_token_id = tokenizer.unk_token_id
pixel_values = res["pixel_values"] pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"] tgt_sizes = res["tgt_sizes"]
...@@ -98,8 +129,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor): ...@@ -98,8 +129,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
f"{len(pixel_values)} vs. {len(tgt_sizes)}" f"{len(pixel_values)} vs. {len(tgt_sizes)}"
) )
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
pixel_values_flat: List[torch.Tensor] = [] pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = [] tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
...@@ -109,21 +138,30 @@ class MiniCPMVImageProcessor(BaseImageProcessor): ...@@ -109,21 +138,30 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}" "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
) )
for pixel_n, tgt_n in zip(pixel_b, tgt_b): for pixel_n, tgt_n in zip(pixel_b, tgt_b):
# per patch
pixel_values_flat += [pixel_n] pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n] tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat pixel_values = pixel_values_flat
if len(tgt_sizes_flat) == 0:
tgt_sizes = None
else:
tgt_sizes = torch.stack(tgt_sizes_flat) tgt_sizes = torch.stack(tgt_sizes_flat)
if not isinstance(res["audio_features"], list):
res["audio_features"] = [res["audio_features"]]
return { return {
"input_ids": res["input_ids"].flatten().tolist(), "input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values, "pixel_values": pixel_values,
"tgt_sizes": tgt_sizes, "tgt_sizes": tgt_sizes,
"image_hashes": base_output.image_hashes, "data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"], "modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id, "audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"audio_features": res["audio_features"],
"audio_bounds": res["audio_bounds"],
"audio_feature_lens": res["audio_feature_lens"],
"im_token_id": im_token_id, "im_token_id": im_token_id,
"im_end_id": im_end_id, "im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
"slice_start_id": slice_start_id, "slice_start_id": slice_start_id,
"slice_end_id": slice_end_id, "slice_end_id": slice_end_id,
} }
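Side note (illustrative, not part of this change): the image/audio placeholder strings configured above are presumably what load_mm_data matches against in the request text. A minimal, self-contained sketch of a mixed prompt, with made-up content:

# Illustrative only: a prompt mixing the image and audio placeholders defined
# in MiniCPMMultimodalProcessor above; the chat template and actual media
# loading are omitted.
image_token = "(<image>./</image>)"
audio_token = "(<audio>./</audio>)"
prompt = f"{image_token}\n{audio_token}\nDescribe the picture and transcribe the speech."
assert image_token in prompt and audio_token in prompt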
import asyncio
from typing import List, Union

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    get_global_processor,
)
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image


class MllamaImageProcessor(BaseMultimodalProcessor):
    models = [MllamaForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
@@ -34,7 +34,7 @@ class MllamaImageProcessor(BaseImageProcessor):
        return image_inputs

    async def process_mm_data_async(
        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
    ):
        if not image_data:
@@ -53,7 +53,7 @@ class MllamaImageProcessor(BaseImageProcessor):
            images = load_image(image_data[0])[0]

        image_inputs = await self._process_single_image(images, input_text)
        image_inputs["data_hashes"] = [hash(str(image_data))]
        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]

        return image_inputs
import asyncio
import math
from typing import List, Union

import torch
from PIL import Image

from sglang.srt.managers.multimodal_processor import (
    BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
    MultimodalSpecialTokens,
    get_global_processor,
)
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
@@ -14,7 +18,7 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration

# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]

    def __init__(self, hf_config, server_args, _processor):
@@ -59,7 +63,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
        else:
            return self._process_images_task(images, input_text, self.hf_config)

    async def process_mm_data_async(
        self,
        image_data: List[Union[str, bytes]],
        input_ids,
@@ -68,16 +72,17 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
        *args,
        **kwargs,
    ):
        if not image_data:
            return None

        if isinstance(image_data, str):
            image_data = [image_data]

        image_token = self.IMAGE_TOKEN
        base_output = self.load_mm_data(
            input_ids=input_ids,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
            max_req_input_len=max_req_input_len,
        )
@@ -139,7 +144,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
            return math.floor(number / factor) * factor

        images = [resize_image(image) for image in base_output.images]

        ret = await self._process_single_image(
            images=images, input_text=base_output.input_text
@@ -147,11 +152,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
        image_grid_thws = torch.concat([ret["image_grid_thw"]])
        video_grid_thws = None

        return {
            "input_ids": ret["input_ids"].flatten().tolist(),
            "pixel_values": ret["pixel_values"],
            "data_hashes": base_output.mm_data_hashes,
            "modalities": request_obj.modalities or ["image"],
            "image_grid_thws": image_grid_thws,
            "video_grid_thws": video_grid_thws,
......
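Side note (illustrative, not part of this change): the rounding helper documented above ("largest integer less than or equal to 'number' that is divisible by 'factor'") behaves as in this small sketch. The function name and the factor of 28 are assumptions for illustration only.

import math

def floor_by_factor(number: int, factor: int) -> int:
    # Largest integer <= number that is divisible by factor.
    return math.floor(number / factor) * factor

# Snapping a resized image height to an assumed patch-aligned factor of 28:
assert floor_by_factor(1000, 28) == 980
assert floor_by_factor(980, 28) == 980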
@@ -144,11 +144,11 @@ class FINISH_ABORT(BaseFinishReason):
@dataclasses.dataclass
class MultimodalInputs:
    """The image related inputs."""

    pixel_values: Union[torch.Tensor, np.array]
    data_hashes: Optional[list] = None
    image_sizes: Optional[list] = None
    image_offsets: Optional[list] = None
    image_pad_len: Optional[list] = None
@@ -182,20 +182,27 @@ class ImageInputs:
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None

    # [num_images, 2 (w, h)]
    tgt_sizes: Optional[list] = None

    # audio
    audio_start_id: Optional[torch.Tensor] = None
    audio_end_id: Optional[torch.Tensor] = None
    audio_features: Optional[List[torch.Tensor]] = None
    audio_feature_lens: Optional[List[torch.Tensor]] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = MultimodalInputs(
            pixel_values=obj["pixel_values"],
            data_hashes=obj["data_hashes"],
        )

        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]

        optional_args = [
            "image_sizes",
@@ -211,6 +218,10 @@ class ImageInputs:
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
            "audio_start_id",
            "audio_end_id",
            "audio_features",
            "audio_feature_lens",
        ]
        for arg in optional_args:
            if arg in obj:
@@ -223,9 +234,19 @@ class ImageInputs:
            or isinstance(ret.pixel_values, list)
        )
        assert ret.audio_features is None or isinstance(ret.audio_features, list)

        return ret

    def contains_image_inputs(self) -> bool:
        """ """
        return self.pixel_values is not None and self.pixel_values != []

    def contains_audio_inputs(self) -> bool:
        """ """
        return self.audio_features is not None and self.audio_features != []

    def merge(self, other: MultimodalInputs):
        """
        merge image inputs when requests are being merged
        """
@@ -268,10 +289,12 @@ class ImageInputs:
        # Please note that if the `input_ids` is later used in the model forward,
        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
        # errors in cuda kernels. See also llava.py for example.
        self.data_hashes += other.data_hashes
        self.pad_values = [x % (1 << 30) for x in self.data_hashes]

        # args needed to be merged
        optional_args = [
            "audio_features",
            "image_sizes",
            "image_offsets",
            "image_pad_len",
@@ -362,7 +385,7 @@ class Req:
        self.decoded_text = ""

        # For multimodal inputs
        self.multimodal_inputs: Optional[MultimodalInputs] = None

        # Prefix info
        # The indices to kv cache for the shared prefix.
@@ -458,10 +481,10 @@ class Req:
        return len(self.origin_input_ids) + len(self.output_ids)

    def extend_image_inputs(self, image_inputs):
        if self.multimodal_inputs is None:
            self.multimodal_inputs = image_inputs
        else:
            self.multimodal_inputs.merge(image_inputs)

    def finished(self) -> bool:
        # Whether request reached finished condition
@@ -802,7 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.encoder_cached = []

        for req in self.reqs:
            im = req.multimodal_inputs
            if im is None or im.num_image_tokens is None:
                # No image input
                self.encoder_lens_cpu.append(0)
@@ -1391,7 +1414,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,
            extend_logprob_start_lens=extend_logprob_start_lens,
            multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
            encoder_cached=self.encoder_cached,
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
@@ -1474,7 +1497,7 @@ class ModelWorkerBatch:
    extend_input_logprob_token_ids: Optional[torch.Tensor]

    # For multimodal
    multimodal_inputs: Optional[List[MultimodalInputs]]

    # For encoder-decoder
    encoder_cached: Optional[List[bool]]
......
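Side note (illustrative, not part of this change): to make the pad-value comments in MultimodalInputs concrete, here is a minimal, self-contained sketch of turning data hashes into fake token ids for prefix matching. The hash inputs are made up, and as the comments above note, real use still requires clamping to [0, vocab_size) before any model forward.

# Illustrative only: fake token ids derived from data hashes, mirroring
# `pad_values = [x % (1 << 30) for x in data_hashes]` above.
data_hashes = [hash("image-0.png"), hash("audio-0.wav")]
pad_values = [h % (1 << 30) for h in data_hashes]
assert all(0 <= v < (1 << 30) for v in pad_values)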
@@ -88,7 +88,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    MultimodalInputs,
    Req,
    ScheduleBatch,
    global_server_args_dict,
@@ -841,8 +841,8 @@ class Scheduler(
            return

        # Handle multimodal inputs
        if recv_req.mm_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
@@ -856,7 +856,7 @@ class Scheduler(
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
@@ -960,7 +960,7 @@ class Scheduler(
        # Handle multimodal inputs
        if recv_req.image_inputs is not None:
            image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
            # Expand a single image token into multiple dummy tokens for receiving image embeddings
            req.origin_input_ids = self.pad_input_ids_func(
                req.origin_input_ids, image_inputs
@@ -974,7 +974,7 @@ class Scheduler(
                )
                logger.error(error_msg)
                req.origin_input_ids = [0]
                req.multimodal_inputs = None
                req.sampling_params.max_new_tokens = 0
                req.finished_reason = FINISH_ABORT(
                    error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
......
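Side note (illustrative, not part of this change): the "expand a single image token into multiple dummy tokens" step handled by pad_input_ids_func can be pictured with this standalone sketch; the function and all values here are assumptions for illustration, not the real SGLang implementation.

def expand_image_token(input_ids, image_token_id, pad_value, num_image_tokens):
    # Replace each occurrence of the placeholder id with `num_image_tokens`
    # copies of the image's pad value, so each image gets a distinguishable run
    # of dummy ids that later receives the image embeddings.
    out = []
    for tok in input_ids:
        if tok == image_token_id:
            out.extend([pad_value] * num_image_tokens)
        else:
            out.append(tok)
    return out

print(expand_image_token([1, 99, 2], image_token_id=99, pad_value=777, num_image_tokens=4))
# -> [1, 777, 777, 777, 777, 2]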
@@ -138,7 +138,7 @@ class Session:
                token_ids_logprob=req.token_ids_logprob,
            )
            if last_req is not None:
                new_req.multimodal_inputs = last_req.mm_inputs
            new_req.tokenizer = tokenizer
            if abort:
                new_req.to_abort = True
......