Unverified Commit b5e3d603 authored by Mick, committed by GitHub

vlm: support video as an input modality (#5888)

parent 4ed57807
@@ -88,9 +88,11 @@ class Conversation:
     stop_str: Union[str, List[str]] = None
     # The string that represents an image token in the prompt
     image_token: str = "<image>"
+    video_token: str = "<video>"
     audio_token: str = "<audio>"

     image_data: Optional[List[str]] = None
+    video_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
     stop_token_ids: Optional[int] = None
@@ -380,11 +382,15 @@ class Conversation:
         self.messages.append([role, message])

     def append_image(self, image: str):
-        """Append a new message."""
+        """Append a new image."""
         self.image_data.append(image)

+    def append_video(self, video: str):
+        """Append a new video."""
+        self.video_data.append(video)
+
     def append_audio(self, audio: str):
-        """Append a new message."""
+        """Append a new audio."""
         self.audio_data.append(audio)

     def update_last_message(self, message: str):
@@ -433,6 +439,7 @@ class Conversation:
             sep2=self.sep2,
             stop_str=self.stop_str,
             image_token=self.image_token,
+            video_token=self.video_token,
             audio_token=self.audio_token,
         )
@@ -495,8 +502,12 @@ def generate_embedding_convs(
             sep2=conv_template.sep2,
             stop_str=conv_template.stop_str,
             image_data=[],
+            video_data=[],
+            audio_data=[],
             modalities=[],
             image_token=conv_template.image_token,
+            video_token=conv_template.video_token,
+            audio_token=conv_template.audio_token,
         )
         real_content = ""
@@ -557,10 +568,12 @@ def generate_chat_conv(
         sep2=conv.sep2,
         stop_str=conv.stop_str,
         image_data=[],
+        video_data=[],
         audio_data=[],
         modalities=[],
         image_token=conv.image_token,
         audio_token=conv.audio_token,
+        video_token=conv.video_token,
     )

     if isinstance(request.messages, str):
@@ -602,6 +615,7 @@ def generate_chat_conv(
                 image_token = ""

                 audio_token = conv.audio_token
+                video_token = conv.video_token
                 for content in message.content:
                     if content.type == "text":
                         if num_image_url > 16:
@@ -614,6 +628,9 @@ def generate_chat_conv(
                         else:
                             real_content += image_token
                         conv.append_image(content.image_url.url)
+                    elif content.type == "video_url":
+                        real_content += video_token
+                        conv.append_video(content.video_url.url)
                     elif content.type == "audio_url":
                         real_content += audio_token
                         conv.append_audio(content.audio_url.url)
@@ -810,6 +827,7 @@ register_conv_template(
         sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
         stop_str=["<|im_end|>"],
         image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        video_token="<|vision_start|><|video_pad|><|vision_end|>",
     )
 )
@@ -870,6 +888,7 @@ register_conv_template(
         sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
         stop_str=("<|im_end|>", "<|endoftext|>"),
         image_token="(<image>./</image>)",
+        video_token="(<video>./</video>)",
     )
 )
......
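The effect of the conversation-template changes, in brief: a registered template now carries a `video_token` and a `video_data` list, and `append_video` records each referenced clip while the token is spliced into the prompt. A minimal sketch; the construction of `conv` from a registered template is assumed and omitted here:

```python
# Sketch only: assumes `conv` is a Conversation built from a registered template
# (e.g. the qwen2-vl template shown above); construction details are omitted.
conv.append_video("https://example.com/clip.mp4")  # stored in conv.video_data
print(conv.video_token)
# For the qwen2-vl template this prints "<|vision_start|><|video_pad|><|vision_end|>",
# which is what gets inserted into the prompt where the video was referenced.
```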
@@ -267,6 +267,10 @@ class ChatCompletionMessageContentImageURL(BaseModel):
     detail: Optional[Literal["auto", "low", "high"]] = "auto"


+class ChatCompletionMessageContentVideoURL(BaseModel):
+    url: str
+
+
 class ChatCompletionMessageContentAudioURL(BaseModel):
     url: str
@@ -277,6 +281,11 @@ class ChatCompletionMessageContentImagePart(BaseModel):
     modalities: Optional[Literal["image", "multi-images", "video"]] = "image"


+class ChatCompletionMessageContentVideoPart(BaseModel):
+    type: Literal["video_url"]
+    video_url: ChatCompletionMessageContentVideoURL
+
+
 class ChatCompletionMessageContentAudioPart(BaseModel):
     type: Literal["audio_url"]
     audio_url: ChatCompletionMessageContentAudioURL
@@ -285,6 +294,7 @@ class ChatCompletionMessageContentAudioPart(BaseModel):
 ChatCompletionMessageContentPart = Union[
     ChatCompletionMessageContentTextPart,
     ChatCompletionMessageContentImagePart,
+    ChatCompletionMessageContentVideoPart,
     ChatCompletionMessageContentAudioPart,
 ]
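With the new `ChatCompletionMessageContentVideoPart`, an OpenAI-compatible client can attach a video the same way it attaches an image. A hedged example against a locally served model; the base URL and model name are placeholders, only the content-part shape comes from the classes above:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder; any served video-capable VLM
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe what happens in this clip."},
                # Shape follows ChatCompletionMessageContentVideoPart above:
                {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)
```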
@@ -629,6 +639,7 @@ class MessageProcessingResult:
     prompt_ids: Union[str, List[int]]
     image_data: Optional[Any]
     audio_data: Optional[Any]
+    video_data: Optional[Any]
     modalities: List[str]
     stop: List[str]
     tool_call_constraint: Optional[Any] = None
@@ -82,6 +82,7 @@ class OpenAIServingChat(OpenAIServingBase):
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
+            video_data=processed_messages.video_data,
             audio_data=processed_messages.audio_data,
             sampling_params=sampling_params,
             return_logprob=request.logprobs,
@@ -143,6 +144,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt_ids = []
         openai_compatible_messages = []
         image_data = []
+        video_data = []
         audio_data = []
         modalities = []
@@ -158,6 +160,7 @@ class OpenAIServingChat(OpenAIServingBase):
                 msg_dict,
                 template_content_format,
                 image_data,
+                video_data,
                 audio_data,
                 modalities,
             )
@@ -214,11 +217,13 @@ class OpenAIServingChat(OpenAIServingBase):
             stop = request.stop
         image_data = image_data if image_data else None
         audio_data = audio_data if audio_data else None
+        video_data = video_data if video_data else None
         modalities = modalities if modalities else []
         return MessageProcessingResult(
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
@@ -260,6 +265,7 @@ class OpenAIServingChat(OpenAIServingBase):
         prompt = conv.get_prompt()

         image_data = conv.image_data if conv.image_data else None
+        video_data = conv.video_data if conv.video_data else None
         audio_data = conv.audio_data if conv.audio_data else None
         modalities = conv.modalities if conv.modalities else []
         stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])
@@ -277,6 +283,7 @@ class OpenAIServingChat(OpenAIServingBase):
             prompt=prompt,
             prompt_ids=prompt_ids,
             image_data=image_data,
+            video_data=video_data,
             audio_data=audio_data,
             modalities=modalities,
             stop=stop,
......
@@ -110,6 +110,7 @@ def process_content_for_template_format(
     msg_dict: dict,
     content_format: str,
     image_data: list,
+    video_data: list,
     audio_data: list,
     modalities: list,
 ) -> dict:
@@ -120,6 +121,7 @@ def process_content_for_template_format(
         msg_dict: Message dictionary with content
         content_format: 'string' or 'openai' (detected via AST analysis)
         image_data: List to append extracted image URLs
+        video_data: List to append extracted video URLs
         audio_data: List to append extracted audio URLs
         modalities: List to append modalities
@@ -143,6 +145,12 @@ def process_content_for_template_format(
                     modalities.append(chunk.get("modalities"))
                 # Normalize to simple 'image' type for template compatibility
                 processed_content_parts.append({"type": "image"})
+            elif chunk_type == "video_url":
+                video_data.append(chunk["video_url"]["url"])
+                if chunk.get("modalities"):
+                    modalities.append(chunk.get("modalities"))
+                # Normalize to simple 'video' type for template compatibility
+                processed_content_parts.append({"type": "video"})
             elif chunk_type == "audio_url":
                 audio_data.append(chunk["audio_url"]["url"])
                 # Normalize to simple 'audio' type
......
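For clarity, here is what the new `video_url` branch does to one content chunk: the URL is collected into the `video_data` list (which later becomes `GenerateReqInput.video_data`) and the chunk is normalized to a bare `{"type": "video"}` part for the chat template. A standalone restatement with illustrative variable names:

```python
# Standalone restatement of the "video_url" branch above (illustrative only).
video_data, modalities, processed_content_parts = [], [], []

chunk = {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}}
if chunk.get("type") == "video_url":
    video_data.append(chunk["video_url"]["url"])       # forwarded as video_data
    if chunk.get("modalities"):
        modalities.append(chunk.get("modalities"))
    processed_content_parts.append({"type": "video"})  # normalized for the template

assert video_data == ["https://example.com/clip.mp4"]
```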
@@ -65,6 +65,8 @@ class GenerateReqInput:
     ] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[List[str]], List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -110,7 +112,11 @@ class GenerateReqInput:
     data_parallel_rank: Optional[int] = None

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def normalize_batch_and_arguments(self):
         """
@@ -232,6 +238,7 @@ class GenerateReqInput:
         self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
+        self._normalize_video_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
         self._normalize_logprob_params(num)
@@ -300,6 +307,15 @@ class GenerateReqInput:
             self.image_data = wrapped_images * self.parallel_sample_num
             self.modalities = ["image"] * num

+    def _normalize_video_data(self, num):
+        """Normalize video data for batch processing."""
+        if self.video_data is None:
+            self.video_data = [None] * num
+        elif not isinstance(self.video_data, list):
+            self.video_data = [self.video_data] * num
+        elif isinstance(self.video_data, list):
+            self.video_data = self.video_data * self.parallel_sample_num
+
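The `_normalize_video_data` method just added mirrors the existing image/audio normalization: a missing value becomes one empty slot per request, a single item is broadcast across the batch, and a list is repeated for parallel sampling. A standalone paraphrase for illustration (not the class method itself):

```python
from typing import List, Optional, Union

def normalize_video_data_sketch(
    video_data: Optional[Union[List[str], str]], num: int, parallel_sample_num: int = 1
) -> List:
    """Paraphrase of GenerateReqInput._normalize_video_data, for illustration."""
    if video_data is None:
        return [None] * num                     # one empty slot per request
    if not isinstance(video_data, list):
        return [video_data] * num               # broadcast a single video to the batch
    return video_data * parallel_sample_num     # repeat the list for parallel sampling

assert normalize_video_data_sketch(None, 3) == [None, None, None]
assert normalize_video_data_sketch("clip.mp4", 2) == ["clip.mp4", "clip.mp4"]
```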
     def _normalize_audio_data(self, num):
         """Normalize audio data for batch processing."""
         if self.audio_data is None:
@@ -408,6 +424,7 @@ class GenerateReqInput:
                 self.input_embeds[i] if self.input_embeds is not None else None
             ),
             image_data=self.image_data[i],
+            video_data=self.video_data[i],
             audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
@@ -507,6 +524,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
+    video_data: Optional[Union[List[str], str]] = None
     # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
     audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
@@ -578,7 +597,11 @@ class EmbeddingReqInput:
         return self.rid

     def contains_mm_input(self) -> bool:
-        return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
+        return (
+            has_valid_data(self.image_data)
+            or has_valid_data(self.video_data)
+            or has_valid_data(self.audio_data)
+        )

     def __getitem__(self, i):
         if self.is_cross_encoder_request:
......
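Putting the io_struct changes together, a native generate request can now carry videos next to images and audio, and `contains_mm_input()` accounts for them. A sketch; the module path and the exact video token string are assumptions, not part of this diff:

```python
from sglang.srt.managers.io_struct import GenerateReqInput  # assumed module path

req = GenerateReqInput(
    text="<|vision_start|><|video_pad|><|vision_end|> What is shown in the video?",
    video_data=["https://example.com/clip.mp4"],  # file name, URL, or base64 string
    sampling_params={"max_new_tokens": 128},
)
assert req.contains_mm_input()  # true when only video_data is set, after this change
```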
@@ -4,7 +4,7 @@ Multi-modality utils
 import hashlib
 from abc import abstractmethod
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import numpy as np
 import torch
@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
+        print(f"{mm_inputs.mm_items=}")
         data_token_pairs = self.data_token_id_pairs
         mm_inputs.data_offsets = []
         if data_token_pairs is None:
@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         return ret_input_ids

-embedding_cache = None
+embedding_cache: Optional[MultiModalCache] = None

-def init_embedding_cache(max_size: int):
+def init_embedding_cache(max_size: int = 0):
     global embedding_cache
     embedding_cache = MultiModalCache(max_size)
@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
             continue
         embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
         items_offset = items_offset_list[i]
+        assert items_offset is not None, items_offset
         embedding_items_hash = get_embedding_hash(embedding_items_per_req)
         # if all items has been prefixed, we do not need to calculate embedding
         if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
@@ -380,11 +382,9 @@ def embed_mm_inputs(
     extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
-    ] = None,
+    multimodal_model: nn.Module = None,
+    data_embedding_func_mapping: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
+    ] = None,
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
@@ -397,8 +397,6 @@ def embed_mm_inputs(
         extend_seq_lens: Sequence lengths for each request
         input_ids: Input token IDs tensor
         input_embedding: Embedding layer for text tokens
-        image_data_embedding_func: Function to embed image data
-        audio_data_embedding_func: Function to embed audio data
         placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)

     Returns:
@@ -415,88 +413,53 @@ def embed_mm_inputs(
         item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]

     embeddings, masks = [], []
     # 2. Get multimodal embedding separately
-    # TODO: make this more generic
-    # Try get image embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_image())
-        and image_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_image()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        items_offsets = []
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
-            items_size[i + 1] = len(image_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.image_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_image()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0).tolist()
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=image_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
-
-    # Try get audio embedding if any
-    if (
-        any(True for item in item_flatten_list if item.is_audio())
-        and audio_data_embedding_func
-    ):
-        items = [item for item in item_flatten_list if item.is_audio()]
-        placeholder_tensor = torch.tensor(
-            [item.pad_value for item in items],
-            device=input_ids.device,
-        )
-        items_offsets = []
-        # calculate per request items length offset
-        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
-        for i, mm_inputs in enumerate(mm_inputs_list):
-            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
-            items_size[i + 1] = len(audio_items)
-            items_offsets.append(
-                flatten_nested_list(
-                    [
-                        item.audio_offsets
-                        for item in mm_inputs.mm_items
-                        if item.is_audio()
-                    ]
-                )
-            )
-        items_size = torch.cumsum(items_size, dim=0)
-
-        embedding, mask = get_embedding_and_mask(
-            data_embedding_func=audio_data_embedding_func,
-            embedding_items=items,
-            placeholder_tensor=placeholder_tensor,
-            input_ids=input_ids,
-            items_size=items_size,
-            prefix_length=extend_prefix_lens,
-            extend_length=extend_seq_lens,
-            items_offset_list=items_offsets,
-        )
-        embeddings += [embedding]
-        masks += [mask]
+    # Try get mm embedding if any
+    for modality in Modality.all():
+        items = [
+            item for item in item_flatten_list if item.is_modality(modality=modality)
+        ]
+        embedder = (
+            None
+            if data_embedding_func_mapping is None
+            else data_embedding_func_mapping.get(modality, None)
+        )
+        if embedder is None:
+            # "image", "video", etc
+            modality_id = modality.name.lower()
+            embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
+        if len(items) != 0 and embedder is not None:
+            placeholder_tensor = torch.tensor(
+                [item.pad_value for item in items],
+                device=input_ids.device,
+            )
+            # calculate per request items length offset
+            items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+            items_offsets = []
+            for i, mm_inputs in enumerate(mm_inputs_list):
+                mm_items = [
+                    item
+                    for item in mm_inputs.mm_items
+                    if item.is_modality(modality=modality)
+                ]
+                items_size[i + 1] = len(mm_items)
+                items_offsets.append(
+                    flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
+                )
+            items_size = torch.cumsum(items_size, dim=0).tolist()
+
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=embedder,
+                embedding_items=items,
+                placeholder_tensor=placeholder_tensor,
+                input_ids=input_ids,
+                items_size=items_size,
+                prefix_length=extend_prefix_lens,
+                extend_length=extend_seq_lens,
+                items_offset_list=items_offsets,
+            )
+            embeddings += [embedding]
+            masks += [mask]

     # 3. Get input embeddings
     vocab_size = input_embedding.num_embeddings
@@ -523,11 +486,9 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
-    audio_data_embedding_func: Optional[
-        Callable[[List[MultimodalDataItem]], torch.Tensor]
-    ] = None,
+    multimodal_model: Optional[nn.Module] = None,
+    data_embedding_funcs: Dict[
+        Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
+    ] = None,
     placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
@@ -572,8 +533,8 @@ def general_mm_embed_routine(
             extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             input_embedding=embed_tokens,
-            image_data_embedding_func=image_data_embedding_func,
-            audio_data_embedding_func=audio_data_embedding_func,
+            multimodal_model=multimodal_model,
+            data_embedding_func_mapping=data_embedding_funcs,
             placeholder_tokens=placeholder_tokens,
         )
         # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
......
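The `embed_mm_inputs` / `general_mm_embed_routine` refactor replaces the two hard-coded image/audio arguments with a single loop over `Modality.all()`. A model can either hand itself over and rely on the `get_<modality>_feature` naming convention, or pass an explicit mapping. A sketch of the two call patterns from inside a model's forward; `self`, `input_ids`, `forward_batch`, and `positions` are assumed to come from the surrounding method, and the video embedder name is hypothetical:

```python
from sglang.srt.managers.schedule_batch import Modality

# Pattern 1: rely on the naming convention; the routine falls back to
# getattr(self, "get_image_feature"), getattr(self, "get_video_feature"), ...
hidden_states = general_mm_embed_routine(
    input_ids=input_ids,
    forward_batch=forward_batch,
    language_model=self.language_model,
    multimodal_model=self,
    positions=positions,
)

# Pattern 2: pass an explicit modality -> embedder mapping.
hidden_states = general_mm_embed_routine(
    input_ids=input_ids,
    forward_batch=forward_batch,
    language_model=self.language_model,
    data_embedding_funcs={
        Modality.IMAGE: self.get_image_feature,
        Modality.VIDEO: self.get_video_feature,  # hypothetical video embedder
    },
    positions=positions,
)
```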
@@ -185,6 +185,10 @@ class Modality(Enum):
             f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
         )

+    @staticmethod
+    def all():
+        return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
+

 @dataclasses.dataclass
 class MultimodalDataItem:
@@ -200,7 +204,7 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     image_sizes: Tuple[int, int] = None
-    image_offsets: Optional[list] = None
+    offsets: Optional[list] = None

     # the real data, pixel_values or audio_features
     # data: Union[List[torch.Tensor], List[np.ndarray]]
@@ -253,12 +257,17 @@ class MultimodalDataItem:
                 self.hash = hash_feature(self.audio_features)
             elif self.input_features is not None:
                 self.hash = hash_feature(self.input_features)
+        elif self.is_video():
+            self.hash = hash_feature(self.pixel_values_videos)
         else:
             self.hash = hash_feature(self.pixel_values)

         assert self.hash is not None
         self.pad_value = self.hash % (1 << 30)

+    def is_modality(self, modality: Modality) -> bool:
+        return self.modality == modality
+
     def is_audio(self):
         return (self.modality == Modality.AUDIO) and (
             self.precomputed_features is not None
@@ -268,7 +277,7 @@ class MultimodalDataItem:
     def is_image(self):
         return (
-            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+            self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
         ) and (
             self.precomputed_features is not None
             or not MultimodalDataItem.is_empty_list(self.pixel_values)
@@ -277,7 +286,7 @@ class MultimodalDataItem:
     def is_video(self):
         return (self.modality == Modality.VIDEO) and (
             self.precomputed_features is not None
-            or not MultimodalDataItem.is_empty_list(self.pixel_values)
+            or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
         )

     def is_valid(self) -> bool:
@@ -351,6 +360,7 @@ class MultimodalInputs:
             "im_token_id",
             "im_start_id",
             "im_end_id",
+            "video_token_id",
             "slice_start_id",
             "slice_end_id",
             "audio_start_id",
@@ -364,11 +374,12 @@ class MultimodalInputs:
         return ret

     def contains_image_inputs(self) -> bool:
         """ """
         return any(item.is_image() for item in self.mm_items)

+    def contains_video_inputs(self) -> bool:
+        return any(item.is_video() for item in self.mm_items)
+
     def contains_audio_inputs(self) -> bool:
         """ """
         return any(item.is_audio() for item in self.mm_items)

     def contains_mm_input(self) -> bool:
......
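The schedule_batch changes make the modality checks table-driven: `is_modality()` compares against the item's `modality` field, `is_video()` now looks at `pixel_values_videos` (previously it incorrectly checked `pixel_values`), and the generic `offsets` field replaces `image_offsets`. A tiny illustration; constructor arguments beyond `modality` are omitted here and may be required in practice:

```python
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

item = MultimodalDataItem(modality=Modality.VIDEO)  # other fields omitted for brevity
assert item.is_modality(Modality.VIDEO)
# is_video() additionally requires precomputed_features or a non-empty
# pixel_values_videos, per the corrected check in the diff.
```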
@@ -453,8 +453,20 @@ class ForwardBatch:
             for mm_input in self.mm_inputs
         )

+    def contains_video_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_video_inputs()
+            for mm_input in self.mm_inputs
+        )
+
     def contains_mm_inputs(self) -> bool:
-        return self.contains_audio_inputs() or self.contains_image_inputs()
+        return (
+            self.contains_audio_inputs()
+            or self.contains_video_inputs()
+            or self.contains_image_inputs()
+        )

     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
......
@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.language_model,
             positions=positions,
         )
......
@@ -227,7 +227,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.language_model,
         )
......
@@ -374,7 +374,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             positions=positions,
         )
......
 import logging
 import re
 from functools import lru_cache
-from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
+from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union

 import torch
 from torch import nn
@@ -25,6 +25,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,
@@ -434,8 +435,10 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
             per_layer_inputs=per_layer_inputs,
         )
......
@@ -29,7 +29,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
@@ -523,7 +527,9 @@ class InternVLChatModel(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
......
@@ -67,7 +67,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -168,7 +172,9 @@ class KimiVLForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
......
@@ -787,7 +787,9 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
             forward_batch=forward_batch,
             get_embedding=get_embedding,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             placeholder_tokens=None,  # using mm_item.pad_value
             positions=positions,
         )
......
@@ -142,7 +142,7 @@ class LlavaVidForCausalLM(nn.Module):
             )
             image_offsets = [
                 flatten_nested_list(
-                    [item.image_offsets for item in image_inputs[i].mm_items]
+                    [item.offsets for item in image_inputs[i].mm_items]
                 )
                 for i in range(bs)
                 if need_vision[i]
......
@@ -1827,8 +1827,7 @@ class MiniCPMO(MiniCPMBaseModel):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.llm,
-            image_data_embedding_func=self.get_image_feature,
-            audio_data_embedding_func=self.get_audio_feature,
+            multimodal_model=self,
             positions=positions,
         )
         return hidden_states
......
@@ -573,7 +573,7 @@ class MiniCPMBaseModel(nn.Module):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            image_data_embedding_func=self.get_image_feature,
+            multimodal_model=self,
             language_model=self.llm,
             positions=positions,
         )
......
@@ -6,8 +6,11 @@ from typing import List, Optional, Set, Tuple
 import torch
 from torch import nn
-from transformers import Llama4Config, Llama4VisionModel
-from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
+from transformers import Llama4Config
+from transformers.models.llama4.modeling_llama4 import (
+    Llama4MultiModalProjector,
+    Llama4VisionModel,
+)

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -16,7 +19,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu
@@ -166,7 +173,9 @@ class Llama4ForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=image_embedding_func,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
......
@@ -31,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.idefics2 import Idefics2VisionTransformer
@@ -439,7 +443,9 @@ class Phi4MMForCausalLM(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+            },
             positions=positions,
         )
......
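End to end, a running server can now accept video through the native generate endpoint as well. A hedged example with plain HTTP; the host, port, and the prompt's video token are placeholders for whatever model and chat template is being served, while the `video_data` field name comes from the extended `GenerateReqInput`:

```python
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "<|vision_start|><|video_pad|><|vision_end|> Summarize the clip.",
        "video_data": "https://example.com/clip.mp4",  # new field from this change
        "sampling_params": {"max_new_tokens": 64},
    },
)
print(resp.json())
```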