Unverified Commit b5e3d603 authored by Mick, committed by GitHub

vlm: support video as an input modality (#5888)

parent 4ed57807
@@ -56,7 +56,6 @@ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInp
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
@@ -507,11 +506,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

-    def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
-        pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
-        video_embeds = self.visual(
-            pixel_values_videos, grid_thw=video_input["video_grid_thw"]
-        )
+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [getattr(item, "pixel_values_videos") for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
         return video_embeds

     def get_input_embeddings(self):
@@ -553,7 +556,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
@@ -493,6 +493,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
         return image_embeds

+    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat(
+            [item.pixel_values_videos for item in items], dim=0
+        ).type(self.visual.dtype)
+        video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert video_grid_thw.dim() == 2, video_grid_thw.dim()
+        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
+        return video_embeds
+
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
         pixel_values_videos = video_input["pixel_values_videos"].type(self.visual.dtype)
         video_embeds = self.visual(
@@ -538,7 +549,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
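A minimal, self-contained sketch of the batching pattern used by the new `get_video_feature` methods above, using plain PyTorch with made-up shapes (the patch dimension and grid values are illustrative, not taken from the model config): each item contributes a flattened `(num_patches, patch_dim)` tensor plus a `(num_videos, 3)` T/H/W grid, and both are concatenated along dim 0 so the vision tower runs once for the whole batch.

```python
import torch

# Hypothetical stand-ins for MultimodalDataItem fields; real items carry
# `pixel_values_videos` and `video_grid_thw` produced by the HF processor.
items = [
    {"pixel_values_videos": torch.randn(120, 1176), "video_grid_thw": torch.tensor([[4, 6, 5]])},
    {"pixel_values_videos": torch.randn(240, 1176), "video_grid_thw": torch.tensor([[8, 6, 5]])},
]

pixel_values = torch.cat([it["pixel_values_videos"] for it in items], dim=0)
video_grid_thw = torch.cat([it["video_grid_thw"] for it in items], dim=0)

assert pixel_values.dim() == 2    # (total_patches, patch_feature_dim)
assert video_grid_thw.dim() == 2  # (total_videos, 3): T, H, W in grid units
# A single call to the vision tower would then embed all videos at once:
# video_embeds = visual(pixel_values, grid_thw=video_grid_thw)
```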
@@ -17,7 +17,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -223,7 +227,9 @@ class VILAForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             get_embedding=get_embedding,
             positions=positions,
         )
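The VILA call site above switches from a single image embedding callback to a per-modality dispatch table. A toy sketch of that pattern, with names that are illustrative rather than sglang's internal API:

```python
from enum import Enum, auto
from typing import Callable, Dict, List

import torch


class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()


def embed_items(
    items_by_modality: Dict[Modality, List[torch.Tensor]],
    data_embedding_funcs: Dict[Modality, Callable[[List[torch.Tensor]], torch.Tensor]],
) -> Dict[Modality, torch.Tensor]:
    """Run the registered embedding function for each modality that has data."""
    out = {}
    for modality, items in items_by_modality.items():
        func = data_embedding_funcs.get(modality)
        if func is None:
            raise ValueError(f"No embedding function registered for {modality}")
        out[modality] = func(items)
    return out
```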
...@@ -5,7 +5,7 @@ import multiprocessing as mp ...@@ -5,7 +5,7 @@ import multiprocessing as mp
import os import os
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -14,7 +14,7 @@ from PIL import Image ...@@ -14,7 +14,7 @@ from PIL import Image
from transformers import BaseImageProcessorFast from transformers import BaseImageProcessorFast
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.utils import encode_video, load_audio, load_image from sglang.srt.utils import load_audio, load_image, load_video, logger
@dataclasses.dataclass @dataclasses.dataclass
@@ -25,14 +25,22 @@ class BaseMultiModalProcessorOutput:
     # frames loaded from image and video, in given order
     images: Optional[list[Union[Image.Image, dict]]] = None

+    # videos
+    videos: Optional[list[Union[torch.Tensor, dict]]] = None
+
     # audios
     audios: Optional[list[Union[np.ndarray, dict]]] = None

-    def normalize(self):
-        for field_name in ["images", "audios"]:
-            field = getattr(self, field_name, None)
-            if field is not None and isinstance(field, list) and len(field) == 0:
-                setattr(self, field_name, None)
+    def organize_results(self) -> List[Tuple[Modality, Any]]:
+        """
+        :return: a list of results, with their corresponding modalities
+        """
+        return (
+            [(Modality.IMAGE, data) for data in self.images]
+            + [(Modality.VIDEO, data) for data in self.videos]
+            + [(Modality.AUDIO, data) for data in self.audios]
+        )


 @dataclasses.dataclass
@@ -41,6 +49,10 @@ class MultimodalSpecialTokens:
     video_token: Optional[Union[int, str, List[str]]] = None
     audio_token: Optional[Union[int, str, List[str]]] = None

+    image_token_regex: Optional[re.Pattern] = None
+    video_token_regex: Optional[re.Pattern] = None
+    audio_token_regex: Optional[re.Pattern] = None
+
     def convert_to_str(self, token: Union[str, int], processor) -> str:
         if token is None:
             return token
@@ -53,11 +65,29 @@
         self.video_token = self.convert_to_str(self.video_token, processor)
         self.audio_token = self.convert_to_str(self.audio_token, processor)

-    image_token_regex: Optional[re.Pattern] = None
-    video_token_regex: Optional[re.Pattern] = None
-    audio_token_regex: Optional[re.Pattern] = None
+    def get_modality_of_token(self, token) -> Optional[Modality]:
+        """
+        :return: the modality associated with the given token, if the token is a special_token or matches with the multimodal token regex
+        """
+        modality = {
+            self.image_token: Modality.IMAGE,
+            self.video_token: Modality.VIDEO,
+            self.audio_token: Modality.AUDIO,
+        }.get(token)
+        if modality:
+            return modality
+
+        for regex, modality in [
+            (self.image_token_regex, Modality.IMAGE),
+            (self.video_token_regex, Modality.VIDEO),
+            (self.audio_token_regex, Modality.AUDIO),
+        ]:
+            if regex and regex.match(token):
+                return modality
+        return None

-    def __post_init__(self):
+    def parse_regex(self):
         if self.image_token_regex is None and self.image_token is not None:
             self.image_token_regex = re.compile(re.escape(self.image_token))
         if self.video_token_regex is None and self.video_token is not None:
@@ -65,7 +95,7 @@
         if self.audio_token_regex is None and self.audio_token is not None:
             self.audio_token_regex = re.compile(re.escape(self.audio_token))

-    def collect(self) -> re.Pattern:
+    def combine_regex(self) -> re.Pattern:
         tokens = [
             self.image_token_regex,
             self.video_token_regex,
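A small, standalone sketch of what `parse_regex`, `combine_regex`, and `get_modality_of_token` accomplish together: escape each special token into a regex, join the patterns inside a capture group so `re.split` keeps the matched tokens, and map each matched token back to its modality. The token strings below are placeholders; real values come from the model's chat template.

```python
import re
from enum import Enum, auto
from typing import Optional


class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()


IMAGE_TOKEN, VIDEO_TOKEN, AUDIO_TOKEN = "<image>", "<video>", "<audio>"
token_to_modality = {
    IMAGE_TOKEN: Modality.IMAGE,
    VIDEO_TOKEN: Modality.VIDEO,
    AUDIO_TOKEN: Modality.AUDIO,
}

# The capture group makes re.split keep the matched tokens in the result list.
combined = re.compile("(" + "|".join(re.escape(t) for t in token_to_modality) + ")")


def modality_of(token: str) -> Optional[Modality]:
    return token_to_modality.get(token)


prompt = "Describe <image> and then summarize <video>."
parts = re.split(combined, prompt)
# ['Describe ', '<image>', ' and then summarize ', '<video>', '.']
print([(p, modality_of(p)) for p in parts])
```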
@@ -105,6 +135,7 @@ class BaseMultimodalProcessor(ABC):
         self.ATTR_NAME_TO_MODALITY = {
             # Image-related attributes
             "pixel_values": Modality.IMAGE,
+            "pixel_values_videos": Modality.VIDEO,
             "image_sizes": Modality.IMAGE,
             "image_grid_thw": Modality.IMAGE,
             "image_emb_mask": Modality.IMAGE,
@@ -120,7 +151,7 @@ class BaseMultimodalProcessor(ABC):
             "input_features": Modality.AUDIO,
             "input_features_mask": Modality.AUDIO,
             # Video-related attributes
-            "video_grid_thws": Modality.VIDEO,
+            "video_grid_thw": Modality.VIDEO,
             # Generic attributes that could apply to multiple modalities
             # "precomputed_features" - handled specially as it can be any modality
         }
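A hedged sketch of how an attribute-name table like `ATTR_NAME_TO_MODALITY` can be used to bucket a HF processor's output by modality. This is a simplification: the real collection step in the base processor also handles `precomputed_features`, `input_ids`, and other non-feature fields.

```python
from enum import Enum, auto
from typing import Any, Dict


class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()


# Subset of the attribute-to-modality table, for illustration only.
ATTR_NAME_TO_MODALITY = {
    "pixel_values": Modality.IMAGE,
    "image_grid_thw": Modality.IMAGE,
    "pixel_values_videos": Modality.VIDEO,
    "video_grid_thw": Modality.VIDEO,
    "input_features": Modality.AUDIO,
}


def group_by_modality(processor_output: Dict[str, Any]) -> Dict[Modality, Dict[str, Any]]:
    """Bucket processor outputs into one dict of attributes per modality."""
    grouped: Dict[Modality, Dict[str, Any]] = {}
    for name, value in processor_output.items():
        modality = ATTR_NAME_TO_MODALITY.get(name)
        if modality is not None:
            grouped.setdefault(modality, {})[name] = value
    return grouped
```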
@@ -196,20 +227,25 @@ class BaseMultimodalProcessor(ABC):
     @staticmethod
     def _load_single_item(
-        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+        data, modality: Modality, frame_count_limit=None, discard_alpha_channel=True
     ):
-        """Static method that can be pickled for multiprocessing"""
+        """
+        Load a single multimodal data.
+        If data is precomputed, returns directly.
+        Static method that can be pickled for multiprocessing"""
         if isinstance(data, dict):
             return data
         try:
-            if is_audio:
-                return load_audio(data)
-            elif is_video:
-                path = data[len("video:") :]
-                return encode_video(path, frame_count_limit)
-            else:
+            if modality == Modality.IMAGE:
                 img, _ = load_image(data)
                 return img.convert("RGB") if discard_alpha_channel else img
+            elif modality == Modality.VIDEO:
+                return load_video(data, frame_count_limit)
+            elif modality == Modality.AUDIO:
+                return load_audio(data)
         except Exception as e:
             raise RuntimeError(f"Error while loading data {data}: {e}")
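A compact illustration of the loading scheme `_load_single_item` plugs into: one executor task per multimodal token, dispatched on `Modality`, with precomputed dict items passed through untouched. The loader bodies below are stand-ins for `load_image`/`load_video`/`load_audio`, so the snippet runs on its own.

```python
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto


class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()
    AUDIO = auto()


def load_one(data, modality, frame_count_limit=None):
    if isinstance(data, dict):  # precomputed features pass through untouched
        return data
    if modality == Modality.IMAGE:
        return f"image loaded from {data}"
    if modality == Modality.VIDEO:
        return f"video loaded from {data} (<= {frame_count_limit} frames)"
    if modality == Modality.AUDIO:
        return f"audio loaded from {data}"
    raise ValueError(modality)


# One task per multimodal token, so downloads/decodes overlap.
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [
        pool.submit(load_one, "cat.png", Modality.IMAGE),
        pool.submit(load_one, "clip.mp4", Modality.VIDEO, 16),
    ]
    results = [f.result() for f in futures]
```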
@@ -217,75 +253,78 @@
         self,
         text_parts: List[str],
         multimodal_tokens: MultimodalSpecialTokens,
-        image_data: Optional[list] = None,
-        audio_data: Optional[list] = None,
+        data_iterators: dict,
         discard_alpha_channel: bool = True,
-    ):
+        image_estimated_frames_iter: Optional[iter] = None,
+        image_scaling_factor: float = 1.0,
+        max_image_frames: int = 30,
+    ) -> Tuple[List, List]:
         """
-        load multimodal data parallelly
+        load multimodal data parallelly using iterators.
         """
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        # Submit all tasks
         futures = []
         task_info = []
-        image_index, audio_index = 0, 0

         for text_part in text_parts:
-            if (
-                multimodal_tokens.image_token_regex
-                and multimodal_tokens.image_token_regex.match(text_part)
-            ):
-                data = image_data[image_index]
-                is_video = isinstance(data, str) and data.startswith("video:")
-                estimated_frames = estimated_frames_list[image_index]
-                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+            modality = multimodal_tokens.get_modality_of_token(text_part)
+            if modality is not None:
+                data_iterator = data_iterators.get(modality)
+                if data_iterator is None:
+                    raise ValueError(f"No data iterator found for token: {text_part}")
+
+                try:
+                    data = next(data_iterator)
+                except StopIteration:
+                    raise ValueError(
+                        f"Mismatch: More '{text_part}' tokens found than corresponding data items provided."
+                    )
+
+                frame_count_limit = None
+                if modality == Modality.IMAGE and image_estimated_frames_iter:
+                    try:
+                        estimated_frames = next(image_estimated_frames_iter)
+                        # Use the pre-calculated scaling factor and max frames
+                        frame_count_limit = max(
+                            1, int(estimated_frames * image_scaling_factor)
+                        )
+                        # Ensure we don't exceed the absolute max (redundant if scaling_factor handles it)
+                        # frame_count_limit = min(frame_count_limit, max_image_frames)
+                    except StopIteration:
+                        raise ValueError(
+                            "Mismatch between image tokens and estimated frame counts."
+                        )
+
                 futures.append(
                     self.io_executor.submit(
                         BaseMultimodalProcessor._load_single_item,
                         data,
-                        is_video,
-                        False,
+                        modality,
                         frame_count_limit,
                         discard_alpha_channel,
                     )
                 )
-                task_info.append((Modality.IMAGE, data, frame_count_limit))
-                image_index += 1
-            elif (
-                multimodal_tokens.audio_token_regex
-                and multimodal_tokens.audio_token_regex.match(text_part)
-            ):
-                data = audio_data[audio_index]
-                futures.append(
-                    self.io_executor.submit(
-                        BaseMultimodalProcessor._load_single_item,
-                        data,
-                        False,
-                        True,
-                        None,
-                        discard_alpha_channel,
-                    )
-                )
-                task_info.append((Modality.AUDIO, data, None))
-                audio_index += 1
+                task_info.append((modality, data, frame_count_limit))
+
+        for modality, iterator in data_iterators.items():
+            try:
+                next(iterator)
+                logger.warning(
+                    f"Warning: More {modality.name.lower()} data items provided than corresponding tokens found in the prompt."
+                )
+            except StopIteration:
+                pass
+            except Exception:
+                pass

         return futures, task_info

     def load_mm_data(
         self,
-        prompt: str | List[int],
+        prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
+        video_data: Optional[list] = None,
         audio_data: Optional[list] = None,
         return_text: Optional[bool] = True,
         discard_alpha_channel: bool = True,
@@ -299,14 +338,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
-        if not return_text:
-            raise NotImplementedError()
-        if image_data is None:
-            image_data = []
-
         multimodal_tokens.convert_to_strs(self._processor)
-        multimodal_tokens_pattern = multimodal_tokens.collect()
+        multimodal_tokens.parse_regex()
+        multimodal_tokens_pattern = multimodal_tokens.combine_regex()

         if isinstance(prompt, list) and return_text:
             assert len(prompt) and isinstance(prompt[0], int)
             prompt = self._processor.tokenizer.decode(prompt)
@@ -317,59 +351,84 @@
         # split text into list of normal text and special tokens
         text_parts = re.split(multimodal_tokens_pattern, prompt)

+        # collect all data
+        data_iterators = {}
+        if multimodal_tokens.image_token and image_data:
+            data_iterators[Modality.IMAGE] = iter(image_data)
+        if multimodal_tokens.video_token and video_data:
+            data_iterators[Modality.VIDEO] = iter(video_data)
+        if multimodal_tokens.audio_token and audio_data:
+            data_iterators[Modality.AUDIO] = iter(audio_data)
+
+        # futures: the futures of loaded data
+        # task_info: modality, raw_data, and other metadata of each data
         futures, task_info = self.submit_data_loading_tasks(
             text_parts=text_parts,
             multimodal_tokens=multimodal_tokens,
-            image_data=image_data,
-            audio_data=audio_data,
+            data_iterators=data_iterators,
             discard_alpha_channel=discard_alpha_channel,
         )
+        task_info_iter = iter(task_info)
+        futures_iter = iter(futures)

         # Process results
-        images, audios = [], []
-        new_text = ""
-        task_ptr = 0
-
+        images, videos, audios = [], [], []
+        new_text_parts = []
         for text_part in text_parts:
-            if multimodal_tokens_pattern.match(text_part):
-                task_type, data, frame_limit = task_info[task_ptr]
-                result = futures[task_ptr].result()
-                task_ptr += 1
-
-                if task_type == Modality.IMAGE:
-                    # If data is already processed it will be a
-                    # dictionary. In this case we want to keep the
-                    # expanded tokens in text_part. Otherwise, we will
-                    # call the processor code, so keep only a single image
-                    # token.
-                    mm_tokens = (
-                        text_part
-                        if isinstance(data, dict)
-                        else multimodal_tokens.image_token
-                    )
-                    frames = [result] if not isinstance(result, list) else result
-                    if frames:
-                        images += frames
-                        new_text += mm_tokens * len(frames)
-                elif task_type == Modality.AUDIO:
-                    # audio
-                    mm_tokens = (
-                        text_part
-                        if isinstance(data, dict)
-                        else multimodal_tokens.audio_token
-                    )
-                    audios.append(result)
-                    new_text += mm_tokens
-                # TODO: handle video
-            else:
-                new_text += text_part
-
-        out = BaseMultiModalProcessorOutput(
-            input_text=new_text,
-            images=images,
-            audios=audios,
-        )
-        out.normalize()
-        return out
+            try:
+                if multimodal_tokens_pattern.match(text_part):
+                    modality, raw_data, frame_limit = next(task_info_iter)
+                    is_precomputed = isinstance(raw_data, dict)
+                    result = next(futures_iter).result()
+
+                    if modality == Modality.IMAGE:
+                        # If data is already processed it will be a
+                        # dictionary(precomputed). In this case we want to keep the
+                        # expanded tokens in text_part. Otherwise, we will
+                        # call the processor code, so keep only a single image
+                        # token.
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.image_token
+                        )
+                        frames = [result] if not isinstance(result, list) else result
+                        if frames:
+                            # only for minicpmv
+                            images += frames
+                            new_text_parts += mm_tokens * len(frames)
+                    elif modality == Modality.VIDEO:
+                        # load as video
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.video_token
+                        )
+                        videos += [result]
+                        new_text_parts += mm_tokens
+                    elif modality == Modality.AUDIO:
+                        # audio
+                        mm_tokens = (
+                            text_part
+                            if is_precomputed
+                            else multimodal_tokens.audio_token
+                        )
+                        audios += [result]
+                        new_text_parts += mm_tokens
+                else:
+                    # normal text
+                    new_text_parts += [text_part]
+            except Exception as e:
+                raise RuntimeError(
+                    f"An exception occurred while loading multimodal data: {e}"
+                )
+
+        return BaseMultiModalProcessorOutput(
+            images=images,
+            audios=audios,
+            videos=videos,
+            input_text="".join(new_text_parts),
+        )

     @staticmethod
     def get_mm_items_offset(
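The core of the new matching logic in `load_mm_data`/`submit_data_loading_tasks`, reduced to a few lines: one iterator per modality, each multimodal token consumes the next item of its modality, and leftovers or shortfalls are reported. Tokens and file names below are placeholders.

```python
from enum import Enum, auto


class Modality(Enum):
    IMAGE = auto()
    VIDEO = auto()


data_iterators = {
    Modality.IMAGE: iter(["img0.png", "img1.png"]),
    Modality.VIDEO: iter(["clip.mp4"]),
}
token_to_modality = {"<image>": Modality.IMAGE, "<video>": Modality.VIDEO}
text_parts = ["Compare ", "<image>", " with ", "<image>", ", then watch ", "<video>", "."]

paired = []
for part in text_parts:
    modality = token_to_modality.get(part)
    if modality is None:
        continue
    try:
        paired.append((modality, next(data_iterators[modality])))
    except StopIteration:
        raise ValueError(f"More '{part}' tokens than provided data items")

# Leftover items mean the request supplied more data than tokens.
for modality, it in data_iterators.items():
    if next(it, None) is not None:
        print(f"warning: unused {modality.name.lower()} data")
```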
@@ -460,21 +519,19 @@ class BaseMultimodalProcessor(ABC):
                 )
             except ValueError:
                 modality = Modality.IMAGE

             if modality:
                 # Create item if needed
                 if modality not in items:
                     items[modality] = MultimodalDataItem(modality=modality)
                 # Set attribute
-                if hasattr(items[modality], attr_name):
-                    setattr(items[modality], attr_name, value)
+                setattr(items[modality], attr_name, value)

         return list(items.values())

     def _process_and_collect_mm_items(
         self, input_text: str, images=None, audios=None, videos=None, **kwargs
-    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
         """
         Helper method to process multimodal data and create mm_items in one step.
@@ -488,11 +545,11 @@ class BaseMultimodalProcessor(ABC):
         input_ids = ret["input_ids"].flatten()
         collected_items = self.collect_mm_items_from_processor_output(ret)

-        return collected_items, input_ids
+        return collected_items, input_ids, ret

     def process_and_combine_mm_data(
         self, base_output: BaseMultiModalProcessorOutput
-    ) -> Tuple[List[MultimodalDataItem], torch.Tensor]:
+    ) -> Tuple[List[MultimodalDataItem], torch.Tensor, dict]:
         """
         Process multimodal data and return the combined multimodal items and input_ids.
         Supports mixed modalities (images and audio in the same request).
@@ -501,8 +558,7 @@ class BaseMultimodalProcessor(ABC):
             Tuple of (list of mm_items, input_ids)
         """
         # Collect all items and categorize them
-        all_items = (base_output.images or []) + (base_output.audios or [])
+        all_items = base_output.organize_results()

         # Handle text-only case
         if not all_items:
             input_ids = self._processor.tokenizer(
@@ -510,19 +566,20 @@ class BaseMultimodalProcessor(ABC):
                 return_tensors="pt",
                 add_special_tokens=True,
             ).input_ids.flatten()
-            return [], input_ids
+            return [], input_ids, {}

-        dict_items, raw_images, raw_audios = [], [], []
-        for item in all_items:
+        dict_items, raw_images, raw_audios, raw_videos = [], [], [], []
+        for modality, item in all_items:
             if isinstance(item, dict):
                 dict_items.append(item)
-            elif isinstance(item, Image.Image):
+            elif modality == Modality.IMAGE:
                 raw_images.append(item)
-            elif isinstance(item, np.ndarray):
+            elif modality == Modality.AUDIO:
                 raw_audios.append(item)
+            elif modality == Modality.VIDEO:
+                raw_videos.append(item)
             else:
                 raise ValueError(f"Unknown multimodal item type: {type(item)}")

         # Process items and get input_ids
         all_collected_items = []
         input_ids = None
@@ -534,13 +591,16 @@ class BaseMultimodalProcessor(ABC):
             )

         # Handle raw items (need processing)
-        if raw_images or raw_audios:
-            collected_items, input_ids = self._process_and_collect_mm_items(
+        if raw_images or raw_audios or raw_videos:
+            collected_items, input_ids, ret = self._process_and_collect_mm_items(
                 input_text=base_output.input_text,
                 images=raw_images,
                 audios=raw_audios,
+                videos=raw_videos,
             )
             all_collected_items.extend(collected_items)
+        else:
+            ret = None

         # Fallback tokenization if no raw items were processed
         if input_ids is None:
@@ -553,21 +613,21 @@ class BaseMultimodalProcessor(ABC):
         # Add offsets to all items
         for mm_item in all_collected_items:
             if mm_item.modality in [Modality.IMAGE, Modality.MULTI_IMAGES]:
-                mm_item.image_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.IM_TOKEN_ID,
                 )
             elif mm_item.modality == Modality.AUDIO:
-                mm_item.audio_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.AUDIO_TOKEN_ID,
                 )
             elif mm_item.modality == Modality.VIDEO:
-                mm_item.video_offsets = self.get_mm_items_offset(
+                mm_item.offsets = self.get_mm_items_offset(
                     input_ids=input_ids,
                     mm_token_id=self.VIDEO_TOKEN_ID,
                 )
             else:
                 raise ValueError(f"Unknown modality: {mm_item.modality}")

-        return all_collected_items, input_ids
+        return all_collected_items, input_ids, ret
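All item types now share a single `offsets` field filled by `get_mm_items_offset`. A standalone sketch of what such an offset helper can compute: the contiguous runs of the modality's token id inside `input_ids`, as inclusive start/end pairs. This mirrors the idea, not necessarily the exact return format of the sglang helper.

```python
import torch


def mm_items_offsets(input_ids: torch.Tensor, mm_token_id: int):
    """Return (start, end) index pairs of each contiguous run of mm_token_id."""
    mask = (input_ids == mm_token_id).long()
    # Boundaries where the mask switches between 0 and 1.
    diff = torch.diff(mask, prepend=torch.tensor([0]), append=torch.tensor([0]))
    starts = (diff == 1).nonzero(as_tuple=True)[0]
    ends = (diff == -1).nonzero(as_tuple=True)[0] - 1
    return list(zip(starts.tolist(), ends.tolist()))


ids = torch.tensor([5, 7, 7, 7, 2, 9, 7, 7, 3])
print(mm_items_offsets(ids, mm_token_id=7))  # [(1, 3), (6, 7)]
```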
@@ -69,7 +69,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             )
             item = MultimodalDataItem(
                 pixel_values=res["images"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
                 image_emb_mask=images_seq_mask,
                 image_spatial_crop=batched_images_spatial_crop,
@@ -36,6 +36,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
+        print(f"{image_data=}")
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -46,8 +47,9 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
             discard_alpha_channel=True,
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
-
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)
+        print(f"{base_output=}")
+        print(f"{mm_items=}")
         return {
             "input_ids": input_ids.tolist(),
             "mm_items": mm_items,
@@ -72,7 +72,7 @@ class Gemma3nSGLangProcessor(SGLangBaseProcessor):
             ),
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
@@ -225,7 +225,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=pixel_values,
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
@@ -49,7 +49,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
                 MultimodalDataItem(
                     pixel_values=res["pixel_values"],
                     image_emb_mask=res["images_emb_mask"],
-                    image_offsets=image_offsets,
+                    offsets=image_offsets,
                     modality=Modality.IMAGE,
                 )
             ],
@@ -39,7 +39,7 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
             max_req_input_len=max_req_input_len,
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
@@ -19,6 +19,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"
+        self.video_token = "(<video>./</video>)"

     async def process_mm_data_async(
         self,
@@ -36,6 +37,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=self.image_token,
+                video_token=self.video_token,
                 audio_token=self.audio_token,
             ),
         )
@@ -113,7 +115,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if len(pixel_values) != 0:
             item = MultimodalDataItem(
                 pixel_values=pixel_values,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 tgt_size=tgt_sizes_flat,
                 modality=Modality.IMAGE,
             )
@@ -135,11 +137,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             item = MultimodalDataItem(
                 audio_features=[res["audio_features"]],
                 audio_feature_lens=res["audio_feature_lens"],
-                audio_offsets=audio_offsets,
+                offsets=audio_offsets,
                 modality=Modality.AUDIO,
             )
             items += [item]
-
         return {
             "mm_items": items,
             "input_ids": input_ids.tolist(),
@@ -144,7 +144,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             MultimodalDataItem(
                 pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
@@ -65,7 +65,7 @@ class Phi4MMImageProcessor(BaseMultimodalProcessor):
                 pixel_values=res["input_image_embeds"],
                 image_sizes=res["image_sizes"],
                 image_emb_mask=res["image_attention_mask"],
-                image_offsets=image_offsets,
+                offsets=image_offsets,
                 modality=Modality.IMAGE,
             )
         ]
@@ -106,7 +106,7 @@ class PixtralProcessor(BaseMultimodalProcessor):
                 pixel_values=processor_output["pixel_values"],
                 image_sizes=processor_output["image_sizes"],
                 modality=Modality.IMAGE,
-                image_offsets=image_offsets,
+                offsets=image_offsets,
             )
         ]
 import asyncio
 import math
+import os
 import re
-from typing import Dict, List, Union
+from typing import List, Union

+import torch
+import torchvision
 from PIL import Image
+from torchvision.transforms import InterpolationMode

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
@@ -12,6 +16,185 @@ from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
 from sglang.srt.multimodal.processors.base_processor import MultimodalSpecialTokens
+from sglang.utils import logger
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 4 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+VIDEO_TOTAL_PIXELS = int(
+    float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9))
+)
+VIDEO_MIN_PIXELS = 128 * 28 * 28
+VIDEO_MAX_PIXELS = 768 * 28 * 28
+FRAME_FACTOR = 2
+FPS = 2.0
+FPS_MIN_FRAMES = 4
+FPS_MAX_FRAMES = 768
+
+
+def smart_resize(
+    height: int,
+    width: int,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+    1. Both dimensions (height and width) are divisible by 'factor'.
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
+
+def resize_image(image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
+    width, height = image.size
+    min_pixels = MIN_PIXELS
+    max_pixels = MAX_PIXELS
+    resized_height, resized_width = smart_resize(
+        height,
+        width,
+        factor=size_factor,
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
+    )
+    image = image.resize((resized_width, resized_height))
+    return image
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+async def resize_image_async(image):
+    return resize_image(image)
+
+
+def smart_nframes(
+    ele: dict,
+    total_frames: int,
+    video_fps: int | float,
+) -> int:
+    """calculate the number of frames for video used for model inputs.
+
+    Args:
+        ele (dict): a dict contains the configuration of video.
+            support either `fps` or `nframes`:
+                - nframes: the number of frames to extract for model inputs.
+                - fps: the fps to extract frames for model inputs.
+                - min_frames: the minimum number of frames of the video, only used when fps is provided.
+                - max_frames: the maximum number of frames of the video, only used when fps is provided.
+        total_frames (int): the original total number of frames of the video.
+        video_fps (int | float): the original fps of the video.
+
+    Raises:
+        ValueError: nframes should in interval [FRAME_FACTOR, total_frames].
+
+    Returns:
+        int: the number of frames for video used for model inputs.
+    """
+    assert not (
+        "fps" in ele and "nframes" in ele
+    ), "Only accept either `fps` or `nframes`"
+    if "nframes" in ele:
+        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+    else:
+        fps = ele.get("fps", FPS)
+        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+        max_frames = floor_by_factor(
+            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
+        )
+        nframes = total_frames / video_fps * fps
+        if nframes > total_frames:
+            logger.warning(
+                f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]"
+            )
+        nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+        nframes = floor_by_factor(nframes, FRAME_FACTOR)
+    if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+        raise ValueError(
+            f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
+        )
+    return nframes
+
+
+# process video, qwen-specific
+async def preprocess_video(
+    vr,
+    image_factor: int = IMAGE_FACTOR,
+    # vr: VideoReader, image_factor: int = IMAGE_FACTOR
+) -> torch.Tensor:
+    ele = {}
+    total_frames, video_fps = len(vr), vr.get_avg_fps()
+    nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
+    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
+    video = vr.get_batch(idx).asnumpy()
+    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+    nframes, _, height, width = video.shape
+    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+    max_pixels = max(
+        min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
+        int(min_pixels * 1.05),
+    )
+    max_pixels_supposed = ele.get("max_pixels", max_pixels)
+    if max_pixels_supposed > max_pixels:
+        logger.warning(
+            f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
+        )
+    max_pixels = min(max_pixels_supposed, max_pixels)
+    if "resized_height" in ele and "resized_width" in ele:
+        resized_height, resized_width = smart_resize(
+            ele["resized_height"],
+            ele["resized_width"],
+            factor=image_factor,
+        )
+    else:
+        resized_height, resized_width = smart_resize(
+            height,
+            width,
+            factor=image_factor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+    video = torchvision.transforms.functional.resize(
+        video,
+        [resized_height, resized_width],
+        interpolation=InterpolationMode.BICUBIC,
+        antialias=True,
+    ).float()
+    return video
 # Compatible with Qwen2VL and Qwen2_5VL
@@ -37,104 +220,44 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.MIN_PIXELS = 4 * 28 * 28
         self.MAX_PIXELS = 16384 * 28 * 28
         self.MAX_RATIO = 200
+        # TODO(mick): move all MultimodalSpecialTokens initializations into processor init
+        self.mm_special_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_regex=self.IMAGE_TOKEN_REGEX,
+            video_token=self.VIDEO_TOKEN_ID,
+        )

     async def process_mm_data_async(
         self,
-        image_data: List[Union[str, bytes, Dict]],
+        image_data: List[Union[str, bytes]],
         input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(
-                image_token=self.IMAGE_TOKEN,
-                image_token_regex=self.IMAGE_TOKEN_REGEX,
-            ),
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_special_tokens,
             max_req_input_len=max_req_input_len,
         )

-        def smart_resize(
-            height: int,
-            width: int,
-            factor: int = self.IMAGE_FACTOR,
-            min_pixels: int = self.MIN_PIXELS,
-            max_pixels: int = self.MAX_PIXELS,
-        ) -> tuple[int, int]:
-            """
-            Rescales the image so that the following conditions are met:
-            1. Both dimensions (height and width) are divisible by 'factor'.
-            2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-            3. The aspect ratio of the image is maintained as closely as possible.
-            """
-            if max(height, width) / min(height, width) > self.MAX_RATIO:
-                raise ValueError(
-                    f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
-                )
-            h_bar = max(factor, round_by_factor(height, factor))
-            w_bar = max(factor, round_by_factor(width, factor))
-            if h_bar * w_bar > max_pixels:
-                beta = math.sqrt((height * width) / max_pixels)
-                h_bar = floor_by_factor(height / beta, factor)
-                w_bar = floor_by_factor(width / beta, factor)
-            elif h_bar * w_bar < min_pixels:
-                beta = math.sqrt(min_pixels / (height * width))
-                h_bar = ceil_by_factor(height * beta, factor)
-                w_bar = ceil_by_factor(width * beta, factor)
-            return h_bar, w_bar
-
-        def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
-            width, height = image.size
-            min_pixels = self.MIN_PIXELS
-            max_pixels = self.MAX_PIXELS
-            resized_height, resized_width = smart_resize(
-                height,
-                width,
-                factor=size_factor,
-                min_pixels=min_pixels,
-                max_pixels=max_pixels,
-            )
-            image = image.resize((resized_width, resized_height))
-            return image
-
-        def round_by_factor(number: int, factor: int) -> int:
-            """Returns the closest integer to 'number' that is divisible by 'factor'."""
-            return round(number / factor) * factor
-
-        def ceil_by_factor(number: int, factor: int) -> int:
-            """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-            return math.ceil(number / factor) * factor
-
-        def floor_by_factor(number: int, factor: int) -> int:
-            """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-            return math.floor(number / factor) * factor
-
-        async def resize_image_async(image):
-            return resize_image(image)
-
         # Qwen-specific: resize images if they are raw Image objects
         if base_output.images and isinstance(base_output.images[0], Image.Image):
             resize_tasks = [resize_image_async(image) for image in base_output.images]
             base_output.images = await asyncio.gather(*resize_tasks)

-        video_grid_thw = None  # TODO
-
-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        if base_output.videos:
+            base_output.videos = [
+                await preprocess_video(video) for video in base_output.videos
+            ]

-        if not mm_items:
-            # Note(Xinyuan): This is the case where image loading fails.
-            return None
-
-        combined_mm_item = mm_items[0]  # only image is supported for now
-        video_grid_thw = None  # TODO
-        second_per_grid_ts = getattr(combined_mm_item, "second_per_grid_ts", None)
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(base_output)

-        input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
             spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
             image_token_id=self.IM_TOKEN_ID,
@@ -145,9 +268,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
                 self.hf_config.vision_config, "tokens_per_second", None
             ),
             input_ids=input_ids.unsqueeze(0),
-            image_grid_thw=combined_mm_item.image_grid_thw,
-            video_grid_thw=video_grid_thw,
-            second_per_grid_ts=second_per_grid_ts,
+            image_grid_thw=getattr(ret, "image_grid_thw", None),
+            video_grid_thw=getattr(ret, "video_grid_thw", None),
+            second_per_grid_ts=getattr(ret, "second_per_grid_ts", None),
         )
         mrope_positions = mrope_positions.squeeze(1)
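A worked numeric example of the resizing and frame-sampling rules added in this file. The 1920x1080 image and the 10 s / 30 fps clip are made-up inputs; the constants follow the hunk above, and the factor helpers are copied standalone so the snippet runs by itself.

```python
import math

def round_by_factor(n, f): return round(n / f) * f
def floor_by_factor(n, f): return math.floor(n / f) * f

# smart_resize on a 1920x1080 frame with factor=28:
#   round_by_factor(1080, 28) = 39 * 28 = 1092
#   round_by_factor(1920, 28) = 69 * 28 = 1932
#   1092 * 1932 = 2,109,744 pixels, already inside [MIN_PIXELS, MAX_PIXELS],
#   so the frame is resized to (1092, 1932) -- both sides divisible by 28.
assert round_by_factor(1080, 28) == 1092 and round_by_factor(1920, 28) == 1932

# smart_nframes with the defaults (fps=2.0, FRAME_FACTOR=2) on a 10 s clip
# recorded at 30 fps (total_frames=300):
#   nframes = total_frames / video_fps * fps = 300 / 30 * 2 = 20
#   clamped to [FPS_MIN_FRAMES, min(FPS_MAX_FRAMES, 300)] and floored to a
#   multiple of FRAME_FACTOR -> 20 sampled frames.
total_frames, video_fps, fps = 300, 30, 2.0
nframes = floor_by_factor(min(max(total_frames / video_fps * fps, 4), 300), 2)
assert nframes == 20
```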
@@ -57,7 +57,7 @@ class VILAMultimodalProcessor(BaseMultimodalProcessor):
             image_data=image_data,
         )

-        mm_items, input_ids = self.process_and_combine_mm_data(base_output)
+        mm_items, input_ids, _ = self.process_and_combine_mm_data(base_output)

         return {
             "input_ids": input_ids.tolist(),
@@ -728,33 +728,6 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
     return audio


-def encode_video(video_path, frame_count_limit=None):
-    # Lazy import because decord is not available on some arm platforms.
-    from decord import VideoReader, cpu
-
-    if not os.path.exists(video_path):
-        logger.error(f"Video {video_path} does not exist")
-        return []
-
-    if frame_count_limit == 0:
-        return []
-
-    def uniform_sample(l, n):
-        gap = len(l) / n
-        idxs = [int(i * gap + gap / 2) for i in range(n)]
-        return [l[i] for i in idxs]
-
-    vr = VideoReader(video_path, ctx=cpu(0))
-    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-    frame_indices = [i for i in range(0, len(vr), sample_fps)]
-    if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-        frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-    frames = vr.get_batch(frame_indices).asnumpy()
-    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-    return frames
-
-
 def load_image(
     image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
@@ -774,9 +747,6 @@ def load_image(
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
-    elif image_file.startswith("video:"):
-        image_file = image_file.replace("video:", "")
-        image, image_size = decode_video_base64(image_file)
     elif isinstance(image_file, str):
         image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
@@ -785,6 +755,61 @@ def load_image(
     return image, image_size


+def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
+    # We import decord here to avoid a strange Segmentation fault (core dumped) issue.
+    from decord import VideoReader, cpu, gpu
+
+    try:
+        from decord.bridge import decord_bridge
+
+        ctx = gpu(0)
+        _ = decord_bridge.get_ctx_device(ctx)
+    except Exception:
+        ctx = cpu(0)
+
+    tmp_file = None
+    vr = None
+    try:
+        if isinstance(video_file, bytes):
+            tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+            tmp_file.write(video_file)
+            tmp_file.close()
+            vr = VideoReader(tmp_file.name, ctx=ctx)
+        elif isinstance(video_file, str):
+            if video_file.startswith(("http://", "https://")):
+                timeout = int(os.getenv("REQUEST_TIMEOUT", "10"))
+                response = requests.get(video_file, stream=True, timeout=timeout)
+                response.raise_for_status()
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                for chunk in response.iter_content(chunk_size=8192):
+                    tmp_file.write(chunk)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif video_file.startswith("data:"):
+                _, encoded = video_file.split(",", 1)
+                video_bytes = base64.b64decode(encoded)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+            elif os.path.isfile(video_file):
+                vr = VideoReader(video_file, ctx=ctx)
+            else:
+                video_bytes = base64.b64decode(video_file)
+                tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
+                tmp_file.write(video_bytes)
+                tmp_file.close()
+                vr = VideoReader(tmp_file.name, ctx=ctx)
+        else:
+            raise ValueError(f"Unsupported video input type: {type(video_file)}")
+        return vr
+    finally:
+        if tmp_file and os.path.exists(tmp_file.name):
+            os.unlink(tmp_file.name)
+
+
 def suppress_other_loggers():
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
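A hedged usage sketch of the new `load_video` helper. The path is illustrative, and the Qwen-specific `preprocess_video` step is shown only as a comment because it is an async function defined in the processor file above.

```python
from sglang.srt.utils import load_video

vr = load_video("/path/to/clip.mp4")   # decord.VideoReader (GPU context if available)
print(len(vr), vr.get_avg_fps())       # total frame count and average fps

# The Qwen2-VL processor then samples and resizes frames from this reader:
# video_tensor = await preprocess_video(vr)   # float tensor in TCHW layout
```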
...@@ -3,7 +3,6 @@ Unit tests for Jinja chat template utils. ...@@ -3,7 +3,6 @@ Unit tests for Jinja chat template utils.
""" """
import unittest import unittest
from unittest.mock import patch
from sglang.srt.jinja_template_utils import ( from sglang.srt.jinja_template_utils import (
detect_jinja_template_content_format, detect_jinja_template_content_format,
...@@ -76,11 +75,12 @@ class TestTemplateContentFormatDetection(CustomTestCase): ...@@ -76,11 +75,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
} }
image_data = [] image_data = []
video_data = []
audio_data = [] audio_data = []
modalities = [] modalities = []
result = process_content_for_template_format( result = process_content_for_template_format(
msg_dict, "openai", image_data, audio_data, modalities msg_dict, "openai", image_data, video_data, audio_data, modalities
) )
# Check that image_data was extracted # Check that image_data was extracted
...@@ -111,11 +111,12 @@ class TestTemplateContentFormatDetection(CustomTestCase): ...@@ -111,11 +111,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
} }
image_data = [] image_data = []
video_data = []
audio_data = [] audio_data = []
modalities = [] modalities = []
result = process_content_for_template_format( result = process_content_for_template_format(
msg_dict, "string", image_data, audio_data, modalities msg_dict, "string", image_data, video_data, audio_data, modalities
) )
# For string format, should flatten to text only # For string format, should flatten to text only
...@@ -139,11 +140,12 @@ class TestTemplateContentFormatDetection(CustomTestCase): ...@@ -139,11 +140,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
} }
image_data = [] image_data = []
video_data = []
audio_data = [] audio_data = []
modalities = [] modalities = []
result = process_content_for_template_format( result = process_content_for_template_format(
msg_dict, "openai", image_data, audio_data, modalities msg_dict, "openai", image_data, video_data, audio_data, modalities
) )
# Check that audio_data was extracted # Check that audio_data was extracted
...@@ -162,11 +164,12 @@ class TestTemplateContentFormatDetection(CustomTestCase): ...@@ -162,11 +164,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
msg_dict = {"role": "user", "content": "Hello world"} msg_dict = {"role": "user", "content": "Hello world"}
image_data = [] image_data = []
video_data = []
audio_data = [] audio_data = []
modalities = [] modalities = []
result = process_content_for_template_format( result = process_content_for_template_format(
msg_dict, "openai", image_data, audio_data, modalities msg_dict, "openai", image_data, video_data, audio_data, modalities
) )
# Should pass through unchanged # Should pass through unchanged
...@@ -188,11 +191,12 @@ class TestTemplateContentFormatDetection(CustomTestCase): ...@@ -188,11 +191,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
} }
image_data = [] image_data = []
video_data = []
audio_data = [] audio_data = []
modalities = [] modalities = []
result = process_content_for_template_format( result = process_content_for_template_format(
msg_dict, "openai", image_data, audio_data, modalities msg_dict, "openai", image_data, video_data, audio_data, modalities
) )
# Check that modalities was extracted # Check that modalities was extracted
...@@ -209,11 +213,12 @@ class TestTemplateContentFormatDetection(CustomTestCase): ...@@ -209,11 +213,12 @@ class TestTemplateContentFormatDetection(CustomTestCase):
} }
image_data = [] image_data = []
video_data = []
audio_data = [] audio_data = []
modalities = [] modalities = []
result = process_content_for_template_format( result = process_content_for_template_format(
msg_dict, "string", image_data, audio_data, modalities msg_dict, "string", image_data, video_data, audio_data, modalities
) )
# None values should be filtered out # None values should be filtered out
......
@@ -35,6 +35,9 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+

 class TestQwen2_5_VLServer(TestOpenAIVisionServer):
     @classmethod
@@ -54,6 +57,9 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

+    def test_video_chat_completion(self):
+        self._test_video_chat_completion()
+

 class TestVLMContextLengthIssue(CustomTestCase):
     @classmethod
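For reference, the kind of request the new `test_video_chat_completion` cases exercise might look roughly like the snippet below. The `video_url` content-part schema is an assumption mirroring the existing `image_url` convention and is not taken from this diff; the model name and URL are placeholders.

```python
import openai

# Assumed local sglang server started by the test harness.
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",
    messages=[
        {
            "role": "user",
            "content": [
                # Hypothetical content part; exact schema defined by the test helper.
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
                {"type": "text", "text": "Describe what happens in this video."},
            ],
        }
    ],
    max_tokens=128,
)
print(response.choices[0].message.content)
```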
@@ -93,7 +93,7 @@ class TestJanusProServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass

     def test_single_image_chat_completion(self):
@@ -170,7 +170,7 @@ class TestKimiVLServer(TestOpenAIVisionServer):
         )
         cls.base_url += "/v1"

-    def test_video_chat_completion(self):
+    def test_video_images_chat_completion(self):
         pass