Unverified Commit 1e86457c authored by Mick's avatar Mick Committed by GitHub
Browse files

model: Minicpmo (#3023)

parent 64129fa6
import argparse
import PIL.Image
import torch
from data_utils import save_json
from eval_utils import (
......@@ -10,22 +11,38 @@ from eval_utils import (
process_result,
)
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
from transformers import AutoModel, AutoProcessor, GenerationConfig
@torch.no_grad()
def eval_mmmu(args):
eval_args = EvalArgs.from_cli_args(args)
try:
from transformers import AutoModelForImageTextToText
model = AutoModelForImageTextToText.from_pretrained(
args.model_path,
torch_dtype="auto",
trust_remote_code=True,
)
except Exception as first_exception:
try:
model = AutoModel.from_pretrained(
args.model_path,
torch_dtype="auto",
trust_remote_code=True,
init_tts=False,
)
except Exception as second_exception:
raise RuntimeError(
f"Failed to load model: First attempt failed with {first_exception}, "
f"second attempt failed with {second_exception}"
) from second_exception
model = AutoModelForImageTextToText.from_pretrained(
args.model_path,
torch_dtype="auto",
trust_remote_code=True,
)
model = model.eval().cuda()
processor = AutoProcessor.from_pretrained(
args.model_path, torch_dtype="auto", device_map="auto"
args.model_path, torch_dtype="auto", device_map="auto", trust_remote_code=True
)
samples = prepare_samples(eval_args)
......
......@@ -24,7 +24,7 @@
- InternLM 2
- Exaone 3
- BaiChuan2
- MiniCPM / MiniCPM 3 / MiniCPMV
- MiniCPM / MiniCPM 3 / MiniCPM-v / MiniCPM-o
- XVERSE / XVERSE MoE
- SmolLM
- GLM-4
......@@ -70,9 +70,9 @@ LLM.
1. **Register your new model as multimodal**: Extend `is_multimodal_model` in [
`model_config.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/configs/model_config.py) to
return True for your model.
2. **Process Images**: Create a new `ImageProcessor` class that inherits from `BaseImageProcessor` and register this
2. **Process Images**: Define a new `Processor` class that inherits from `BaseProcessor` and register this
processor as your model's dedicated processor. See [
`image_processor.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/image_processor.py)
`multimodal_processor.py`](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/multimodal_processor.py)
for more details.
3. **Handle Image Tokens**: Implement a `pad_input_ids` function for your new model, in which image tokens in the prompt
should be expanded and replaced with image-hashes, so that SGLang can recognize different images for
......@@ -80,7 +80,7 @@ LLM.
4. Replace Multi-headed `Attention` of ViT with SGLang's `VisionAttention`.
You can refer [Qwen2VL](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/qwen2_vl.py) or other
vLMs. These models demonstrate how to properly handle both visual and textual inputs.
vLMs. These models demonstrate how to properly handle both multimodal and textual inputs.
You should test the new vLM locally against hf models. See [`mmmu`](https://github.com/sgl-project/sglang/tree/main/benchmark/mmmu) for an example.
......
......@@ -34,6 +34,7 @@ runtime_common = [
"pydantic",
"python-multipart",
"pyzmq>=25.1.2",
"soundfile==0.13.1",
"torchao>=0.7.0",
"transformers==4.50.0",
"uvicorn",
......
......@@ -15,6 +15,7 @@ class ChatTemplate:
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
stop_str: List[str] = ()
image_token: str = "<image>"
audio_token: str = "<audio>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(
......@@ -253,6 +254,22 @@ register_chat_template(
)
)
# https://huggingface.co/openbmb/MiniCPM-o-2_6
register_chat_template(
ChatTemplate(
name="minicpmo",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", " "),
"user": ("user:", " "),
"assistant": ("assistant:", "</s>"),
},
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
)
)
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
register_chat_template(
ChatTemplate(
......@@ -474,12 +491,6 @@ def match_chat_ml(model_path: str):
return get_chat_template("chatml-llava")
@register_chat_template_matching_function
def match_chat_minicpm(model_path: str):
if "minicpm" in model_path:
return get_chat_template("minicpmv")
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
......@@ -499,8 +510,10 @@ def match_gemma_it(model_path: str):
@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
model_path = model_path.lower()
if "minicpm" in model_path:
if "minicpm-v" in model_path:
return get_chat_template("minicpmv")
elif "minicpm-o" in model_path:
return get_chat_template("minicpmo")
@register_chat_template_matching_function
......
......@@ -462,18 +462,19 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [
"DeepseekVL2ForCausalLM",
"LlavaLlamaForCausalLM",
"LlavaQwenForCausalLM",
"LlavaMistralForCausalLM",
"LlavaVidForCausalLM",
"Gemma3ForConditionalGeneration",
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"LlavaLlamaForCausalLM",
"LlavaMistralForCausalLM",
"LlavaQwenForCausalLM",
"LlavaVidForCausalLM",
"MiniCPMO",
"MiniCPMV",
"MultiModalityCausalLM",
"MllamaForConditionalGeneration",
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"MiniCPMV",
"MultiModalityCausalLM",
]
......
......@@ -73,11 +73,14 @@ class Conversation:
stop_str: Union[str, List[str]] = None
# The string that represents an image token in the prompt
image_token: str = "<image>"
audio_token: str = "<audio>"
image_data: Optional[List[str]] = None
modalities: Optional[List[str]] = None
stop_token_ids: Optional[int] = None
audio_data: Optional[List[str]] = None
def get_prompt(self) -> str:
"""Get the prompt for generation."""
system_prompt = self.system_template.format(system_message=self.system_message)
......@@ -327,6 +330,10 @@ class Conversation:
"""Append a new message."""
self.image_data.append(image)
def append_audio(self, audio: str):
"""Append a new message."""
self.audio_data.append(audio)
def update_last_message(self, message: str):
"""Update the last output.
......@@ -373,6 +380,7 @@ class Conversation:
sep2=self.sep2,
stop_str=self.stop_str,
image_token=self.image_token,
audio_token=self.audio_token,
)
def dict(self):
......@@ -459,8 +467,10 @@ def generate_chat_conv(
sep2=conv.sep2,
stop_str=conv.stop_str,
image_data=[],
audio_data=[],
modalities=[],
image_token=conv.image_token,
audio_token=conv.audio_token,
)
if isinstance(request.messages, str):
......@@ -498,6 +508,7 @@ def generate_chat_conv(
if conv.name != "qwen2-vl"
else conv.image_token
)
audio_token = conv.audio_token
for content in message.content:
if content.type == "text":
if num_image_url > 16:
......@@ -507,6 +518,10 @@ def generate_chat_conv(
# NOTE: Only works for llava
real_content += image_token
conv.append_image(content.image_url.url)
elif content.type == "audio_url":
real_content += audio_token
conv.append_audio(content.audio_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
parsed_content = ""
......@@ -704,3 +719,18 @@ register_conv_template(
image_token="<image_placeholder>",
)
)
# Reference: https://huggingface.co/openbmb/MiniCPM-o-2_6#usage
register_conv_template(
Conversation(
name="minicpmo",
system_message="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
system_template="<|im_start|>system\n{system_message}",
roles=("<|im_start|>user", "<|im_start|>assistant"),
sep="<|im_end|>\n",
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
audio_token="(<audio>./</audio>)",
)
)
......@@ -45,6 +45,8 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The audio input. Like image data, tt can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
......@@ -167,6 +169,13 @@ class GenerateReqInput:
elif isinstance(self.image_data, list):
pass
if self.audio_data is None:
self.audio_data = [None] * num
elif not isinstance(self.audio_data, list):
self.audio_data = [self.audio_data] * num
elif isinstance(self.audio_data, list):
pass
if self.sampling_params is None:
self.sampling_params = [{}] * num
elif not isinstance(self.sampling_params, list):
......@@ -231,6 +240,7 @@ class GenerateReqInput:
text=self.text[i] if self.text is not None else None,
input_ids=self.input_ids[i] if self.input_ids is not None else None,
image_data=self.image_data[i],
audio_data=self.audio_data[i],
sampling_params=self.sampling_params[i],
rid=self.rid[i],
return_logprob=self.return_logprob[i],
......@@ -259,8 +269,8 @@ class TokenizedGenerateReqInput:
input_text: str
# The input token ids
input_ids: List[int]
# The image inputs
image_inputs: dict
# The multimodal inputs
mm_inputs: dict
# The sampling parameters
sampling_params: SamplingParams
# Whether to return the logprobs
......
......@@ -9,7 +9,7 @@ import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
ImageInputs,
MultimodalInputs,
global_server_args_dict,
logger,
)
......@@ -26,7 +26,7 @@ class MultiModalityDataPaddingPattern:
@abstractmethod
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
self, input_ids: List[int], image_inputs: MultimodalInputs
) -> List[int]:
"""
Pad the input ids sequence containing data tokens, and replace them with pad_values
......@@ -44,16 +44,16 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
self.data_token_id_pairs = data_token_pairs
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will replace the data-tokens inbetween with pad_values accordingly
"""
pad_values = image_inputs.pad_values
pad_values = mm_inputs.pad_values
data_token_pairs = self.data_token_id_pairs
image_inputs.image_offsets = []
mm_inputs.image_offsets = []
if data_token_pairs is None:
data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
if data_token_pairs is None:
logger.warning(
"No data_token_pairs provided, RadixAttention might be influenced."
......@@ -61,8 +61,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
return input_ids
start_token_ids = [s for s, _e in data_token_pairs]
end_tokens_ids = [e for _s, e in data_token_pairs]
# First start token marks new data
data_start_token = start_token_ids[0]
padded_ids = []
last_idx = 0
......@@ -77,9 +75,12 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
for start_idx, end_idx in zip(start_indices, end_indices):
padded_ids.extend(input_ids[last_idx : start_idx + 1])
if input_ids[start_idx] == data_start_token:
if input_ids[start_idx] in start_token_ids:
data_idx += 1
image_inputs.image_offsets += [start_idx]
mm_inputs.image_offsets += [start_idx]
if data_idx >= len(mm_inputs.pad_values):
data_idx = len(mm_inputs.pad_values) - 1
num_tokens = end_idx - start_idx - 1
pad_value = pad_values[data_idx]
......@@ -89,7 +90,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
padded_ids.extend(input_ids[last_idx:])
assert len(input_ids) == len(padded_ids)
assert len(input_ids) == len(padded_ids), "Length validation fails"
return padded_ids
......@@ -107,26 +108,25 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
self.num_data_token_calc_func = num_data_token_calc_func
def pad_input_tokens(
self, input_ids: List[int], image_inputs: ImageInputs
self, input_ids: List[int], mm_inputs: MultimodalInputs
) -> List[int]:
"""
This function will follow the procedure of:
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
2. the padded data tokens will be replaced with their pad_values
"""
image_grid_thws = image_inputs.image_grid_thws
pad_values = image_inputs.pad_values
image_grid_thws = mm_inputs.image_grid_thws
pad_values = mm_inputs.pad_values
image_indices = [
idx
for idx, token in enumerate(input_ids)
if token == image_inputs.im_token_id
idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
]
image_inputs.image_offsets = []
mm_inputs.image_offsets = []
input_ids_with_image = []
for image_cnt, _ in enumerate(image_grid_thws):
# print(f"image_cnt {image_cnt}")
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
if image_cnt == 0:
non_image_tokens = input_ids[: image_indices[image_cnt]]
......@@ -135,7 +135,7 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
]
input_ids_with_image.extend(non_image_tokens)
image_inputs.image_offsets.append(len(input_ids_with_image))
mm_inputs.image_offsets.append(len(input_ids_with_image))
pad_ids = pad_values * (
(num_image_tokens + len(pad_values)) // len(pad_values)
)
......@@ -170,11 +170,11 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
return input_ids_tensor.tolist()
def embed_image_inputs(
image_input: ImageInputs,
def embed_mm_inputs(
mm_input: MultimodalInputs,
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_embedding_func,
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
"""
......@@ -184,10 +184,10 @@ def embed_image_inputs(
Returns:
final embedding: Optional[torch.Tensor]
"""
if image_input is None:
if mm_input is None:
return None
placeholder_token_ids = placeholder_token_ids or image_input.pad_values
placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
# boolean masking the special tokens
special_image_mask = torch.isin(
......@@ -196,12 +196,18 @@ def embed_image_inputs(
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
# print(f"{num_image_tokens_in_input_ids}")
# print(f"{input_ids}")
# return
if num_image_tokens_in_input_ids == 0:
# unexpected
inputs_embeds = input_embedding(input_ids)
else:
image_embedding = image_embedding_func(image_input)
# print(f"Getting image feature")
image_embedding = mm_data_embedding_func(mm_input)
# print(f"image_embedding: {image_embedding.shape}")
if image_embedding.dim() == 2:
num_image_tokens_in_embedding = image_embedding.shape[0]
......@@ -273,31 +279,95 @@ def embed_image_embedding(
def general_mm_embed_routine(
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
embed_tokens: nn.Embedding,
image_embedding_func: Callable[[ImageInputs], torch.Tensor],
mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
placeholder_token_ids: List[int] = None,
):
"""
a general wrapper function to get final input embeds from multimodal models
with a language model as causal model
Args:
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
"""
if (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
not forward_batch.forward_mode.is_decode()
and forward_batch.contains_mm_inputs()
):
inputs_embeds = embed_tokens(input_ids)
else:
image = forward_batch.merge_image_inputs()
inputs_embeds = embed_image_inputs(
image_input=image,
image = forward_batch.merge_mm_inputs()
inputs_embeds = embed_mm_inputs(
mm_input=image,
input_ids=input_ids,
input_embedding=embed_tokens,
image_embedding_func=image_embedding_func,
mm_data_embedding_func=mm_data_embedding_func,
placeholder_token_ids=placeholder_token_ids,
)
# once used, image_inputs is useless
# once used, mm_inputs is useless
# just being defensive here
forward_batch.image_inputs = None
forward_batch.mm_inputs = None
else:
inputs_embeds = embed_tokens(input_ids)
return inputs_embeds
def get_multimodal_data_bounds(
input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
"""
Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
Returns:
[bounds_count, 2]
"""
# All the images in the batch should share the same special image
# bound token ids.
start_tokens = [s for s, _e in token_pairs]
end_tokens = [e for _s, e in token_pairs]
assert all(isinstance(t, int) for t in start_tokens)
assert all(isinstance(t, int) for t in end_tokens)
# print(input_ids)
start_cond = torch.isin(
input_ids, torch.tensor(start_tokens, device=input_ids.device)
)
end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
(data_start_tokens,) = torch.where(start_cond)
(data_end_tokens,) = torch.where(end_cond)
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
if len(data_start_tokens) != len(data_end_tokens):
if (
len(data_start_tokens) + 1 == len(data_end_tokens)
and input_ids[0] in pad_values
and data_end_tokens[0] < data_start_tokens[0]
):
data_start_tokens = torch.cat(
[
torch.tensor([0], device=data_start_tokens.device),
data_start_tokens,
]
)
valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
if valid_image_nums == 0:
return torch.zeros((0, 2), device=input_ids.device)
# Filter out pairs where start_token >= end_token
valid_pairs = []
for i in range(valid_image_nums):
start_token = data_start_tokens[i]
end_token = data_end_tokens[i]
if start_token < end_token:
valid_pairs.append((start_token + 1, end_token - 1))
if not valid_pairs:
return torch.zeros((0, 2), device=input_ids.device)
# Convert valid pairs to tensor
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor
......@@ -4,46 +4,41 @@ import inspect
import logging
import pkgutil
from functools import lru_cache
from typing import Union
from torch import Tensor
from transformers import IMAGE_PROCESSOR_MAPPING
from transformers import PROCESSOR_MAPPING
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor,
DummyImageProcessor,
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
PROCESSOR_MAPPING = {}
IMAGE_PROCESSOR_MAPPING = {}
class DummyMultimodalProcessor(BaseMultimodalProcessor):
def __init__(self):
pass
def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
raise ValueError(
f"No image processor found for architecture: {hf_config.architectures}"
)
async def process_mm_data_async(self, *args, **kwargs):
return None
def get_dummy_image_processor():
return DummyImageProcessor()
def get_dummy_processor():
return DummyMultimodalProcessor()
@lru_cache()
def import_image_processors():
package_name = "sglang.srt.managers.image_processors"
def import_processors():
package_name = "sglang.srt.managers.multimodal_processors"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f" Ignore import error when loading {name}: " f"{e}")
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
continue
all_members = inspect.getmembers(module, inspect.isclass)
classes = [
......@@ -51,11 +46,23 @@ def import_image_processors():
for name, member in all_members
if member.__module__ == module.__name__
]
for cls in classes:
if issubclass(cls, BaseImageProcessor):
for arch in getattr(cls, "models"):
IMAGE_PROCESSOR_MAPPING[arch] = cls
for cls in (
cls for cls in classes if issubclass(cls, BaseMultimodalProcessor)
):
assert hasattr(cls, "models")
for arch in getattr(cls, "models"):
PROCESSOR_MAPPING[arch] = cls
def get_mm_processor(
hf_config, server_args: ServerArgs, processor
) -> BaseMultimodalProcessor:
for model_cls, processor_cls in PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
raise ValueError(
f"No processor registered for architecture: {hf_config.architectures}.\n"
f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
)
# also register processors
import_image_processors()
self.image_proce
......@@ -4,16 +4,16 @@ import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional, Union
from typing import Optional
import numpy as np
import PIL
import transformers
from decord import VideoReader, cpu
from openai import BadRequestError
from PIL import Image
from sglang.srt.utils import load_image
from sglang.utils import logger
from sglang.srt.utils import load_audio, load_image, logger
global global_processor
......@@ -24,21 +24,41 @@ def get_global_processor():
@dataclasses.dataclass
class BaseImageProcessorOutput:
image_hashes: list[int]
image_sizes: list[tuple[int, int]]
all_frames: [PIL.Image]
# input_text, with each frame of video/image represented as an image_token
class BaseMultiModalProcessorOutput:
# input_text, with each frame of video/image represented with a image_token
input_text: str
mm_data_hashes: Optional[list[int]]
# images
image_sizes: Optional[list[int]]
# frames loaded from image and video, in given order
images: Optional[list[PIL.Image]] = None
# audios
audios: Optional[list[np.ndarray]] = None
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "all_frames"]:
for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
class BaseImageProcessor(ABC):
@dataclasses.dataclass
class MultimodalSpecialTokens:
image_token: Optional[str] = None
video_token: Optional[str] = None
audio_token: Optional[str] = None
def collect(self) -> list[str]:
return [
token
for token in [self.image_token, self.video_token, self.audio_token]
if token
]
class BaseMultimodalProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
......@@ -72,7 +92,7 @@ class BaseImageProcessor(ABC):
)
@abstractmethod
async def process_images_async(
async def process_mm_data_async(
self, image_data, input_text, max_req_input_len, **kwargs
):
pass
......@@ -120,29 +140,33 @@ class BaseImageProcessor(ABC):
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
def load_images(
def load_mm_data(
self,
input_ids: list[int],
image_data,
image_token: Union[int, str],
multimodal_tokens: MultimodalSpecialTokens,
max_req_input_len: int,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
) -> BaseImageProcessorOutput:
) -> BaseMultiModalProcessorOutput:
"""
Each frame of video/image will be replaced by a single image token
Args:
image_token: The token ID representing the image placeholder.
multimodal_tokens (list[str]): list of special token which denoting a single multimodal data
e.g. image token or audio token
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if isinstance(image_token, int):
image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
image_token
if isinstance(multimodal_tokens.image_token, int):
multimodal_tokens.image_token = (
self._processor.tokenizer.convert_ids_to_tokens(
multimodal_tokens.image_token
)
)
else:
image_token_str = image_token
multimodal_tokens.image_token = multimodal_tokens.image_token
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
......@@ -152,7 +176,11 @@ class BaseImageProcessor(ABC):
if return_text:
import re
pattern = "(" + "|".join(re.escape(sep) for sep in [image_token]) + ")"
pattern = (
"("
+ "|".join(re.escape(sep) for sep in multimodal_tokens.collect())
+ ")"
)
# split text into list of normal text and special tokens
text_parts = re.split(pattern, input_text)
......@@ -162,7 +190,7 @@ class BaseImageProcessor(ABC):
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
_scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
......@@ -171,9 +199,16 @@ class BaseImageProcessor(ABC):
new_text = ""
for index, text_part in enumerate(text_parts):
try:
if text_part == image_token:
if text_part == multimodal_tokens.image_token:
# load as image
frames_to_process = estimated_frames_list[image_index]
if len(images) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
estimated_frames = estimated_frames_list[image_index]
frames_to_process = max(
1, int(estimated_frames * scaling_factor)
)
if frames_to_process == 0:
frames = []
else:
......@@ -183,7 +218,7 @@ class BaseImageProcessor(ABC):
):
# video
path = image_file[len("video:") :]
frames = self.encode_video(
frames = BaseMultimodalProcessor.encode_video(
path, frame_count_limit=frames_to_process
)
else:
......@@ -200,40 +235,41 @@ class BaseImageProcessor(ABC):
images += frames
image_index += 1
if frames_to_process != 0:
new_text += image_token * len(frames)
new_text += multimodal_tokens.image_token * len(frames)
assert frames_to_process == len(frames)
elif text_part == multimodal_tokens.audio_token:
# load as audio
audio_file = audio_data[audio_index]
audio = load_audio(audio_file)
hashes += [hash(audio_file)]
audios += [audio]
audio_index += 1
new_text += multimodal_tokens.audio_token
else:
# TODO(mick): handle video
# normal text
new_text += text_part
except Exception as e:
logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError(
f"An exception occurred while loading images: {e}"
)
return BaseImageProcessorOutput(
image_hashes=hashes,
out = BaseMultiModalProcessorOutput(
mm_data_hashes=hashes,
image_sizes=image_sizes,
all_frames=images,
images=images,
audios=audios,
input_text=new_text,
)
out.normalize()
return out
class DummyImageProcessor(BaseImageProcessor):
def __init__(self):
pass
async def process_images_async(self, *args, **kwargs):
return None
def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
"""Init the global processor for multi-modal models."""
def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
"""
Init the global processor for multimodal models."""
global global_processor
transformers.logging.set_verbosity_error()
global_processor = sglang_image_processor._build_processor(server_args=server_args)
global_processor = sglang_processor._build_processor(server_args=server_args)
......@@ -20,14 +20,15 @@ import asyncio
import torch
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
class DeepseekVL2ImageProcessor(BaseImageProcessor):
class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor):
......@@ -63,7 +64,23 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
return image_inputs
async def process_images_async(
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
DeepseekVL2ImageProcessor._process_images_task,
image_data,
input_text,
max_req_input_len,
)
else:
image_inputs = self._process_images_task(
image_data, input_text, max_req_input_len
)
return image_inputs
async def process_mm_data_async(
self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
):
if not image_data:
......@@ -75,11 +92,14 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids, image_data, image_token, max_req_input_len
base_output = self.load_mm_data(
input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
res = await self._process_images(
base_output.all_frames, base_output.input_text, max_req_input_len
base_output.images, base_output.input_text, max_req_input_len
)
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
......@@ -91,7 +111,7 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
"input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"],
"im_token_id": res["im_token_id"],
"image_hashes": base_output.image_hashes,
"data_hashes": base_output.mm_data_hashes,
"image_sizes": image_sizes,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
......
import asyncio
from typing import List, Union
from transformers.utils import logging
from sglang.srt.managers.image_processor import (
BaseImageProcessor as SGLangBaseImageProcessor,
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
......@@ -16,7 +16,7 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration
logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
......@@ -47,7 +47,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"pixel_values": pixel_values,
}
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
......@@ -62,22 +62,22 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
base_output = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
discard_alpha_channel=True,
)
ret = await self._process_single_image(
input_text=base_output.input_text, images=base_output.all_frames
input_text=base_output.input_text, images=base_output.images
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"data_hashes": base_output.mm_data_hashes,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processors.base_image_processor import (
BaseImageProcessor as SGLangBaseImageProcessor,
)
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
class JanusProProcessor(SGLangBaseImageProcessor):
class JanusProImageProcessor(BaseMultimodalProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor):
......@@ -36,7 +35,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
self.executor,
JanusProProcessor._process_images_task,
JanusProImageProcessor._process_images_task,
images,
input_text,
)
......@@ -47,7 +46,7 @@ class JanusProProcessor(SGLangBaseImageProcessor):
return image_inputs
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
......@@ -61,20 +60,24 @@ class JanusProProcessor(SGLangBaseImageProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
base_out = self.load_images(
base_out = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
image_token="<image_placeholder>",
multimodal_tokens=MultimodalSpecialTokens(
image_token="<image_placeholder>"
),
max_req_input_len=max_req_input_len,
)
images = base_out.all_frames
images = base_out.images
res = await self._process_images(images=images, input_text=base_out.input_text)
# print(res)
# print(base_out)
# print("", res["images_emb_mask"].shape)
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"images_emb_mask": res["images_emb_mask"],
"image_hashes": base_out.image_hashes,
"data_hashes": base_out.mm_data_hashes,
"im_start_id": res["im_start_id"],
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],
......
......@@ -3,8 +3,8 @@ from typing import List, Optional, Union
import numpy as np
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.mm_utils import expand2square, process_anyres_image
......@@ -14,7 +14,7 @@ from sglang.srt.utils import load_image, logger
from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseImageProcessor):
class LlavaImageProcessor(BaseMultimodalProcessor):
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
def __init__(self, hf_config, server_args, _processor):
......@@ -86,7 +86,7 @@ class LlavaImageProcessor(BaseImageProcessor):
image_data, aspect_ratio, grid_pinpoints
)
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_text,
......@@ -113,7 +113,7 @@ class LlavaImageProcessor(BaseImageProcessor):
if "multi-images" in modalities or "video" in modalities:
# Multiple images
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
pixel_values, image_hashes, image_sizes = [], [], []
pixel_values, data_hashes, image_sizes = [], [], []
res = []
for img_data in image_data:
res.append(
......@@ -124,7 +124,7 @@ class LlavaImageProcessor(BaseImageProcessor):
res = await asyncio.gather(*res)
for pixel_v, image_h, image_s in res:
pixel_values.append(pixel_v)
image_hashes.append(image_h)
data_hashes.append(image_h)
image_sizes.append(image_s)
if isinstance(pixel_values[0], np.ndarray):
......@@ -134,14 +134,14 @@ class LlavaImageProcessor(BaseImageProcessor):
pixel_values, image_hash, image_size = await self._process_single_image(
image_data[0], aspect_ratio, grid_pinpoints
)
image_hashes = [image_hash]
data_hashes = [image_hash]
image_sizes = [image_size]
else:
raise ValueError(f"Invalid image data: {image_data}")
return {
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"data_hashes": data_hashes,
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
......@@ -3,82 +3,113 @@ from typing import List, Union
import torch
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.minicpmo import MiniCPMO
from sglang.srt.models.minicpmv import MiniCPMV
class MiniCPMVImageProcessor(BaseImageProcessor):
models = [MiniCPMV]
# Compatible with both 'O' and 'V'
class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
models = [MiniCPMV, MiniCPMO]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
self.image_token = "(<image>./</image>)"
self.audio_token = "(<audio>./</audio>)"
@staticmethod
def _process_images_task(images, input_text):
processor = get_global_processor()
result = processor.__call__(text=input_text, images=images, return_tensors="pt")
def _process_data_task(input_text, images=None, audios=None):
if isinstance(images, list) and len(images) == 0:
images = None
if isinstance(audios, list) and len(audios) == 0:
audios = None
result = get_global_processor().__call__(
text=input_text,
images=images,
audios=audios,
return_tensors="pt",
chunk_input=True,
)
return {
"input_ids": result.input_ids,
"pixel_values": result.pixel_values,
"tgt_sizes": result.tgt_sizes,
"pixel_values": getattr(result, "pixel_values", None),
"tgt_sizes": getattr(result, "tgt_sizes", None),
"audio_features": getattr(result, "audio_features", None),
"audio_feature_lens": getattr(result, "audio_feature_lens", None),
"audio_bounds": getattr(result, "audio_bounds", None),
}
async def _process_images(self, images, input_text):
async def _process_data(self, images, input_text, audios=None):
if self.executor is not None:
loop = asyncio.get_event_loop()
image_inputs = await loop.run_in_executor(
multimodal_data_inputs = await loop.run_in_executor(
self.executor,
MiniCPMVImageProcessor._process_images_task,
images,
MiniCPMMultimodalProcessor._process_data_task,
input_text,
images,
audios,
)
else:
image_inputs = self._processor(
images=images, text=input_text, return_tensors="pt"
multimodal_data_inputs = self._processor(
images=images, text=input_text, audios=audios, return_tensors="pt"
)
return image_inputs
return multimodal_data_inputs
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
request_obj,
max_req_input_len,
):
if not image_data:
audio_data = request_obj.audio_data
if not image_data and not audio_data:
return None
if not isinstance(image_data, list):
image_data = [image_data]
if not isinstance(audio_data, list):
audio_data = [audio_data]
base_output = self.load_images(
base_output = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
image_token=self.IMAGE_TOKEN,
max_req_input_len=max_req_input_len,
audio_data=audio_data,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(
image_token=self.image_token, audio_token=self.audio_token
),
)
if base_output is None:
return None
if len(base_output.all_frames) == 0:
return None
res = await self._process_images(
images=base_output.all_frames, input_text=base_output.input_text
res = await self._process_data(
images=base_output.images,
input_text=base_output.input_text,
audios=base_output.audios,
)
# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id
im_token_id = tokenizer.unk_token_id
im_end_id = tokenizer.im_end_id
slice_start_id, slice_end_id, audio_start_id, audio_end_id = (
None,
None,
None,
None,
)
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
if hasattr(tokenizer, "audio_start_id"):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_token_id = tokenizer.unk_token_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
......@@ -98,8 +129,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
......@@ -109,21 +138,30 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
# per patch
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
tgt_sizes = torch.stack(tgt_sizes_flat)
if len(tgt_sizes_flat) == 0:
tgt_sizes = None
else:
tgt_sizes = torch.stack(tgt_sizes_flat)
if not isinstance(res["audio_features"], list):
res["audio_features"] = [res["audio_features"]]
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"image_hashes": base_output.image_hashes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"audio_features": res["audio_features"],
"audio_bounds": res["audio_bounds"],
"audio_feature_lens": res["audio_feature_lens"],
"im_token_id": im_token_id,
"im_end_id": im_end_id,
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
import asyncio
from typing import List, Union
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
get_global_processor,
)
from sglang.srt.models.mllama import MllamaForConditionalGeneration
from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseImageProcessor):
class MllamaImageProcessor(BaseMultimodalProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
......@@ -34,7 +34,7 @@ class MllamaImageProcessor(BaseImageProcessor):
return image_inputs
async def process_images_async(
async def process_mm_data_async(
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
):
if not image_data:
......@@ -53,7 +53,7 @@ class MllamaImageProcessor(BaseImageProcessor):
images = load_image(image_data[0])[0]
image_inputs = await self._process_single_image(images, input_text)
image_inputs["image_hashes"] = [hash(str(image_data))]
image_inputs["data_hashes"] = [hash(str(image_data))]
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
import asyncio
import math
import time
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
from sglang.srt.managers.multimodal_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
from sglang.srt.managers.multimodal_processors.base_processor import (
MultimodalSpecialTokens,
get_global_processor,
)
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
......@@ -14,7 +18,7 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor):
class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
......@@ -59,7 +63,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
async def process_mm_data_async(
self,
image_data: List[Union[str, bytes]],
input_ids,
......@@ -68,16 +72,17 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
*args,
**kwargs,
):
start = time.time()
if not image_data:
return None
if isinstance(image_data, str):
image_data = [image_data]
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
base_output = self.load_mm_data(
input_ids=input_ids,
image_data=image_data,
image_token=image_token,
multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
max_req_input_len=max_req_input_len,
)
......@@ -139,7 +144,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
images = [resize_image(image) for image in base_output.all_frames]
images = [resize_image(image) for image in base_output.images]
ret = await self._process_single_image(
images=images, input_text=base_output.input_text
......@@ -147,11 +152,10 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"data_hashes": base_output.mm_data_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": image_grid_thws,
"video_grid_thws": video_grid_thws,
......
......@@ -144,11 +144,11 @@ class FINISH_ABORT(BaseFinishReason):
@dataclasses.dataclass
class ImageInputs:
class MultimodalInputs:
"""The image related inputs."""
pixel_values: Union[torch.Tensor, np.array]
image_hashes: Optional[list] = None
data_hashes: Optional[list] = None
image_sizes: Optional[list] = None
image_offsets: Optional[list] = None
image_pad_len: Optional[list] = None
......@@ -182,20 +182,27 @@ class ImageInputs:
im_end_id: Optional[int] = None
slice_start_id: Optional[int] = None
slice_end_id: Optional[int] = None
# [num_images, 2 (w, h)]
tgt_sizes: Optional[list] = None
# audio
audio_start_id: Optional[torch.Tensor] = None
audio_end_id: Optional[torch.Tensor] = None
audio_features: Optional[List[torch.Tensor]] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
ret = MultimodalInputs(
pixel_values=obj["pixel_values"],
image_hashes=obj["image_hashes"],
data_hashes=obj["data_hashes"],
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.
ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
optional_args = [
"image_sizes",
......@@ -211,6 +218,10 @@ class ImageInputs:
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"audio_start_id",
"audio_end_id",
"audio_features",
"audio_feature_lens",
]
for arg in optional_args:
if arg in obj:
......@@ -223,9 +234,19 @@ class ImageInputs:
or isinstance(ret.pixel_values, list)
)
assert ret.audio_features is None or isinstance(ret.audio_features, list)
return ret
def merge(self, other: ImageInputs):
def contains_image_inputs(self) -> bool:
""" """
return self.pixel_values is not None and self.pixel_values != []
def contains_audio_inputs(self) -> bool:
""" """
return self.audio_features is not None and self.audio_features != []
def merge(self, other: MultimodalInputs):
"""
merge image inputs when requests are being merged
"""
......@@ -268,10 +289,12 @@ class ImageInputs:
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.
self.image_hashes += other.image_hashes
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
self.data_hashes += other.data_hashes
self.pad_values = [x % (1 << 30) for x in self.data_hashes]
# args needed to be merged
optional_args = [
"audio_features",
"image_sizes",
"image_offsets",
"image_pad_len",
......@@ -362,7 +385,7 @@ class Req:
self.decoded_text = ""
# For multimodal inputs
self.image_inputs: Optional[ImageInputs] = None
self.multimodal_inputs: Optional[MultimodalInputs] = None
# Prefix info
# The indices to kv cache for the shared prefix.
......@@ -458,10 +481,10 @@ class Req:
return len(self.origin_input_ids) + len(self.output_ids)
def extend_image_inputs(self, image_inputs):
if self.image_inputs is None:
self.image_inputs = image_inputs
if self.multimodal_inputs is None:
self.multimodal_inputs = image_inputs
else:
self.image_inputs.merge(image_inputs)
self.multimodal_inputs.merge(image_inputs)
def finished(self) -> bool:
# Whether request reached finished condition
......@@ -802,7 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.encoder_cached = []
for req in self.reqs:
im = req.image_inputs
im = req.multimodal_inputs
if im is None or im.num_image_tokens is None:
# No image input
self.encoder_lens_cpu.append(0)
......@@ -1391,7 +1414,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
image_inputs=[r.image_inputs for r in self.reqs],
multimodal_inputs=[r.multimodal_inputs for r in self.reqs],
encoder_cached=self.encoder_cached,
encoder_lens=self.encoder_lens,
encoder_lens_cpu=self.encoder_lens_cpu,
......@@ -1474,7 +1497,7 @@ class ModelWorkerBatch:
extend_input_logprob_token_ids: Optional[torch.Tensor]
# For multimodal
image_inputs: Optional[List[ImageInputs]]
multimodal_inputs: Optional[List[MultimodalInputs]]
# For encoder-decoder
encoder_cached: Optional[List[bool]]
......
......@@ -88,7 +88,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
ImageInputs,
MultimodalInputs,
Req,
ScheduleBatch,
global_server_args_dict,
......@@ -841,8 +841,8 @@ class Scheduler(
return
# Handle multimodal inputs
if recv_req.image_inputs is not None:
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
if recv_req.mm_inputs is not None:
image_inputs = MultimodalInputs.from_dict(recv_req.mm_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
......@@ -856,7 +856,7 @@ class Scheduler(
)
logger.error(error_msg)
req.origin_input_ids = [0]
req.image_inputs = None
req.multimodal_inputs = None
req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
......@@ -960,7 +960,7 @@ class Scheduler(
# Handle multimodal inputs
if recv_req.image_inputs is not None:
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
image_inputs = MultimodalInputs.from_dict(recv_req.image_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
req.origin_input_ids = self.pad_input_ids_func(
req.origin_input_ids, image_inputs
......@@ -974,7 +974,7 @@ class Scheduler(
)
logger.error(error_msg)
req.origin_input_ids = [0]
req.image_inputs = None
req.multimodal_inputs = None
req.sampling_params.max_new_tokens = 0
req.finished_reason = FINISH_ABORT(
error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
......
......@@ -138,7 +138,7 @@ class Session:
token_ids_logprob=req.token_ids_logprob,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
new_req.multimodal_inputs = last_req.mm_inputs
new_req.tokenizer = tokenizer
if abort:
new_req.to_abort = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment