Unverified commit 11577ced authored by Mick, committed by GitHub

refactor: bug fixes and refactor for vlm (#4661)

parent ca75741e
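Note: the call sites throughout this diff converge on a shared `general_mm_embed_routine` helper (with an `embed_image_inputs` variant) imported from `sglang.srt.managers.mm_utils`, whose implementation is not part of the files shown here. The following is only a minimal sketch of what such a routine might do, inferred from the keyword arguments used at the call sites; the internals (placeholder lookup via `pad_values`, the `merge_image_inputs()` call borrowed from the MiniCPM-V path) are assumptions, not the actual sglang code.

```python
# Hypothetical sketch only -- not the actual sglang.srt.managers.mm_utils code.
from typing import Callable, Optional

import torch
from torch import nn


def general_mm_embed_routine(
    input_ids: torch.Tensor,
    positions: Optional[torch.Tensor],   # accepted for interface parity; unused in this sketch
    forward_batch,                        # ForwardBatch (duck-typed here)
    embed_tokens: nn.Embedding,
    image_embedding_func: Callable[..., torch.Tensor],  # e.g. model.get_image_feature
) -> torch.Tensor:
    """Embed text tokens, then overwrite image-placeholder positions with image features."""
    if forward_batch.forward_mode.is_decode() or not forward_batch.contains_image_inputs():
        # Decode steps (or prompts without images) only need plain text embeddings.
        return embed_tokens(input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1))

    image_inputs = forward_batch.merge_image_inputs()  # assumed helper (see MiniCPM-V path)
    # Placeholder ids carry image hashes used for radix-cache prefix matching,
    # so locate them first, then clamp ids into the vocabulary range.
    placeholder_mask = torch.isin(
        input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
    )
    inputs_embeds = embed_tokens(input_ids.clamp(min=0, max=embed_tokens.num_embeddings - 1))

    if placeholder_mask.any():
        image_features = image_embedding_func(image_inputs)
        inputs_embeds[placeholder_mask] = image_features.to(inputs_embeds.dtype).reshape(
            -1, inputs_embeds.shape[-1]
        )
    return inputs_embeds
```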
......@@ -47,8 +47,9 @@ from sglang.srt.configs.janus_pro import *
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
)
self.logits_processor = LogitsProcessor(config)
def prepare_images_seq_mask(
self, input_ids: torch.Tensor, image_inputs: ImageInputs
) -> Optional[torch.LongTensor]:
images_seq_mask = torch.isin(
input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values
bs, n = pixel_values.shape[0:2]
pixel_values = pixel_values.to(
device=self.vision_model.device, dtype=self.vision_model.dtype
)
if images_seq_mask.sum() == 0:
# sometimes image_inputs is not empty, but input_ids contains no image tokens because of the prefix cache
return None
else:
return images_seq_mask
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
return images_embeds
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.model.embed_tokens
@torch.no_grad()
def forward(
......@@ -1978,86 +1986,22 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = None
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) != 0
and forward_batch.image_inputs[0] is not None
):
image_inputs = forward_batch.image_inputs[0]
images_seq_mask = self.prepare_images_seq_mask(
input_ids=input_ids, image_inputs=image_inputs
)
if images_seq_mask is not None:
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
inputs_embeds = self.prepare_inputs_embeds(
input_ids=input_ids,
pixel_values=image_inputs.pixel_values,
images_seq_mask=images_seq_mask,
images_emb_mask=image_inputs.images_emb_mask,
)
input_ids = None
if input_ids is not None:
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
return self.language_model(
input_ids=input_ids,
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
get_embedding=False,
)
def prepare_inputs_embeds(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.BoolTensor,
**_kwargs,
):
"""
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
bs, n = pixel_values.shape[0:2]
pixel_values = pixel_values.to(
device=self.vision_model.device, dtype=self.vision_model.dtype
)
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D]
# ignore the image embeddings
input_ids[input_ids < 0] = 0
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
# replace with the image embeddings
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
return inputs_embeds
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
return self.gen_aligner(self.gen_embed(image_ids))
......
import collections
import itertools
import math
import warnings
from enum import Enum
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from sglang.srt.configs import DeepseekVL2Config
from sglang.srt.configs.deepseekvl2 import (
DeepseekVL2Config,
DeepseekVL2MlpProjectorConfig,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
LinearBase,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
......@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
**kwargs: object,
):
input_embeds = self.language_model.model.embed_tokens(input_ids)
if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
None
]:
if (
forward_batch.forward_mode.is_extend()
and forward_batch.contains_image_inputs()
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.image_inputs):
......@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
continue
start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx]
pixel_values = image.pixel_values.to(
device="cuda", dtype=torch.bfloat16
)
image_seq_mask = image.image_seq_mask.to(device="cuda")
image_spatial_crop = image.image_spatial_crop
input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
pixel_values,
image_seq_mask,
image_spatial_crop,
input_embeds[start_idx:end_idx],
)
images_emb_mask = image.images_emb_mask.to(device="cuda")
image_features = self.get_image_feature(image)
input_embeds[start_idx:end_idx] = input_embeds[
start_idx:end_idx
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
outputs = self.language_model.forward(
input_ids=input_ids,
......@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
return input_ids
def prepare_inputs_embeds(
self,
pixel_values,
images_seq_mask,
images_spatial_crop,
input_embeds,
):
def get_image_feature(self, image_input: ImageInputs):
pixel_values = image_input.pixel_values.type(
next(self.vision.parameters()).dtype
).to(device=next(self.vision.parameters()).device)
image_feature = self.vision.forward_features(pixel_values)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
images_in_this_batch = []
images_spatial_crop = image_input.image_spatial_crop
for jdx in range(images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
......@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
images_in_this_batch.append(global_local_features)
if len(images_in_this_batch) > 0:
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
input_embeds.masked_scatter_(
images_seq_mask.unsqueeze(-1), images_in_this_batch
)
return input_embeds
return torch.cat(images_in_this_batch, dim=0)
EntryClass = DeepseekVL2ForCausalLM
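The DeepseekVL2 forward path above now merges image features into the text embedding with `Tensor.masked_scatter` instead of the removed `prepare_inputs_embeds` slicing logic. A tiny self-contained illustration of that scatter pattern, with toy shapes rather than the model's real dimensions:

```python
import torch

# Toy text embedding: 6 token positions, hidden size 4.
input_embeds = torch.zeros(6, 4)
# Positions 2..4 are image placeholders (mirrors images_emb_mask in the diff).
images_emb_mask = torch.tensor([False, False, True, True, True, False])
# Three projected image-token features to drop into those positions.
image_features = torch.arange(12, dtype=torch.float32).reshape(3, 4)

# masked_scatter fills the masked rows in order from the source tensor,
# which is how the diff writes image features over the placeholder tokens.
merged = input_embeds.masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
print(merged[2])  # tensor([0., 1., 2., 3.])
```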
......@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
......@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
else:
hidden_states = input_embeds
if len(positions.shape) == 1:
if positions.dim() == 1:
positions = einops.rearrange(positions, "s -> 1 s")
position_embeddings_global = self.rotary_emb(hidden_states, positions)
......@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
)
self.post_init()
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def dtype(self) -> torch.dtype:
return self.model.layers[0].mlp.gate_up_proj.weight.dtype
return next(self.parameters()).dtype
@torch.no_grad()
def forward(
......
......@@ -34,8 +34,9 @@ from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -264,10 +265,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
def get_image_features(self, pixel_values: torch.Tensor):
def get_image_feature(self, image_input: ImageInputs):
"""
Projects the last hidden state from the vision model into language model space.
......@@ -277,6 +278,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
"""
pixel_values = image_input.pixel_values
pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
......@@ -305,7 +307,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
return inputs_embeds
else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_features(image_input.pixel_values)
image_features = self.get_image_feature(image_input)
# print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = (
......@@ -397,20 +399,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
else:
llm_input_ids = input_ids
merged_image_input = forward_batch.get_merged_image_inputs()
if (
not forward_batch.forward_mode.is_decode()
and merged_image_input is not None
):
inputs_embeds = self.embed_image_inputs(
input_ids=llm_input_ids,
forward_batch=forward_batch,
image_input=merged_image_input,
)
else:
llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
inputs_embeds = general_mm_embed_routine(
input_ids=llm_input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
outputs = self.language_model(
input_ids=None,
......
......@@ -50,8 +50,9 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
embed_image_inputs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
......@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (
image_inputs["data"]
.type(vlm_embedding.dtype)
.to(vlm_embedding.device)
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding, vision_hidden_states
def _parse_and_validate_inputs(
self,
input_ids: torch.Tensor,
......@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
type="image_embeds",
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += pixel_n
tgt_sizes_flat += tgt_n
# NOTE: Input IDs do not contain image tokens during memory profiling,
# so we allow it to be empty
if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError(
"Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}"
)
if len(pixel_values_flat) == 0:
return None
image_bounds = self._get_image_bounds(
input_ids=input_ids,
pad_values=pad_values,
......@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
)
return MiniCPMVImagePixelInputs(
image_bounds=image_bounds.to(device=input_ids.device),
data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
data=pixel_values,
tgt_sizes=tgt_sizes,
type="pixel_values",
)
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (
image_inputs["data"]
.type(vlm_embedding.dtype)
.to(vlm_embedding.device)
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding, vision_hidden_states
def get_input_embeddings(self) -> nn.Embedding:
return self.llm.get_input_embedding()
def forward(
self,
input_ids: torch.Tensor,
......@@ -899,58 +863,29 @@ class MiniCPMVBaseModel(nn.Module):
**kwargs: Any,
) -> torch.Tensor:
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) > 0
and forward_batch.image_inputs[0] is not None
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
# TODO: batch
kwargs.update(
{
"pixel_values": (
None
if forward_batch.image_inputs is None
else [
i.pixel_values
for i in forward_batch.image_inputs
if i is not None
]
),
"tgt_sizes": (
None
if forward_batch.image_inputs is None
else [
i.tgt_sizes
for i in forward_batch.image_inputs
if i is not None
]
),
"im_start_id": forward_batch.image_inputs[0].im_start_id,
"im_end_id": forward_batch.image_inputs[0].im_end_id,
"slice_start_id": forward_batch.image_inputs[0].slice_start_id,
"slice_end_id": forward_batch.image_inputs[0].slice_end_id,
"pad_values": forward_batch.image_inputs[0].pad_values,
}
inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
else:
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
image_inputs = forward_batch.merge_image_inputs()
inputs_embeds = embed_image_inputs(
image_input=image_inputs,
input_ids=input_ids,
input_embedding=self.get_input_embeddings(),
image_embedding_func=self.get_image_features,
placeholder_token_ids=[image_inputs.im_token_id]
+ image_inputs.pad_values,
)
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.llm.model(
input_ids=input_ids,
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=vlm_embeddings,
input_embeds=inputs_embeds,
)
return self.logits_processor(
......@@ -990,7 +925,7 @@ class MiniCPMVBaseModel(nn.Module):
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
raise NotImplementedError
......@@ -1100,12 +1035,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
)
return vision_embedding
def get_vision_hidden_states(
def get_image_features(
self,
data: MiniCPMVImageInputs,
image_inputs: ImageInputs,
) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]
# list of tensors
pixel_values = image_inputs.pixel_values
tgt_sizes = image_inputs.tgt_sizes
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype
......
......@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def get_input_embedding(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()
def forward(
self,
......
......@@ -26,7 +26,6 @@ import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
......@@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = (
self.window_size // self.spatial_merge_size // self.patch_size
)
window_index: list = []
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.spatial_merge_size,
......@@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
@property
......@@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
for i in range(grid_thw.size(0)):
t, h, w = grid_thw[i].tolist()
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
......@@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
......@@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat(
[
torch.tensor([0], device=grid_thw.device),
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
]
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
......@@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
processor = cached_get_processor(self.config._name_or_path)
grid_t, grid_h, grid_w = image_grid_thw
num_image_tokens = (
grid_t
* grid_h
* grid_w
// processor.image_processor.merge_size
// processor.image_processor.merge_size
)
return num_image_tokens
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
......@@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
......@@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
)
return video_embeds
def get_input_embeddings(self):
return self.model.embed_tokens
def forward(
self,
input_ids: torch.Tensor,
......@@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions
image_inputs = None
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
if (
if not (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
or not forward_batch.contains_image_inputs()
):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# [B, s, hidden_size]
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.to(device="cuda")
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
image_offsets = image.image_offsets
image_input = Qwen2VLImageInputs(
pixel_values=pixel_values, image_grid_thw=image_grid_thws
)
image_embeds = self._process_image_input(image_input)
image_embeds_offset = 0
for idx, image_offset in enumerate(image_offsets):
if image_offset < prefix_len:
continue
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len)
right_idx = left_idx + num_image_tokens
tp_size = get_tensor_model_parallel_world_size()
hidden_size = image_embeds.shape[-1]
if hidden_size % tp_size != 0:
padding_size = tp_size - (hidden_size % tp_size)
image_embeds = F.pad(image_embeds, (0, padding_size))
inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
hidden_chunk_size = image_embeds.shape[-1] // tp_size
rank = get_tensor_model_parallel_rank()
start_dim = rank * hidden_chunk_size
end_dim = (rank + 1) * hidden_chunk_size
inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
image_embeds[
image_embeds_offset : image_embeds_offset
+ num_image_tokens,
...,
start_dim:end_dim,
]
)
image_embeds_offset += num_image_tokens
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=input_ids,
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
......
......@@ -26,7 +26,6 @@ import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -42,8 +41,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
@property
def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
return next(self.parameters()).dtype
@property
def device(self) -> torch.device:
......@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
for i in range(grid_thw.size(0)):
t, h, w = grid_thw[i].tolist()
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = (
......@@ -480,9 +481,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
......@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return video_embeds
def get_input_embeddings(self):
return self.model.embed_tokens
def forward(
self,
input_ids: torch.Tensor,
......@@ -514,67 +518,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions
image_inputs = None
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
if (
if not (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
or not forward_batch.contains_image_inputs()
):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.clone()
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
image_offsets = image.image_offsets
image_input = Qwen2VLImageInputs(
pixel_values=pixel_values, image_grid_thw=image_grid_thws
)
image_embeds = self._process_image_input(image_input)
image_embeds_offset = 0
for idx, image_offset in enumerate(image_offsets):
if image_offset < prefix_len:
continue
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len + 1)
right_idx = left_idx + num_image_tokens
inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens
]
image_embeds_offset += num_image_tokens
input_ids = None
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=input_ids,
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
......
......@@ -23,6 +23,17 @@ from sglang.test.test_utils import (
popen_launch_server,
)
# image
IMAGE_MAN_IRONING_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/man_ironing_on_back_of_suv.png"
IMAGE_SGL_LOGO_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/images/sgl_logo.png"
# video
VIDEO_JOBS_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4"
# audio
AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/Trump_WEF_2018_10s.mp3"
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
class TestOpenAIVisionServer(unittest.TestCase):
@classmethod
......@@ -58,9 +69,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
......@@ -96,9 +105,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
......@@ -153,9 +160,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
"image_url": {"url": IMAGE_SGL_LOGO_URL},
"modalities": "multi-images",
},
{
......@@ -242,10 +247,12 @@ class TestOpenAIVisionServer(unittest.TestCase):
]
return messages
def test_video_chat_completion(self):
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
def get_or_download_file(self, url: str) -> str:
cache_dir = os.path.expanduser("~/.cache")
file_path = os.path.join(cache_dir, "jobs.mp4")
if url is None:
raise ValueError()
file_name = url.split("/")[-1]
file_path = os.path.join(cache_dir, file_name)
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
......@@ -254,6 +261,11 @@ class TestOpenAIVisionServer(unittest.TestCase):
with open(file_path, "wb") as f:
f.write(response.content)
return file_path
def test_video_chat_completion(self):
url = VIDEO_JOBS_URL
file_path = self.get_or_download_file(url)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
......@@ -289,6 +301,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"present" in video_response
or "examine" in video_response
or "display" in video_response
or "hold" in video_response
)
assert "black" in video_response or "dark" in video_response
self.assertIsNotNone(video_response)
......@@ -312,9 +325,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
......@@ -344,18 +355,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
content.append(
{
"type": "image_url",
"image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
"image_url": {"url": IMAGE_MAN_IRONING_URL},
}
)
elif image_id == 1:
content.append(
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
"image_url": {"url": IMAGE_SGL_LOGO_URL},
}
)
else:
......@@ -465,9 +472,7 @@ class TestVLMContextLengthIssue(unittest.TestCase):
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
},
"image_url": {"url": IMAGE_MAN_IRONING_URL},
},
{
"type": "text",
......
......@@ -13,6 +13,8 @@ from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.managers.mm_utils import embed_image_inputs
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs
......@@ -168,10 +170,14 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
).eval()
cls.model.to(cls.device)
async def test_encode_output(self):
async def test_vlm_embedding_output(self):
"""
Compares the embedding output of the VLM between SGLang and HuggingFace
"""
inputs = self.get_processor_output()
with torch.no_grad():
# hf
model_inputs = {
"input_ids": inputs.input_ids,
"image_bound": inputs.image_bound,
......@@ -183,22 +189,20 @@ class TestMiniCPMVLogits(VisionLLMLogitsBase):
)
hf_output = hf_output.squeeze(0)
with torch.no_grad():
# sglang
model = self.get_sglang_model()
input_ids = inputs["input_ids"].to(self.device).flatten()
image_inputs = model._parse_and_validate_inputs(
sglang_output = embed_image_inputs(
image_input=ImageInputs(
pixel_values=inputs["pixel_values"][0],
tgt_sizes=inputs["tgt_sizes"][0],
),
input_ids=input_ids,
**{
"pixel_values": [inputs["pixel_values"]],
"tgt_sizes": [inputs["tgt_sizes"]],
"im_start_id": self.tokenizer.im_start_id,
"im_end_id": self.tokenizer.im_end_id,
"slice_start_id": self.tokenizer.slice_start_id,
"slice_end_id": self.tokenizer.slice_end_id,
},
)
(sglang_output, _) = model.get_embedding(
input_ids=input_ids, image_inputs=image_inputs
input_embedding=model.get_input_embeddings(),
image_embedding_func=model.get_image_features,
placeholder_token_ids=[
self.processor.tokenizer.unk_token_id,
],
)
self.compare_outputs(sglang_output, hf_output)
......