Unverified Commit c998d04b authored by Mick, committed by GitHub

vlm: enable radix cache for qwen-vl models (#5349)


Co-authored-by: Xinyuan Tong <justinning0323@outlook.com>
parent 7d0edf3c
@@ -89,7 +89,7 @@ def set_seed(seed_value):
 def prepare_samples(eval_args: EvalArgs):
-    print("preparing samples...")
+    print("Preparing samples...")
     # Build prompts
     set_seed(eval_args.seed)
@@ -105,15 +105,40 @@ def prepare_samples(eval_args: EvalArgs):
         assert len(value) == 1, "key {} has more than one value".format(key)
         eval_args.config[key] = value[0]

-    # run for each subject
+    # run for each subject in parallel
     sub_dataset_list = []
-    for subject in tqdm(CAT_SHORT2LONG.values()):
-        sub_dataset = load_dataset(
-            eval_args.dataset_path, subject, split=eval_args.split
-        )
-        sub_dataset_list.append(sub_dataset)
-        # break
+    subjects = list(CAT_SHORT2LONG.values())  # Get a fixed list of subjects
+    print(f"Loading datasets for {len(subjects)} subjects...")
+    with ThreadPoolExecutor() as executor:
+        # Submit all load_dataset tasks
+        future_to_subject = {
+            executor.submit(
+                load_dataset, eval_args.dataset_path, subject, split=eval_args.split
+            ): subject
+            for subject in subjects
+        }
+
+        # Collect results as they complete
+        results = {}
+        for future in tqdm(
+            as_completed(future_to_subject),
+            total=len(subjects),
+            desc="Loading datasets",
+        ):
+            subject = future_to_subject[future]
+            try:
+                results[subject] = future.result()
+            except Exception as exc:
+                print(f"{subject} generated an exception: {exc}")
+
+    # Ensure datasets are added in the original order for consistency
+    for subject in subjects:
+        if subject in results:
+            sub_dataset_list.append(results[subject])
+        else:
+            # Handle cases where a dataset failed to load (optional, depends on desired behavior)
+            print(f"Warning: Dataset for subject '{subject}' could not be loaded.")

     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)
@@ -133,18 +158,25 @@ def prepare_samples(eval_args: EvalArgs):
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
             return None, True
-        image_path = f"{images_path}/image_{i}.png"
+        # Use a unique identifier for the image path to avoid potential collisions if indices reset
+        image_path = f"{images_path}/image_{sample['id']}.png"
         if not os.path.exists(image_path):
             image.save(image_path)
         sample["image_path"] = image_path
         return sample, False

+    print("Processing samples...")
     with ThreadPoolExecutor() as executor:
+        # Pass the sample itself to process_sample, index is less reliable now
        futures = [
-            executor.submit(process_sample, i, sample)
+            executor.submit(
+                process_sample, i, sample
+            )  # Keep index i for tqdm maybe? Or remove it. Let's keep it for now.
            for i, sample in enumerate(dataset)
        ]
-        for future in tqdm(as_completed(futures), total=len(futures)):
+        for future in tqdm(
+            as_completed(futures), total=len(dataset), desc="Processing samples"
+        ):
            sample, skipped = future.result()
            if skipped:
                skip_count += 1
@@ -152,9 +184,9 @@ def prepare_samples(eval_args: EvalArgs):
            samples.append(sample)
    print(
-        f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
+        f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
    )
-    print("samples have been prepared")
+    print("Samples have been prepared")
    return samples
......
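For reference, the submit/as_completed idiom used above is plain `concurrent.futures`; a minimal, standalone sketch (the `load_one` helper is a stand-in for `datasets.load_dataset`, not part of the benchmark code):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def load_one(subject: str) -> str:
    # Placeholder for an I/O-bound loader such as datasets.load_dataset(...)
    return f"dataset[{subject}]"

subjects = ["art", "biology", "chemistry"]
results = {}
with ThreadPoolExecutor() as executor:
    # Remember which future belongs to which subject
    future_to_subject = {executor.submit(load_one, s): s for s in subjects}
    for future in as_completed(future_to_subject):
        subject = future_to_subject[future]
        try:
            results[subject] = future.result()
        except Exception as exc:
            print(f"{subject} failed: {exc}")

# Re-assemble in the original subject order so downstream concatenation stays deterministic
ordered = [results[s] for s in subjects if s in results]
print(ordered)
```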
@@ -73,15 +73,14 @@ class ModelConfig:
         )
         if enable_multimodal is None:
-            if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
                 enable_multimodal = False
                 logger.info(
-                    "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
-                )
-            elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
-                enable_multimodal = False
-                logger.info(
-                    "Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                 )
             else:
                 enable_multimodal = True
......
@@ -877,127 +877,163 @@ class MRotaryEmbedding(RotaryEmbedding):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

-    @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
-        image_token_id: int,
-        video_token_id: int,
-        vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
-        tokens_per_second: Optional[int] = None,
-    ) -> Tuple[List[List[int]], int]:
-        """
-        Get mrope input positions and delta value.
-
-        :arg
-            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
-                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
-
-        """
-        if isinstance(image_grid_thw, torch.Tensor):
-            image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
-
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
-            )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions = llm_positions[:, context_len:seq_len]
-
-        return llm_positions.tolist(), mrope_position_delta
+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
+    @staticmethod
+    def get_rope_index(
+        spatial_merge_size: int,
+        image_token_id: int,
+        video_token_id: int,
+        vision_start_token_id: int,
+        model_type: str,
+        tokens_per_second: Optional[int] = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
+            )
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
+            )
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+
+            return position_ids, mrope_position_deltas

     @staticmethod
     def get_next_input_positions(
......
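A rough, self-contained illustration (not the sglang API) of the 3-D position layout that `get_rope_index` builds: for a prompt of `text_len` text tokens followed by one merged image grid of size `t x h x w`, all three channels advance together over text, then enumerate (t, h, w) over the vision block. The helper name and sizes below are made up for the example:

```python
import torch

def toy_mrope_positions(text_len: int, grid_t: int, grid_h: int, grid_w: int) -> torch.Tensor:
    # Text tokens: the t/h/w channels share the same increasing position.
    text_pos = torch.arange(text_len).view(1, -1).expand(3, -1)

    # Vision tokens: temporal/height/width indices, offset past the text block.
    t_index = torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w).flatten()
    h_index = torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
    w_index = torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()
    vision_pos = torch.stack([t_index, h_index, w_index]) + text_len

    return torch.cat([text_pos, vision_pos], dim=1)  # shape (3, text_len + t*h*w)

pos = toy_mrope_positions(text_len=4, grid_t=1, grid_h=2, grid_w=2)
print(pos)
# The position delta used later for decoding is max(pos) + 1 - total_len
print(pos.max().item() + 1 - pos.shape[1])
```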
@@ -463,6 +463,8 @@ class EmbeddingReqInput:
     image_data: Optional[
         Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
     ] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The token ids for text; one can either specify text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The request id.
......
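A hedged sketch of what a request carrying audio might look like with the new field (the field names come from the dataclass above; the overall payload shape and URL are assumptions for illustration only):

```python
# Assumed request body; only `audio_data` is new here.
payload = {
    "text": "Embed this clip together with the prompt",
    "image_data": None,
    "audio_data": "https://example.com/sample.wav",  # file name, URL, or base64 string
}
```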
@@ -10,12 +10,13 @@ import torch
 from torch import nn

 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import flatten_nested_list, print_warning_once

 logger = logging.getLogger(__name__)
@@ -97,31 +98,80 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids


-class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be represented as repetitions of a single token
     e.g. <image><image>....<image>, or <audio><audio>...<audio>
     """

-    def __init__(self, image_token_id: torch.Tensor) -> None:
-        self.image_token_id = image_token_id
+    def __init__(self, token_ids: List[int]) -> None:
+        self.token_ids = token_ids

-    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
         """
-        This function will replace the data-tokens in between with pad_values accordingly
+        Finds contiguous regions of tokens matching `self.token_ids` in `input_ids`
+        and replaces each region with the corresponding `pad_value` from `mm_inputs.mm_items`.
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
-        assert len(pad_values) != 0
+        if not pad_values:
+            # No multimodal items, return original input_ids
+            return input_ids
+        if not input_ids:
+            return []

         input_ids_tensor = torch.tensor(input_ids)
-        mask = torch.isin(input_ids_tensor, self.image_token_id)
-
-        num_image_tokens = mask.sum().item()
-        repeated_pad_values = torch.tensor(pad_values).repeat(
-            num_image_tokens // len(pad_values) + 1
-        )[:num_image_tokens]
-
-        input_ids_tensor[mask] = repeated_pad_values
-        return input_ids_tensor.tolist()
+        device = input_ids_tensor.device
+        token_ids_tensor = torch.tensor(self.token_ids, device=device)
+        mask = torch.isin(input_ids_tensor, token_ids_tensor)
+
+        if not mask.any():
+            # No tokens match token_ids, return original input_ids
+            return input_ids
+
+        # Find contiguous regions
+        padded_mask = torch.cat(
+            (
+                torch.tensor([False], device=device),
+                mask,
+                torch.tensor([False], device=device),
+            )
+        )
+        # Find indices where the mask value changes
+        diff_indices = torch.where(padded_mask[1:] != padded_mask[:-1])[0]
+
+        # Start indices are where False changes to True
+        starts = diff_indices[::2]
+        # End indices are where True changes to False (exclusive index)
+        ends = diff_indices[1::2]
+
+        # Check if the number of regions matches the number of pad values
+        if len(starts) != len(pad_values):
+            # Maybe log a warning here?
+            num_regions = len(starts)
+            num_pad_values = len(pad_values)
+            if num_regions > 0 and num_pad_values > 0:
+                pad_values = (pad_values * (num_regions // num_pad_values + 1))[
+                    :num_regions
+                ]
+            else:  # If no regions or no pad_values, this loop won't run anyway.
+                pad_values = []  # Ensure pad_values is empty if starts is empty
+
+        # Create a copy to modify
+        output_ids_tensor = input_ids_tensor.clone()
+
+        # Replace tokens in each region with the corresponding pad value
+        # Ensure we don't iterate if pad_values became empty due to mismatch and num_regions=0
+        for i in range(min(len(starts), len(pad_values))):
+            start_idx = starts[i]
+            end_idx = ends[i]
+            pad_value = pad_values[i]
+            if pad_value is not None:  # Ensure pad_value is not None before assignment
+                output_ids_tensor[start_idx:end_idx] = pad_value
+            else:
+                logger.warning(f"Skipping region {i} due to None pad_value.")
+
+        return output_ids_tensor.tolist()


 def get_embedding_and_mask(
@@ -150,7 +200,6 @@ def get_embedding_and_mask(
     ).unsqueeze(-1)
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
-
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -190,13 +239,13 @@ def embed_mm_inputs(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
     Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations

     Args:
-        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+        placeholder_tokens: denoting the token of multimodal data in input_ids.
             If none, the pad_values of multimodal items are used

     Returns:
@@ -208,9 +257,17 @@ def embed_mm_inputs(
     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-    placeholder_token_ids = placeholder_token_ids or [
-        item.pad_value for item in mm_inputs.mm_items
-    ]
+    # See `pad_input_ids` for more detail
+
+    # if placeholder_tokens is specified
+    if placeholder_tokens is not None:
+        placeholder_token_ids = flatten_nested_list(
+            [placeholder_token for placeholder_token in placeholder_tokens.values()]
+        )
+    else:
+        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
+
+    assert isinstance(placeholder_token_ids[0], int)

     placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
@@ -233,7 +290,7 @@ def embed_mm_inputs(
     using_all_items = False
     if len(appearing_items) == 0:
         # This happens mostly when arg placeholder_token_ids is passed
-        logger.warning_once(
+        logger.warning(
             "No multimodal data item's pad value exist in placeholder ids. Using all items"
         )
         using_all_items = True
@@ -253,7 +310,8 @@ def embed_mm_inputs(
             data_embedding_func=image_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-                placeholder_tensor
+                # use the specified modality token to identify the location to embed
+                placeholder_tokens[Modality.IMAGE]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -275,7 +333,7 @@ def embed_mm_inputs(
             data_embedding_func=audio_data_embedding_func,
             embedding_items=items,
             placeholder_tensor=(
-                placeholder_tensor
+                placeholder_tokens[Modality.AUDIO]
                 if using_all_items
                 else torch.tensor(
                     [item.pad_value for item in items],
@@ -296,7 +354,7 @@ def embed_mm_inputs(
     input_ids.clamp_(min=0, max=vocab_size - 1)
     inputs_embeds = input_embedding(input_ids)

-    # 4. scatter embeddings into input embedding
+    # 4. Scatter embeddings into input embedding
     for embedding, mask in zip(embeddings, masks):
         mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
         inputs_embeds = inputs_embeds.masked_scatter(
@@ -316,7 +374,7 @@ def general_mm_embed_routine(
     audio_data_embedding_func: Callable[
         [List[MultimodalDataItem]], torch.Tensor
     ] = None,
-    placeholder_token_ids: List[int] = None,
+    placeholder_tokens: dict[Modality, List[int]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -328,7 +386,6 @@ def general_mm_embed_routine(
         audio_data_embedding_func : the function returning the image embedding

     Returns:
-        inputs_embedding
         forwarded hidden states

     """
@@ -346,9 +403,9 @@ def general_mm_embed_routine(
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
             audio_data_embedding_func=audio_data_embedding_func,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens=placeholder_tokens,
         )
-        # once used, mm_inputs is useless
+        # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
         # just being defensive here
         forward_batch.mm_inputs = None
     else:
......
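A minimal, standalone sketch of the boundary-detection trick used by `MultiModalityDataPaddingPatternMultimodalTokens.pad_input_tokens` above: pad the boolean mask with `False` on both ends, diff it to find where runs of multimodal tokens start and end, then overwrite each run with that item's pad value. Token ids and pad values here are toy numbers, not sglang objects:

```python
import torch

input_ids = torch.tensor([5, 7, 151655, 151655, 151655, 9, 151655, 151655, 3])
mm_token_ids = torch.tensor([151655])
pad_values = [111, 222]  # e.g. one hash-derived pad value per multimodal item

mask = torch.isin(input_ids, mm_token_ids)
# Pad with False so that runs touching either end still produce start/end boundaries.
padded = torch.cat((torch.tensor([False]), mask, torch.tensor([False])))
diff = torch.where(padded[1:] != padded[:-1])[0]
starts, ends = diff[::2], diff[1::2]  # ends are exclusive

out = input_ids.clone()
for (s, e), pad_value in zip(zip(starts, ends), pad_values):
    out[s:e] = pad_value
print(out.tolist())  # [5, 7, 111, 111, 111, 9, 222, 222, 3]
```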
@@ -8,6 +8,7 @@ from typing import List, Optional
 import numpy as np
 import PIL
+from PIL import Image
 from transformers import BaseImageProcessorFast

 from sglang.srt.managers.schedule_batch import Modality
@@ -92,7 +93,12 @@ class BaseMultimodalProcessor(ABC):
     @abstractmethod
     async def process_mm_data_async(
-        self, image_data, input_text, max_req_input_len, **kwargs
+        self,
+        image_data,
+        input_text,
+        request_obj,
+        max_req_input_len,
+        **kwargs,
     ):
         pass
@@ -104,6 +110,8 @@ class BaseMultimodalProcessor(ABC):
         from decord import VideoReader, cpu

         # Before processing inputs
+        if not image_data or len(image_data) == 0:
+            return []
         estimated_frames_list = []
         for image in image_data:
             if isinstance(image, str) and image.startswith("video:"):
@@ -215,6 +223,9 @@ class BaseMultimodalProcessor(ABC):
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
+        if image_data is None:
+            image_data = []
+
         if isinstance(multimodal_tokens.image_token, int):
             multimodal_tokens.image_token = (
                 self._processor.tokenizer.convert_ids_to_tokens(
@@ -229,6 +240,8 @@ class BaseMultimodalProcessor(ABC):
                 prompt = self._processor.tokenizer.decode(prompt)
             else:
                 prompt = prompt
+        assert isinstance(prompt, str)
+
         if return_text:
             import re
......
@@ -16,6 +16,7 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+from typing import List, Union

 import torch
@@ -35,7 +36,13 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         self.IMAGE_TOKEN = "<image>"

     async def process_mm_data_async(
-        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        max_req_input_len,
+        *args,
+        **kwargs
     ):
         if not image_data:
             return None
@@ -45,7 +52,7 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids,
+            input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
......
 from typing import List, Union

-from transformers.utils import logging
-
 from sglang.srt.managers.multimodal_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -13,7 +11,6 @@ from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
 # will be removed in the future
-logger = logging.get_logger(__name__)


 class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
@@ -28,7 +25,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
@@ -41,7 +38,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
......
@@ -17,7 +17,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
         **kwargs,
@@ -31,7 +31,7 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         processor = self._processor

         base_out = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
                 image_token=processor.image_token
......
@@ -51,9 +51,10 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_ids,
+        input_text,
         request_obj,
         max_req_input_len,
+        **kwargs,
     ):
         audio_data = request_obj.audio_data
         if not image_data and not audio_data:
@@ -64,7 +65,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            prompt=input_ids,
+            prompt=input_text,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
......
@@ -5,6 +5,7 @@ from typing import List, Union
 import torch
 from PIL import Image

+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor as SGLangBaseProcessor,
 )
@@ -27,6 +28,8 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
         self.image_token_id = hf_config.image_token_id
         self.video_token_id = hf_config.video_token_id
+        self.vision_start_token_id = hf_config.vision_start_token_id
+        self.vision_end_token_id = hf_config.vision_end_token_id
         self.NUM_TOKEN_PER_FRAME = 770
         self.IMAGE_FACTOR = 28
         self.MIN_PIXELS = 4 * 28 * 28
@@ -36,20 +39,18 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
-        prompt,
+        input_text,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
-        if not image_data:
-            return None
         if isinstance(image_data, str):
             image_data = [image_data]

         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            prompt=prompt,
+            prompt=input_text,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -116,29 +117,53 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         async def resize_image_async(image):
             return resize_image(image)

-        resize_tasks = [resize_image_async(image) for image in base_output.images]
-        resized_images = await asyncio.gather(*resize_tasks)
+        if base_output.images:
+            resize_tasks = [resize_image_async(image) for image in base_output.images]
+            base_output.images = await asyncio.gather(*resize_tasks)

         ret = self.process_mm_data(
             input_text=base_output.input_text,
-            images=resized_images,
+            images=base_output.images,
         )

-        image_grid_thws = torch.concat([ret["image_grid_thw"]])
-        return {
-            "input_ids": ret["input_ids"].flatten().tolist(),
-            "mm_items": [
+        items = []
+
+        input_ids = ret["input_ids"].flatten().tolist()
+        if "pixel_values" in ret:
+            items += [
                 MultimodalDataItem(
                     pixel_values=ret["pixel_values"],
-                    image_grid_thws=image_grid_thws,
+                    image_grid_thws=torch.concat([ret["image_grid_thw"]]),
                     # TODO
                     video_grid_thws=None,
                     second_per_grid_ts=ret.get("second_per_grid_ts", None),
                     modality=Modality.IMAGE,
                 )
-            ],
+            ]
+
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
+            spatial_merge_size=self.hf_config.vision_config.spatial_merge_size,
+            image_token_id=self.image_token_id,
+            video_token_id=self.video_token_id,
+            vision_start_token_id=self.vision_start_token_id,
+            model_type=self.hf_config.model_type,
+            tokens_per_second=getattr(
+                self.hf_config.vision_config, "tokens_per_second", None
+            ),
+            input_ids=torch.tensor(input_ids).unsqueeze(0),
+            image_grid_thw=ret.get("image_grid_thw", None),
+            video_grid_thw=ret.get("video_grid_thw", None),
+            second_per_grid_ts=ret.get("second_per_grid_ts", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        return {
+            "input_ids": input_ids,
+            "mm_items": items,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
         }
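The `mrope_positions` / `mrope_position_delta` entries returned here are what allow radix-cache reuse for qwen-vl: positions are computed once over the full token sequence at processing time, so during a later extend the scheduler only needs to slice the precomputed `(3, seq_len)` tensor by the cached-prefix length. A toy slice under assumed lengths (illustrative only, not the sglang scheduler code):

```python
import torch

# Hypothetical precomputed mrope positions for a 10-token request, shape (3, seq_len).
mrope_positions = torch.arange(10).view(1, -1).expand(3, -1).clone()

extend_prefix_len, extend_seq_len = 6, 4  # 6 tokens hit the radix cache, 4 are new
sliced = mrope_positions[:, extend_prefix_len : extend_prefix_len + extend_seq_len]
print(sliced.shape)  # torch.Size([3, 4])
```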
@@ -285,6 +285,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None

     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
@@ -310,16 +311,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]
-        assert len(ret.mm_items) != 0
-
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()

         optional_args = [
+            "mrope_positions",
+            "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -350,20 +347,26 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.

         # args needed to be merged
         optional_args = [
             "mm_items",
             "image_pad_len",
-            "mrope_position_delta",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))

+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )
+
         # other args would be kept intact
......
@@ -419,7 +419,10 @@ class TokenizerManager:
             input_ids = self.tokenizer.encode(input_text)

         image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+            image_data=obj.image_data,
+            input_text=input_text or input_ids,
+            request_obj=obj,
+            max_req_input_len=self.max_req_input_len,
         )
         if image_inputs and "input_ids" in image_inputs:
             input_ids = image_inputs["input_ids"]
......
@@ -407,8 +407,6 @@ class ForwardBatch:
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
-        device = model_runner.device
-        hf_config = model_runner.model_config.hf_config
         mrope_positions_list = [None] * self.seq_lens.shape[0]
         if self.forward_mode.is_decode():
             for i, _ in enumerate(mrope_positions_list):
@@ -417,93 +415,44 @@ class ForwardBatch:
                     if batch.multimodal_inputs[i] is None
                     else batch.multimodal_inputs[i].mrope_position_delta
                 )
-                mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
-                    mrope_position_delta,
-                    int(self.seq_lens[i]) - 1,
-                    int(self.seq_lens[i]),
+                mrope_positions_list[i] = torch.tensor(
+                    MRotaryEmbedding.get_next_input_positions(
+                        mrope_position_delta,
+                        int(self.seq_lens[i]) - 1,
+                        int(self.seq_lens[i]),
+                    )
                 )
         elif self.forward_mode.is_extend():
-            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
             for i, mm_input in enumerate(batch.multimodal_inputs):
-                extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    extend_start_loc_cpu[i],
+                extend_seq_len, extend_prefix_len = (
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
                 if mm_input is None:
                     # text only
-                    mrope_positions = [
-                        [
-                            pos
-                            for pos in range(
-                                extend_prefix_len, extend_prefix_len + extend_seq_len
-                            )
-                        ]
-                    ] * 3
-                else:
-                    image_grid_thws_list = [
-                        item.image_grid_thws
-                        for item in mm_input.mm_items
-                        if item.image_grid_thws is not None
-                    ]
-                    image_grid_thw = (
-                        None
-                        if len(image_grid_thws_list) == 0
-                        else torch.cat(image_grid_thws_list, dim=0)
-                    )
-
-                    video_grid_thws_list = [
-                        item.video_grid_thws
-                        for item in mm_input.mm_items
-                        if item.video_grid_thws is not None
-                    ]
-                    video_grid_thw = (
-                        None
-                        if len(video_grid_thws_list) == 0
-                        else torch.cat(video_grid_thws_list, dim=0)
-                    )
-
-                    second_per_grid_ts_list = [
-                        item.second_per_grid_ts
-                        for item in mm_input.mm_items
-                        if item.second_per_grid_ts is not None
-                    ]
-                    second_per_grid_ts = (
-                        None
-                        if len(second_per_grid_ts_list) == 0
-                        else torch.cat(second_per_grid_ts_list, dim=0)
-                    )
-
-                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
-                    mrope_positions, mrope_position_delta = (
-                        MRotaryEmbedding.get_input_positions(
-                            input_tokens=self.input_ids[
-                                extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
-                            image_grid_thw=image_grid_thw,
-                            video_grid_thw=video_grid_thw,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
-                            vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
-                            spatial_merge_size=hf_config.vision_config.spatial_merge_size,
-                            context_len=0,
-                            seq_len=len(self.input_ids),
-                            second_per_grid_ts=second_per_grid_ts,
-                            tokens_per_second=getattr(
-                                hf_config.vision_config, "tokens_per_second", None
-                            ),
-                        )
-                    )
-                    batch.multimodal_inputs[i].mrope_position_delta = (
-                        mrope_position_delta
-                    )
+                    mrope_positions = torch.tensor(
+                        [
+                            [
+                                pos
+                                for pos in range(
+                                    extend_prefix_len,
+                                    extend_prefix_len + extend_seq_len,
+                                )
+                            ]
+                        ]
+                        * 3
+                    )
+                else:
+                    mrope_positions = mm_input.mrope_positions[
+                        :,
+                        extend_prefix_len : extend_prefix_len + extend_seq_len,
+                    ]
                 mrope_positions_list[i] = mrope_positions

         self.mrope_positions = torch.cat(
-            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
-            axis=1,
-        )
+            [pos.to(device=model_runner.device) for pos in mrope_positions_list],
+            dim=1,
+        ).to(device=model_runner.device)
         self.mrope_positions = self.mrope_positions.to(torch.int64)

     def get_max_chunk_capacity(self):
......
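For decode steps, `get_next_input_positions` advances all three channels by the same scalar, `seq_len - 1 + mrope_position_delta`; a toy sketch of that arithmetic (the helper below is illustrative, not the sglang function):

```python
import torch

def next_mrope_positions(mrope_position_delta: int, seq_len: int) -> torch.Tensor:
    # One new position per channel (t, h, w); all three advance together for text.
    next_pos = mrope_position_delta + seq_len - 1
    return torch.full((3, 1), next_pos, dtype=torch.int64)

print(next_mrope_positions(mrope_position_delta=-2, seq_len=9))  # tensor([[6], [6], [6]])
```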
@@ -310,15 +310,6 @@ class ModelRunner:
             )
             server_args.chunked_prefill_size = -1

-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ] or self.model_config.hf_config.architectures == [
-            "Qwen2_5_VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info("Automatically disable radix cache for qwen-vl series.")
-            server_args.disable_radix_cache = True
-
         if server_args.enable_deepep_moe:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
......
@@ -12,7 +12,7 @@ from sglang.srt.configs.deepseekvl2 import (
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -249,8 +249,8 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)

     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        helper = MultiModalityDataPaddingPatternImageTokens(
-            image_token_id=image_inputs.im_token_id
+        helper = MultiModalityDataPaddingPatternMultimodalTokens(
+            [image_inputs.im_token_id]
         )
         return helper.pad_input_tokens(input_ids, image_inputs)
......
@@ -43,6 +43,7 @@ from sglang.srt.managers.mm_utils import (
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import (
+    Modality,
     MultimodalDataItem,
     MultimodalInputs,
     flatten_nested_list,
@@ -1834,7 +1835,10 @@ class MiniCPMO(MiniCPMBaseModel):
             language_model=self.llm,
             image_data_embedding_func=self.get_image_feature,
             audio_data_embedding_func=self.get_audio_feature,
-            placeholder_token_ids=placeholder_token_ids,
+            placeholder_tokens={
+                Modality.IMAGE: placeholder_token_ids,
+                Modality.AUDIO: placeholder_token_ids,
+            },
             positions=positions,
         )

         return hidden_states
......
@@ -10,7 +10,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternImageTokens,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -53,7 +53,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         # Get all special token IDs
         im_token_id: int = mm_inputs.im_token_id

-        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(
......
@@ -49,7 +49,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -488,11 +488,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        im_token_id: int = mm_inputs.im_token_id
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
......
@@ -42,7 +42,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.mm_utils import (
-    MultiModalityDataPaddingPatternTokenPairs,
+    MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
@@ -490,15 +490,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    # Use grid_t * grid_w * grid_h to pad tokens for each image
-    # add replaced padding by unique image hash
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         # Get all special token IDs
-        im_start_id: int = mm_inputs.im_start_id
-        im_end_id: int = mm_inputs.im_end_id
-        media_token_pairs = [(im_start_id, im_end_id)]
-        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        im_token_id: int = mm_inputs.im_token_id
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
......