Unverified commit 11577ced, authored by Mick, committed by GitHub

refactor: bug fixes and refactor for vlm (#4661)

parent ca75741e
@@ -47,8 +47,9 @@ from sglang.srt.configs.janus_pro import *
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def prepare_images_seq_mask(
-        self, input_ids: torch.Tensor, image_inputs: ImageInputs
-    ) -> Optional[torch.LongTensor]:
-        images_seq_mask = torch.isin(
-            input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
-        )
-        if images_seq_mask.sum() == 0:
-            # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
-            return None
-        else:
-            return images_seq_mask
+    def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
+        )
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+
+        return images_embeds
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.model.embed_tokens

     @torch.no_grad()
     def forward(
@@ -1978,86 +1986,22 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        inputs_embeds = None
-        if (
-            forward_batch.image_inputs is not None
-            and len(forward_batch.image_inputs) != 0
-            and forward_batch.image_inputs[0] is not None
-        ):
-            image_inputs = forward_batch.image_inputs[0]
-            images_seq_mask = self.prepare_images_seq_mask(
-                input_ids=input_ids, image_inputs=image_inputs
-            )
-
-            if images_seq_mask is not None:
-                input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-                inputs_embeds = self.prepare_inputs_embeds(
-                    input_ids=input_ids,
-                    pixel_values=image_inputs.pixel_values,
-                    images_seq_mask=images_seq_mask,
-                    images_emb_mask=image_inputs.images_emb_mask,
-                )
-                input_ids = None
-
-        if input_ids is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         return self.language_model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
             get_embedding=False,
         )

-    def prepare_inputs_embeds(
-        self,
-        input_ids: torch.LongTensor,
-        pixel_values: torch.FloatTensor,
-        images_seq_mask: torch.LongTensor,
-        images_emb_mask: torch.BoolTensor,
-        **_kwargs,
-    ):
-        """
-        Args:
-            input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
-            images_seq_mask (torch.BoolTensor): [b, T]
-            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
-
-            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
-
-        Returns:
-            input_embeds (torch.Tensor): [b, T, D]
-        """
-        bs, n = pixel_values.shape[0:2]
-        pixel_values = pixel_values.to(
-            device=self.vision_model.device, dtype=self.vision_model.dtype
-        )
-        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
-
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-
-        # [b, n, T2] -> [b, n x T2]
-        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
-
-        # [b, T, D]
-        # ignore the image embeddings
-        input_ids[input_ids < 0] = 0
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # replace with the image embeddings
-        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
-
-        return inputs_embeds
-
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
...
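Note: the helper `general_mm_embed_routine` from `sglang.srt.managers.mm_utils` is referenced throughout this commit but its implementation is not shown in the diff. The following is a minimal, hypothetical sketch of what such a routine could look like, with a simplified signature (the real helper also receives `positions` and the `ForwardBatch`, and derives the placeholder ids itself): clamp the hash-valued placeholder ids into the vocabulary, embed the text tokens, then overwrite the placeholder positions with the vision features.

```python
import torch
import torch.nn as nn
from typing import Callable, List


def simplified_mm_embed_routine(
    input_ids: torch.Tensor,                            # [seq_len]; image slots hold out-of-vocab pad values
    embed_tokens: nn.Embedding,                         # text embedding table
    image_embedding_func: Callable[[], torch.Tensor],   # returns [num_image_tokens, hidden]
    placeholder_token_ids: List[int],                   # token ids that mark image positions
) -> torch.Tensor:
    # Positions carrying a placeholder id will be overwritten by vision features.
    is_image = torch.isin(
        input_ids, torch.tensor(placeholder_token_ids, device=input_ids.device)
    )
    # Clamp so the placeholder ids (image hashes used for radix-cache prefix matching)
    # stay inside the vocabulary before the embedding lookup.
    text_ids = input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1)
    inputs_embeds = embed_tokens(text_ids)
    if is_image.any():
        image_embeds = image_embedding_func()           # [num_image_tokens, hidden]
        inputs_embeds[is_image] = image_embeds.to(inputs_embeds.dtype)
    return inputs_embeds
```

This is only a sketch under those assumptions, not the actual implementation in `mm_utils`.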
-import collections
-import itertools
-import math
-import warnings
-from enum import Enum
-from functools import partial
-from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Tuple

 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat
 from torch import nn

-from sglang.srt.configs import DeepseekVL2Config
 from sglang.srt.configs.deepseekvl2 import (
     DeepseekVL2Config,
     DeepseekVL2MlpProjectorConfig,
 )
-from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import (
-    ColumnParallelLinear,
-    LinearBase,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
+from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
         input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
-            None
-        ]:
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.contains_image_inputs()
+        ):
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
             extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
             for idx, image in enumerate(forward_batch.image_inputs):
@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
                     continue
                 start_idx = extend_start_loc_cpu[idx]
                 end_idx = start_idx + extend_seq_lens_cpu[idx]
-                pixel_values = image.pixel_values.to(
-                    device="cuda", dtype=torch.bfloat16
-                )
-                image_seq_mask = image.image_seq_mask.to(device="cuda")
-                image_spatial_crop = image.image_spatial_crop
-                input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
-                    pixel_values,
-                    image_seq_mask,
-                    image_spatial_crop,
-                    input_embeds[start_idx:end_idx],
-                )
+                images_emb_mask = image.images_emb_mask.to(device="cuda")
+                image_features = self.get_image_feature(image)
+                input_embeds[start_idx:end_idx] = input_embeds[
+                    start_idx:end_idx
+                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)

         outputs = self.language_model.forward(
             input_ids=input_ids,
@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
         return input_ids

-    def prepare_inputs_embeds(
-        self,
-        pixel_values,
-        images_seq_mask,
-        images_spatial_crop,
-        input_embeds,
-    ):
+    def get_image_feature(self, image_input: ImageInputs):
+        pixel_values = image_input.pixel_values.type(
+            next(self.vision.parameters()).dtype
+        ).to(device=next(self.vision.parameters()).device)
         image_feature = self.vision.forward_features(pixel_values)
         images_embeds = self.projector(image_feature)
         _, hw, n_dim = images_embeds.shape
         h = w = int(hw**0.5)
         tile_index = 0
         images_in_this_batch = []
+        images_spatial_crop = image_input.image_spatial_crop
         for jdx in range(images_spatial_crop.shape[1]):
             num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
             if num_width_tiles == 0 or num_height_tiles == 0:
@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
                 images_in_this_batch.append(global_local_features)

-        if len(images_in_this_batch) > 0:
-            images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
-            input_embeds.masked_scatter_(
-                images_seq_mask.unsqueeze(-1), images_in_this_batch
-            )
-
-        return input_embeds
+        return torch.cat(images_in_this_batch, dim=0)

 EntryClass = DeepseekVL2ForCausalLM
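As an aside, the `masked_scatter` call that now lives in `forward` replaces the rows of the text embeddings flagged by `images_emb_mask` with the vision features, in order. A small self-contained illustration of that pattern (all shapes are invented for the example):

```python
import torch

hidden = 4
input_embeds = torch.zeros(6, hidden)                       # [seq_len, hidden] text embeddings
images_emb_mask = torch.tensor([False, True, True, False, True, False])
image_features = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)

# Rows 1, 2 and 4 are overwritten by the three image-feature rows, in order;
# the remaining rows keep their original (text) embeddings.
merged = input_embeds.masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
print(merged)
```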
@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
-from sglang.srt.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
         else:
             hidden_states = input_embeds

-        if len(positions.shape) == 1:
+        if positions.dim() == 1:
             positions = einops.rearrange(positions, "s -> 1 s")
         position_embeddings_global = self.rotary_emb(hidden_states, positions)
@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
         )
         self.post_init()

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

     def dtype(self) -> torch.dtype:
-        return self.model.layers[0].mlp.gate_up_proj.weight.dtype
+        return next(self.parameters()).dtype

     @torch.no_grad()
     def forward(
...
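A quick standalone check of the dtype idiom adopted above: `next(self.parameters()).dtype` reads the dtype of the first registered parameter, so it keeps working regardless of which submodules the model happens to contain (unlike hard-coding a path such as `model.layers[0].mlp.gate_up_proj`). The module below is made up purely for the illustration.

```python
import torch
import torch.nn as nn

module = nn.Linear(8, 8).to(torch.bfloat16)
# Works for any module with at least one parameter, independent of its internal layout.
assert next(module.parameters()).dtype == torch.bfloat16
```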
@@ -34,8 +34,9 @@ from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -264,10 +265,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             kwargs["local_attn_masks"] = local_attn_masks
         return kwargs

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()

-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_feature(self, image_input: ImageInputs):
         """
         Projects the last hidden state from the vision model into language model space.
@@ -277,6 +278,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
+        pixel_values = image_input.pixel_values
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
@@ -305,7 +307,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
             return inputs_embeds
         else:
             # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_features(image_input.pixel_values)
+            image_features = self.get_image_feature(image_input.pixel_values)
             # print(f"image tokens from image embeddings: {image_features.numel()}")
             num_image_tokens_in_embedding = (
@@ -397,20 +399,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids

-        merged_image_input = forward_batch.get_merged_image_inputs()
-
-        if (
-            not forward_batch.forward_mode.is_decode()
-            and merged_image_input is not None
-        ):
-            inputs_embeds = self.embed_image_inputs(
-                input_ids=llm_input_ids,
-                forward_batch=forward_batch,
-                image_input=merged_image_input,
-            )
-        else:
-            llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
-            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=llm_input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         outputs = self.language_model(
             input_ids=None,
...
@@ -50,8 +50,9 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    embed_image_inputs,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.embeddings

     def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
         valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
         return valid_pairs_tensor

-    def get_embedding(
-        self,
-        input_ids: torch.Tensor,
-        image_inputs: Optional[MiniCPMVImageInputs],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
-
-        if image_inputs is None:  # No image
-            vision_hidden_states = torch.tensor([], device=input_ids.device)
-        else:
-            if image_inputs["type"] == "image_embeds":
-                vision_hidden_states = (
-                    image_inputs["data"]
-                    .type(vlm_embedding.dtype)
-                    .to(vlm_embedding.device)
-                )
-            else:
-                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
-
-            # See NOTE in _parse_and_validate_inputs
-            image_bounds = image_inputs["image_bounds"]
-            if len(image_bounds) > 0:
-                image_indices = torch.stack(
-                    [
-                        torch.arange(start, end, dtype=torch.long)
-                        for start, end in image_bounds.tolist()
-                    ]
-                ).to(vlm_embedding.device)
-
-                vlm_embedding.scatter_(
-                    0,
-                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
-                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
-                )
-
-        return vlm_embedding, vision_hidden_states
-
     def _parse_and_validate_inputs(
         self,
         input_ids: torch.Tensor,
@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
                 type="image_embeds",
             )

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
-            )
-
-        if not isinstance(tgt_sizes, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
-            )
-
-        if len(pixel_values) != len(tgt_sizes):
-            raise ValueError(
-                "Inconsistent batch lengths, found: "
-                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
-            )
-
-        pixel_values_flat: List[torch.Tensor] = []
-        tgt_sizes_flat: List[torch.Tensor] = []
-        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
-            if len(pixel_b) != len(tgt_b):
-                raise ValueError(
-                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
-                )
-            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
-                pixel_values_flat += pixel_n
-                tgt_sizes_flat += tgt_n
-
-        # NOTE: Input IDs does not contain image tokens during memory profiling,
-        # so we allow it to be empty
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError(
-                "Inconsistent flattened lengths, found: "
-                f"{len(pixel_values_flat)} vs. "
-                f"{len(tgt_sizes_flat)}"
-            )
-
-        if len(pixel_values_flat) == 0:
-            return None
-
         image_bounds = self._get_image_bounds(
             input_ids=input_ids,
             pad_values=pad_values,
@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
         )
         return MiniCPMVImagePixelInputs(
             image_bounds=image_bounds.to(device=input_ids.device),
-            data=pixel_values_flat,
-            tgt_sizes=torch.stack(tgt_sizes_flat),
+            data=pixel_values,
+            tgt_sizes=tgt_sizes,
             type="pixel_values",
         )

+    def get_embedding(
+        self,
+        input_ids: torch.Tensor,
+        image_inputs: Optional[MiniCPMVImageInputs],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+
+        if image_inputs is None:  # No image
+            vision_hidden_states = torch.tensor([], device=input_ids.device)
+        else:
+            if image_inputs["type"] == "image_embeds":
+                vision_hidden_states = (
+                    image_inputs["data"]
+                    .type(vlm_embedding.dtype)
+                    .to(vlm_embedding.device)
+                )
+            else:
+                vision_hidden_states = self.get_vision_hidden_states(image_inputs)
+
+            # See NOTE in _parse_and_validate_inputs
+            image_bounds = image_inputs["image_bounds"]
+            if len(image_bounds) > 0:
+                image_indices = torch.stack(
+                    [
+                        torch.arange(start, end, dtype=torch.long)
+                        for start, end in image_bounds.tolist()
+                    ]
+                ).to(vlm_embedding.device)
+
+                vlm_embedding.scatter_(
+                    0,
+                    image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
+                    vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
+                )
+
+        return vlm_embedding, vision_hidden_states
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.llm.get_input_embedding()
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -899,58 +863,29 @@ class MiniCPMVBaseModel(nn.Module):
         **kwargs: Any,
     ) -> torch.Tensor:
         if (
-            forward_batch.image_inputs is not None
-            and len(forward_batch.image_inputs) > 0
-            and forward_batch.image_inputs[0] is not None
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
         ):
-            # TODO: bath
-            kwargs.update(
-                {
-                    "pixel_values": (
-                        None
-                        if forward_batch.image_inputs is None
-                        else [
-                            i.pixel_values
-                            for i in forward_batch.image_inputs
-                            if i is not None
-                        ]
-                    ),
-                    "tgt_sizes": (
-                        None
-                        if forward_batch.image_inputs is None
-                        else [
-                            i.tgt_sizes
-                            for i in forward_batch.image_inputs
-                            if i is not None
-                        ]
-                    ),
-                    "im_start_id": forward_batch.image_inputs[0].im_start_id,
-                    "im_end_id": forward_batch.image_inputs[0].im_end_id,
-                    "slice_start_id": forward_batch.image_inputs[0].slice_start_id,
-                    "slice_end_id": forward_batch.image_inputs[0].slice_end_id,
-                    "pad_values": forward_batch.image_inputs[0].pad_values,
-                }
+            inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
+        else:
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            image_inputs = forward_batch.merge_image_inputs()
+            inputs_embeds = embed_image_inputs(
+                image_input=image_inputs,
+                input_ids=input_ids,
+                input_embedding=self.get_input_embeddings(),
+                image_embedding_func=self.get_image_features,
+                placeholder_token_ids=[image_inputs.im_token_id]
+                + image_inputs.pad_values,
             )
-        image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
-
-        # Clamp input ids. This is because the input_ids for the image tokens are
-        # filled with the hash values of the image for the prefix matching in the radix attention.
-        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-        vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
-
-        # always pass the input via `inputs_embeds`
-        # to make sure the computation graph is consistent
-        # for `torch.compile` integration
-        input_ids = None

         hidden_states = self.llm.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=vlm_embeddings,
+            input_embeds=inputs_embeds,
         )

         return self.logits_processor(
@@ -990,7 +925,7 @@ class MiniCPMVBaseModel(nn.Module):
     ) -> torch.Tensor:
         raise NotImplementedError

-    def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
+    def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
         raise NotImplementedError
@@ -1100,12 +1035,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         )
         return vision_embedding

-    def get_vision_hidden_states(
+    def get_image_features(
         self,
-        data: MiniCPMVImageInputs,
+        image_inputs: ImageInputs,
     ) -> torch.Tensor:
-        pixel_values = data["data"]
-        tgt_sizes = data["tgt_sizes"]
+        # list of tensors
+        pixel_values = image_inputs.pixel_values
+        tgt_sizes = image_inputs.tgt_sizes
+
         device = self.vpm.embeddings.position_embedding.weight.device
         dtype = self.vpm.embeddings.position_embedding.weight.dtype
...
@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

+    def get_input_embedding(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
...
@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
 from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
@@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )

     def get_window_index(self, grid_thw):
-        window_index: list = []
         cu_window_seqlens: list = [0]
         window_index_id = 0
         vit_merger_window_size = (
             self.window_size // self.spatial_merge_size // self.patch_size
         )
+        window_index: list = []
         for grid_t, grid_h, grid_w in grid_thw:
             llm_grid_h, llm_grid_w = (
                 grid_h // self.spatial_merge_size,
@@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
             cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
             window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()

         window_index = torch.cat(window_index, dim=0)
         return window_index, cu_window_seqlens

     @property
@@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for t, h, w in grid_thw:
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             hpos_ids = hpos_ids.reshape(
                 h // self.spatial_merge_size,
                 self.spatial_merge_size,
@@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             )
             wpos_ids = wpos_ids.permute(0, 2, 1, 3)
             wpos_ids = wpos_ids.flatten()
+
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
@@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         position_embeddings = (emb.cos(), emb.sin())

         # compute cu_seqlens
-        cu_seqlens = torch.repeat_interleave(
-            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
-        ).cumsum(dim=0, dtype=torch.int32)
+        cu_seqlens = torch.cat(
+            [
+                torch.tensor([0], device=grid_thw.device),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+            ]
+        )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

         # transformers
@@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

-    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
     def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
         # Get all special token IDs
         im_start_id: int = image_inputs.im_start_id
@@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, image_inputs)

-    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+    def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         )
         return video_embeds

+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions

-        image_inputs = None
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
+        if not (
             forward_batch.forward_mode.is_decode()
-            or image_inputs is None
-            or len(image_inputs) == 0
+            or not forward_batch.contains_image_inputs()
         ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-            # [B, s, hidden_size]
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-                pixel_values = image.pixel_values.to(device="cuda")
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-                    left_idx = start_idx + (image_offset - prefix_len)
-                    right_idx = left_idx + num_image_tokens
-                    tp_size = get_tensor_model_parallel_world_size()
-                    hidden_size = image_embeds.shape[-1]
-                    if hidden_size % tp_size != 0:
-                        padding_size = tp_size - (hidden_size % tp_size)
-                        image_embeds = F.pad(image_embeds, (0, padding_size))
-                        inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
-                    hidden_chunk_size = image_embeds.shape[-1] // tp_size
-                    rank = get_tensor_model_parallel_rank()
-                    start_dim = rank * hidden_chunk_size
-                    end_dim = (rank + 1) * hidden_chunk_size
-                    inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
-                        image_embeds[
-                            image_embeds_offset : image_embeds_offset
-                            + num_image_tokens,
-                            ...,
-                            start_dim:end_dim,
-                        ]
-                    )
-                image_embeds_offset += num_image_tokens
+
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         hidden_states = self.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
...
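A standalone numeric illustration of the new `cu_seqlens` construction above (the grid sizes below are invented for the example): each image contributes `t * h * w` patch tokens, and the cumulative sums with a leading zero give the per-image token boundaries used by the vision attention.

```python
import torch

grid_thw = torch.tensor([[1, 4, 4], [2, 6, 4]])  # hypothetical (t, h, w) per image
cu_seqlens = torch.cat(
    [
        torch.tensor([0]),
        (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
    ]
)
print(cu_seqlens)  # tensor([ 0, 16, 64]): image 0 covers tokens [0, 16), image 1 covers [16, 64)
```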
@@ -26,7 +26,6 @@ import logging
 from functools import lru_cache, partial
 from typing import Iterable, List, Optional, Tuple, Type, TypedDict

-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -42,8 +41,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.fc2.weight.dtype
+        return next(self.parameters()).dtype

     @property
     def device(self) -> torch.device:
@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
-        for t, h, w in grid_thw:
+        for i in range(grid_thw.size(0)):
+            t, h, w = grid_thw[i].tolist()
             hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
             wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
             hpos_ids = (
@@ -480,9 +481,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
         return pattern.pad_input_tokens(input_ids, image_inputs)

-    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
-        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
+    def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values.type(self.visual.dtype)
+        image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
         return image_embeds

     def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )
         return video_embeds

+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -514,67 +518,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
             positions = forward_batch.mrope_positions

-        image_inputs = None
-        if forward_batch.image_inputs is not None:
-            image_inputs = [
-                img for img in forward_batch.image_inputs if img is not None
-            ]
-
-        if (
+        if not (
             forward_batch.forward_mode.is_decode()
-            or image_inputs is None
-            or len(image_inputs) == 0
+            or not forward_batch.contains_image_inputs()
         ):
-            inputs_embeds = self.model.embed_tokens(input_ids)
-        else:
             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
                 )
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
-            for i, image in enumerate(forward_batch.image_inputs):
-                if image is None or image.pixel_values is None:
-                    continue
-                start_idx = extend_start_loc_cpu[i]
-                prefix_len = prefix_lens_cpu[i]
-                pixel_values = image.pixel_values.clone()
-                image_grid_thws = torch.tensor(
-                    np.array(image.image_grid_thws), device="cuda"
-                )
-                image_offsets = image.image_offsets
-                image_input = Qwen2VLImageInputs(
-                    pixel_values=pixel_values, image_grid_thw=image_grid_thws
-                )
-                image_embeds = self._process_image_input(image_input)
-                image_embeds_offset = 0
-                for idx, image_offset in enumerate(image_offsets):
-                    if image_offset < prefix_len:
-                        continue
-                    num_image_tokens = self.calculate_num_image_tokens(
-                        image_grid_thws[idx]
-                    )
-                    left_idx = start_idx + (image_offset - prefix_len + 1)
-                    right_idx = left_idx + num_image_tokens
-                    inputs_embeds[left_idx:right_idx] = image_embeds[
-                        image_embeds_offset : image_embeds_offset + num_image_tokens
-                    ]
-                    image_embeds_offset += num_image_tokens
-            input_ids = None
+
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            image_embedding_func=self.get_image_feature,
+        )

         hidden_states = self.model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
...
@@ -23,6 +23,17 @@ from sglang.test.test_utils import (
     popen_launch_server,
 )

+# image
+IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
+IMAGE_SGL_LOGO_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/sgl_logo.png"
+
+# video
+VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4"
+
+# audio
+AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/Trump_WEF_2018_10s.mp3"
+AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
+
 class TestOpenAIVisionServer(unittest.TestCase):
     @classmethod
@@ -58,9 +69,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                    "content": [
                        {
                            "type": "image_url",
-                            "image_url": {
-                                "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                            },
+                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
                        },
                        {
                            "type": "text",
@@ -96,9 +105,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                    "content": [
                        {
                            "type": "image_url",
-                            "image_url": {
-                                "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                            },
+                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
                        },
                        {
                            "type": "text",
@@ -153,9 +160,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                        },
                        {
                            "type": "image_url",
-                            "image_url": {
-                                "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
-                            },
+                            "image_url": {"url": IMAGE_SGL_LOGO_URL},
                            "modalities": "multi-images",
                        },
                        {
@@ -242,10 +247,12 @@ class TestOpenAIVisionServer(unittest.TestCase):
        ]
        return messages

-    def test_video_chat_completion(self):
-        url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
+    def get_or_download_file(self, url: str) -> str:
         cache_dir = os.path.expanduser("~/.cache")
-        file_path = os.path.join(cache_dir, "jobs.mp4")
+        if url is None:
+            raise ValueError()
+        file_name = url.split("/")[-1]
+        file_path = os.path.join(cache_dir, file_name)
         os.makedirs(cache_dir, exist_ok=True)

         if not os.path.exists(file_path):
@@ -254,6 +261,11 @@ class TestOpenAIVisionServer(unittest.TestCase):
             with open(file_path, "wb") as f:
                 f.write(response.content)

+        return file_path
+
+    def test_video_chat_completion(self):
+        url = VIDEO_JOBS_URL
+        file_path = self.get_or_download_file(url)

         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -289,6 +301,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
             "present" in video_response
             or "examine" in video_response
             or "display" in video_response
+            or "hold" in video_response
         )
         assert "black" in video_response or "dark" in video_response
         self.assertIsNotNone(video_response)
@@ -312,9 +325,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
                    "content": [
                        {
                            "type": "image_url",
-                            "image_url": {
-                                "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                            },
+                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
                        },
                        {
                            "type": "text",
@@ -344,18 +355,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
                content.append(
                    {
                        "type": "image_url",
-                        "image_url": {
-                            "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                        },
+                        "image_url": {"url": IMAGE_MAN_IRONING_URL},
                    }
                )
            elif image_id == 1:
                content.append(
                    {
                        "type": "image_url",
-                        "image_url": {
-                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
-                        },
+                        "image_url": {"url": IMAGE_SGL_LOGO_URL},
                    }
                )
            else:
@@ -465,9 +472,7 @@ class TestVLMContextLengthIssue(unittest.TestCase):
                    "content": [
                        {
                            "type": "image_url",
-                            "image_url": {
-                                "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
-                            },
+                            "image_url": {"url": IMAGE_MAN_IRONING_URL},
                        },
                        {
                            "type": "text",
...
@@ -13,6 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.managers.mm_utils import embed_image_inputs
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 from sglang.srt.server_args import ServerArgs
@@ -168,10 +170,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
         ).eval()
         cls.model.to(cls.device)

-    async def test_encode_output(self):
+    async def test_vlm_embedding_output(self):
+        """
+        Compares the embedding output of vlm
+        """
         inputs = self.get_processor_output()

         with torch.no_grad():
+            # hf
             model_inputs = {
                 "input_ids": inputs.input_ids,
                 "image_bound": inputs.image_bound,
@@ -183,22 +189,20 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
             )
             hf_output = hf_output.squeeze(0)

-        with torch.no_grad():
+            # sglang
             model = self.get_sglang_model()
             input_ids = inputs["input_ids"].to(self.device).flatten()
-            image_inputs = model._parse_and_validate_inputs(
+            sglang_output = embed_image_inputs(
+                image_input=ImageInputs(
+                    pixel_values=inputs["pixel_values"][0],
+                    tgt_sizes=inputs["tgt_sizes"][0],
+                ),
                 input_ids=input_ids,
-                **{
-                    "pixel_values": [inputs["pixel_values"]],
-                    "tgt_sizes": [inputs["tgt_sizes"]],
-                    "im_start_id": self.tokenizer.im_start_id,
-                    "im_end_id": self.tokenizer.im_end_id,
-                    "slice_start_id": self.tokenizer.slice_start_id,
-                    "slice_end_id": self.tokenizer.slice_end_id,
-                },
-            )
-            (sglang_output, _) = model.get_embedding(
-                input_ids=input_ids, image_inputs=image_inputs
+                input_embedding=model.get_input_embeddings(),
+                image_embedding_func=model.get_image_features,
+                placeholder_token_ids=[
+                    self.processor.tokenizer.unk_token_id,
+                ],
             )

             self.compare_outputs(sglang_output, hf_output)
...