Unverified Commit 4685fbb8 authored by Chang Su, committed by GitHub

[VLM] Support chunk prefill for VLM (#6355)


Co-authored-by: yizhang2077 <1109276519@qq.com>
parent 0a4fc73b
......@@ -116,6 +116,10 @@ class ModelConfig:
self.is_audio_model = enable_multimodal and is_audio_model(
self.hf_config.architectures
)
self.is_multimodal_chunked_prefill_supported = (
enable_multimodal
and is_multimodal_chunked_prefill_supported(self.hf_config.architectures)
)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
......@@ -574,6 +578,21 @@ def is_encoder_decoder_model(model_architectures: List[str]):
return "MllamaForConditionalGeneration" in model_architectures
def is_multimodal_chunked_prefill_supported(model_architectures: List[str]):
"""Check if chunked prefill is supported for a MultiModal model."""
unsupported = [
"Grok1VForCausalLM",
"Grok1AForCausalLM",
"LlavaLlamaForCausalLM",
"MllamaForConditionalGeneration",
"CLIPModel",
]
if any(multi_model_arch in unsupported for multi_model_arch in model_architectures):
return False
else:
return True
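As a quick, illustrative check of the helper above (assuming it is importable from sglang.srt.configs.model_config alongside is_encoder_decoder_model; the architecture names come from the unsupported list and from processors elsewhere in this diff):
from sglang.srt.configs.model_config import is_multimodal_chunked_prefill_supported

# Architectures not on the unsupported list pass the check.
assert is_multimodal_chunked_prefill_supported(["Qwen2_5_VLForConditionalGeneration"])
# Any unsupported architecture in the list disables chunked prefill for the model.
assert not is_multimodal_chunked_prefill_supported(["MllamaForConditionalGeneration"])
assert not is_multimodal_chunked_prefill_supported(
    ["Qwen2_5_VLForConditionalGeneration", "CLIPModel"]
)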
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
......
......@@ -16,10 +16,15 @@ from sglang.srt.managers.schedule_batch import (
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import flatten_nested_list, print_warning_once
from sglang.utils import logger
logger = logging.getLogger(__name__)
# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
# to ensure consistent logging behavior across the codebase. This prevents issues with log
# propagation that can cause some log messages (like 'server is fired up') to not appear
# in the console when multimodal support is enabled.
class MultiModalityDataPaddingPattern:
......@@ -189,26 +194,137 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
return output_ids_tensor.tolist()
embedding_cache = None
def init_embedding_cache(max_size: int):
global embedding_cache
embedding_cache = MultiModalCache(max_size)
def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
hash_list = [item.hash for item in embedding_items]
return hash(tuple(hash_list))
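The per-request cache key is simply a hash over the individual item hashes; a minimal sketch with stand-in objects (only the .hash attribute is consulted, so a namedtuple is enough here):
from collections import namedtuple

from sglang.srt.managers.mm_utils import get_embedding_hash

# Stand-in for MultimodalDataItem; get_embedding_hash only reads `.hash`.
FakeItem = namedtuple("FakeItem", ["hash"])

items = [FakeItem(hash=0x1234), FakeItem(hash=0x5678)]
assert get_embedding_hash(items) == hash((0x1234, 0x5678))
# Item order is part of the key, so per-request ordering must be stable.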
def get_embedding_chunk(
embedding: torch.Tensor,
extend_prefix_len: int,
extend_seq_len: int,
items_offset: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, int, int]:
"""
Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges.
Args:
embedding: The full embedding tensor to extract a chunk from
extend_prefix_len: The starting position (prefix length) for extraction
extend_seq_len: The number of tokens to extract
items_offset: List of [start, end] offset ranges for multimodal items in the input sequence
Returns:
A tuple containing:
- The extracted embedding chunk as a tensor
- The start index used for extraction
- The end index used for extraction
Note:
If there's no overlap between the requested range and the offset ranges,
an empty tensor is returned with zeros for start and end indices.
"""
start_index, end_index = 0, 0
extend_start_index = extend_prefix_len
extend_end_index = extend_prefix_len + extend_seq_len - 1
for start, end in items_offset:
if extend_start_index >= start and extend_start_index <= end:
start_index += extend_start_index - start
elif extend_start_index > end:
start_index += end - start + 1
if extend_end_index >= start and extend_end_index <= end:
end_index += extend_end_index - start + 1
elif extend_end_index > end:
end_index += end - start + 1
# some models' embeddings are 3-dim; reshape to 2-dim
embedding = embedding.reshape(-1, embedding.shape[-1])
embedding_chunk = embedding[start_index:end_index]
return embedding_chunk, start_index, end_index
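A worked example of the chunk arithmetic above, assuming one image whose placeholder tokens occupy sequence positions 4-11 and whose encoder output has 8 rows; chunked prefill splits the request after token 6:
import torch

from sglang.srt.managers.mm_utils import get_embedding_chunk

full_embedding = torch.arange(8 * 16, dtype=torch.float32).reshape(8, 16)
items_offset = [(4, 11)]  # image placeholder positions within the request

# First pass covers tokens [0, 6): it overlaps positions 4 and 5, i.e. rows 0..1.
chunk, start, end = get_embedding_chunk(
    embedding=full_embedding,
    extend_prefix_len=0,
    extend_seq_len=6,
    items_offset=items_offset,
)
assert (start, end) == (0, 2) and chunk.shape == (2, 16)

# Second pass covers tokens [6, 12): it consumes the remaining rows 2..7.
chunk, start, end = get_embedding_chunk(
    embedding=full_embedding,
    extend_prefix_len=6,
    extend_seq_len=6,
    items_offset=items_offset,
)
assert (start, end) == (2, 8) and chunk.shape == (6, 16)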
def get_embedding_and_mask(
data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
embedding_items: List[MultimodalDataItem],
placeholder_tensor: torch.Tensor,
input_ids: torch.Tensor,
):
items_size: List[int],
prefix_length: List[int],
extend_length: List[int],
items_offset_list: List[List[Tuple[int, int]]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get the multimodal embedding and its mask from input_ids
Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
Args:
data_embedding_func: Function that generates embeddings for multimodal items
embedding_items: List of multimodal items to embed
placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
input_ids: The input token IDs tensor
items_size: Cumulative sizes of multimodal items per request
prefix_length: Prefix lengths for each request
extend_length: Sequence lengths for each request
items_offset_list: List of offset ranges for multimodal items in each request
Returns:
A tuple containing:
- The generated embeddings tensor
- A boolean mask tensor indicating where these embeddings should be placed
Raises:
AssertionError: If the number of multimodal tokens in input_ids doesn't match
the number of tokens in the generated embeddings
"""
# 1. Get the embedding
embedding = data_embedding_func(embedding_items)
# Calculate the embedding for each request, trying the cache first to avoid repeated computation
embedding_list = []
for i in range(len(items_size) - 1):
if items_size[i] == items_size[i + 1]:
continue
embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
items_offset = items_offset_list[i]
embedding_items_hash = get_embedding_hash(embedding_items_per_req)
# if all items have been prefixed, we do not need to calculate the embedding
if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
continue
embedding_per_req = embedding_cache.get(embedding_items_hash)
if embedding_per_req is None:
embedding_per_req = data_embedding_func(embedding_items_per_req)
if not embedding_cache.put(embedding_items_hash, embedding_per_req):
print_warning_once(
"Multimodal embedding cache is full. Consider increasing the "
"`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
)
embedding_per_req_chunk, _, end_index = get_embedding_chunk(
embedding=embedding_per_req,
extend_prefix_len=prefix_length[i],
extend_seq_len=extend_length[i],
items_offset=items_offset,
)
# remove this item from the cache once the chunk reaches the end
embedding_per_req_length = (
embedding_per_req.shape[0]
if embedding_per_req.dim() == 2
else embedding_per_req.shape[0] * embedding_per_req.shape[1]
)
if end_index == embedding_per_req_length:
embedding_cache.free(embedding_items_hash)
embedding_list.append(embedding_per_req_chunk)
if len(embedding_list) == 0:
return None, None
embedding = torch.concat(embedding_list, dim=0)
# 2. Check the embedding
if embedding.dim() == 2:
num_mm_tokens_in_embedding = embedding.shape[0]
else:
num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
# the mask of multimodal tokens from input_ids
num_mm_tokens_in_embedding = embedding.shape[0]
special_multimodal_mask = torch.isin(
input_ids,
placeholder_tensor,
......@@ -222,9 +338,6 @@ def get_embedding_and_mask(
"tokens from multimodal embeddings."
)
if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
# a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
# extend_start_loc and extend_seq_lens
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
if chunked_prefill_size != -1:
logger.warning(
......@@ -245,7 +358,9 @@ def get_embedding_and_mask(
def embed_mm_inputs(
mm_inputs: MultimodalInputs,
mm_inputs_list: List[MultimodalInputs],
extend_prefix_lens: List[int],
extend_seq_lens: List[int],
input_ids: torch.Tensor,
input_embedding: nn.Embedding,
image_data_embedding_func: Callable[
......@@ -257,125 +372,133 @@ def embed_mm_inputs(
placeholder_tokens: dict[Modality, List[int]] = None,
) -> Optional[torch.Tensor]:
"""
Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
Args:
placeholder_tokens: denoting the token of multimodal data in input_ids.
If none, the pad_values of multimodal items are used
Embed multimodal inputs and integrate them with text token embeddings.
Args:
mm_inputs_list: List of multimodal inputs to process
extend_prefix_lens: Prefix lengths for each request
extend_seq_lens: Sequence lengths for each request
input_ids: Input token IDs tensor
input_embedding: Embedding layer for text tokens
image_data_embedding_func: Function to embed image data
audio_data_embedding_func: Function to embed audio data
placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
Returns:
final embedding: Optional[torch.Tensor]
Returns:
Combined embedding tensor with multimodal content integrated
"""
if mm_inputs is None:
if mm_inputs_list is None:
return None
# 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
# we assume that multimodal data are represented with its pad_values in input_ids
# See `pad_input_ids` for more detail
item_flatten_list = []
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
# if placeholder_tokens is specified
if placeholder_tokens is not None:
placeholder_token_ids = flatten_nested_list(
[placeholder_token for placeholder_token in placeholder_tokens.values()]
)
else:
placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
assert isinstance(placeholder_token_ids[0], int)
placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
placeholder_masks = torch.isin(input_ids, placeholder_tensor)
appearing_pad_values = torch.unique(
input_ids[placeholder_masks], return_counts=False
)
embeddings, masks = [], []
if appearing_pad_values.numel() == 0:
# all been prefixed
inputs_embeds = input_embedding(input_ids)
else:
appearing_items = [
item
for item in mm_inputs.mm_items
if item.pad_value is not None and item.pad_value in appearing_pad_values
]
using_all_items = False
if len(appearing_items) == 0:
# This happens mostly when arg placeholder_token_ids is passed
logger.warning(
"No multimodal data item's pad value exist in placeholder ids. Using all items"
# 2. Get multimodal embedding separately
# TODO: make this more generic
# Try get image embedding if any
if (
any(True for item in item_flatten_list if item.is_image())
and image_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_image()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
# count items per request; the cumulative sums below give per-request offsets
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
items_offsets = []
for i, mm_inputs in enumerate(mm_inputs_list):
image_items = [item for item in mm_inputs.mm_items if item.is_image()]
items_size[i + 1] = len(image_items)
items_offsets.append(
flatten_nested_list(
[
item.image_offsets
for item in mm_inputs.mm_items
if item.is_image()
]
)
)
using_all_items = True
appearing_items = mm_inputs.mm_items
items_size = torch.cumsum(items_size, dim=0).tolist()
embeddings, masks = [], []
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
# 2. Get multimodal embedding separately
# TODO: make this more generic
# Try get image embedding if any
if (
any(True for item in appearing_items if item.is_image())
and image_data_embedding_func
):
items = [item for item in appearing_items if item.is_image()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=image_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
# use the specified modality token to identify the location to embed
placeholder_tokens[Modality.IMAGE]
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
),
input_ids=input_ids,
# Try get audio embedding if any
if (
any(True for item in item_flatten_list if item.is_audio())
and audio_data_embedding_func
):
items = [item for item in item_flatten_list if item.is_audio()]
placeholder_tensor = torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
items_offsets = []
# count items per request; the cumulative sums below give per-request offsets
items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
for i, mm_inputs in enumerate(mm_inputs_list):
audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
items_size[i + 1] = len(audio_items)
items_offsets.append(
flatten_nested_list(
[
item.audio_offsets
for item in mm_inputs.mm_items
if item.is_audio()
]
)
)
embeddings += [embedding]
masks += [mask]
items_size = torch.cumsum(items_size, dim=0)
# Try get audio embedding if any
if (
any(True for item in appearing_items if item.is_audio())
and audio_data_embedding_func
):
items = [item for item in appearing_items if item.is_audio()]
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=(
placeholder_tokens[Modality.AUDIO]
if using_all_items
else torch.tensor(
[item.pad_value for item in items],
device=input_ids.device,
)
),
input_ids=input_ids,
)
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings
# Important: clamp after getting original multimodal regions
# Clamp input ids. This is because the input_ids for the multimodal tokens are
# filled with the hash values of the multimodal data for prefix matching in the radix attention.
# These values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
# 4. Scatter embeddings into input embedding
for embedding, mask in zip(embeddings, masks):
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(
mask,
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
)
embedding, mask = get_embedding_and_mask(
data_embedding_func=audio_data_embedding_func,
embedding_items=items,
placeholder_tensor=placeholder_tensor,
input_ids=input_ids,
items_size=items_size,
prefix_length=extend_prefix_lens,
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
embeddings += [embedding]
masks += [mask]
# 3. Get input embeddings
vocab_size = input_embedding.num_embeddings
# Important: clamp after getting original multimodal regions
# Clamp input ids. This is because the input_ids for the multimodal tokens are
# filled with the hash values of the multimodal data for prefix matching in the radix attention.
# These values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=vocab_size - 1)
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
for embedding, mask in zip(embeddings, masks):
if embedding is None or mask is None:
continue
mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(
mask,
embedding.to(inputs_embeds.device, inputs_embeds.dtype),
)
return inputs_embeds
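For reference, the scatter step at the end of embed_mm_inputs relies on masked_scatter filling True positions in row-major order; a tiny standalone torch sketch with toy shapes (not the model's real dimensions):
import torch

hidden = 4
inputs_embeds = torch.zeros(6, hidden)          # embeddings of 6 text tokens
mm_embedding = torch.ones(3, hidden)            # 3 multimodal rows to place
mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool).unsqueeze(-1)

out = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), mm_embedding)
# Rows 1, 2 and 4 now hold the multimodal rows; the other rows keep their text values.
assert out[1].sum() == hidden and out[0].sum() == 0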
......@@ -393,16 +516,19 @@ def general_mm_embed_routine(
**kwargs,
) -> torch.Tensor:
"""
A general wrapper function to get final input embeds from multimodal models with a language model as causal model
Args:
placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
image_data_embedding_func : the function returning the image embedding
audio_data_embedding_func : the function returning the audio embedding
Process multimodal inputs and forward through language model.
Returns:
forwarded hidden states
Args:
input_ids: Input token IDs tensor
forward_batch: Batch information for model forward pass
language_model: Base language model to use
image_data_embedding_func: Function to embed image data
audio_data_embedding_func: Function to embed audio data
placeholder_tokens: Token IDs for multimodal placeholders
**kwargs: Additional arguments passed to language model
Returns:
Hidden states from language model forward pass
"""
assert hasattr(language_model, "get_input_embeddings")
embed_tokens = language_model.get_input_embeddings()
......@@ -410,9 +536,23 @@ def general_mm_embed_routine(
not forward_batch.forward_mode.is_decode()
and forward_batch.contains_mm_inputs()
):
mm_input = forward_batch.merge_mm_inputs()
mm_inputs_list = [
mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
]
extend_prefix_lens = [
prefix_len
for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
extend_seq_lens = [
seq_len
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
inputs_embeds = embed_mm_inputs(
mm_inputs=mm_input,
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
input_embedding=embed_tokens,
image_data_embedding_func=image_data_embedding_func,
......
......@@ -5,7 +5,7 @@ import multiprocessing as mp
import os
import re
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
......@@ -343,6 +343,33 @@ class BaseMultimodalProcessor(ABC):
out.normalize()
return out
@staticmethod
def get_mm_items_offset(
input_ids: torch.Tensor, mm_token_id: int
) -> List[Tuple[int, int]]:
"""
Get the list of (start, end) index ranges that mm_items occupy in input_ids.
Example:
input_ids = [1, 2, 3, 3, 3, 4, 3, 3]
mm_token_id = 3
returns [(2, 4), (6, 7)]
"""
mask = input_ids == mm_token_id
start_positions = (mask & ~torch.roll(mask, 1)).nonzero(as_tuple=True)[0]
end_positions = (mask & ~torch.roll(mask, -1)).nonzero(as_tuple=True)[0]
return list(zip(start_positions.tolist(), end_positions.tolist()))
@staticmethod
def get_mm_items_offset_by_pair(
input_ids: torch.Tensor, mm_start_id: int, mm_end_id: int
) -> List[Tuple[int, int]]:
indices_start = (input_ids == mm_start_id).nonzero(as_tuple=True)[0] + 1
indices_end = (input_ids == mm_end_id).nonzero(as_tuple=True)[0] - 1
return list(zip(indices_start.tolist(), indices_end.tolist()))
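A quick illustration of both offset helpers with toy token IDs (the IDs are arbitrary, not real tokenizer values):
import torch

from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
)

# Contiguous runs of the placeholder token become (start, end) ranges.
input_ids = torch.tensor([1, 2, 3, 3, 3, 4, 3, 3])
assert BaseMultimodalProcessor.get_mm_items_offset(input_ids, mm_token_id=3) == [
    (2, 4),
    (6, 7),
]

# With explicit start/end markers (5 and 6 here), the returned range covers the
# tokens strictly between each marker pair.
paired = torch.tensor([1, 5, 9, 9, 9, 6, 2])
assert BaseMultimodalProcessor.get_mm_items_offset_by_pair(
    paired, mm_start_id=5, mm_end_id=6
) == [(2, 4)]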
def mm_inputs_are_preprocessed(self, mm_inputs: Optional[list]):
"""Returns true if all images are preprocessed, false if all are not, and error otherwise."""
if not mm_inputs:
......
......@@ -70,8 +70,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
items = []
input_ids = res["input_ids"]
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=self._processor.image_token_id
)
item = MultimodalDataItem(
pixel_values=res["images"],
image_offsets=image_offsets,
modality=Modality.IMAGE,
image_emb_mask=images_seq_mask,
image_spatial_crop=batched_images_spatial_crop,
......@@ -80,6 +85,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
return {
"mm_items": items,
"input_ids": res["input_ids"].tolist(),
"input_ids": input_ids.tolist(),
"im_token_id": self._processor.image_token_id,
}
......@@ -61,6 +61,11 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
)
items = []
input_ids = ret["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.hf_config.image_token_index,
)
for i, image in enumerate(base_output.images):
if images_are_preprocessed:
pixel_values = image.pixel_values
......@@ -73,12 +78,13 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
pixel_values=pixel_values,
precomputed_features=precomputed_features,
modality=Modality.IMAGE,
image_offsets=image_offsets[i],
)
items += [item]
return {
"mm_items": items,
"input_ids": ret["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
......@@ -209,7 +209,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
return None
pixel_values = torch.cat(pixel_values, dim=0)
items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
for idx, num_patches in enumerate(num_patches_list):
image_tokens = (
......@@ -220,10 +219,21 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
input_text = input_text.replace("<image>", image_tokens, 1)
tokenizer = self._processor
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.img_context_token_id,
)
items = [
MultimodalDataItem(
pixel_values=pixel_values,
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
]
return {
"input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
.flatten()
.tolist(),
"input_ids": input_ids.tolist(),
"mm_items": items,
"im_start_id": self.img_start_token_id,
"im_end_id": self.img_end_token_id,
......
......@@ -45,15 +45,21 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
prompt=base_out.input_text,
images=images,
)
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids, mm_token_id=processor.image_id
)
return {
"mm_items": [
MultimodalDataItem(
pixel_values=res["pixel_values"],
image_emb_mask=res["images_emb_mask"],
image_offsets=image_offsets,
modality=Modality.IMAGE,
)
],
"input_ids": res["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"im_start_id": processor.image_start_id,
"im_end_id": processor.image_end_id,
"im_token_id": processor.image_id,
......
import asyncio
import math
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor as SGLangBaseProcessor,
)
......@@ -57,13 +52,19 @@ class KimiVLImageProcessor(SGLangBaseProcessor):
input_text=base_output.input_text,
images=base_output.images,
)
input_ids = ret["input_ids"].flatten()
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.im_token_id,
)
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"mm_items": [
MultimodalDataItem(
pixel_values=ret["pixel_values"],
image_grid_thws=ret["image_grid_hws"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
],
"im_token_id": self.im_token_id,
......
import asyncio
import importlib
from typing import List, Optional, Union
import numpy as np
......
......@@ -69,6 +69,8 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
audio_start_id = tokenizer.audio_start_id
audio_end_id = tokenizer.audio_end_id
im_start_id = tokenizer.im_start_id
im_end_id = tokenizer.im_end_id
im_token_id = tokenizer.unk_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
......@@ -104,9 +106,20 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
pixel_values = pixel_values_flat
items = []
input_ids = res["input_ids"].flatten()
image_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
)
slice_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
)
image_offsets.extend(slice_offsets)
image_offsets = sorted(image_offsets)
if len(pixel_values) != 0:
item = MultimodalDataItem(
pixel_values=pixel_values,
image_offsets=image_offsets,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
)
......@@ -117,21 +130,30 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
and res["audio_features"] is not None
and len(res["audio_features"]) != 0
):
if audio_start_id is not None and audio_end_id is not None:
audio_offsets = self.get_mm_items_offset_by_pair(
input_ids=input_ids,
mm_start_id=audio_start_id,
mm_end_id=audio_end_id,
)
else:
audio_offsets = None
item = MultimodalDataItem(
audio_features=[res["audio_features"]],
audio_feature_lens=res["audio_feature_lens"],
audio_offsets=audio_offsets,
modality=Modality.AUDIO,
)
items += [item]
return {
"mm_items": items,
"input_ids": res["input_ids"].flatten().tolist(),
"input_ids": input_ids.tolist(),
"audio_start_id": audio_start_id,
"audio_end_id": audio_end_id,
"im_token_id": im_token_id,
"im_start_id": tokenizer.im_start_id,
"im_end_id": tokenizer.im_end_id,
"im_start_id": im_start_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
......@@ -135,11 +135,17 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
processor_output["im_end_id"] = self.eoi_token_index
processor_output["im_token_id"] = self.image_token_index
image_offsets = self.get_mm_items_offset(
input_ids=torch.tensor(processor_output["input_ids"]),
mm_token_id=self.image_token_index,
)
# Add metadata for image processing
processor_output["mm_items"] = [
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
]
......
import asyncio
import math
from typing import List, Optional, Union
from typing import List, Union
import numpy as np
from transformers import PretrainedConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
......@@ -12,11 +10,7 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
MultimodalSpecialTokens,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
from sglang.srt.models.pixtral import PixtralVisionModel
......@@ -108,15 +102,21 @@ class PixtralProcessor(BaseMultimodalProcessor):
)
if "pixel_values" in processor_output:
input_ids = processor_output["input_ids"].view(-1)
image_offsets = self.get_mm_items_offset(
input_ids=input_ids,
mm_token_id=self.image_token_id,
)
mm_items = [
MultimodalDataItem(
pixel_values=processor_output["pixel_values"],
image_sizes=processor_output["image_sizes"],
modality=Modality.IMAGE,
image_offsets=image_offsets,
)
]
input_ids = processor_output["input_ids"].view(-1).tolist()
input_ids = input_ids.tolist()
processor_output.update(
input_ids=input_ids,
mm_items=mm_items,
......
......@@ -135,6 +135,9 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
images=None if images_are_preprocessed else base_output.images,
)
input_ids = ret["input_ids"].flatten().tolist()
image_offsets = self.get_mm_items_offset(
input_ids=ret["input_ids"].flatten(), mm_token_id=self.image_token_id
)
image_grid_thw = None
video_grid_thw = None # TODO
items = []
......@@ -175,6 +178,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_grid_thws=image_grid_thw,
video_grid_thws=video_grid_thw,
precomputed_features=precomputed_features,
image_offsets=image_offsets,
modality=Modality.IMAGE,
)
]
......
......@@ -197,6 +197,7 @@ class MultimodalDataItem:
audio_features: Union[torch.Tensor, np.ndarray] = None
audio_feature_lens: Optional[List[torch.Tensor]] = None
audio_offsets: Optional[List[Tuple[int, int]]] = None
precomputed_features: Optional[Union[torch.Tensor, np.ndarray]] = None
......@@ -1097,7 +1098,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
else:
self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
assert len(self.out_cache_loc) == self.extend_num_tokens
assert (
len(self.out_cache_loc) == self.extend_num_tokens
), f"Expected {len(self.out_cache_loc)}, got {self.extend_num_tokens}"
def prepare_for_extend(self):
self.forward_mode = ForwardMode.EXTEND
......
......@@ -102,6 +102,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
MultimodalInputs,
......@@ -2282,6 +2283,10 @@ def run_scheduler_process(
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
embedding_cache_size = 100
if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
init_embedding_cache(embedding_cache_size * 1024 * 1024)
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
......
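The cache budget above is read once when the scheduler process starts, so the variable has to be set in the environment that launches the server; a minimal sketch (256 is an arbitrary example value, the default being 100 MB per the code above):
import os

# Must be set before the scheduler process is spawned.
os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "256"
# ...then start the server / Engine as usual.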
from typing import Dict
import torch
class MultiModalCache:
"""MultiModalCache is used to store vlm encoder results"""
def __init__(
self,
max_size: int,
):
self.max_size = max_size
self.mm_cache: Dict[int, torch.Tensor] = {}
self.current_size = 0
def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
if mm_hash in self.mm_cache:
return True
data_size = self._get_tensor_size(embedding)
if self.current_size + data_size > self.max_size:
return False
self.mm_cache[mm_hash] = embedding
self.current_size += data_size
return True
def get(self, mm_hash: int) -> torch.Tensor:
return self.mm_cache.get(mm_hash)
def free(self, mm_hash: int) -> bool:
if mm_hash not in self.mm_cache:
return False
old_embedding = self.mm_cache.pop(mm_hash)
self.current_size -= self._get_tensor_size(old_embedding)
return True
def clear(self):
self.mm_cache.clear()
self.current_size = 0
def _get_tensor_size(self, embedding: torch.Tensor):
return embedding.element_size() * embedding.numel()
def __len__(self):
return len(self.mm_cache)
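A short usage sketch of MultiModalCache with a deliberately tiny byte budget, showing the put/get/free contract used by get_embedding_and_mask (sizes and keys are illustrative):
import torch

from sglang.srt.mem_cache.multimodal_cache import MultiModalCache

cache = MultiModalCache(max_size=4 * 1024)       # 4 KiB budget
key = hash((0xABC,))                             # shaped like a get_embedding_hash result

small = torch.zeros(256, dtype=torch.float32)    # 1 KiB
big = torch.zeros(4096, dtype=torch.float32)     # 16 KiB, exceeds the budget

assert cache.put(key, small) is True
assert cache.get(key) is small
assert cache.put(hash((0xDEF,)), big) is False   # cache full: caller recomputes and warns
assert cache.free(key) is True and len(cache) == 0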
......@@ -166,6 +166,9 @@ class ModelRunner:
self.is_draft_worker = is_draft_worker
self.is_generation = model_config.is_generation
self.is_multimodal = model_config.is_multimodal
self.is_multimodal_chunked_prefill_supported = (
model_config.is_multimodal_chunked_prefill_supported
)
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
......@@ -389,12 +392,15 @@ class ModelRunner:
if self.is_multimodal:
self.mem_fraction_static *= 0.90
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} because this is a multimodal model."
)
server_args.chunked_prefill_size = -1
logger.info(
"Automatically turn off --chunked-prefill-size for multimodal model."
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
f"because this is a multimodal model."
)
if not self.is_multimodal_chunked_prefill_supported:
server_args.chunked_prefill_size = -1
logger.info(
f"Automatically turn of --chunked-prefill-size as it is not supported for "
f"{self.model_config.hf_config.model_type}"
)
if not self.use_mla_backend:
server_args.disable_chunked_prefix_cache = True
......
......@@ -1826,22 +1826,12 @@ class MiniCPMO(MiniCPMBaseModel):
**kwargs: Any,
) -> torch.Tensor:
mm_input = forward_batch.merge_mm_inputs()
placeholder_token_ids = (
([mm_input.im_token_id] + [item.pad_value for item in mm_input.mm_items])
if forward_batch.contains_mm_inputs()
else []
)
hidden_states = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.llm,
image_data_embedding_func=self.get_image_feature,
audio_data_embedding_func=self.get_audio_feature,
placeholder_tokens={
Modality.IMAGE: placeholder_token_ids,
Modality.AUDIO: placeholder_token_ids,
},
positions=positions,
)
return hidden_states
......
......@@ -294,20 +294,24 @@ class TestOpenAIVisionServer(CustomTestCase):
print("-" * 30)
# Add assertions to validate the video response
assert "iPod" in video_response or "device" in video_response, video_response
assert (
"iPod" in video_response or "device" in video_response
), f"video_response: {video_response}, should contain 'iPod' or 'device'"
assert (
"man" in video_response
or "person" in video_response
or "individual" in video_response
or "speaker" in video_response
), video_response
), f"video_response: {video_response}, should either have 'man' in video_response, or 'person' in video_response, or 'individual' in video_response or 'speaker' in video_response"
assert (
"present" in video_response
or "examine" in video_response
or "display" in video_response
or "hold" in video_response
)
assert "black" in video_response or "dark" in video_response
), f"video_response: {video_response}, should contain 'present', 'examine', 'display', or 'hold'"
assert (
"black" in video_response or "dark" in video_response
), f"video_response: {video_response}, should contain 'black' or 'dark'"
self.assertIsNotNone(video_response)
self.assertGreater(len(video_response), 0)
......
......@@ -21,7 +21,10 @@ from transformers import (
from sglang import Engine
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.mm_utils import embed_mm_inputs, init_embedding_cache
from sglang.srt.managers.multimodal_processors.base_processor import (
BaseMultimodalProcessor,
)
from sglang.srt.managers.schedule_batch import (
Modality,
MultimodalDataItem,
......@@ -188,6 +191,7 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
.eval()
.to(cls.device)
)
init_embedding_cache(0)
async def test_vlm_embedding_output(self):
"""
......@@ -226,17 +230,41 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
im_start_id, im_end_id = (
self.tokenizer.im_start_id,
self.tokenizer.im_end_id,
)
slice_start_id, slice_end_id = (
self.tokenizer.slice_start_id,
self.tokenizer.slice_end_id,
)
image_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=im_start_id, mm_end_id=im_end_id
)
slice_offsets = BaseMultimodalProcessor.get_mm_items_offset_by_pair(
input_ids=input_ids, mm_start_id=slice_start_id, mm_end_id=slice_end_id
)
image_offsets.extend(slice_offsets)
image_offsets = sorted(image_offsets)
sglang_output = embed_mm_inputs(
mm_inputs=MultimodalInputs(
mm_items=[
MultimodalDataItem(
pixel_values=pixel_values_flat,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
pad_value=self.processor.tokenizer.unk_token_id,
)
]
),
mm_inputs_list=[
MultimodalInputs(
mm_items=[
MultimodalDataItem(
pixel_values=pixel_values_flat,
image_offsets=image_offsets,
tgt_size=tgt_sizes_flat,
modality=Modality.IMAGE,
pad_value=self.processor.tokenizer.unk_token_id,
)
]
),
],
extend_prefix_lens=[0],
extend_seq_lens=[input_ids.shape[0]],
input_ids=input_ids,
input_embedding=model.get_input_embeddings(),
image_data_embedding_func=model.get_image_feature,
......