Unverified commit b5e3d603 authored by Mick, committed by GitHub

vlm: support video as an input modality (#5888)

parent 4ed57807
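
A minimal client-side sketch of the new modality (not part of the diff itself; the endpoint, model id, and video URL below are placeholders):

# Hypothetical usage sketch: pass a video through the OpenAI-compatible
# chat endpoint using the new "video_url" content part. Server address,
# model id, and URL are placeholders, not values taken from this commit.
import requests

payload = {
    "model": "Qwen/Qwen2-VL-7B-Instruct",  # assumed video-capable model
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe what happens in this clip."},
                {
                    "type": "video_url",
                    "video_url": {"url": "https://example.com/clip.mp4"},
                },
            ],
        }
    ],
    "max_tokens": 128,
}

resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
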
......@@ -88,9 +88,11 @@ class Conversation:
stop_str: Union[str, List[str]] = None
# The string that represents an image token in the prompt
image_token: str = "<image>"
video_token: str = "<video>"
audio_token: str = "<audio>"
image_data: Optional[List[str]] = None
video_data: Optional[List[str]] = None
modalities: Optional[List[str]] = None
stop_token_ids: Optional[int] = None
......@@ -380,11 +382,15 @@ class Conversation:
self.messages.append([role, message])
def append_image(self, image: str):
"""Append a new message."""
"""Append a new image."""
self.image_data.append(image)
def append_video(self, video: str):
"""Append a new video."""
self.video_data.append(video)
def append_audio(self, audio: str):
"""Append a new message."""
"""Append a new audio."""
self.audio_data.append(audio)
def update_last_message(self, message: str):
......@@ -433,6 +439,7 @@ class Conversation:
sep2=self.sep2,
stop_str=self.stop_str,
image_token=self.image_token,
video_token=self.video_token,
audio_token=self.audio_token,
)
......@@ -495,8 +502,12 @@ def generate_embedding_convs(
sep2=conv_template.sep2,
stop_str=conv_template.stop_str,
image_data=[],
video_data=[],
audio_data=[],
modalities=[],
image_token=conv_template.image_token,
video_token=conv_template.video_token,
audio_token=conv_template.audio_token,
)
real_content = ""
......@@ -557,10 +568,12 @@ def generate_chat_conv(
sep2=conv.sep2,
stop_str=conv.stop_str,
image_data=[],
video_data=[],
audio_data=[],
modalities=[],
image_token=conv.image_token,
audio_token=conv.audio_token,
video_token=conv.video_token,
)
if isinstance(request.messages, str):
......@@ -602,6 +615,7 @@ def generate_chat_conv(
image_token = ""
audio_token = conv.audio_token
video_token = conv.video_token
for content in message.content:
if content.type == "text":
if num_image_url > 16:
......@@ -614,6 +628,9 @@ def generate_chat_conv(
else:
real_content += image_token
conv.append_image(content.image_url.url)
elif content.type == "video_url":
real_content += video_token
conv.append_video(content.video_url.url)
elif content.type == "audio_url":
real_content += audio_token
conv.append_audio(content.audio_url.url)
......@@ -810,6 +827,7 @@ register_conv_template(
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=["<|im_end|>"],
image_token="<|vision_start|><|image_pad|><|vision_end|>",
video_token="<|vision_start|><|video_pad|><|vision_end|>",
)
)
......@@ -870,6 +888,7 @@ register_conv_template(
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
stop_str=("<|im_end|>", "<|endoftext|>"),
image_token="(<image>./</image>)",
video_token="(<video>./</video>)",
)
)
......
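
A sketch of how the new Conversation fields fit together. Only video_token, video_data, and append_video come from the diff above; the template registry name and the "qwen2-vl" key are assumptions:

# Illustrative only: render a prompt that carries a video placeholder.
# `chat_templates` and the "qwen2-vl" template name are assumed registry details.
from sglang.srt.conversation import chat_templates  # lookup name is an assumption

conv = chat_templates["qwen2-vl"].copy()
conv.image_data, conv.video_data, conv.audio_data, conv.modalities = [], [], [], []
conv.append_message(conv.roles[0], conv.video_token + "What is shown in the video?")
conv.append_video("https://example.com/clip.mp4")  # stored in conv.video_data
conv.append_message(conv.roles[1], None)

# For the qwen2-vl template, get_prompt() now contains
# "<|vision_start|><|video_pad|><|vision_end|>" where the video frames go.
print(conv.get_prompt())
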
......@@ -267,6 +267,10 @@ class ChatCompletionMessageContentImageURL(BaseModel):
detail: Optional[Literal["auto", "low", "high"]] = "auto"
class ChatCompletionMessageContentVideoURL(BaseModel):
url: str
class ChatCompletionMessageContentAudioURL(BaseModel):
url: str
......@@ -277,6 +281,11 @@ class ChatCompletionMessageContentImagePart(BaseModel):
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
class ChatCompletionMessageContentVideoPart(BaseModel):
type: Literal["video_url"]
video_url: ChatCompletionMessageContentVideoURL
class ChatCompletionMessageContentAudioPart(BaseModel):
type: Literal["audio_url"]
audio_url: ChatCompletionMessageContentAudioURL
......@@ -285,6 +294,7 @@ class ChatCompletionMessageContentAudioPart(BaseModel):
ChatCompletionMessageContentPart = Union[
ChatCompletionMessageContentTextPart,
ChatCompletionMessageContentImagePart,
ChatCompletionMessageContentVideoPart,
ChatCompletionMessageContentAudioPart,
]
......@@ -629,6 +639,7 @@ class MessageProcessingResult:
prompt_ids: Union[str, List[int]]
image_data: Optional[Any]
audio_data: Optional[Any]
video_data: Optional[Any]
modalities: List[str]
stop: List[str]
tool_call_constraint: Optional[Any] = None
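
A small sketch of the new request-schema part; the import path is an assumption and may differ across sglang versions:

# Sketch: the video content part validates like the existing image/audio parts.
from sglang.srt.entrypoints.openai.protocol import (  # path is an assumption
    ChatCompletionMessageContentVideoPart,
    ChatCompletionMessageContentVideoURL,
)

part = ChatCompletionMessageContentVideoPart(
    type="video_url",
    video_url=ChatCompletionMessageContentVideoURL(url="https://example.com/clip.mp4"),
)
assert part.video_url.url == "https://example.com/clip.mp4"
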
......@@ -82,6 +82,7 @@ class OpenAIServingChat(OpenAIServingBase):
adapted_request = GenerateReqInput(
**prompt_kwargs,
image_data=processed_messages.image_data,
video_data=processed_messages.video_data,
audio_data=processed_messages.audio_data,
sampling_params=sampling_params,
return_logprob=request.logprobs,
......@@ -143,6 +144,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt_ids = []
openai_compatible_messages = []
image_data = []
video_data = []
audio_data = []
modalities = []
......@@ -158,6 +160,7 @@ class OpenAIServingChat(OpenAIServingBase):
msg_dict,
template_content_format,
image_data,
video_data,
audio_data,
modalities,
)
......@@ -214,11 +217,13 @@ class OpenAIServingChat(OpenAIServingBase):
stop = request.stop
image_data = image_data if image_data else None
audio_data = audio_data if audio_data else None
video_data = video_data if video_data else None
modalities = modalities if modalities else []
return MessageProcessingResult(
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
video_data=video_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,
......@@ -260,6 +265,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt = conv.get_prompt()
image_data = conv.image_data if conv.image_data else None
video_data = conv.video_data if conv.video_data else None
audio_data = conv.audio_data if conv.audio_data else None
modalities = conv.modalities if conv.modalities else []
stop = copy.copy(conv.stop_str or [] if not request.ignore_eos else [])
......@@ -277,6 +283,7 @@ class OpenAIServingChat(OpenAIServingBase):
prompt=prompt,
prompt_ids=prompt_ids,
image_data=image_data,
video_data=video_data,
audio_data=audio_data,
modalities=modalities,
stop=stop,
......
......@@ -110,6 +110,7 @@ def process_content_for_template_format(
msg_dict: dict,
content_format: str,
image_data: list,
video_data: list,
audio_data: list,
modalities: list,
) -> dict:
......@@ -120,6 +121,7 @@ def process_content_for_template_format(
msg_dict: Message dictionary with content
content_format: 'string' or 'openai' (detected via AST analysis)
image_data: List to append extracted image URLs
video_data: List to append extracted video URLs
audio_data: List to append extracted audio URLs
modalities: List to append modalities
......@@ -143,6 +145,12 @@ def process_content_for_template_format(
modalities.append(chunk.get("modalities"))
# Normalize to simple 'image' type for template compatibility
processed_content_parts.append({"type": "image"})
elif chunk_type == "video_url":
video_data.append(chunk["video_url"]["url"])
if chunk.get("modalities"):
modalities.append(chunk.get("modalities"))
# Normalize to simple 'video' type for template compatibility
processed_content_parts.append({"type": "video"})
elif chunk_type == "audio_url":
audio_data.append(chunk["audio_url"]["url"])
# Normalize to simple 'audio' type
......
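
To make the normalization above concrete, here is the effect of the video_url branch on a single chunk (URL is a placeholder):

# Stand-alone restatement of the video_url branch above.
chunk = {"type": "video_url", "video_url": {"url": "https://example.com/clip.mp4"}}
video_data, modalities, processed_content_parts = [], [], []

if chunk["type"] == "video_url":
    video_data.append(chunk["video_url"]["url"])       # raw URL kept for the processor
    if chunk.get("modalities"):
        modalities.append(chunk.get("modalities"))
    processed_content_parts.append({"type": "video"})  # simplified part for the template

assert video_data == ["https://example.com/clip.mp4"]
assert processed_content_parts == [{"type": "video"}]
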
......@@ -65,6 +65,8 @@ class GenerateReqInput:
] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[List[str]], List[str], str]] = None
# The sampling_params. See descriptions below.
sampling_params: Optional[Union[List[Dict], Dict]] = None
# The request id.
......@@ -110,7 +112,11 @@ class GenerateReqInput:
data_parallel_rank: Optional[int] = None
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
return (
has_valid_data(self.image_data)
or has_valid_data(self.video_data)
or has_valid_data(self.audio_data)
)
def normalize_batch_and_arguments(self):
"""
......@@ -232,6 +238,7 @@ class GenerateReqInput:
self._normalize_rid(num)
self._normalize_lora_paths(num)
self._normalize_image_data(num)
self._normalize_video_data(num)
self._normalize_audio_data(num)
self._normalize_sampling_params(num)
self._normalize_logprob_params(num)
......@@ -300,6 +307,15 @@ class GenerateReqInput:
self.image_data = wrapped_images * self.parallel_sample_num
self.modalities = ["image"] * num
def _normalize_video_data(self, num):
"""Normalize video data for batch processing."""
if self.video_data is None:
self.video_data = [None] * num
elif not isinstance(self.video_data, list):
self.video_data = [self.video_data] * num
elif isinstance(self.video_data, list):
self.video_data = self.video_data * self.parallel_sample_num
def _normalize_audio_data(self, num):
"""Normalize audio data for batch processing."""
if self.audio_data is None:
......@@ -408,6 +424,7 @@ class GenerateReqInput:
self.input_embeds[i] if self.input_embeds is not None else None
),
image_data=self.image_data[i],
video_data=self.video_data[i],
audio_data=self.audio_data[i],
sampling_params=self.sampling_params[i],
rid=self.rid[i],
......@@ -507,6 +524,8 @@ class EmbeddingReqInput:
image_data: Optional[
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
] = None
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
video_data: Optional[Union[List[str], str]] = None
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
audio_data: Optional[Union[List[str], str]] = None
# The token ids for text; one can either specify text or input_ids.
......@@ -578,7 +597,11 @@ class EmbeddingReqInput:
return self.rid
def contains_mm_input(self) -> bool:
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
return (
has_valid_data(self.image_data)
or has_valid_data(self.video_data)
or has_valid_data(self.audio_data)
)
def __getitem__(self, i):
if self.is_cross_encoder_request:
......
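
A sketch of the expanded GenerateReqInput. The prompt text and URL are placeholders, and the exact video placeholder token depends on the model's chat template:

# Sketch: video_data sits next to image_data/audio_data on GenerateReqInput.
from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(
    text="<|vision_start|><|video_pad|><|vision_end|>Describe the clip.",
    video_data="https://example.com/clip.mp4",  # file path, URL, or base64 also accepted
    sampling_params={"max_new_tokens": 64},
)
assert req.contains_mm_input()  # true once video_data is set
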
......@@ -4,7 +4,7 @@ Multi-modality utils
import hashlib
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
......@@ -76,6 +76,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
This function will replace the data-tokens in between with pad_values accordingly
"""
pad_values = [item.pad_value for item in mm_inputs.mm_items]
data_token_pairs = self.data_token_id_pairs
mm_inputs.data_offsets = []
if data_token_pairs is None:
......@@ -159,10 +160,10 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
return ret_input_ids
embedding_cache = None
embedding_cache: Optional[MultiModalCache] = None
def init_embedding_cache(max_size: int):
def init_embedding_cache(max_size: int = 0):
global embedding_cache
embedding_cache = MultiModalCache(max_size)
......@@ -255,6 +256,7 @@ def _get_chunked_prefill_embedding(
continue
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
items_offset = items_offset_list[i]
assert items_offset is not None, items_offset
embedding_items_hash = get_embedding_hash(embedding_items_per_req)
# if all items have been prefixed, we do not need to calculate the embedding
if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
......@@ -380,11 +382,9 @@ def embed_mm_inputs(
extend_seq_lens: List[int],
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
] = None,
audio_data_embedding_func: Callable[
[List[MultimodalDataItem]], torch.Tensor
multimodal_model: nn.Module = None,
data_embedding_func_mapping: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
......@@ -397,8 +397,6 @@ def embed_mm_inputs(
extend_seq_lens: Sequence lengths for each request
input_ids: Input token IDs tensor
input_embedding: Embedding layer for text tokens
image_data_embedding_func: Function to embed image data
audio_data_embedding_func: Function to embed audio data
placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
Returns:
......@@ -415,88 +413,53 @@ def embed_mm_inputs(
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks = [], []
# 2. Get multimodal embedding separately
# TODO: make this more generic
# Try get image embedding if any
if (
any(True for item in item_flatten_list if item.is_image())
and image_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_image()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
# Try get mm embedding if any
for modality in Modality.all():
items = [
item for item in item_flatten_list if item.is_modality(modality=modality)
]
embedder = (
None
if data_embedding_func_mapping is None
else data_embedding_func_mapping.get(modality, None)
)
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
items_offsets = []
for i, mm_inputs in enumerate(mm_inputs_list):
image_items = [item for item in mm_inputs.mm_items if item.is_image()]
items_size[i + 1] = len(image_items)
items_offsets.append(
flatten_nested_list(
[
item.image_offsets
for item in mm_inputs.mm_items
if item.is_image()
]
)
if embedder is None:
# "image", "video", etc
modality_id = modality.name.lower()
embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
if len(items) != 0 and embedder is not None:
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
items_size = torch.cumsum(items_size, dim=0).tolist()
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
# Try get audio embedding if any
if (
any(True for item in item_flatten_list if item.is_audio())
and audio_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_audio()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
items_offsets = []
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
for i, mm_inputs in enumerate(mm_inputs_list):
audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
items_size[i + 1] = len(audio_items)
items_offsets.append(
flatten_nested_list(
[
item.audio_offsets
for item in mm_inputs.mm_items
if item.is_audio()
]
# calculate per request items length offset
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
items_offsets = []
for i, mm_inputs in enumerate(mm_inputs_list):
mm_items = [
item
for item in mm_inputs.mm_items
if item.is_modality(modality=modality)
]
items_size[i + 1] = len(mm_items)
items_offsets.append(
flatten_nested_list([item.offsets for item in mm_inputs.mm_items])
)
items_size = torch.cumsum(items_size, dim=0).tolist()
embedding, mask = get_embedding_and_mask(
data_embedding_func=embedder,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
items_size = torch.cumsum(items_size, dim=0)
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings
......@@ -523,11 +486,9 @@ def general_mm_embed_routine(
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
language_model: nn.Module,
image_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
audio_data_embedding_func: Optional[
Callable[[List[MultimodalDataItem]], torch.Tensor]
multimodal_model: Optional[nn.Module] = None,
data_embedding_funcs: Dict[
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
**kwargs,
......@@ -572,8 +533,8 @@ def general_mm_embed_routine(
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
input_embedding=embed_tokens,
image_data_embedding_func=image_data_embedding_func,
audio_data_embedding_func=audio_data_embedding_func,
multimodal_model=multimodal_model,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
)
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
......
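
The embedder lookup in embed_mm_inputs reduces to the following fallback rule. DummyModel and resolve_embedder are illustrative stand-ins; only the mapping-then-getattr order mirrors the code above:

# Sketch of the per-modality embedder resolution: an explicit entry in
# data_embedding_func_mapping wins, otherwise fall back to
# get_<modality>_feature on the multimodal model.
from sglang.srt.managers.schedule_batch import Modality

class DummyModel:
    def get_image_feature(self, items):
        return "image-embedding"

    def get_video_feature(self, items):
        return "video-embedding"

def resolve_embedder(modality, multimodal_model, mapping=None):
    embedder = None if mapping is None else mapping.get(modality, None)
    if embedder is None:
        modality_id = modality.name.lower()  # "image", "video", "audio"
        embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
    return embedder

model = DummyModel()
assert resolve_embedder(Modality.VIDEO, model)([]) == "video-embedding"
assert resolve_embedder(Modality.AUDIO, model, {}) is None  # no audio path defined
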
......@@ -185,6 +185,10 @@ class Modality(Enum):
f"Invalid modality string: {modality_str}. Valid modalities are: {[m.name for m in Modality]}"
)
@staticmethod
def all():
return [Modality.IMAGE, Modality.VIDEO, Modality.AUDIO]
@dataclasses.dataclass
class MultimodalDataItem:
......@@ -200,7 +204,7 @@ class MultimodalDataItem:
hash: int = None
pad_value: int = None
image_sizes: Tuple[int, int] = None
image_offsets: Optional[list] = None
offsets: Optional[list] = None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.ndarray]]
......@@ -253,12 +257,17 @@ class MultimodalDataItem:
self.hash = hash_feature(self.audio_features)
elif self.input_features is not None:
self.hash = hash_feature(self.input_features)
elif self.is_video():
self.hash = hash_feature(self.pixel_values_videos)
else:
self.hash = hash_feature(self.pixel_values)
assert self.hash is not None
self.pad_value = self.hash % (1 << 30)
def is_modality(self, modality: Modality) -> bool:
return self.modality == modality
def is_audio(self):
return (self.modality == Modality.AUDIO) and (
self.precomputed_features is not None
......@@ -268,7 +277,7 @@ class MultimodalDataItem:
def is_image(self):
return (
self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
self.is_modality(Modality.IMAGE) or self.is_modality(Modality.MULTI_IMAGES)
) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
......@@ -277,7 +286,7 @@ class MultimodalDataItem:
def is_video(self):
return (self.modality == Modality.VIDEO) and (
self.precomputed_features is not None
or not MultimodalDataItem.is_empty_list(self.pixel_values)
or not MultimodalDataItem.is_empty_list(self.pixel_values_videos)
)
def is_valid(self) -> bool:
......@@ -351,6 +360,7 @@ class MultimodalInputs:
"im_token_id",
"im_start_id",
"im_end_id",
"video_token_id",
"slice_start_id",
"slice_end_id",
"audio_start_id",
......@@ -364,11 +374,12 @@ class MultimodalInputs:
return ret
def contains_image_inputs(self) -> bool:
""" """
return any(item.is_image() for item in self.mm_items)
def contains_video_inputs(self) -> bool:
return any(item.is_video() for item in self.mm_items)
def contains_audio_inputs(self) -> bool:
""" """
return any(item.is_audio() for item in self.mm_items)
def contains_mm_input(self) -> bool:
......
......@@ -453,8 +453,20 @@ class ForwardBatch:
for mm_input in self.mm_inputs
)
def contains_video_inputs(self) -> bool:
if self.mm_inputs is None:
return False
return any(
mm_input is not None and mm_input.contains_video_inputs()
for mm_input in self.mm_inputs
)
def contains_mm_inputs(self) -> bool:
return self.contains_audio_inputs() or self.contains_image_inputs()
return (
self.contains_audio_inputs()
or self.contains_video_inputs()
or self.contains_image_inputs()
)
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
......
......@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
language_model=self.language_model,
positions=positions,
)
......
......@@ -227,7 +227,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
language_model=self.language_model,
)
......
......@@ -374,7 +374,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
input_ids=llm_input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
positions=positions,
)
......
import logging
import re
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
......@@ -25,6 +25,7 @@ from sglang.srt.managers.mm_utils import (
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
......@@ -434,8 +435,10 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
Modality.AUDIO: self.get_audio_feature,
},
positions=positions,
per_layer_inputs=per_layer_inputs,
)
......
......@@ -29,7 +29,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
......@@ -523,7 +527,9 @@ class InternVLChatModel(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
......
......@@ -67,7 +67,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
......@@ -168,7 +172,9 @@ class KimiVLForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
......
......@@ -787,7 +787,9 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
forward_batch=forward_batch,
get_embedding=get_embedding,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
placeholder_tokens=None, # using mm_item.pad_value
positions=positions,
)
......
......@@ -142,7 +142,7 @@ class LlavaVidForCausalLM(nn.Module):
)
image_offsets = [
flatten_nested_list(
[item.image_offsets for item in image_inputs[i].mm_items]
[item.offsets for item in image_inputs[i].mm_items]
)
for i in range(bs)
if need_vision[i]
......
......@@ -1827,8 +1827,7 @@ class MiniCPMO(MiniCPMBaseModel):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.llm,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
multimodal_model=self,
positions=positions,
)
return hidden_states
......
......@@ -573,7 +573,7 @@ class MiniCPMBaseModel(nn.Module):
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
image_data_embedding_func=self.get_image_feature,
multimodal_model=self,
language_model=self.llm,
positions=positions,
)
......
......@@ -6,8 +6,11 @@ from typing import List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import Llama4Config, Llama4VisionModel
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
from transformers import Llama4Config
from transformers.models.llama4.modeling_llama4 import (
Llama4MultiModalProjector,
Llama4VisionModel,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
......@@ -16,7 +19,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cpu
......@@ -166,7 +173,9 @@ class Llama4ForConditionalGeneration(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=image_embedding_func,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
......
......@@ -31,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternMultimodalTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.idefics2 import Idefics2VisionTransformer
......@@ -439,7 +443,9 @@ class Phi4MMForCausalLM(nn.Module):
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
},
positions=positions,
)
......