Unverified commit 9f635ea5 authored by Mick, committed by GitHub

[Fix] Address remaining issues of supporting MiniCPMV (#2977)

parent 76285fde
...@@ -78,6 +78,7 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm
To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically (see the sketch after this list),
- Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
- Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
- Replace the multi-headed `Attention` of the ViT with SGLang's `VisionAttention`.
- Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
- Remove `Sample`.
- Change `forward()` functions, and add `forward_batch`.
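Below is a minimal sketch of what the resulting attention layer looks like after these steps. It is not a drop-in implementation: the `RadixAttention` constructor and call signature (`num_heads, head_dim, scaling, num_kv_heads, layer_id`, invoked as `attn(q, k, v, forward_batch)`) are assumptions based on the Llama port linked above; check `python/sglang/srt/models/llama.py` for the authoritative API.

```python
# Hedged sketch only; signatures follow the Llama port and may differ per release.
import torch
from torch import nn

from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class PortedAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, layer_id: int):
        super().__init__()
        head_dim = hidden_size // num_heads
        # `layer_id` must be threaded from the model constructor all the way down here.
        self.attn = RadixAttention(
            num_heads=num_heads,
            head_dim=head_dim,
            scaling=head_dim**-0.5,
            num_kv_heads=num_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        # `forward_batch` replaces vLLM's attention metadata / KV-cache arguments.
        return self.attn(q, k, v, forward_batch)
```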
......
...@@ -166,6 +166,12 @@ def _fwd_kernel(
def context_attention_fwd(
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
"""
q, k, v: [b * s, head, head_dim]
b_start_loc: [b]
b_seq_len: [b]
o: [b * s, head, head_dim] (output, written in place)
"""
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
BLOCK = 128
else:
......
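For reference, a hedged usage sketch of the shape contract documented in the new docstring above, assuming `context_attention_fwd` is imported from this module. The sizes are illustrative only, and the Triton kernel requires CUDA tensors.

```python
import torch

# Illustrative sizes only.
b, s, heads, head_dim = 2, 16, 8, 64
q = torch.randn(b * s, heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)

b_seq_len = torch.full((b,), s, device="cuda", dtype=torch.int32)    # [b]
b_start_loc = torch.arange(b, device="cuda", dtype=torch.int32) * s  # [b]

# Non-causal, as used by the vision attention path.
context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, s, is_causal=False)
```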
...@@ -4,6 +4,7 @@ from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from sglang.srt.distributed import parallel_state
...@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
class VisionAttention(nn.Module):
r"""
Multi-headed attention without any cache, mostly used for ViT.
Args:
use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
use_context_forward (bool, defaults to True):
if ``True``, a flash-attention-style (Triton) attention is applied;
otherwise, full-sequence attention is applied.
use_full_precision_softmax (bool, defaults to False):
if ``True``, the softmax is performed in full precision;
otherwise, it is performed in half precision.
"""
def __init__(
self,
...@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
projection_size: int,
use_qkv_parallel: bool,
quant_config: Optional[QuantizationConfig] = None,
dropout: float = 0.0,
use_context_forward: bool = True,
use_full_precision_softmax: bool = False,
flatten_batch: bool = False,
prefix: str = "",
):
super().__init__()
self.use_context_forward = use_context_forward
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.dropout = dropout
self.head_size = embed_dim // num_heads
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size
)
# self.tp_size = get_tensor_model_parallel_world_size()
# num_heads = self.num_heads_per_partition
if self.use_context_forward:
self.qkv_backend = VisionTritonAttention()
else:
self.qkv_backend = VisionSdpaAttention(
head_size=self.head_size,
dropout=dropout,
flatten_batch=flatten_batch,
use_full_precision_softmax=use_full_precision_softmax,
)
self.use_qkv_parallel = use_qkv_parallel
if use_qkv_parallel:
self.head_dim = embed_dim // num_heads
self.qkv_proj = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.head_dim, head_size=self.head_size,
total_num_heads=num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
...@@ -114,12 +142,15 @@ class VisionAttention(nn.Module): ...@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
x: torch.Tensor, x: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None,
rotary_pos_emb: torch.Tensor = None, rotary_pos_emb: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
r"""
Args:
x: [b, s, embed_dim]
cu_seqlens: [b + 1]
Returns:
[b, s, num_heads * head_size]
""" """
Input shape: [b, s, embed_dim]
Output shape: [s, b, num_heads * head_size]
"""
bsz, s, _ = x.shape bsz, s, _ = x.shape
if self.use_qkv_parallel: if self.use_qkv_parallel:
# [b, s, embed_dim] --> [b, s, embed_dim] # [b, s, embed_dim] --> [b, s, embed_dim]
...@@ -136,19 +167,19 @@ class VisionAttention(nn.Module): ...@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
else: else:
# [b, s, embed_dim] --> [s, b, embed_dim] # [b, s, embed_dim] --> [s, b, embed_dim]
x = rearrange(x, "b s ... -> s b ...") x = rearrange(x, "b s ... -> s b ...")
# [s, b, embed_dim] --> [s, b, head * 3 * head_dim] # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
qkv, _ = self.qkv_proj(x) qkv, _ = self.qkv_proj(x)
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
new_x_shape = qkv.size()[:-1] + ( new_x_shape = qkv.size()[:-1] + (
self.num_attention_heads_per_partition, self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head, 3 * self.hidden_size_per_attention_head,
) )
qkv = qkv.view(*new_x_shape) qkv = qkv.view(*new_x_shape)
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
# [s, b, head, head_dim] --> [b, s, head, head_dim] # [s, b, head, head_size] --> [b, s, head, head_size]
q, k, v = [ q, k, v = [
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
] ]
...@@ -160,45 +191,217 @@ class VisionAttention(nn.Module): ...@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
if self.use_qkv_parallel: if self.use_qkv_parallel:
pass pass
else: else:
# [b, s, head, head_dim] --> [b * s, head, head_dim] # [b, s, head, head_size] --> [b * s, head, head_size]
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
# [b * s, num_heads, head_size]
output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
output = torch.empty_like(q)
seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
max_seqlen = seq_lens.max().item()
context_attention_fwd(
q,
k,
v,
output,
cu_seqlens.cuda(),
seq_lens,
max_seqlen,
is_causal=False,
)
if self.use_qkv_parallel: if self.use_qkv_parallel:
# [b * s, h, head_size] --> [b, s, h * head_size]
# [b * s, head, head_dim] --> [b, s, head * head_dim]
output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz) output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
# [b, s, head, head_dim] --> [b, s, head, head_dim] # [b, s, h * head_size] --> [b, s, h * head_size]
output, _ = self.proj(output) output, _ = self.proj(output)
else: else:
# [b * s, head, head_dim] --> [b, s, head, head_dim] # [b * s, h, head_size] --> [s, b, h * head_size]
context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
# [s, b, num_heads * head_size]
context_layer = rearrange( context_layer = rearrange(
context_layer, "b s h d -> s b (h d)" output, "(b s) h d -> s b (h d)", b=bsz, s=s
).contiguous() ).contiguous()
# [s, b, num_heads * head_size] --> [s, b, num_heads * head_size] # [s, b, h * head_size] --> [s, b, h * head_size]
output, _ = self.proj(context_layer) output, _ = self.proj(context_layer)
# [s, b, h * head_size] --> [b, s, h * head_size]
output = output.view(bsz, s, -1) output = output.view(bsz, s, -1)
return output return output
class VisionSdpaAttention(nn.Module):
r"""
Scaled Dot Product Attention inner product
"""
# TODO: should this cache be released after use?
_mask_cache = {}
def __init__(
self,
head_size: int,
dropout: float = 0.0,
flatten_batch: bool = False,
use_full_precision_softmax: bool = False,
):
super().__init__()
self.head_size = head_size
self.flatten_batch = flatten_batch
self.use_full_precision_softmax = use_full_precision_softmax
self.dropout = dropout
def generate_patch_attention_mask(
self,
s: int,
bsz: int,
device,
cu_seqlens: Optional[torch.Tensor],
flatten_batch: bool = False,
dtype=torch.bfloat16,
) -> torch.Tensor:
r"""
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
When `flatten_batch` is True:
- All sequences in the batch are flattened into a single dimension
- `s` represents the total number of tokens across all sequences in the batch
- Returns a unified mask of shape `(1, 1, s, s)`
When `flatten_batch` is False:
- Each sequence has its own attention mask
- `s` represents the maximum sequence length in the batch
- Returns separate masks of shape `(b, 1, s, s)`
Args:
flatten_batch (bool):
If True, treats all sequences in the batch as a single flattened sequence
If False, generates separate masks for each sequence
Returns:
Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
"""
cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
if cache_key in VisionSdpaAttention._mask_cache:
cached_mask = VisionSdpaAttention._mask_cache[cache_key]
# print(f"cache hit for key: {cache_key}")
return cached_mask.to(device=device, dtype=dtype)
if cu_seqlens is None:
raise ValueError("Internal Error: cu_seqlens cannot be None")
if flatten_batch:
mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
start = cu_seqlens[i - 1]
end = cu_seqlens[i]
mask[
...,
start:end,
start:end,
] = True
else:
# [1, 1, 1, s]
row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
# [1, 1, s, 1]
col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
# [b, 1, 1, 1]
seq_lens = (
(cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
)
mask = (row_indices < seq_lens) & (col_indices < seq_lens)
# Convert the boolean mask to an additive attention bias (allowed -> 0, masked -> -inf)
mask = (~mask).to(dtype) * torch.finfo(dtype).min
VisionSdpaAttention._mask_cache[cache_key] = mask
return mask
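# Illustrative example (not part of the change): with cu_seqlens = [0, 2, 5],
# i.e. two sequences of length 2 and 3:
#   flatten_batch=True  -> a single block-diagonal mask over the 5 flattened
#                          tokens; tokens from different sequences cannot
#                          attend to each other.
#   flatten_batch=False -> a (2, 1, 3, 3) mask padded to the longest sequence;
#                          positions beyond each sequence length receive -inf.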
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz: int,
cu_seqlens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
cu_seqlens: [b + 1]
Returns:
[b * s, h, head_size]
"""
s = q.shape[0] // bsz
# [b, 1, s, s]
if attention_mask is None:
attention_mask = self.generate_patch_attention_mask(
s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
)
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
# [b, 1, s]
if self.use_full_precision_softmax:
scale = self.head_size**-0.5
k_transposed = rearrange(k, "b h s d -> b h d s")
attn_weights = torch.matmul(q, k_transposed) * scale
del k, k_transposed
attn_weights = attn_weights + attention_mask
del attention_mask
# full-precision
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=False
)
output = torch.matmul(attn_weights, v)
del attn_weights, v
else:
# SDPA
# [b, h, s, head_size]
output = F.scaled_dot_product_attention(
q, k, v, attention_mask, dropout_p=self.dropout
)
# [b, h, s, head_size] --> [b * s, h, head_size]
output = rearrange(output, "b h s d -> (b s) h d")
return output
class VisionTritonAttention(nn.Module):
"""
Triton-implemented attention without a causal mask
"""
def __init__(
self,
):
super().__init__()
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
_bsz: int,
cu_seqlens: Optional[torch.Tensor],
**kwargs,
) -> torch.Tensor:
r"""
Args:
cu_seqlens: [b + 1]
Returns:
[b * s, h, head_size]
"""
# [b * s, head, head_size]
output = torch.empty_like(q)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
max_seqlen = seq_lens.max().item()
context_attention_fwd(
q,
k,
v,
output,
cu_seqlens.cuda(),
seq_lens.cuda(),
max_seqlen,
is_causal=False,
)
return output
...@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor): ...@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
class MiniCPMVImageProcessor(BaseImageProcessor): class MiniCPMVImageProcessor(BaseImageProcessor):
def __init__(self, hf_config, server_args, _processor): def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor) super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
@staticmethod @staticmethod
def _process_images_task(images, input_text): def _process_images_task(images, input_text):
...@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor): ...@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
async def process_images_async( async def process_images_async(
self, self,
image_data: List[Union[str, bytes]], image_data: List[Union[str, bytes]],
input_text, input_ids,
request_obj, request_obj,
max_req_input_len, max_req_input_len,
): ):
...@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor): ...@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
image_data = [image_data] image_data = [image_data]
image_hashes, image_sizes = [], [] image_hashes, image_sizes = [], []
raw_images = [] all_frames = []
IMAGE_TOKEN = "(<image>./</image>)"
# roughly calculate the max number of frames under the max_req_input_len limit
# TODO: the process should be applied to all the visual inputs
def calculate_max_num_frames() -> int: def calculate_max_num_frames() -> int:
# Model-specific # Model-specific
NUM_TOKEN_PER_FRAME = 330 NUM_TOKEN_PER_FRAME = 330
ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
return min(ret, 100) return min(ret, 100)
# if cuda OOM set a smaller number
MAX_NUM_FRAMES = calculate_max_num_frames() MAX_NUM_FRAMES = calculate_max_num_frames()
print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
def encode_video(video_path):
# print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
def get_estimated_frames_list():
"""
Estimate the total frame count from all visual inputs.
"""
# Before processing inputs
estimated_frames_list = []
for image in image_data:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
# Estimate frames for the video
vr = VideoReader(path, ctx=cpu(0))
num_frames = len(vr)
else:
# For images, each contributes one frame
num_frames = 1
estimated_frames_list.append(num_frames)
return estimated_frames_list
estimated_frames_list = get_estimated_frames_list()
total_frame_count = sum(estimated_frames_list)
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
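# Illustrative budget (assumed numbers): with max_req_input_len = 4096 and a
# 1000-token prompt, MAX_NUM_FRAMES = (4096 - 1000) // 330 = 9. For one video
# estimated at 40 frames plus one still image (41 estimated frames in total),
# scaling_factor = min(1.0, 9 / 41) ~= 0.22, so the video is sampled down to
# int(40 * 0.22) = 8 frames while the image keeps its single frame.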
def encode_video(video_path, frame_count_limit=None):
if not os.path.exists(video_path): if not os.path.exists(video_path):
logger.error(f"Video {video_path} does not exist") logger.error(f"Video {video_path} does not exist")
return [] return []
if MAX_NUM_FRAMES == 0: if frame_count_limit == 0:
return [] return []
def uniform_sample(l, n): def uniform_sample(l, n):
...@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor): ...@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
vr = VideoReader(video_path, ctx=cpu(0)) vr = VideoReader(video_path, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)] frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > MAX_NUM_FRAMES: if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) frame_idx = uniform_sample(frame_idx, frame_count_limit)
frames = vr.get_batch(frame_idx).asnumpy() frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype("uint8")) for v in frames] frames = [Image.fromarray(v.astype("uint8")) for v in frames]
return frames return frames
if isinstance(input_text, list): if isinstance(input_ids, list):
assert len(input_text) and isinstance(input_text[0], int) assert len(input_ids) and isinstance(input_ids[0], int)
input_text = self._processor.tokenizer.decode(input_text) input_text = self._processor.tokenizer.decode(input_ids)
else:
input_text = input_ids
# MiniCPMV requires each video frame to be represented as a separate image token
text_parts = input_text.split(IMAGE_TOKEN) text_parts = input_text.split(self.IMAGE_TOKEN)
new_text_parts = [] new_text_parts = []
for image_index, image in enumerate(image_data):
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = encode_video(path)
else:
raw_image, size = load_image(image)
frames = [raw_image]
if len(frames) == 0:
continue
except FileNotFoundError as e:
print(e)
return None
image_sizes += frames[0].size * len(frames)
image_hashes += [hash(image)] * len(frames)
raw_images += frames
# Process each input with allocated frames
for image_index, (image, estimated_frames) in enumerate(
zip(image_data, estimated_frames_list)
):
if len(all_frames) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
frames_to_process = max(1, int(estimated_frames * scaling_factor))
if frames_to_process == 0:
frames = []
else:
try:
if isinstance(image, str) and image.startswith("video:"):
path = image[len("video:") :]
frames = encode_video(path, frame_count_limit=frames_to_process)
else:
raw_image, _size = load_image(image)
frames = [raw_image]
if len(frames) == 0:
continue
except FileNotFoundError as e:
print(e)
return None
image_sizes += frames[0].size * len(frames)
image_hashes += [hash(image)] * len(frames)
all_frames += frames
assert frames_to_process == len(frames)
new_text_parts.append(text_parts[image_index]) new_text_parts.append(text_parts[image_index])
new_text_parts.append(IMAGE_TOKEN * len(frames))
if frames_to_process != 0:
new_text_parts.append(self.IMAGE_TOKEN * len(frames))
new_text_parts.append(text_parts[-1]) new_text_parts.append(text_parts[-1])
input_text = "".join(new_text_parts) input_text = "".join(new_text_parts)
if len(raw_images) == 0:
if len(all_frames) == 0:
return None return None
res = await self._process_images(images=raw_images, input_text=input_text) res = await self._process_images(images=all_frames, input_text=input_text)
pixel_values = res["pixel_values"] pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"] tgt_sizes = res["tgt_sizes"]
input_ids = res["input_ids"] input_ids = res["input_ids"]
...@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor): ...@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
if tokenizer.slice_start_id: if tokenizer.slice_start_id:
slice_start_id = [tokenizer.slice_start_id] slice_start_id = [tokenizer.slice_start_id]
slice_end_id = [tokenizer.slice_end_id] slice_end_id = [tokenizer.slice_end_id]
return { return {
"input_ids": input_ids.flatten().tolist(), "input_ids": input_ids.flatten().tolist(),
"pixel_values": pixel_values, "pixel_values": pixel_values,
......
# Adapted from # Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team. # Copyright 2023 The SGLang team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# #
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" """Inference-only MiniCPM-V model compatible with HuggingFace weights."""
from functools import cached_property, partial from functools import partial
from typing import ( from typing import (
Any, Any,
Callable, Callable,
...@@ -33,16 +33,13 @@ from typing import ( ...@@ -33,16 +33,13 @@ from typing import (
Union, Union,
) )
import numpy as np
import torch import torch
import torch.types import torch.types
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from torch.nn.init import trunc_normal_ from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.activation import get_act_fn
...@@ -63,6 +60,88 @@ from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM ...@@ -63,6 +60,88 @@ from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
RawImageType = Union[Image.Image, torch.Tensor] RawImageType = Union[Image.Image, torch.Tensor]
# sin/cos positional embedding helpers are adapted from:
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> torch.Tensor:
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,) / (H, W)
out: (M, D) / (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
if version == (2, 0):
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
else:
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)
) -> torch.Tensor:
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0], version
) # (H*W, D/2) or (H, W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1], version
) # (H*W, D/2) or (H, W, D/2)
if version == (2, 0):
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
else:
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def get_2d_sincos_pos_embed(
embed_dim: int,
grid_size: Union[int, Tuple[int, int]],
cls_token: bool = False,
version: Tuple[int, int] = (2, 0),
) -> torch.Tensor:
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_h_size, grid_w_size = grid_size, grid_size
else:
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size)
if version == (2, 0):
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
else:
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
return pos_embed
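# Quick shape check (illustrative, default version=(2, 0)):
#   get_2d_sincos_pos_embed(embed_dim=8, grid_size=(4, 3)).shape == (12, 8)
#   get_2d_sincos_pos_embed(embed_dim=8, grid_size=(4, 3), cls_token=True).shape == (13, 8)
# i.e. one row per patch of the 4 x 3 grid, plus a prepended all-zero row for cls_token.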
class Idefics2VisionMLP(nn.Module): class Idefics2VisionMLP(nn.Module):
def __init__( def __init__(
...@@ -116,6 +195,10 @@ class Idefics2EncoderLayer(nn.Module): ...@@ -116,6 +195,10 @@ class Idefics2EncoderLayer(nn.Module):
projection_size=config.intermediate_size, projection_size=config.intermediate_size,
use_qkv_parallel=True, use_qkv_parallel=True,
quant_config=quant_config, quant_config=quant_config,
dropout=config.attention_dropout,
use_context_forward=False,
use_full_precision_softmax=True,
flatten_batch=False,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
) )
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
...@@ -126,7 +209,6 @@ class Idefics2EncoderLayer(nn.Module): ...@@ -126,7 +209,6 @@ class Idefics2EncoderLayer(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
...@@ -136,11 +218,8 @@ class Idefics2EncoderLayer(nn.Module): ...@@ -136,11 +218,8 @@ class Idefics2EncoderLayer(nn.Module):
""" """
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm1(hidden_states) hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn( hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens)
hidden_states,
cu_seqlens=cu_seqlens,
# , forward_batch=forward_batch
)
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
residual = hidden_states residual = hidden_states
hidden_states = self.layer_norm2(hidden_states) hidden_states = self.layer_norm2(hidden_states)
...@@ -181,7 +260,6 @@ class Idefics2Encoder(nn.Module): ...@@ -181,7 +260,6 @@ class Idefics2Encoder(nn.Module):
self, self,
inputs_embeds: torch.Tensor, inputs_embeds: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
Args: Args:
...@@ -195,7 +273,8 @@ class Idefics2Encoder(nn.Module): ...@@ -195,7 +273,8 @@ class Idefics2Encoder(nn.Module):
hidden_states = inputs_embeds hidden_states = inputs_embeds
for encoder_layer in self.layers: for encoder_layer in self.layers:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch hidden_states,
cu_seqlens=cu_seqlens,
) )
hidden_states = layer_outputs hidden_states = layer_outputs
return hidden_states return hidden_states
...@@ -232,19 +311,14 @@ class Idefics2VisionEmbeddings(nn.Module): ...@@ -232,19 +311,14 @@ class Idefics2VisionEmbeddings(nn.Module):
self.num_positions = self.num_patches self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
def forward( def get_position_ids(
self, self,
pixel_values: torch.FloatTensor, pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor, patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None, tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor: ):
batch_size, _, max_im_h, max_im_w = pixel_values.shape batch_size, _, max_im_h, max_im_w = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
pixel_values = pixel_values.to(
device=self.patch_embedding.weight.device, dtype=target_dtype
)
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = ( max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size, max_im_h // self.patch_size,
max_im_w // self.patch_size, max_im_w // self.patch_size,
...@@ -277,6 +351,24 @@ class Idefics2VisionEmbeddings(nn.Module): ...@@ -277,6 +351,24 @@ class Idefics2VisionEmbeddings(nn.Module):
).flatten() ).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device) position_ids = position_ids.to(self.position_embedding.weight.device)
return position_ids
def forward(
self,
pixel_values: torch.FloatTensor,
patch_attention_mask: torch.BoolTensor,
tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
pixel_values = pixel_values.to(
device=self.patch_embedding.weight.device, dtype=target_dtype
)
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
position_ids = self.get_position_ids(
pixel_values, patch_attention_mask, tgt_sizes
)
embeddings = embeddings + self.position_embedding(position_ids) embeddings = embeddings + self.position_embedding(position_ids)
return embeddings return embeddings
...@@ -287,7 +379,6 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -287,7 +379,6 @@ class Idefics2VisionTransformer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -302,8 +393,6 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -302,8 +393,6 @@ class Idefics2VisionTransformer(nn.Module):
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,) patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,)
# do a prefix sum to get cu_seqlens; note that a 0 is inserted at the front as the offset
cu_seqlens = torch.cat( cu_seqlens = torch.cat(
[ [
torch.tensor([0], device=patch_len.device, dtype=torch.int32), torch.tensor([0], device=patch_len.device, dtype=torch.int32),
...@@ -316,19 +405,18 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -316,19 +405,18 @@ class Idefics2VisionTransformer(nn.Module):
def forward( def forward(
self, self,
pixel_values, pixel_values,
forward_batch: ForwardBatch,
patch_attention_mask: Optional[torch.BoolTensor] = None, patch_attention_mask: Optional[torch.BoolTensor] = None,
tgt_sizes: Optional[torch.IntTensor] = None, tgt_sizes: Optional[torch.IntTensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embeddings( hidden_states = self.embeddings(
pixel_values=pixel_values, pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask, patch_attention_mask=patch_attention_mask,
# forward_batch=forward_batch,
tgt_sizes=tgt_sizes, tgt_sizes=tgt_sizes,
) )
cu_seqlens = self.compute_cu_seqlens(tgt_sizes) cu_seqlens = self.compute_cu_seqlens(tgt_sizes)
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch hidden_states,
cu_seqlens=cu_seqlens,
) )
last_hidden_state = self.post_layernorm(encoder_outputs) last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state return last_hidden_state
...@@ -573,14 +661,12 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -573,14 +661,12 @@ class MiniCPMVBaseModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
# multimodal_config = config.model_config.multimodal_config
super().__init__() super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but # All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
# check `tie_word_embeddings` until vLLM integrates the MiniCPM-V model # check `tie_word_embeddings` until SGLang integrates the MiniCPM-V model
# and config class # and config class
self.config = config self.config = config
# self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config) self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config=config, quant_config=quant_config) self.llm = self.init_llm(config=config, quant_config=quant_config)
...@@ -598,13 +684,6 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -598,13 +684,6 @@ class MiniCPMVBaseModel(nn.Module):
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
@cached_property
def sampler(self):
if hasattr(self.llm, "sampler"):
return self.llm.sampler
return get_sampler()
def _get_image_bounds( def _get_image_bounds(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -666,7 +745,6 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -666,7 +745,6 @@ class MiniCPMVBaseModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs], image_inputs: Optional[MiniCPMVImageInputs],
forward_batch: ForwardBatch,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
...@@ -680,10 +758,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -680,10 +758,7 @@ class MiniCPMVBaseModel(nn.Module):
.to(vlm_embedding.device) .to(vlm_embedding.device)
) )
else: else:
vision_hidden_states = self.get_vision_hidden_states( vision_hidden_states = self.get_vision_hidden_states(image_inputs)
forward_batch, image_inputs
)
# See NOTE in _parse_and_validate_inputs # See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"] image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0: if len(image_bounds) > 0:
...@@ -693,6 +768,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -693,6 +768,7 @@ class MiniCPMVBaseModel(nn.Module):
for start, end in image_bounds.tolist() for start, end in image_bounds.tolist()
] ]
).to(vlm_embedding.device) ).to(vlm_embedding.device)
vlm_embedding.scatter_( vlm_embedding.scatter_(
0, 0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
...@@ -839,7 +915,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -839,7 +915,7 @@ class MiniCPMVBaseModel(nn.Module):
# These values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1) input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch) vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
# always pass the input via `inputs_embeds` # always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent # to make sure the computation graph is consistent
...@@ -857,29 +933,6 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -857,29 +933,6 @@ class MiniCPMVBaseModel(nn.Module):
input_ids, hidden_states, self.llm.lm_head, forward_batch input_ids, hidden_states, self.llm.lm_head, forward_batch
) )
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.llm.compute_logits(hidden_states, sampling_metadata)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="llm", connector="resampler", tower_model="vpm"
)
def init_llm( def init_llm(
self, self,
config: Qwen2Config, config: Qwen2Config,
...@@ -910,9 +963,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -910,9 +963,7 @@ class MiniCPMVBaseModel(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def get_vision_hidden_states( def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs
) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -1019,7 +1070,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -1019,7 +1070,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
def get_vision_hidden_states( def get_vision_hidden_states(
self, self,
forward_batch: ForwardBatch,
data: MiniCPMVImageInputs, data: MiniCPMVImageInputs,
) -> torch.Tensor: ) -> torch.Tensor:
pixel_values = data["data"] pixel_values = data["data"]
...@@ -1042,15 +1092,18 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -1042,15 +1092,18 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
patch_attn_mask = torch.zeros( patch_attn_mask = torch.zeros(
(B, 1, max_patches), dtype=torch.bool, device=device (B, 1, max_patches), dtype=torch.bool, device=device
) )
for i in range(B):
patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device)
mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1]
patch_attn_mask[:, 0, :] = torch.arange(
patch_attn_mask.size(2), device=patch_attn_mask.device
).unsqueeze(0) < mask_shapes.unsqueeze(1)
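# The arange/broadcast comparison above is a vectorized form of the previous
# per-sample loop: position j of sample i is valid iff j < tgt_h_i * tgt_w_i.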
vision_embedding = self.vpm( vision_embedding = self.vpm(
all_pixel_values.type(dtype), all_pixel_values.type(dtype),
forward_batch=forward_batch,
patch_attention_mask=patch_attn_mask, patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes, tgt_sizes=tgt_sizes,
) )
return self.resampler(vision_embedding, tgt_sizes) return self.resampler(vision_embedding, tgt_sizes)
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
...@@ -1138,7 +1191,7 @@ class MiniCPMV: ...@@ -1138,7 +1191,7 @@ class MiniCPMV:
""" """
Different versions of MiniCPMV use different visual encoders and LLMs, Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them. bitsandbytes in SGLang. Therefore, it is necessary to separate them.
""" """
# Ensure that the LoRA support check passes when the class is not # Ensure that the LoRA support check passes when the class is not
......
...@@ -17,6 +17,7 @@ from transformers.models.mllama.modeling_mllama import ( ...@@ -17,6 +17,7 @@ from transformers.models.mllama.modeling_mllama import (
import sglang.srt.distributed.parallel_state as ps import sglang.srt.distributed.parallel_state as ps
from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
...@@ -145,61 +146,6 @@ class MllamaPrecomputedPositionEmbedding(nn.Module): ...@@ -145,61 +146,6 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
return hidden_state return hidden_state
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig):
super().__init__()
model_parallel_size = get_tensor_model_parallel_world_size()
self.embed_dim = config.hidden_size
self.num_heads = config.attention_heads
self.head_dim = config.hidden_size // config.attention_heads
self.num_local_heads = self.num_heads // model_parallel_size
self.q_size = self.num_local_heads * self.head_dim
self.kv_size = self.num_local_heads * self.head_dim
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
bias=False,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=False,
input_is_parallel=True,
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_state)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q = q.view(
q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
).transpose(1, 2)
k = k.view(
k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
).transpose(1, 2)
v = v.view(
v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
).transpose(1, 2)
# TODO: remove padding in image encoder
attn_output = F.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask, dropout_p=0.0
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(
attn_output.shape[0], attn_output.shape[1], -1
)
output, _ = self.o_proj(attn_output)
return output
class MllamaVisionMLP(nn.Module): class MllamaVisionMLP(nn.Module):
def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
...@@ -237,7 +183,17 @@ class MllamaVisionEncoderLayer(nn.Module): ...@@ -237,7 +183,17 @@ class MllamaVisionEncoderLayer(nn.Module):
self.is_gated = is_gated self.is_gated = is_gated
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(config) self.self_attn = VisionAttention(
self.hidden_size,
self.num_attention_heads,
self.hidden_size,
use_qkv_parallel=True,
quant_config=None,
dropout=0.0,
use_context_forward=False,
use_full_precision_softmax=False,
flatten_batch=False,
)
self.mlp = MllamaVisionMLP(config) self.mlp = MllamaVisionMLP(config)
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
...@@ -992,6 +948,10 @@ class MllamaForConditionalGeneration(nn.Module): ...@@ -992,6 +948,10 @@ class MllamaForConditionalGeneration(nn.Module):
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
if "vision_model" in name:
# adapt to VisionAttention
name = name.replace("self_attn.o_proj", "self_attn.proj")
param = params_dict.pop(name) param = params_dict.pop(name)
weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
......
...@@ -249,7 +249,10 @@ class Qwen2Model(nn.Module): ...@@ -249,7 +249,10 @@ class Qwen2Model(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
return self.embed_tokens(input_ids) * self.config.scale_emb
else:
return self.embed_tokens(input_ids)
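# Note: MiniCPM-family configs define `scale_emb` and expect token embeddings to be
# multiplied by it (an assumption based on the MiniCPM reference modeling code);
# plain Qwen2 configs have no such attribute and take the unscaled path.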
def forward( def forward(
self, self,
......
...@@ -30,12 +30,10 @@ import numpy as np ...@@ -30,12 +30,10 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange
from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.activation import QuickGELU
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.distributed import utils as dist_utils
from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
...@@ -118,6 +116,7 @@ class Qwen2VisionBlock(nn.Module): ...@@ -118,6 +116,7 @@ class Qwen2VisionBlock(nn.Module):
mlp_ratio: float, mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU, act_layer: Type[nn.Module] = QuickGELU,
norm_layer: Type[nn.Module] = None, norm_layer: Type[nn.Module] = None,
attn_implementation: Optional[str] = "sdpa",
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -126,12 +125,24 @@ class Qwen2VisionBlock(nn.Module): ...@@ -126,12 +125,24 @@ class Qwen2VisionBlock(nn.Module):
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
if attn_implementation == "sdpa":
use_context_forward = False
use_full_precision_softmax = False
elif attn_implementation == "flash_attention_2":
use_full_precision_softmax = False
use_context_forward = True
elif attn_implementation == "eager":
use_full_precision_softmax = True
use_context_forward = False
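# Mapping used above:
#   "sdpa"              -> VisionSdpaAttention via F.scaled_dot_product_attention
#   "flash_attention_2" -> VisionTritonAttention (context_attention_fwd)
#   "eager"             -> VisionSdpaAttention with full-precision softmax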
self.attn = VisionAttention( self.attn = VisionAttention(
embed_dim=dim, embed_dim=dim,
num_heads=num_heads, num_heads=num_heads,
projection_size=dim, projection_size=dim,
use_qkv_parallel=False, use_qkv_parallel=False,
use_context_forward=use_context_forward,
use_full_precision_softmax=use_full_precision_softmax,
flatten_batch=True,
quant_config=quant_config, quant_config=quant_config,
) )
self.mlp = Qwen2VisionMLP( self.mlp = Qwen2VisionMLP(
...@@ -286,7 +297,6 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -286,7 +297,6 @@ class Qwen2VisionTransformer(nn.Module):
norm_layer = partial(nn.LayerNorm, eps=norm_eps) norm_layer = partial(nn.LayerNorm, eps=norm_eps)
head_dim = embed_dim // num_heads head_dim = embed_dim // num_heads
self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
Qwen2VisionBlock( Qwen2VisionBlock(
...@@ -294,6 +304,7 @@ class Qwen2VisionTransformer(nn.Module): ...@@ -294,6 +304,7 @@ class Qwen2VisionTransformer(nn.Module):
num_heads=num_heads, num_heads=num_heads,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
norm_layer=norm_layer, norm_layer=norm_layer,
attn_implementation="sdpa",
quant_config=quant_config, quant_config=quant_config,
) )
for _ in range(depth) for _ in range(depth)
...@@ -482,10 +493,6 @@ class Qwen2VLForConditionalGeneration(nn.Module): ...@@ -482,10 +493,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
opensource models), the shape will be `(3, seq_len)`, opensource models), the shape will be `(3, seq_len)`,
otherwise it will be `(seq_len,). otherwise it will be `(seq_len,).
(Use input_metadata.mrope_positions to replace it) (Use input_metadata.mrope_positions to replace it)
pixel_values: Pixel values to be fed to a model.
`None` if no images are passed.
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
`None` if no images are passed.
""" """
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions positions = forward_batch.mrope_positions
...@@ -540,15 +547,18 @@ class Qwen2VLForConditionalGeneration(nn.Module): ...@@ -540,15 +547,18 @@ class Qwen2VLForConditionalGeneration(nn.Module):
num_image_tokens = self.calculate_num_image_tokens( num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx] image_grid_thws[idx]
) )
left_idx = start_idx + (image_offset - prefix_len) left_idx = start_idx + (image_offset - prefix_len)
right_idx = ( right_idx = (
start_idx + (image_offset - prefix_len) + num_image_tokens start_idx + (image_offset - prefix_len) + num_image_tokens
) )
inputs_embeds[left_idx:right_idx] = image_embeds[ inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens image_embeds_offset : image_embeds_offset + num_image_tokens
] ]
image_embeds_offset += num_image_tokens image_embeds_offset += num_image_tokens
input_ids = None
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
......
...@@ -444,8 +444,6 @@ def load_image(image_file: Union[str, bytes]): ...@@ -444,8 +444,6 @@ def load_image(image_file: Union[str, bytes]):
else: else:
raise ValueError(f"Invalid image: {image}") raise ValueError(f"Invalid image: {image}")
# if image_size is None:
# image_size = image.size
return image, image_size return image, image_size
......
...@@ -48,6 +48,7 @@ suites = { ...@@ -48,6 +48,7 @@ suites = {
"test_update_weights_from_disk.py", "test_update_weights_from_disk.py",
"test_update_weights_from_tensor.py", "test_update_weights_from_tensor.py",
"test_vision_chunked_prefill.py", "test_vision_chunked_prefill.py",
"test_vision_llm.py",
"test_vision_openai_server.py", "test_vision_openai_server.py",
"test_w8a8_quantization.py", "test_w8a8_quantization.py",
"test_fp8_kvcache.py", "test_fp8_kvcache.py",
...@@ -72,7 +73,6 @@ for target_suite_name, target_tests in suites.items(): ...@@ -72,7 +73,6 @@ for target_suite_name, target_tests in suites.items():
tests.remove(target_suite_name) tests.remove(target_suite_name)
tests.extend(target_tests) tests.extend(target_tests)
if __name__ == "__main__": if __name__ == "__main__":
arg_parser = argparse.ArgumentParser() arg_parser = argparse.ArgumentParser()
arg_parser.add_argument( arg_parser.add_argument(
......
"""
"""
import unittest
from io import BytesIO
import numpy as np
import requests
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.conversation import generate_chat_conv
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.openai_api.protocol import ChatCompletionRequest
from sglang.srt.server_args import ServerArgs
MiniCPMV = "openbmb/MiniCPM-V-2_6"
# Test the logits output between HF and SGLang
class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
@classmethod
def setUpClass(cls):
cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls.model_path = ""
cls.chat_template = ""
cls.processor = ""
response = requests.get(cls.image_url)
cls.main_image = Image.open(BytesIO(response.content))
def compare_outputs(self, sglang_output: torch.Tensor, hf_output: torch.Tensor):
# Convert to float32 for numerical stability if needed
hf = hf_output.float()
sg = sglang_output.float()
# Basic shape and dtype comparison
print("\n=== Basic Properties ===")
print(f"Shapes match: {hf.shape == sg.shape}")
print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
# Move tensors to CPU for numpy operations
hf_np = hf.cpu().numpy()
sg_np = sg.cpu().numpy()
# Statistical metrics
print("\n=== Statistical Metrics ===")
print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
print(
f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
)
# Cosine similarity (across feature dimension)
cos_sim = F.cosine_similarity(hf, sg)
print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")
# Find largest absolute differences
print("\n=== Largest Absolute Differences ===")
diffs = torch.abs(hf - sg)
flat_diffs = diffs.flatten()
# Get indices of top 10 differences
top_k = 10
top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
# Convert flat indices to multidimensional indices
top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)
print(f"\nTop {top_k} largest absolute differences:")
print(
"Index".ljust(30)
+ "Difference".ljust(15)
+ "HF Value".ljust(15)
+ "SGLang Value"
)
print("-" * 75)
for i in range(top_k):
# Get the index tuple for this difference
idx = tuple(dim[i] for dim in top_indices)
diff_val = top_values[i].item()
hf_val = hf[idx].item()
sg_val = sg[idx].item()
# Format the index tuple and values
idx_str = str(idx)
print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")
np.testing.assert_allclose(hf_np, sg_np)
def get_processor_output(self):
json_str = f"""
{{
"model": "{self.model_path}",
"messages": [
{{
"role": "user",
"content": [
{{
"type": "image_url",
"image_url": {{
"url": "{self.image_url}"
}}
}},
{{
"type": "text",
"text": "Whats in this picture?"
}}
]
}}
]
}}
"""
req = ChatCompletionRequest.model_validate_json(json_str)
conv = generate_chat_conv(req, template_name=self.chat_template)
text = conv.get_prompt()
# Process inputs using processor
# FIXME: the formal arguments may differ
inputs = self.processor(
text=[text],
images=[self.main_image],
return_tensors="pt",
).to(self.device)
return inputs
def get_sglang_model(self):
model_runner = ModelRunner(
model_config=ModelConfig(self.model_path, model_override_args="{}"),
mem_fraction_static=0.8,
gpu_id=0,
tp_rank=0,
tp_size=1,
nccl_port=12435,
server_args=ServerArgs(
model_path=self.model_path,
disable_cuda_graph=True,
),
)
return model_runner.model
class TestMiniCPMVLogits(VisionLLMLogitsBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model_path = MiniCPMV
cls.tokenizer = AutoTokenizer.from_pretrained(
cls.model_path, trust_remote_code=True
)
cls.processor = AutoProcessor.from_pretrained(
cls.model_path, trust_remote_code=True
)
cls.chat_template = "minicpmv"
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls.model = AutoModel.from_pretrained(
cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
).eval()
cls.model.to(cls.device)
async def test_encode_output(self):
inputs = self.get_processor_output()
with torch.no_grad():
model_inputs = {
"input_ids": inputs.input_ids,
"image_bound": inputs.image_bound,
"pixel_values": inputs.pixel_values,
"tgt_sizes": inputs.tgt_sizes,
}
(hf_output, _) = self.model.get_vllm_embedding(
model_inputs,
)
hf_output = hf_output.squeeze(0)
with torch.no_grad():
model = self.get_sglang_model()
input_ids = inputs["input_ids"].to(self.device).flatten()
image_inputs = model._parse_and_validate_inputs(
input_ids=input_ids,
**{
"pixel_values": [inputs["pixel_values"]],
"tgt_sizes": [inputs["tgt_sizes"]],
"im_start_id": [self.tokenizer.im_start_id],
"im_end_id": [self.tokenizer.im_end_id],
"slice_start_id": [self.tokenizer.slice_start_id],
"slice_end_id": [self.tokenizer.slice_end_id],
},
)
(sglang_output, _) = model.get_embedding(
input_ids=input_ids, image_inputs=image_inputs
)
self.compare_outputs(sglang_output, hf_output)
if __name__ == "__main__":
unittest.main()
...@@ -180,7 +180,9 @@ class TestOpenAIVisionServer(unittest.TestCase): ...@@ -180,7 +180,9 @@ class TestOpenAIVisionServer(unittest.TestCase):
assert response.usage.total_tokens > 0 assert response.usage.total_tokens > 0
def prepare_video_messages(self, video_path):
max_frames_num = 32
# the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa
# the size of the video embeds differs from the `modality` argument when preprocessed
max_frames_num = 12
vr = VideoReader(video_path, ctx=cpu(0)) vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr) total_frame_num = len(vr)
uniform_sampled_frames = np.linspace( uniform_sampled_frames = np.linspace(
......