Unverified commit 5cb552b1, authored by Mick and committed by GitHub

refactor: multimodal data (#4754)

parent c7457191
......@@ -1308,6 +1308,9 @@ class DeepseekV2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size()
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()
def forward(
self,
......
......@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternImageTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
......@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
return x
# todo
class DeepseekVL2ForCausalLM(nn.Module):
def __init__(
......@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
**kwargs: object,
):
input_embeds = self.language_model.model.embed_tokens(input_ids)
if (
forward_batch.forward_mode.is_extend()
and forward_batch.contains_image_inputs()
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.mm_inputs):
if image is None:
continue
start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx]
images_emb_mask = image.images_emb_mask.to(device="cuda")
image_features = self.get_image_feature(image)
input_embeds[start_idx:end_idx] = input_embeds[
start_idx:end_idx
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
outputs = self.language_model.forward(
hs = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=input_embeds,
image_data_embedding_func=self.get_image_feature,
language_model=self.language_model,
)
return outputs
return hs
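This refactor replaces each model's hand-rolled masked-scatter of image embeddings with the shared general_mm_embed_routine helper. A minimal sketch of the call pattern, assuming only the keyword names visible at the call sites in this diff (the helper's internals and full signature are not shown here):

# The embedding callback takes a List[MultimodalDataItem] and returns a tensor
# of multimodal embeddings; the routine merges them into the text embeddings at
# the placeholder positions produced by pad_input_ids, forwards through the
# given language model, and returns its output.
hidden_states = general_mm_embed_routine(
    input_ids=input_ids,
    positions=positions,
    forward_batch=forward_batch,
    language_model=self.language_model,
    image_data_embedding_func=self.get_image_feature,
    # optional keywords seen at other call sites in this diff:
    # audio_data_embedding_func=self.get_audio_feature,
    # placeholder_token_ids=placeholder_token_ids,
)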
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
weights_loader(param, loaded_weight)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
return input_ids
def get_image_feature(self, image_input: MultimodalInputs):
pixel_values = image_input.pixel_values.type(
next(self.vision.parameters()).dtype
).to(device=next(self.vision.parameters()).device)
image_feature = self.vision.forward_features(pixel_values)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
helper = MultiModalityDataPaddingPatternImageTokens(
image_token_id=image_inputs.im_token_id
)
return helper.pad_input_tokens(input_ids, image_inputs)
def get_image_feature(self, items: List[MultimodalDataItem]):
images_spatial_crop = torch.cat(
[item.image_spatial_crop for item in items], dim=0
)
assert images_spatial_crop.dim() == 3
# TODO: can it be batched?
images_in_this_batch = []
images_spatial_crop = image_input.image_spatial_crop
for jdx in range(images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
num_tiles_in_image = num_width_tiles * num_height_tiles
# [hw, D]
global_features = images_embeds[tile_index]
# [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]
tile_index += num_tiles_in_image + 1
# format global and local features
# ----------------- global view add newline -----------------
# [hw, D] -> [h, w, D]
global_features = global_features.view(h, w, n_dim)
# [D] -> [h, 1, D]
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
# [h, w + 1, D] -> [h * (w + 1), D]
global_features = global_features.view(-1, n_dim)
# ----------------- local view add newline -----------------
# [num_height_tiles * num_width_tiles, h * w, D] ->
# [num_height_tiles * h, num_width_tiles * w, D]
local_features = rearrange(
local_features,
"(th tw) (h w) d -> (th h) (tw w) d",
th=num_height_tiles,
tw=num_width_tiles,
h=h,
w=w,
for item in items:
assert item.pixel_values.dim() == 4
image_feature = self.vision.forward_features(
item.pixel_values.type(next(self.vision.parameters()).dtype).to(
device=next(self.vision.parameters()).device
)
)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
for jdx in range(item.image_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
break
num_tiles_in_image = num_width_tiles * num_height_tiles
# [hw, D]
global_features = images_embeds[tile_index]
# [num_height_tiles * num_width_tiles, hw, D]
local_features = images_embeds[
tile_index + 1 : tile_index + 1 + num_tiles_in_image
]
tile_index += num_tiles_in_image + 1
# [D] -> [num_height_tiles * h, 1, D]
new_lines_in_local = repeat(
self.image_newline,
"d -> (th h) 1 d",
th=num_height_tiles,
h=h,
)
# format global and local features
# ----------------- global view add newline -----------------
# [hw, D] -> [h, w, D]
global_features = global_features.view(h, w, n_dim)
# [D] -> [h, 1, D]
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
local_features = local_features.view(-1, n_dim)
# merge global and local tiles
if self.global_view_pos == "head":
global_local_features = torch.cat(
[
global_features,
self.view_seperator[None, :],
local_features,
]
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
global_features = torch.cat(
[global_features, new_lines_in_global], dim=1
)
else:
global_local_features = torch.cat(
[
local_features,
self.view_seperator[None, :],
global_features,
]
# [h, w + 1, D] -> [h * (w + 1), D]
global_features = global_features.view(-1, n_dim)
# ----------------- local view add newline -----------------
# [num_height_tiles * num_width_tiles, h * w, D] ->
# [num_height_tiles * h, num_width_tiles * w, D]
local_features = rearrange(
local_features,
"(th tw) (h w) d -> (th h) (tw w) d",
th=num_height_tiles,
tw=num_width_tiles,
h=h,
w=w,
)
images_in_this_batch.append(global_local_features)
# [D] -> [num_height_tiles * h, 1, D]
new_lines_in_local = repeat(
self.image_newline,
"d -> (th h) 1 d",
th=num_height_tiles,
h=h,
)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
# [num_height_tiles * h, num_width_tiles * w + 1, D]
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
local_features = local_features.view(-1, n_dim)
# merge global and local tiles
if self.global_view_pos == "head":
global_local_features = torch.cat(
[
global_features,
self.view_seperator[None, :],
local_features,
]
)
else:
global_local_features = torch.cat(
[
local_features,
self.view_seperator[None, :],
global_features,
]
)
images_in_this_batch.append(global_local_features)
return torch.cat(images_in_this_batch, dim=0)
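As a sanity check on the tile bookkeeping above, the number of output embeddings per image follows directly from the reshapes; a small illustrative calculation (tile counts and grid size are made up):

# Global view: h rows of (w + 1) tokens after the appended newline column.
# Local view: (th * h) rows of (tw * w + 1) tokens, plus one view-separator token.
h = w = 24                              # per-tile grid, h == w == int(hw ** 0.5)
th, tw = 2, 3                           # num_height_tiles, num_width_tiles
global_tokens = h * (w + 1)             # 600
local_tokens = (th * h) * (tw * w + 1)  # 48 * 73 = 3504
assert global_tokens + 1 + local_tokens == 4105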
......
......@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
import torch
from torch import nn
from transformers import (
AutoModel,
BatchFeature,
Gemma3Config,
Gemma3Processor,
PreTrainedModel,
)
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from transformers import AutoModel, Gemma3Config, PreTrainedModel
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
......@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
......@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
"""
return self.language_model.get_attention_sliding_window_size()
def get_image_feature(self, image_input: MultimodalInputs):
def get_image_feature(self, items: List[MultimodalDataItem]):
"""
Projects the last hidden state from the vision model into language model space.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
The tensors corresponding to the input images.
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
pixel_values = image_input.pixel_values
pixel_values = torch.stack(
flatten_nested_list([item.pixel_values for item in items]), dim=0
)
pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
......@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
image_features = self.multi_modal_projector(vision_outputs)
return image_features
def embed_mm_inputs(
self,
input_ids: torch.Tensor,
forward_batch: ForwardBatch,
image_input: MultimodalInputs,
) -> torch.Tensor:
if input_ids is None:
raise ValueError("Unimplemented")
# boolean-masking image tokens
special_image_mask = torch.isin(
input_ids,
torch.tensor(image_input.pad_values, device=input_ids.device),
).unsqueeze(-1)
num_image_tokens_in_input_ids = special_image_mask.sum()
inputs_embeds = None
if num_image_tokens_in_input_ids == 0:
inputs_embeds = self.get_input_embeddings()(input_ids)
return inputs_embeds
else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_feature(image_input.pixel_values)
# print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = (
image_features.shape[0] * image_features.shape[1]
)
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
image_features = image_features[:num_image, :]
logger.warning(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
"tokens from image embeddings."
)
# Important: clamp after extracting original image boundaries
input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(input_ids)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
image_features = image_features.to(
inputs_embeds.device, inputs_embeds.dtype
)
inputs_embeds = inputs_embeds.masked_scatter(
special_image_mask, image_features
)
return inputs_embeds
@torch.no_grad()
def forward(
self,
......@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
else:
llm_input_ids = input_ids
inputs_embeds = general_mm_embed_routine(
hs = general_mm_embed_routine(
input_ids=llm_input_ids,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_feature,
)
outputs = self.language_model(
input_ids=None,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
**kwargs,
)
return outputs
return hs
def tie_weights(self):
return self.language_model.tie_weights()
......
......@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
else:
return self.pooler(hidden_states, forward_batch)
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def get_hidden_dim(self, module_name):
# return input_dim, output_dim
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
......
......@@ -31,7 +31,7 @@ from transformers import (
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
......@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, flatten_nested_list
class LlavaBaseForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
image_sizes = flatten_nested_list(
[item.image_sizes for item in image_inputs.mm_items]
)
pad_values = [item.pad_value for item in image_inputs.mm_items]
# hardcode for spatial_unpad + anyres
if image_inputs.modalities is not None and (
"multi-images" in image_inputs.modalities
or "video" in image_inputs.modalities
if any(
item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
for item in image_inputs.mm_items
):
image_aspect_ratio = "pad"
else:
......@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
math.ceil(self.image_size / self.patch_size / 2) ** 2
)
else:
new_image_feature_len = self.image_feature_len # multiimage
new_image_feature_len = self.image_feature_len # multi-image
height = width = self.num_patches_per_side
if "anyres" in image_aspect_ratio:
......@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
# old_len + pad_len - 1, because we need to remove image_token_id
input_ids = (
input_ids[:offset]
+ [pad_values[image_idx]] * new_image_feature_len
+ [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
+ input_ids[offset + 1 :]
)
offset_list.append(offset)
......@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
modalities_list = []
max_image_offset = []
for im in image_inputs:
if im and im.modalities is not None:
modalities_list.extend(im.modalities)
if im:
modalities_list.extend([item.modality for item in im.mm_items])
if im and im.image_offsets:
max_image_offset.append(
np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
......@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
if need_vision.any():
bs = forward_batch.batch_size
pixel_values = [
image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
]
pixel_values = flatten_nested_list(
[
[item.pixel_values for item in image_inputs[i].mm_items]
for i in range(bs)
if need_vision[i]
]
)
image_sizes = [
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
flatten_nested_list(
[item.image_sizes for item in image_inputs[i].mm_items]
)
for i in range(bs)
if need_vision[i]
]
########## Encode Image ########
......@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
new_image_features = []
height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features):
if modalities_list[image_idx] == "image":
if modalities_list[image_idx] == Modality.IMAGE:
image_aspect_ratio = (
self.config.image_aspect_ratio
) # single image
elif (
modalities_list[image_idx] == "multi-images"
or modalities_list[image_idx] == "video"
modalities_list[image_idx] == Modality.MULTI_IMAGES
or modalities_list[image_idx] == Modality.VIDEO
):
image_aspect_ratio = "pad" # multi image
# image_aspect_ratio = (
......@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
if (
image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image"
and modalities_list[image_idx] == Modality.IMAGE
):
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
......@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
)
image_feature = image_feature.unsqueeze(0)
else:
if modalities_list[image_idx] == "video": # video
if modalities_list[image_idx] == Modality.VIDEO: # video
# 2x2 pooling
num_of_frames = image_feature.shape[0]
image_feature = image_feature.view(
......
......@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
......@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
pad_values = image_inputs.pad_values
pad_values = [item.pad_value for item in image_inputs.mm_items]
new_image_feature_len = self.image_feature_len
pad_ids = pad_values * (
......@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
pixel_values = [
image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
]
pixel_values = flatten_nested_list(
[
[item.pixel_values for item in image_inputs[i].mm_items]
for i in range(bs)
if need_vision[i]
]
)
image_offsets = [
image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
flatten_nested_list(
[item.image_offsets for item in image_inputs[i].mm_items]
)
for i in range(bs)
if need_vision[i]
]
########## Encode Image ########
......@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.vision_tower.vision_tower": "vision_tower",
# Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
"model.image_newline": "language_model.model.image_newline",
}
params_dict = dict(self.named_parameters())
......
......@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
embed_mm_inputs,
get_multimodal_data_bounds,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.managers.schedule_batch import (
MultimodalDataItem,
MultimodalInputs,
flatten_nested_list,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.minicpmv import (
Idefics2VisionTransformer,
MiniCPMVBaseModel,
MiniCPMBaseModel,
Resampler2_5,
)
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
......@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
return hidden_states
class MiniCPMO(MiniCPMVBaseModel):
class MiniCPMO(MiniCPMBaseModel):
def __init__(
self,
config: PretrainedConfig,
......@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):
return input_lengths_after_cnn, input_lengths_after_pooling
def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
r"""
Extract audio embeddings in a streaming manner using cached key-value pairs.
......@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
for streaming scenarios.
Args:
multimodal_input (dict):
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
Returns:
List[List[torch.Tensor]]: audio embeddings
"""
# print("audio embedding")
wavforms = (
[]
if multimodal_input.audio_features is None
else multimodal_input.audio_features
wavforms = flatten_nested_list(
[item.audio_features for item in items if item.audio_features]
)
# list, [[x1, x2], [y1], [z1]]
audio_feature_lens_raw = (
[]
if multimodal_input.audio_feature_lens is None
else multimodal_input.audio_feature_lens
audio_feature_lens_raw = flatten_nested_list(
[item.audio_feature_lens for item in items if item.audio_feature_lens]
)
# if audio exists
......@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
ret[i, start:ending] = True
return ret
def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
r"""
Extract full audio embeddings with optional chunk-based attention.
......@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
not use key-value caching and is suitable for non-streaming inference.
Args:
multimodal_input (dict):
- **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
- **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
attention (>0) during embedding computation.
Returns:
List[List[torch.Tensor]]: audio embeddings
"""
# print("audio embedding")
# (bs, 80, frames) or []; multiple audios need to be filled in advance
wavforms = (
[]
if multimodal_input.audio_features is None
else multimodal_input.audio_features
wavforms = flatten_nested_list(
[item.audio_features for item in items if item.audio_features]
)
# list, [[x1, x2], [y1], [z1]]
audio_feature_lens_raw = (
[]
if multimodal_input.audio_feature_lens is None
else multimodal_input.audio_feature_lens
audio_feature_lens_raw = flatten_nested_list(
[item.audio_feature_lens for item in items if item.audio_feature_lens]
)
final_audio_embeds = []
assert isinstance(wavforms, list)
assert isinstance(wavforms[0], torch.Tensor)
# if audio exists
for wavform in wavforms:
if len(wavform) > 0:
......@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
final_audio_embeds.append(target_audio_embeds)
return final_audio_embeds
def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
embedding = self.get_omni_embedding(
items=items,
chunk_length=self.config.audio_chunk_length,
stream_input=False,
)
return embedding
def get_omni_embedding(
self,
input_ids,
multimodal_input: MultimodalInputs,
input_embeds: torch.Tensor,
forward_mode: ForwardMode,
items: List[MultimodalDataItem],
chunk_length=-1,
stream_input=False,
):
"""
Args:
multimodal_input:
input_embeds:
chunk_length: whether whisper uses full attention (-1) or chunk-based attention (>0)
stream_input: whether to use streaming audio embedding
Returns:
final embeddings with audio features merged in
"""
input_embeds = input_embeds.unsqueeze(0)
if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
audio_bounds = get_multimodal_data_bounds(
input_ids=input_ids,
pad_values=multimodal_input.pad_values,
token_pairs=[
(multimodal_input.audio_start_id, multimodal_input.audio_end_id)
],
)
if audio_bounds.numel() == 0:
input_embeds = input_embeds.squeeze(0)
# TODO
logger.warn("Unimplemented logic. Please try disabling chunked prefill")
return input_embeds
audio_bounds = audio_bounds.unsqueeze(0)
bs = len(input_embeds)
if stream_input:
audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
else:
audio_embeddings = self.get_audio_embedding(
multimodal_input, chunk_length
)
# batch size
assert len(audio_embeddings) == len(input_embeds)
if len(audio_embeddings) > 0:
if self.config.chunk_input:
for i in range(bs):
audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
device=input_embeds.device, dtype=input_embeds.dtype
)
audio_start_pos = 0
for bound in audio_bounds[i]:
audio_len = bound[1] - bound[0] + 1
input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
audio_start_pos : audio_start_pos + audio_len, :
]
audio_start_pos += audio_len
else:
for i in range(bs):
audio_embs = audio_embeddings[i]
bounds = audio_bounds[i]
for embs, bound in zip(audio_embs, bounds):
audio_indices = torch.arange(
bound[0], bound[1], dtype=torch.long
).to(input_embeds.device)
if embs.shape[0] != len(audio_indices):
raise ValueError(
f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
f"to input indices of length {len(audio_indices)}"
)
input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
input_embeds = input_embeds.squeeze(0)
return input_embeds
def get_image_features(
self,
image_inputs: MultimodalInputs,
) -> torch.Tensor:
pixel_values = image_inputs.pixel_values
tgt_sizes = image_inputs.tgt_sizes
if stream_input:
audio_embeddings = self.get_audio_embedding_streaming(items)
else:
audio_embeddings = self.get_audio_embedding(items, chunk_length)
bs = len(audio_embeddings)
# batch size
audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
return audio_embs
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# list of tensors
pixel_values = flatten_nested_list([item.pixel_values for item in items])
tgt_sizes = torch.stack(
flatten_nested_list([item.tgt_size for item in items]), dim=0
)
assert len(pixel_values) == tgt_sizes.shape[0]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
all_pixel_values_lst = [
......@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0
)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros(
......@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
forward_batch: ForwardBatch,
**kwargs: Any,
) -> torch.Tensor:
inputs_embeds = None
# TODO(mick): optimize the logic here: clamp, merge and embedding should happen at most once
if (
not forward_batch.forward_mode.is_decode()
and forward_batch.contains_image_inputs()
):
mm_inputs = forward_batch.merge_mm_inputs()
inputs_embeds = embed_mm_inputs(
mm_input=mm_inputs,
input_ids=input_ids,
input_embedding=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_features,
placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
)
input_ids = input_ids.clamp(
min=0, max=self.get_input_embeddings().num_embeddings - 1
mm_input = forward_batch.merge_mm_inputs()
placeholder_token_ids = (
([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
if forward_batch.contains_mm_inputs()
else []
)
if inputs_embeds is None:
inputs_embeds = self.llm.get_input_embeddings(input_ids)
if (
not forward_batch.forward_mode.is_decode()
and self.config.init_audio
and forward_batch.contains_audio_inputs()
):
mm_input = forward_batch.merge_mm_inputs()
inputs_embeds = self.get_omni_embedding(
input_ids=input_ids,
multimodal_input=mm_input,
input_embeds=inputs_embeds,
forward_mode=forward_batch.forward_mode,
chunk_length=self.config.audio_chunk_length,
stream_input=False,
)
forward_batch.mm_inputs = None
hidden_states = self.llm.model(
input_ids=None,
positions=positions,
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.llm.lm_head, forward_batch
language_model=self.llm,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
placeholder_token_ids=placeholder_token_ids,
positions=positions,
)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
from sglang.srt.utils import add_prefix, flatten_nested_list
RawImageType = Union[Image.Image, torch.Tensor]
......@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
return tuple(int(x) for x in version_str.split("."))
class MiniCPMVBaseModel(nn.Module):
class MiniCPMBaseModel(nn.Module):
"""
The abstract base class for MiniCPM-V models; it can only be inherited,
not instantiated directly.
......@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
return vlm_embedding, vision_hidden_states
def get_input_embeddings(self) -> nn.Embedding:
return self.llm.get_input_embedding()
return self.llm.get_input_embeddings()
def forward(
self,
......@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
forward_batch: ForwardBatch,
**kwargs: Any,
) -> torch.Tensor:
inputs_embeds = general_mm_embed_routine(
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_features,
)
hidden_states = self.llm.model(
input_ids=None,
image_data_embedding_func=self.get_image_feature,
language_model=self.llm,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
)
return self.logits_processor(
input_ids, hidden_states, self.llm.lm_head, forward_batch
)
return hidden_states
def init_llm(
self,
......@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
) -> torch.Tensor:
raise NotImplementedError
def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
raise NotImplementedError
class MiniCPMV2_6(MiniCPMVBaseModel):
class MiniCPMV2_6(MiniCPMBaseModel):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
)
return vision_embedding
def get_image_features(
self,
image_inputs: MultimodalInputs,
) -> torch.Tensor:
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# list of tensors
pixel_values = image_inputs.pixel_values
tgt_sizes = image_inputs.tgt_sizes
pixel_values = flatten_nested_list([item.pixel_values for item in items])
tgt_sizes = torch.stack(
flatten_nested_list([item.tgt_size for item in items]), dim=0
)
assert len(pixel_values) == tgt_sizes.shape[0]
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
......@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
assert isinstance(max_patches, int)
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values_lst, batch_first=True, padding_value=0.0
)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros(
......
......@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config.text_config)
self.capture_mode = False
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
pixel_values = image_inputs.pixel_values
pad_values = image_inputs.pad_values
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
pixel_values = torch.cat(
[item.pixel_values for item in mm_inputs.mm_items], dim=0
)
pad_values = [item.pad_value for item in mm_inputs.mm_items]
num_concurrent_media, num_tiles = pixel_values.shape[1:3]
num_patches = self.vision_model.num_patches
image_len = num_concurrent_media * num_tiles * num_patches
image_inputs.num_image_tokens = image_len
mm_inputs.num_image_tokens = image_len
pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
......@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):
# pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
max_num_images = max_num_tiles = bs = 0
for i, im in enumerate(forward_batch.mm_inputs):
if not forward_batch.encoder_cached[i] and im is not None:
max_num_images = max(max_num_images, im.pixel_values.shape[1])
max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
for i, mm_input in enumerate(forward_batch.mm_inputs):
if not forward_batch.encoder_cached[i] and mm_input is not None:
pixel_values = torch.cat(
[item.pixel_values for item in mm_input.mm_items], dim=0
)
# max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
max_num_images = max(max_num_images, pixel_values.shape[1])
max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
bs += 1
if max_num_images * max_num_tiles * bs == 0:
......@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
)
i = 0
encoder_lens_need = []
for k, im in enumerate(forward_batch.mm_inputs):
if forward_batch.encoder_cached[k] or im is None:
for k, mm_input in enumerate(forward_batch.mm_inputs):
if forward_batch.encoder_cached[k] or mm_input is None:
continue
encoder_lens_need.append(forward_batch.encoder_lens[k])
for j in range(im.pixel_values.shape[1]):
img = im.pixel_values[0, j]
pixel_values = torch.cat(
[item.pixel_values for item in mm_input.mm_items], dim=0
)
for j in range(pixel_values.shape[1]):
img = pixel_values[0, j]
num_tiles = img.shape[0]
batched_images[i, j, :num_tiles] = img
batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[
0
].aspect_ratio_mask[0, j]
i += 1
return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
......
......@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
if hasattr(self.config, "scale_emb"):
return self.embed_tokens(input_ids) * self.config.scale_emb
return self.get_input_embeddings()(input_ids) * self.config.scale_emb
else:
return self.embed_tokens(input_ids)
return self.get_input_embeddings()(input_ids)
def get_input_embeddings(self) -> nn.Embedding:
return self.embed_tokens
def forward(
self,
......@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embedding(input_ids)
def get_input_embedding(self) -> nn.Embedding:
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()
......
......@@ -30,22 +30,13 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoModel, Qwen2VLConfig
from transformers import Qwen2VLConfig
from transformers.activations import ACT2FN
from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLVisionConfig,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
......@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
......@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
im_end_id: int = image_inputs.im_end_id
im_start_id: int = mm_inputs.im_start_id
im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, mm_inputs)
return pattern.pad_input_tokens(input_ids, image_inputs)
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, the feature (last) dim is the same across items, so we can concatenate along dim 0
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
......@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
f"(3, seq_len) positions, but got {positions.size()}"
)
inputs_embeds = general_mm_embed_routine(
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=None,
language_model=self.model,
image_data_embedding_func=self.get_image_feature,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
)
if not get_embedding:
......@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
("gate_up_proj", "up_proj", 1),
("gate_up_proj", "gate_proj", 0),
]
......
......@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
......@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Use grid_t * grid_w * grid_h to pad tokens for each image
# and replace the padding with each image's unique hash value
def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs):
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_start_id: int = multi_modal_inputs.im_start_id
im_end_id: int = multi_modal_inputs.im_end_id
im_start_id: int = mm_inputs.im_start_id
im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, multi_modal_inputs)
return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
# in qwen-vl, the feature (last) dim is the same across items, so we can concatenate along dim 0
pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
......@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
inputs_embeds = general_mm_embed_routine(
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
mm_data_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=None,
language_model=self.model,
image_data_embedding_func=self.get_image_feature,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
)
if not get_embedding:
if get_embedding:
return self.pooler(hidden_states, forward_batch)
else:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -897,6 +897,7 @@ def v1_chat_generate_request(
request_ids: List[str] = None,
):
input_ids = []
prompts = []
sampling_params_list = []
image_data_list = []
audio_data_list = []
......@@ -916,6 +917,7 @@ def v1_chat_generate_request(
# - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput.
strict_tag = None
prompt = ""
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
tools = None
......@@ -1005,11 +1007,13 @@ def v1_chat_generate_request(
image_data = None
audio_data = None
modalities = []
prompt = request.messages
input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path)
prompts.append(prompt)
sampling_params = {
"temperature": request.temperature,
......@@ -1063,10 +1067,14 @@ def v1_chat_generate_request(
audio_data_list.append(audio_data)
modalities_list.append(modalities)
if len(all_requests) == 1:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
if tokenizer_manager.model_config.is_multimodal:
# processor will need text input
prompt_kwargs = {"text": prompts[0]}
else:
prompt_kwargs = {"input_ids": input_ids[0]}
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0]
audio_data_list = audio_data_list[0]
......@@ -1076,10 +1084,14 @@ def v1_chat_generate_request(
modalities_list = modalities_list[0]
lora_paths = lora_paths[0]
else:
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
if tokenizer_manager.model_config.is_multimodal:
# processor will need text input
prompt_kwargs = {"text": prompts}
else:
prompt_kwargs = {"input_ids": input_ids}
if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
adapted_request = GenerateReqInput(
**prompt_kwargs,
......
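For readability, the prompt_kwargs branching introduced above can be restated in one place; a hedged sketch using only names from this diff, shown for the single-request case:

# Multimodal models re-run their processor on the raw chat-templated text, so
# they receive `text`; text-only models keep the previous behavior of passing
# either a raw string or pre-tokenized ids.
if tokenizer_manager.model_config.is_multimodal:
    prompt_kwargs = {"text": prompts[0]}
elif isinstance(input_ids[0], str):
    prompt_kwargs = {"text": input_ids[0]}
else:
    prompt_kwargs = {"input_ids": input_ids[0]}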
......@@ -12,7 +12,6 @@
# limitations under the License.
# ==============================================================================
"""Common utilities."""
import base64
import builtins
import ctypes
......@@ -54,6 +53,7 @@ import torch.distributed
import torch.distributed as dist
import triton
import zmq
from decord import VideoReader, cpu
from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version
from PIL import Image
......@@ -513,13 +513,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
import soundfile as sf
from scipy.signal import resample
# print(f"loading {audio_file}")
# Load audio data
if isinstance(audio_file, bytes):
audio, original_sr = sf.read(BytesIO(audio_file))
elif audio_file.startswith("data:"):
audio_file = audio_file.split(",")[1]
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
elif audio_file.startswith("http://") or audio_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
response = requests.get(audio_file, stream=True, timeout=timeout)
audio_file = BytesIO(response.content)
response.close()
audio, original_sr = sf.read(audio_file)
elif isinstance(audio_file, str):
audio, original_sr = sf.read(audio_file)
else:
......@@ -537,6 +542,30 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
return audio
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
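A hedged usage sketch for the new encode_video helper (the path and frame cap below are illustrative only):

# Decodes the video with decord, keeps roughly one frame per second, uniformly
# subsamples down to at most `frame_count_limit` frames, and returns PIL images.
frames = encode_video("/path/to/clip.mp4", frame_count_limit=16)
print(f"decoded {len(frames)} frames")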
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
image = image_size = None
......@@ -1796,3 +1825,12 @@ def retry(
traceback.print_exc()
time.sleep(delay)
def flatten_nested_list(nested_list):
if isinstance(nested_list, list):
return [
item for sublist in nested_list for item in flatten_nested_list(sublist)
]
else:
return [nested_list]
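For reference, the behavior of the helper above on a couple of inputs (values are illustrative):

# Arbitrarily nested lists are flattened depth-first; a non-list input is
# wrapped into a single-element list.
assert flatten_nested_list([[1, 2], [3, [4, 5]]]) == [1, 2, 3, 4, 5]
assert flatten_nested_list("x") == ["x"]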
......@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase):
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
},
"image_url": {"url": IMAGE_MAN_IRONING_URL},
"modalities": "multi-images",
},
{
......@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase):
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt,
},
{
"type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"},
},
{
"type": "text",
"text": prompt,
},
],
}
]
......
......@@ -3,6 +3,7 @@
import unittest
from io import BytesIO
from typing import List
import numpy as np
import requests
......@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import MultimodalInputs
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs
......@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
# sglang
model = self.get_sglang_model()
input_ids = inputs["input_ids"].to(self.device).flatten()
pixel_values = inputs["pixel_values"]
tgt_sizes = inputs["tgt_sizes"]
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: "
f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
sglang_output = embed_mm_inputs(
mm_input=MultimodalInputs(
pixel_values=inputs["pixel_values"][0],
tgt_sizes=inputs["tgt_sizes"][0],
mm_inputs=MultimodalInputs(
mm_items=[
MultimodalDataItem(
pixel_values=pixel_values_flat,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
pad_value=self.processor.tokenizer.unk_token_id,
)
]
),
input_ids=input_ids,
input_embedding=model.get_input_embeddings(),
mm_data_embedding_func=model.get_image_features,
image_data_embedding_func=model.get_image_feature,
placeholder_token_ids=[
self.processor.tokenizer.unk_token_id,
],
......