Unverified commit 5cb552b1 authored by Mick, committed by GitHub

refactor: multimodal data (#4754)

parent c7457191
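For orientation, the pattern this refactor converges on (visible in the call sites below) is: each multimodal model hands general_mm_embed_routine its input_ids, its language model, and per-modality embedding functions that take a List[MultimodalDataItem]. The following is a minimal, self-contained sketch of that routing idea using only torch; FakeItem, fake_image_embedder and toy_mm_embed_routine are invented stand-ins for illustration, not sglang APIs.

import torch
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class FakeItem:
    pad_value: int              # placeholder token id reserved for this item
    pixel_values: torch.Tensor  # stand-in "image" features, one row per placeholder


def fake_image_embedder(items: List[FakeItem]) -> torch.Tensor:
    # one embedding row per placeholder token across all items
    return torch.cat([item.pixel_values for item in items], dim=0)


def toy_mm_embed_routine(
    input_ids: torch.Tensor,
    embed_tokens: torch.nn.Embedding,
    items: List[FakeItem],
    image_data_embedding_func: Callable[[List[FakeItem]], torch.Tensor],
) -> torch.Tensor:
    # clamp so out-of-vocab placeholder ids can still be embedded, then overwrite them
    ids = input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1)
    embeds = embed_tokens(ids)
    mask = torch.isin(input_ids, torch.tensor([it.pad_value for it in items]))
    mm_embeds = image_data_embedding_func(items)   # [num_placeholder_tokens, hidden]
    embeds[mask] = mm_embeds.to(embeds.dtype)      # scatter into the text embeddings
    return embeds


embed_tokens = torch.nn.Embedding(32, 8)
items = [FakeItem(pad_value=31, pixel_values=torch.randn(2, 8))]
input_ids = torch.tensor([1, 2, 31, 31, 3])
print(toy_mm_embed_routine(input_ids, embed_tokens, items, fake_image_embedder).shape)
# torch.Size([5, 8])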
@@ -1308,6 +1308,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_attention_dp_size()
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
...
@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
 )
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM

@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
         return x
 
 
-# todo
 class DeepseekVL2ForCausalLM(nn.Module):
 
     def __init__(

@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
-        input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.contains_image_inputs()
-        ):
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
-            for idx, image in enumerate(forward_batch.mm_inputs):
-                if image is None:
-                    continue
-                start_idx = extend_start_loc_cpu[idx]
-                end_idx = start_idx + extend_seq_lens_cpu[idx]
-                images_emb_mask = image.images_emb_mask.to(device="cuda")
-                image_features = self.get_image_feature(image)
-                input_embeds[start_idx:end_idx] = input_embeds[
-                    start_idx:end_idx
-                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
-
-        outputs = self.language_model.forward(
+        hs = general_mm_embed_routine(
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=input_embeds,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
         )
 
-        return outputs
+        return hs
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [

@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return input_ids
-
-    def get_image_feature(self, image_input: MultimodalInputs):
-        pixel_values = image_input.pixel_values.type(
-            next(self.vision.parameters()).dtype
-        ).to(device=next(self.vision.parameters()).device)
-        image_feature = self.vision.forward_features(pixel_values)
-        images_embeds = self.projector(image_feature)
-        _, hw, n_dim = images_embeds.shape
-        h = w = int(hw**0.5)
-        tile_index = 0
+        helper = MultiModalityDataPaddingPatternImageTokens(
+            image_token_id=image_inputs.im_token_id
+        )
+        return helper.pad_input_tokens(input_ids, image_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+        images_spatial_crop = torch.cat(
+            [item.image_spatial_crop for item in items], dim=0
+        )
+
+        assert images_spatial_crop.dim() == 3
+
+        # TODO: can it be batched ?
         images_in_this_batch = []
-        images_spatial_crop = image_input.image_spatial_crop
-        for jdx in range(images_spatial_crop.shape[1]):
-            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
-            if num_width_tiles == 0 or num_height_tiles == 0:
-                break
-            num_tiles_in_image = num_width_tiles * num_height_tiles
-
-            # [hw, D]
-            global_features = images_embeds[tile_index]
-
-            # [num_height_tiles * num_width_tiles, hw, D]
-            local_features = images_embeds[
-                tile_index + 1 : tile_index + 1 + num_tiles_in_image
-            ]
-            tile_index += num_tiles_in_image + 1
-
-            # format global and local features
-            # ----------------- global view add newline -----------------
-            # [hw, D] -> [h, w, D]
-            global_features = global_features.view(h, w, n_dim)
-            # [D] -> [h, 1, D]
-            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
-            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
-            global_features = torch.cat([global_features, new_lines_in_global], dim=1)
-            # [h, w + 1, D] -> [h * (w + 1), D]
-            global_features = global_features.view(-1, n_dim)
-
-            # ----------------- local view add newline -----------------
-            # [num_height_tiles * num_width_tiles, h * w, D] ->
-            # [num_height_tiles * h, num_width_tiles * w, D]
-            local_features = rearrange(
-                local_features,
-                "(th tw) (h w) d -> (th h) (tw w) d",
-                th=num_height_tiles,
-                tw=num_width_tiles,
-                h=h,
-                w=w,
-            )
-
-            # [D] -> [num_height_tiles * h, 1, D]
-            new_lines_in_local = repeat(
-                self.image_newline,
-                "d -> (th h) 1 d",
-                th=num_height_tiles,
-                h=h,
-            )
-
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            local_features = torch.cat([local_features, new_lines_in_local], dim=1)
-
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
-            local_features = local_features.view(-1, n_dim)
-
-            # merge global and local tiles
-            if self.global_view_pos == "head":
-                global_local_features = torch.cat(
-                    [
-                        global_features,
-                        self.view_seperator[None, :],
-                        local_features,
-                    ]
-                )
-            else:
-                global_local_features = torch.cat(
-                    [
-                        local_features,
-                        self.view_seperator[None, :],
-                        global_features,
-                    ]
-                )
-
-            images_in_this_batch.append(global_local_features)
+        for item in items:
+            assert item.pixel_values.dim() == 4
+            image_feature = self.vision.forward_features(
+                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                    device=next(self.vision.parameters()).device
+                )
+            )
+            images_embeds = self.projector(image_feature)
+            _, hw, n_dim = images_embeds.shape
+            h = w = int(hw**0.5)
+            tile_index = 0
+            for jdx in range(item.image_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+                if num_width_tiles == 0 or num_height_tiles == 0:
+                    break
+                num_tiles_in_image = num_width_tiles * num_height_tiles
+
+                # [hw, D]
+                global_features = images_embeds[tile_index]
+
+                # [num_height_tiles * num_width_tiles, hw, D]
+                local_features = images_embeds[
+                    tile_index + 1 : tile_index + 1 + num_tiles_in_image
+                ]
+                tile_index += num_tiles_in_image + 1
+
+                # format global and local features
+                # ----------------- global view add newline -----------------
+                # [hw, D] -> [h, w, D]
+                global_features = global_features.view(h, w, n_dim)
+
+                # [D] -> [h, 1, D]
+                new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
+
+                # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
+                global_features = torch.cat(
+                    [global_features, new_lines_in_global], dim=1
+                )
+
+                # [h, w + 1, D] -> [h * (w + 1), D]
+                global_features = global_features.view(-1, n_dim)
+
+                # ----------------- local view add newline -----------------
+                # [num_height_tiles * num_width_tiles, h * w, D] ->
+                # [num_height_tiles * h, num_width_tiles * w, D]
+                local_features = rearrange(
+                    local_features,
+                    "(th tw) (h w) d -> (th h) (tw w) d",
+                    th=num_height_tiles,
+                    tw=num_width_tiles,
+                    h=h,
+                    w=w,
+                )
+
+                # [D] -> [num_height_tiles * h, 1, D]
+                new_lines_in_local = repeat(
+                    self.image_newline,
+                    "d -> (th h) 1 d",
+                    th=num_height_tiles,
+                    h=h,
+                )
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                local_features = torch.cat([local_features, new_lines_in_local], dim=1)
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
+                local_features = local_features.view(-1, n_dim)
+
+                # merge global and local tiles
+                if self.global_view_pos == "head":
+                    global_local_features = torch.cat(
+                        [
+                            global_features,
+                            self.view_seperator[None, :],
+                            local_features,
+                        ]
+                    )
+                else:
+                    global_local_features = torch.cat(
+                        [
+                            local_features,
+                            self.view_seperator[None, :],
+                            global_features,
+                        ]
+                    )
+
+                images_in_this_batch.append(global_local_features)
 
         return torch.cat(images_in_this_batch, dim=0)
...
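The tile stitching inside get_image_feature above is easiest to check on toy shapes. This standalone sketch (all sizes invented) mirrors the rearrange/repeat calls used for the local view:

import torch
from einops import rearrange, repeat

th, tw = 2, 3          # num_height_tiles x num_width_tiles
h = w = 4              # patches per tile side
d = 16                 # embedding dim

local_features = torch.randn(th * tw, h * w, d)   # [num_tiles, hw, D]
image_newline = torch.randn(d)

# stitch the tiles back into one 2-D patch grid, as in the diff above
grid = rearrange(
    local_features, "(th tw) (h w) d -> (th h) (tw w) d", th=th, tw=tw, h=h, w=w
)
print(grid.shape)  # torch.Size([8, 12, 16])

# append one newline embedding per patch row, then flatten row-major
newlines = repeat(image_newline, "d -> (th h) 1 d", th=th, h=h)
flat = torch.cat([grid, newlines], dim=1).view(-1, d)
print(flat.shape)  # torch.Size([104, 16]) == [(th*h) * (tw*w + 1), d]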
@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 import torch
 from torch import nn
-from transformers import (
-    AutoModel,
-    BatchFeature,
-    Gemma3Config,
-    Gemma3Processor,
-    PreTrainedModel,
-)
-from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+from transformers import AutoModel, Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm

@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,

@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         """
         return self.language_model.get_attention_sliding_window_size()
 
-    def get_image_feature(self, image_input: MultimodalInputs):
+    def get_image_feature(self, items: List[MultimodalDataItem]):
         """
         Projects the last hidden state from the vision model into language model space.
 
-        Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-               The tensors corresponding to the input images.
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = image_input.pixel_values
+        pixel_values = torch.stack(
+            flatten_nested_list([item.pixel_values for item in items]), dim=0
+        )
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())

@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
 
-    def embed_mm_inputs(
-        self,
-        input_ids: torch.Tensor,
-        forward_batch: ForwardBatch,
-        image_input: MultimodalInputs,
-    ) -> torch.Tensor:
-        if input_ids is None:
-            raise ValueError("Unimplemented")
-        # boolean-masking image tokens
-        special_image_mask = torch.isin(
-            input_ids,
-            torch.tensor(image_input.pad_values, device=input_ids.device),
-        ).unsqueeze(-1)
-        num_image_tokens_in_input_ids = special_image_mask.sum()
-        inputs_embeds = None
-        if num_image_tokens_in_input_ids == 0:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-            return inputs_embeds
-        else:
-            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_feature(image_input.pixel_values)
-
-            # print(f"image tokens from image embeddings: {image_features.numel()}")
-            num_image_tokens_in_embedding = (
-                image_features.shape[0] * image_features.shape[1]
-            )
-
-            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
-                image_features = image_features[:num_image, :]
-                logger.warning(
-                    f"Number of images does not match number of special image tokens in the input text. "
-                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                    "tokens from image embeddings."
-                )
-
-            # Important: clamp after extracting original image boundaries
-            input_ids.clamp_(min=0, max=self.vocab_size - 1)
-
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-                inputs_embeds.device
-            )
-
-            image_features = image_features.to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                special_image_mask, image_features
-            )
-
-            return inputs_embeds
-
     @torch.no_grad()
     def forward(
         self,

@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids
 
-        inputs_embeds = general_mm_embed_routine(
+        hs = general_mm_embed_routine(
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        outputs = self.language_model(
-            input_ids=None,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            **kwargs,
         )
 
-        return outputs
+        return hs
 
     def tie_weights(self):
         return self.language_model.tie_weights()
...
@@ -428,6 +428,9 @@ class LlamaForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
...
@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import Modality, MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,

@@ -42,17 +42,21 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list
 
 
 class LlavaBaseForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+        image_sizes = flatten_nested_list(
+            [item.image_sizes for item in image_inputs.mm_items]
+        )
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
 
         # hardcode for spatial_unpad + anyres
-        if image_inputs.modalities is not None and (
-            "multi-images" in image_inputs.modalities
-            or "video" in image_inputs.modalities
+        if any(
+            item.modality == Modality.MULTI_IMAGES or item.modality == Modality.VIDEO
+            for item in image_inputs.mm_items
         ):
             image_aspect_ratio = "pad"
         else:

@@ -66,7 +70,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 math.ceil(self.image_size / self.patch_size / 2) ** 2
             )
         else:
-            new_image_feature_len = self.image_feature_len  # multiimage
+            new_image_feature_len = self.image_feature_len  # multi-image
 
         height = width = self.num_patches_per_side
         if "anyres" in image_aspect_ratio:

@@ -101,7 +105,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                + [pad_values[image_idx]] * new_image_feature_len
+                + [pad_values[image_idx % len(pad_values)]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)

@@ -150,8 +154,8 @@ class LlavaBaseForCausalLM(nn.Module):
         modalities_list = []
         max_image_offset = []
         for im in image_inputs:
-            if im and im.modalities is not None:
-                modalities_list.extend(im.modalities)
+            if im:
+                modalities_list.extend([item.modality for item in im.mm_items])
             if im and im.image_offsets:
                 max_image_offset.append(
                     np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))

@@ -164,11 +168,19 @@ class LlavaBaseForCausalLM(nn.Module):
         if need_vision.any():
             bs = forward_batch.batch_size
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_sizes = [
-                image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                flatten_nested_list(
+                    [item.image_sizes for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]
 
             ########## Encode Image ########

@@ -197,13 +209,13 @@ class LlavaBaseForCausalLM(nn.Module):
             new_image_features = []
             height = width = self.num_patches_per_side
             for image_idx, image_feature in enumerate(image_features):
-                if modalities_list[image_idx] == "image":
+                if modalities_list[image_idx] == Modality.IMAGE:
                     image_aspect_ratio = (
                         self.config.image_aspect_ratio
                     )  # single image
                 elif (
-                    modalities_list[image_idx] == "multi-images"
-                    or modalities_list[image_idx] == "video"
+                    modalities_list[image_idx] == Modality.MULTI_IMAGES
+                    or modalities_list[image_idx] == Modality.VIDEO
                 ):
                     image_aspect_ratio = "pad"  # multi image
                     # image_aspect_ratio = (

@@ -212,7 +224,7 @@ class LlavaBaseForCausalLM(nn.Module):
                 if (
                     image_feature.shape[0] > 1
                     and "anyres" in image_aspect_ratio
-                    and modalities_list[image_idx] == "image"
+                    and modalities_list[image_idx] == Modality.IMAGE
                 ):
                     base_image_feature = image_feature[0]
                     image_feature = image_feature[1:]

@@ -312,7 +324,7 @@ class LlavaBaseForCausalLM(nn.Module):
                     )
                     image_feature = image_feature.unsqueeze(0)
                 else:
-                    if modalities_list[image_idx] == "video":  # video
+                    if modalities_list[image_idx] == Modality.VIDEO:  # video
                         # 2x2 pooling
                         num_of_frames = image_feature.shape[0]
                         image_feature = image_feature.view(
...
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, flatten_nested_list
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM

@@ -58,7 +58,7 @@ class LlavaVidForCausalLM(nn.Module):
         )
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pad_values = image_inputs.pad_values
+        pad_values = [item.pad_value for item in image_inputs.mm_items]
         new_image_feature_len = self.image_feature_len
 
         pad_ids = pad_values * (

@@ -133,11 +133,19 @@ class LlavaVidForCausalLM(nn.Module):
         need_vision = start_positions <= np.array(max_image_offset)
 
         if need_vision.any():
-            pixel_values = [
-                image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
-            ]
+            pixel_values = flatten_nested_list(
+                [
+                    [item.pixel_values for item in image_inputs[i].mm_items]
+                    for i in range(bs)
+                    if need_vision[i]
+                ]
+            )
             image_offsets = [
-                image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                flatten_nested_list(
+                    [item.image_offsets for item in image_inputs[i].mm_items]
+                )
+                for i in range(bs)
+                if need_vision[i]
             ]
 
             ########## Encode Image ########

@@ -246,7 +254,8 @@ class LlavaVidForCausalLM(nn.Module):
             "model.mm_projector.2": "multi_modal_projector.linear_2",
             "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
             "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+            "model.vision_tower.vision_tower": "vision_tower",
+            # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
             "model.image_newline": "language_model.model.image_newline",
         }
         params_dict = dict(self.named_parameters())
...
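Several of the rewritten call sites above gather a per-request list of per-item values and pass it through flatten_nested_list. As a rough illustration only, flatten_one_level below is a hypothetical local helper; sglang's actual implementation may flatten deeper nesting:

from typing import Any, List

def flatten_one_level(nested: List[List[Any]]) -> List[Any]:
    # [[a, b], [c]] -> [a, b, c]
    return [x for inner in nested for x in inner]

# two requests: the first carries two images, the second one
per_request = [["img_a", "img_b"], ["img_c"]]
print(flatten_one_level(per_request))  # ['img_a', 'img_b', 'img_c']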
@@ -40,16 +40,19 @@ from transformers.models.whisper.modeling_whisper import (
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
-    embed_mm_inputs,
-    get_multimodal_data_bounds,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.minicpmv import (
     Idefics2VisionTransformer,
-    MiniCPMVBaseModel,
+    MiniCPMBaseModel,
     Resampler2_5,
 )
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM

@@ -1409,7 +1412,7 @@ class MultiModalProjector(nn.Module):
         return hidden_states
 
 
-class MiniCPMO(MiniCPMVBaseModel):
+class MiniCPMO(MiniCPMBaseModel):
     def __init__(
         self,
         config: PretrainedConfig,

@@ -1537,7 +1540,7 @@ class MiniCPMO(MiniCPMVBaseModel):
         return input_lengths_after_cnn, input_lengths_after_pooling
 
-    def get_audio_embedding_streaming(self, multimodal_input: MultimodalInputs):
+    def get_audio_embedding_streaming(self, items: List[MultimodalDataItem]):
         r"""
         Extract audio embeddings in a streaming manner using cached key-value pairs.

@@ -1545,26 +1548,15 @@ class MiniCPMO(MiniCPMVBaseModel):
         for faster inference on subsequent audio frames. It only supports batch_size=1 and is intended
         for streaming scenarios.
 
-        Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
-
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
+        wavforms = flatten_nested_list(
+            [item.audio_features for item in items if item.audio_features]
         )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
         )
 
         # exist audio

@@ -1650,7 +1642,7 @@ class MiniCPMO(MiniCPMVBaseModel):
                     ret[i, start:ending] = True
         return ret
 
-    def get_audio_embedding(self, multimodal_input: MultimodalInputs, chunk_length=-1):
+    def get_audio_embedding(self, items: List[MultimodalDataItem], chunk_length=-1):
         r"""
         Extract full audio embeddings with optional chunk-based attention.

@@ -1659,31 +1651,25 @@ class MiniCPMO(MiniCPMVBaseModel):
         not use key-value caching and is suitable for non-streaming inference.
 
         Args:
-            multimodal_input (dict):
-                - **"audio_features"** (`torch.FloatTensor`): Input mel-spectrograms of shape `(batch_size, 80, frames)`.
-                - **"audio_feature_lens"** (List[List[int]]): Lengths of each audio segment for each item in the batch.
             chunk_length (int, optional): Determines whether to use full attention (-1) or chunk-based
                 attention (>0) during embedding computation.
         Returns:
             List[List[torch.Tensor]]: audio embeddings
         """
-        # print("audio embedding")
         # (bs, 80, frames) or [], multi audios need filled in advance
-        wavforms = (
-            []
-            if multimodal_input.audio_features is None
-            else multimodal_input.audio_features
+        wavforms = flatten_nested_list(
            [item.audio_features for item in items if item.audio_features]
         )
         # list, [[x1, x2], [y1], [z1]]
-        audio_feature_lens_raw = (
-            []
-            if multimodal_input.audio_feature_lens is None
-            else multimodal_input.audio_feature_lens
+        audio_feature_lens_raw = flatten_nested_list(
+            [item.audio_feature_lens for item in items if item.audio_feature_lens]
         )
 
         final_audio_embeds = []
+        assert isinstance(wavforms, list)
+        assert isinstance(wavforms[0], torch.Tensor)
         # exist audio
         for wavform in wavforms:
             if len(wavform) > 0:

@@ -1757,86 +1743,46 @@ class MiniCPMO(MiniCPMVBaseModel):
                 final_audio_embeds.append(target_audio_embeds)
         return final_audio_embeds
 
+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        embedding = self.get_omni_embedding(
+            items=items,
+            chunk_length=self.config.audio_chunk_length,
+            stream_input=False,
+        )
+        return embedding
+
     def get_omni_embedding(
         self,
-        input_ids,
-        multimodal_input: MultimodalInputs,
-        input_embeds: torch.Tensor,
-        forward_mode: ForwardMode,
+        items: List[MultimodalDataItem],
         chunk_length=-1,
         stream_input=False,
     ):
         """
         Args:
-            multimodal_input:
-            input_embeds:
             chunk_length: whisper use full attention or chunk attention
             stream_input: use streaming audio embedding
         Returns:
             final embeddings with audio feature
         """
-        input_embeds = input_embeds.unsqueeze(0)
-        if not forward_mode.is_decode() and multimodal_input.contains_audio_inputs():
-            audio_bounds = get_multimodal_data_bounds(
-                input_ids=input_ids,
-                pad_values=multimodal_input.pad_values,
-                token_pairs=[
-                    (multimodal_input.audio_start_id, multimodal_input.audio_end_id)
-                ],
-            )
-            if audio_bounds.numel() == 0:
-                input_embeds = input_embeds.squeeze(0)
-                # TODO
-                logger.warn("Unimplemented logic. Please try disabling chunked prefill")
-                return input_embeds
-            audio_bounds = audio_bounds.unsqueeze(0)
-            bs = len(input_embeds)
-
-            if stream_input:
-                audio_embeddings = self.get_audio_embedding_streaming(multimodal_input)
-            else:
-                audio_embeddings = self.get_audio_embedding(
-                    multimodal_input, chunk_length
-                )
-            # batch size
-            assert len(audio_embeddings) == len(input_embeds)
-            if len(audio_embeddings) > 0:
-                if self.config.chunk_input:
-                    for i in range(bs):
-                        audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
-                            device=input_embeds.device, dtype=input_embeds.dtype
-                        )
-                        audio_start_pos = 0
-                        for bound in audio_bounds[i]:
-                            audio_len = bound[1] - bound[0] + 1
-                            input_embeds[0, bound[0] : bound[1] + 1] = audio_embs[
-                                audio_start_pos : audio_start_pos + audio_len, :
-                            ]
-                            audio_start_pos += audio_len
-                else:
-                    for i in range(bs):
-                        audio_embs = audio_embeddings[i]
-                        bounds = audio_bounds[i]
-                        for embs, bound in zip(audio_embs, bounds):
-                            audio_indices = torch.arange(
-                                bound[0], bound[1], dtype=torch.long
-                            ).to(input_embeds.device)
-
-                            if embs.shape[0] != len(audio_indices):
-                                raise ValueError(
-                                    f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
-                                    f"to input indices of length {len(audio_indices)}"
-                                )
-                            input_embeds[i, audio_indices] = embs.to(input_embeds.dtype)
-        input_embeds = input_embeds.squeeze(0)
-        return input_embeds
-
-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
-        pixel_values = image_inputs.pixel_values
-        tgt_sizes = image_inputs.tgt_sizes
+        if stream_input:
+            audio_embeddings = self.get_audio_embedding_streaming(items)
+        else:
+            audio_embeddings = self.get_audio_embedding(items, chunk_length)
+        bs = len(audio_embeddings)
+        # batch size
+        audio_embs = torch.cat(flatten_nested_list(audio_embeddings), dim=0)
+
+        return audio_embs
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # list of tensors
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
+
         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
         all_pixel_values_lst = [

@@ -1845,10 +1791,10 @@ class MiniCPMO(MiniCPMVBaseModel):
         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
 
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
 
         patch_attn_mask = torch.zeros(

@@ -1875,53 +1821,23 @@ class MiniCPMO(MiniCPMVBaseModel):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = None
-        # TODO(mick): optimize the logic here: clamp, merge and embedding should happens at most once
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and forward_batch.contains_image_inputs()
-        ):
-            mm_inputs = forward_batch.merge_mm_inputs()
-            inputs_embeds = embed_mm_inputs(
-                mm_input=mm_inputs,
-                input_ids=input_ids,
-                input_embedding=self.get_input_embeddings(),
-                mm_data_embedding_func=self.get_image_features,
-                placeholder_token_ids=[mm_inputs.im_token_id] + mm_inputs.pad_values,
-            )
-            input_ids = input_ids.clamp(
-                min=0, max=self.get_input_embeddings().num_embeddings - 1
-            )
-        if inputs_embeds is None:
-            inputs_embeds = self.llm.get_input_embeddings(input_ids)
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and self.config.init_audio
-            and forward_batch.contains_audio_inputs()
-        ):
-            mm_input = forward_batch.merge_mm_inputs()
-            inputs_embeds = self.get_omni_embedding(
-                input_ids=input_ids,
-                multimodal_input=mm_input,
-                input_embeds=inputs_embeds,
-                forward_mode=forward_batch.forward_mode,
-                chunk_length=self.config.audio_chunk_length,
-                stream_input=False,
-            )
-        forward_batch.mm_inputs = None
-
-        hidden_states = self.llm.model(
-            input_ids=None,
-            positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-        )
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
-        )
+        mm_input = forward_batch.merge_mm_inputs()
+        placeholder_token_ids = (
+            ([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
+            if forward_batch.contains_mm_inputs()
+            else []
+        )
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            image_data_embedding_func=self.get_image_feature,
+            audio_data_embedding_func=self.get_audio_feature,
+            placeholder_token_ids=placeholder_token_ids,
+            positions=positions,
+        )
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -54,12 +54,12 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list
 
 RawImageType = Union[Image.Image, torch.Tensor]

@@ -661,7 +661,7 @@ def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
     return tuple(int(x) for x in version_str.split("."))
 
 
-class MiniCPMVBaseModel(nn.Module):
+class MiniCPMBaseModel(nn.Module):
     """
     The abstract class of MiniCPMV can only be inherited, but cannot be
     instantiated.

@@ -853,7 +853,7 @@ class MiniCPMVBaseModel(nn.Module):
         return vlm_embedding, vision_hidden_states
 
     def get_input_embeddings(self) -> nn.Embedding:
-        return self.llm.get_input_embedding()
+        return self.llm.get_input_embeddings()
 
     def forward(
         self,

@@ -862,23 +862,14 @@ class MiniCPMVBaseModel(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: Any,
     ) -> torch.Tensor:
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_features,
-        )
-
-        hidden_states = self.llm.model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.llm,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-        )
-
-        return self.logits_processor(
-            input_ids, hidden_states, self.llm.lm_head, forward_batch
         )
+        return hidden_states
 
     def init_llm(
         self,

@@ -913,11 +904,11 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError
 
-    def get_image_features(self, image_inputs: MultimodalInputs) -> torch.Tensor:
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         raise NotImplementedError
 
 
-class MiniCPMV2_6(MiniCPMVBaseModel):
+class MiniCPMV2_6(MiniCPMBaseModel):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",

@@ -1023,14 +1014,13 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding
 
-    def get_image_features(
-        self,
-        image_inputs: MultimodalInputs,
-    ) -> torch.Tensor:
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # list of tensors
-        pixel_values = image_inputs.pixel_values
-
-        tgt_sizes = image_inputs.tgt_sizes
+        pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        tgt_sizes = torch.stack(
+            flatten_nested_list([item.tgt_size for item in items]), dim=0
+        )
+        assert len(pixel_values) == tgt_sizes.shape[0]
 
         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype

@@ -1040,10 +1030,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item()
         assert isinstance(max_patches, int)
 
         all_pixel_values = torch.nn.utils.rnn.pad_sequence(
             all_pixel_values_lst, batch_first=True, padding_value=0.0
         )
         B, L, _ = all_pixel_values.shape
         all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
 
         patch_attn_mask = torch.zeros(
...
@@ -796,14 +796,16 @@ class MllamaForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        pixel_values = image_inputs.pixel_values
-        pad_values = image_inputs.pad_values
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pixel_values = torch.cat(
+            [item.pixel_values for item in mm_inputs.mm_items], dim=0
+        )
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]
 
         num_concurrent_media, num_tiles = pixel_values.shape[1:3]
         num_patches = self.vision_model.num_patches
         image_len = num_concurrent_media * num_tiles * num_patches
-        image_inputs.num_image_tokens = image_len
+        mm_inputs.num_image_tokens = image_len
 
         pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

@@ -815,10 +817,16 @@ class MllamaForConditionalGeneration(nn.Module):
         # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
         max_num_images = max_num_tiles = bs = 0
-        for i, im in enumerate(forward_batch.mm_inputs):
-            if not forward_batch.encoder_cached[i] and im is not None:
-                max_num_images = max(max_num_images, im.pixel_values.shape[1])
-                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
+        for i, mm_input in enumerate(forward_batch.mm_inputs):
+            if not forward_batch.encoder_cached[i] and mm_input is not None:
+                pixel_values = torch.cat(
+                    [item.pixel_values for item in mm_input.mm_items], dim=0
+                )
+                # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
+                max_num_images = max(max_num_images, pixel_values.shape[1])
+                max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
                 bs += 1
 
         if max_num_images * max_num_tiles * bs == 0:

@@ -842,17 +850,24 @@ class MllamaForConditionalGeneration(nn.Module):
         )
         i = 0
         encoder_lens_need = []
-        for k, im in enumerate(forward_batch.mm_inputs):
-            if forward_batch.encoder_cached[k] or im is None:
+        for k, mm_input in enumerate(forward_batch.mm_inputs):
+            if forward_batch.encoder_cached[k] or mm_input is None:
                 continue
 
             encoder_lens_need.append(forward_batch.encoder_lens[k])
-            for j in range(im.pixel_values.shape[1]):
-                img = im.pixel_values[0, j]
+            pixel_values = torch.cat(
+                [item.pixel_values for item in mm_input.mm_items], dim=0
+            )
+            for j in range(pixel_values.shape[1]):
+                img = pixel_values[0, j]
                 num_tiles = img.shape[0]
                 batched_images[i, j, :num_tiles] = img
-                batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
-                batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
+                batched_ar_ids[i, j] = mm_input.mm_items[0].aspect_ratio_id[0, j]
+                batched_ar_mask[i, j, :num_tiles] = mm_input.mm_items[
+                    0
+                ].aspect_ratio_mask[0, j]
             i += 1
 
         return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
...
@@ -261,11 +261,14 @@ class Qwen2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
-            return self.embed_tokens(input_ids) * self.config.scale_emb
+            return self.get_input_embeddings()(input_ids) * self.config.scale_emb
         else:
-            return self.embed_tokens(input_ids)
+            return self.get_input_embeddings()(input_ids)
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
 
     def forward(
         self,

@@ -358,10 +361,10 @@ class Qwen2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embedding(input_ids)
 
-    def get_input_embedding(self) -> nn.Embedding:
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
     @torch.no_grad()
...
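The Qwen2 hunk above swaps the two accessor names: get_input_embeddings() now returns the embedding module, while get_input_embedding(input_ids) embeds the ids and applies the optional scale_emb. A minimal standalone sketch of that split (TinyModel and its sizes are invented):

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, vocab=10, hidden=4, scale_emb=None):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.scale_emb = scale_emb  # MiniCPM-style embedding scale, or None

    def get_input_embeddings(self) -> nn.Embedding:
        # module accessor: lets callers embed ids themselves
        return self.embed_tokens

    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
        # tensor accessor: embeds and applies the optional scale
        emb = self.get_input_embeddings()(input_ids)
        return emb * self.scale_emb if self.scale_emb is not None else emb

m = TinyModel(scale_emb=12.0)
print(m.get_input_embedding(torch.tensor([1, 2, 3])).shape)  # torch.Size([3, 4])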
@@ -30,22 +30,13 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import AutoModel, Qwen2VLConfig
+from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
-from transformers.models.qwen2_5_vl import Qwen2_5_VLProcessor
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
-    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
-from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
-    Qwen2_5_VLForConditionalGeneration,
-)
-
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear

@@ -57,7 +48,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model

@@ -513,19 +504,24 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = image_inputs.im_start_id
-        im_end_id: int = image_inputs.im_end_id
+        im_start_id: int = mm_inputs.im_start_id
+        im_end_id: int = mm_inputs.im_end_id
 
         media_token_pairs = [(im_start_id, im_end_id)]
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
 
-        return pattern.pad_input_tokens(input_ids, image_inputs)
-
-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thws.dim() == 2, image_grid_thws.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
         return image_embeds
 
     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:

@@ -570,18 +566,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 f"(3, seq_len) positions, but got {positions.size()}"
             )
 
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        hidden_states = self.model(
-            input_ids=None,
+            language_model=self.model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
         )
 
         if not get_embedding:

@@ -594,9 +584,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
             ("gate_up_proj", "up_proj", 1),
             ("gate_up_proj", "gate_proj", 0),
         ]
...
...@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import ( ...@@ -45,7 +45,7 @@ from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine, general_mm_embed_routine,
) )
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.qwen2 import Qwen2Model
...@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module): ...@@ -472,18 +472,24 @@ class Qwen2VLForConditionalGeneration(nn.Module):
# Use grid_t * grid_w * grid_h to pad tokens for each image # Use grid_t * grid_w * grid_h to pad tokens for each image
# add replaced padding by unique image hash # add replaced padding by unique image hash
def pad_input_ids(self, input_ids: List[int], multi_modal_inputs: MultimodalInputs): def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs # Get all special token IDs
im_start_id: int = multi_modal_inputs.im_start_id im_start_id: int = mm_inputs.im_start_id
im_end_id: int = multi_modal_inputs.im_end_id im_end_id: int = mm_inputs.im_end_id
media_token_pairs = [(im_start_id, im_end_id)] media_token_pairs = [(im_start_id, im_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, multi_modal_inputs) return pattern.pad_input_tokens(input_ids, mm_inputs)
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor: def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype) # in qwen-vl, last dim is the same
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws) pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
self.visual.dtype
)
image_grid_thws = torch.concat([item.image_grid_thws for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thws.dim() == 2, image_grid_thws.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thws)
return image_embeds return image_embeds
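A self-contained sketch of the batching step introduced above, with a stub standing in for the Qwen vision tower; DummyItem and dummy_visual are placeholders for illustration, not sglang or HF APIs.

import torch
from dataclasses import dataclass
from typing import List

@dataclass
class DummyItem:
    pixel_values: torch.Tensor      # (num_patches_i, hidden)
    image_grid_thws: torch.Tensor   # (num_images_i, 3)

def dummy_visual(pixel_values: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
    # Stand-in for self.visual(...): one embedding row per input patch.
    return pixel_values

def get_image_feature(items: List[DummyItem]) -> torch.Tensor:
    # Concatenate per-item tensors so a single vision-tower call covers every image in the batch.
    pixel_values = torch.cat([it.pixel_values for it in items], dim=0)
    image_grid_thws = torch.cat([it.image_grid_thws for it in items], dim=0)
    assert pixel_values.dim() == 2 and image_grid_thws.dim() == 2
    return dummy_visual(pixel_values, grid_thw=image_grid_thws)

items = [DummyItem(torch.randn(4, 8), torch.tensor([[1, 2, 2]])),
         DummyItem(torch.randn(6, 8), torch.tensor([[1, 2, 3]]))]
print(get_image_feature(items).shape)  # torch.Size([10, 8])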
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor: def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
...@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module): ...@@ -527,27 +533,20 @@ class Qwen2VLForConditionalGeneration(nn.Module):
"multimodal section rotary embedding requires " "multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}" f"(3, seq_len) positions, but got {positions.size()}"
) )
hidden_states = general_mm_embed_routine(
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids, input_ids=input_ids,
forward_batch=forward_batch, forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(), language_model=self.model,
mm_data_embedding_func=self.get_image_feature, image_data_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=None,
positions=positions, positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
) )
if not get_embedding: if get_embedding:
return self.pooler(hidden_states, forward_batch)
else:
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch input_ids, hidden_states, self.lm_head, forward_batch
) )
else:
return self.pooler(hidden_states, forward_batch)
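A rough, standalone illustration of what general_mm_embed_routine centralizes: text tokens pass through the embedding table, and positions holding the image placeholder token are overwritten with precomputed image embeddings. The helper and sizes below are made up for demonstration and are not the sglang routine itself.

import torch

def merge_mm_embeddings(
    input_ids: torch.Tensor,         # (seq_len,)
    embed_tokens: torch.nn.Embedding,
    image_embeds: torch.Tensor,      # (num_image_tokens, hidden)
    placeholder_id: int,
) -> torch.Tensor:
    inputs_embeds = embed_tokens(input_ids)
    mask = input_ids == placeholder_id
    assert int(mask.sum()) == image_embeds.shape[0], "placeholder count must match image embeddings"
    inputs_embeds[mask] = image_embeds.to(inputs_embeds.dtype)
    return inputs_embeds

# Tiny usage check with made-up sizes; token id 31 acts as the image placeholder.
emb = torch.nn.Embedding(32, 8)
ids = torch.tensor([1, 2, 31, 31, 3])
img = torch.randn(2, 8)
print(merge_mm_embeddings(ids, emb, img, 31).shape)  # torch.Size([5, 8])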
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -897,6 +897,7 @@ def v1_chat_generate_request( ...@@ -897,6 +897,7 @@ def v1_chat_generate_request(
request_ids: List[str] = None, request_ids: List[str] = None,
): ):
input_ids = [] input_ids = []
prompts = []
sampling_params_list = [] sampling_params_list = []
image_data_list = [] image_data_list = []
audio_data_list = [] audio_data_list = []
...@@ -916,6 +917,7 @@ def v1_chat_generate_request( ...@@ -916,6 +917,7 @@ def v1_chat_generate_request(
# - audio_data: None or a list of audio strings (URLs). # - audio_data: None or a list of audio strings (URLs).
# None skips any image processing in GenerateReqInput. # None skips any image processing in GenerateReqInput.
strict_tag = None strict_tag = None
prompt = ""
if not isinstance(request.messages, str): if not isinstance(request.messages, str):
# Apply chat template and its stop strings. # Apply chat template and its stop strings.
tools = None tools = None
...@@ -1005,11 +1007,13 @@ def v1_chat_generate_request( ...@@ -1005,11 +1007,13 @@ def v1_chat_generate_request(
image_data = None image_data = None
audio_data = None audio_data = None
modalities = [] modalities = []
prompt = request.messages
input_ids.append(prompt_ids) input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs) return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1) logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0) top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path) lora_paths.append(request.lora_path)
prompts.append(prompt)
sampling_params = { sampling_params = {
"temperature": request.temperature, "temperature": request.temperature,
...@@ -1063,10 +1067,14 @@ def v1_chat_generate_request( ...@@ -1063,10 +1067,14 @@ def v1_chat_generate_request(
audio_data_list.append(audio_data) audio_data_list.append(audio_data)
modalities_list.append(modalities) modalities_list.append(modalities)
if len(all_requests) == 1: if len(all_requests) == 1:
if isinstance(input_ids[0], str): if tokenizer_manager.model_config.is_multimodal:
prompt_kwargs = {"text": input_ids[0]} # processor will need text input
prompt_kwargs = {"text": prompts[0]}
else: else:
prompt_kwargs = {"input_ids": input_ids[0]} if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids[0]}
else:
prompt_kwargs = {"input_ids": input_ids[0]}
sampling_params_list = sampling_params_list[0] sampling_params_list = sampling_params_list[0]
image_data_list = image_data_list[0] image_data_list = image_data_list[0]
audio_data_list = audio_data_list[0] audio_data_list = audio_data_list[0]
...@@ -1076,10 +1084,14 @@ def v1_chat_generate_request( ...@@ -1076,10 +1084,14 @@ def v1_chat_generate_request(
modalities_list = modalities_list[0] modalities_list = modalities_list[0]
lora_paths = lora_paths[0] lora_paths = lora_paths[0]
else: else:
if isinstance(input_ids[0], str): if tokenizer_manager.model_config.is_multimodal:
prompt_kwargs = {"text": input_ids} # processor will need text input
prompt_kwargs = {"text": prompts}
else: else:
prompt_kwargs = {"input_ids": input_ids} if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids}
else:
prompt_kwargs = {"input_ids": input_ids}
adapted_request = GenerateReqInput( adapted_request = GenerateReqInput(
**prompt_kwargs, **prompt_kwargs,
......
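A condensed sketch of the request-adaptation branching in the hunk above; the helper name is made up, but the logic mirrors the diff: multimodal models receive the raw chat text so the processor can expand image tokens itself, while other models receive token ids or a plain string.

from typing import List, Union

def build_prompt_kwargs(
    is_multimodal: bool,
    prompts: List[str],
    input_ids: List[Union[str, List[int]]],
    single_request: bool,
) -> dict:
    if is_multimodal:
        # The multimodal processor needs text input, not pre-tokenized ids.
        return {"text": prompts[0] if single_request else prompts}
    first = input_ids[0]
    if isinstance(first, str):
        return {"text": first if single_request else input_ids}
    return {"input_ids": first if single_request else input_ids}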
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common utilities.""" """Common utilities."""
import base64 import base64
import builtins import builtins
import ctypes import ctypes
...@@ -54,6 +53,7 @@ import torch.distributed ...@@ -54,6 +53,7 @@ import torch.distributed
import torch.distributed as dist import torch.distributed as dist
import triton import triton
import zmq import zmq
from decord import VideoReader, cpu
from fastapi.responses import ORJSONResponse from fastapi.responses import ORJSONResponse
from packaging import version as pkg_version from packaging import version as pkg_version
from PIL import Image from PIL import Image
...@@ -513,13 +513,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra ...@@ -513,13 +513,18 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
import soundfile as sf import soundfile as sf
from scipy.signal import resample from scipy.signal import resample
# print(f"loading {audio_file}")
# Load audio data # Load audio data
if isinstance(audio_file, bytes): if isinstance(audio_file, bytes):
audio, original_sr = sf.read(BytesIO(audio_file)) audio, original_sr = sf.read(BytesIO(audio_file))
elif audio_file.startswith("data:"): elif audio_file.startswith("data:"):
audio_file = audio_file.split(",")[1] audio_file = audio_file.split(",")[1]
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file))) audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
elif audio_file.startswith("http://") or audio_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
response = requests.get(audio_file, stream=True, timeout=timeout)
audio_file = BytesIO(response.content)
response.close()
audio, original_sr = sf.read(audio_file)
elif isinstance(audio_file, str): elif isinstance(audio_file, str):
audio, original_sr = sf.read(audio_file) audio, original_sr = sf.read(audio_file)
else: else:
...@@ -537,6 +542,30 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra ...@@ -537,6 +542,30 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
return audio return audio
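Usage sketch for the new remote-audio branch; the URL is a placeholder and the import path is assumed from this file's location in the diff. The HTTP fetch honors the REQUEST_TIMEOUT environment variable, defaulting to 5 seconds.

import os
os.environ.setdefault("REQUEST_TIMEOUT", "10")  # optional override of the 5 s default

from sglang.srt.utils import load_audio  # assumed module path for this file

audio = load_audio("https://example.com/sample.wav", sr=16000, mono=True)
print(audio.shape)  # 1-D float array resampled to 16 kHz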
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist")
return []
if frame_count_limit == 0:
return []
def uniform_sample(l, n):
gap = len(l) / n
idxs = [int(i * gap + gap / 2) for i in range(n)]
return [l[i] for i in idxs]
vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_indices = [i for i in range(0, len(vr), sample_fps)]
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
frame_indices = uniform_sample(frame_indices, frame_count_limit)
frames = vr.get_batch(frame_indices).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames
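Usage sketch, assuming decord is installed and clip.mp4 exists locally: frames are sampled at roughly the video's average FPS, then capped at eight by uniform sampling.

frames = encode_video("clip.mp4", frame_count_limit=8)
print(len(frames))                          # at most 8
print(frames[0].size if frames else None)   # PIL (width, height)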
def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]: def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
image = image_size = None image = image_size = None
...@@ -1796,3 +1825,12 @@ def retry( ...@@ -1796,3 +1825,12 @@ def retry(
traceback.print_exc() traceback.print_exc()
time.sleep(delay) time.sleep(delay)
def flatten_nested_list(nested_list):
if isinstance(nested_list, list):
return [
item for sublist in nested_list for item in flatten_nested_list(sublist)
]
else:
return [nested_list]
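Quick check of flatten_nested_list on an arbitrarily nested structure; non-list leaves are wrapped and concatenated in order.

print(flatten_nested_list([1, [2, [3, 4]], [], 5]))  # [1, 2, 3, 4, 5]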
...@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -155,9 +155,7 @@ class TestOpenAIVisionServer(CustomTestCase):
"content": [ "content": [
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": IMAGE_MAN_IRONING_URL},
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
},
"modalities": "multi-images", "modalities": "multi-images",
}, },
{ {
...@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase): ...@@ -399,14 +397,14 @@ class TestOpenAIVisionServer(CustomTestCase):
{ {
"role": "user", "role": "user",
"content": [ "content": [
{
"type": "text",
"text": prompt,
},
{ {
"type": "audio_url", "type": "audio_url",
"audio_url": {"url": f"{audio_file_name}"}, "audio_url": {"url": f"{audio_file_name}"},
}, },
{
"type": "text",
"text": prompt,
},
], ],
} }
] ]
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import unittest import unittest
from io import BytesIO from io import BytesIO
from typing import List
import numpy as np import numpy as np
import requests import requests
...@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer ...@@ -14,7 +15,11 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import MultimodalInputs from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
...@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase): ...@@ -195,14 +200,35 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
# sglang # sglang
model = self.get_sglang_model() model = self.get_sglang_model()
input_ids = inputs["input_ids"].to(self.device).flatten() input_ids = inputs["input_ids"].to(self.device).flatten()
pixel_values = inputs["pixel_values"]
tgt_sizes = inputs["tgt_sizes"]
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: "
f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
sglang_output = embed_mm_inputs( sglang_output = embed_mm_inputs(
mm_input=MultimodalInputs( mm_inputs=MultimodalInputs(
pixel_values=inputs["pixel_values"][0], mm_items=[
tgt_sizes=inputs["tgt_sizes"][0], MultimodalDataItem(
pixel_values=pixel_values_flat,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
pad_value=self.processor.tokenizer.unk_token_id,
)
]
), ),
input_ids=input_ids, input_ids=input_ids,
input_embedding=model.get_input_embeddings(), input_embedding=model.get_input_embeddings(),
mm_data_embedding_func=model.get_image_features, image_data_embedding_func=model.get_image_feature,
placeholder_token_ids=[ placeholder_token_ids=[
self.processor.tokenizer.unk_token_id, self.processor.tokenizer.unk_token_id,
], ],
......