"tutorials/README.txt" did not exist on "97a7dd11b2f88f3376d607f7b5211e05c95b6c5d"
Unverified Commit e53a0b3d authored by Mick, committed by GitHub

[fix] fix mrope positions not picked up (#5265)

parent 038bc5d5
@@ -94,7 +94,7 @@ class VisionAttention(nn.Module):
             input_size=embed_dim,
             output_size=embed_dim,
             quant_config=quant_config,
-            prefix=add_prefix("out_proj", prefix),
+            prefix=add_prefix("proj", prefix),
         )

     def forward(
...
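The prefix rename matters for weight loading: the projection layer's parameter names must match the keys in the HF Qwen2.5-VL checkpoint ("...attn.proj.weight"), not "out_proj". A minimal sketch of why, assuming add_prefix joins prefix and name with a dot (mirroring sglang.srt.utils.add_prefix; the checkpoint key below is illustrative):

    # Sketch: the generated parameter name must equal the checkpoint key.
    # Assumes add_prefix behaves like sglang.srt.utils.add_prefix.
    def add_prefix(name: str, prefix: str) -> str:
        return name if not prefix else f"{prefix}.{name}"

    prefix = "visual.blocks.0.attn"  # illustrative module path
    # Old: "visual.blocks.0.attn.out_proj.weight" -> no matching checkpoint key
    # New: "visual.blocks.0.attn.proj.weight"     -> matches the HF checkpoint
    print(add_prefix("proj", prefix) + ".weight")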
@@ -268,6 +268,9 @@ class MultimodalDataItem:
             self.modality == Modality.VIDEO
         ) and not MultimodalDataItem.is_empty_list(self.pixel_values)

+    def is_valid(self) -> bool:
+        return self.is_image() or self.is_video() or self.is_audio()
+
     def validate(self):
         ...
         # TODO
@@ -306,11 +309,7 @@ class MultimodalInputs:
         )

         assert isinstance(ret.mm_items, list)
-        ret.mm_items = [
-            item
-            for item in ret.mm_items
-            if item.is_audio() or item.is_image() or item.is_video()
-        ]
+        ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

         assert len(ret.mm_items) != 0
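The new MultimodalDataItem.is_valid folds the three modality checks into one predicate, so callers like from_dict filter items uniformly. A standalone sketch of the intended behavior (the Item class below is a stand-in, not the real MultimodalDataItem):

    # Standalone sketch of the consolidated validity check.
    from dataclasses import dataclass

    @dataclass
    class Item:
        modality: str
        def is_image(self): return self.modality == "image"
        def is_video(self): return self.modality == "video"
        def is_audio(self): return self.modality == "audio"
        def is_valid(self): return self.is_image() or self.is_video() or self.is_audio()

    items = [Item("image"), Item("unknown"), Item("audio")]
    valid = [i for i in items if i.is_valid()]  # drops the "unknown" item
    assert len(valid) == 2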
@@ -345,8 +344,8 @@ class MultimodalInputs:
         """ """
         return any(item.is_audio() for item in self.mm_items)

-    def collect_image_inputs(self) -> List[torch.Tensor]:
-        return [item.pixel_values for item in self.mm_items if item.is_image()]
+    def contains_mm_input(self) -> bool:
+        return any(True for item in self.mm_items if item.is_valid())

     def merge(self, other: MultimodalInputs):
         """
...
@@ -33,7 +33,6 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union

-import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -399,13 +398,13 @@ class ForwardBatch:
                 )
             elif self.forward_mode.is_extend():
                 extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-                for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
+                for i, mm_input in enumerate(batch.multimodal_inputs):
                     extend_start_loc, extend_seq_len, extend_prefix_len = (
                         extend_start_loc_cpu[i],
                         batch.extend_seq_lens[i],
                         batch.extend_prefix_lens[i],
                     )
-                    if multimodal_inputs is None:
+                    if mm_input is None:
                         # text only
                         mrope_positions = [
                             [
@@ -416,23 +415,58 @@ class ForwardBatch:
                             ]
                         ] * 3
                     else:
+                        image_grid_thws_list = [
+                            item.image_grid_thws
+                            for item in mm_input.mm_items
+                            if item.image_grid_thws is not None
+                        ]
+                        image_grid_thw = (
+                            None
+                            if len(image_grid_thws_list) == 0
+                            else torch.cat(image_grid_thws_list, dim=0)
+                        )
+
+                        video_grid_thws_list = [
+                            item.video_grid_thws
+                            for item in mm_input.mm_items
+                            if item.video_grid_thws is not None
+                        ]
+                        video_grid_thw = (
+                            None
+                            if len(video_grid_thws_list) == 0
+                            else torch.cat(video_grid_thws_list, dim=0)
+                        )
+
+                        second_per_grid_ts_list = [
+                            item.second_per_grid_ts
+                            for item in mm_input.mm_items
+                            if item.second_per_grid_ts is not None
+                        ]
+                        second_per_grid_ts = (
+                            None
+                            if len(second_per_grid_ts_list) == 0
+                            else torch.cat(second_per_grid_ts_list, dim=0)
+                        )
+
                         # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                         mrope_positions, mrope_position_delta = (
                             MRotaryEmbedding.get_input_positions(
                                 input_tokens=self.input_ids[
                                     extend_start_loc : extend_start_loc + extend_seq_len
-                                ],
-                                image_grid_thw=multimodal_inputs.image_grid_thws,
-                                video_grid_thw=multimodal_inputs.video_grid_thws,
-                                image_token_id=multimodal_inputs.im_token_id,
-                                video_token_id=multimodal_inputs.video_token_id,
+                                ].tolist(),
+                                image_grid_thw=image_grid_thw,
+                                video_grid_thw=video_grid_thw,
+                                image_token_id=hf_config.image_token_id,
+                                video_token_id=hf_config.video_token_id,
                                 vision_start_token_id=hf_config.vision_start_token_id,
                                 vision_end_token_id=hf_config.vision_end_token_id,
                                 spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                                 context_len=0,
                                 seq_len=len(self.input_ids),
-                                second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
-                                tokens_per_second=hf_config.vision_config.tokens_per_second,
+                                second_per_grid_ts=second_per_grid_ts,
+                                tokens_per_second=getattr(
+                                    hf_config.vision_config, "tokens_per_second", None
+                                ),
                             )
                         )
                         batch.multimodal_inputs[i].mrope_position_delta = (
...
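The core of the fix is visible in this hunk: grid shapes now live per item in mm_input.mm_items rather than on the top-level input object, so they are gathered and concatenated along dim 0 before calling MRotaryEmbedding.get_input_positions, which expects one tensor of [t, h, w] rows per modality. A standalone sketch of the gather-or-None pattern used three times above:

    # Standalone sketch of the gather-and-concatenate pattern.
    # Each (1, 3) row is one item's [t, h, w] grid; items without that
    # modality contribute nothing, and an all-None list yields None.
    import torch

    def cat_or_none(tensors):
        tensors = [t for t in tensors if t is not None]
        return None if len(tensors) == 0 else torch.cat(tensors, dim=0)

    per_item_thws = [torch.tensor([[1, 34, 46]]), None, torch.tensor([[1, 16, 22]])]
    image_grid_thw = cat_or_none(per_item_thws)
    assert image_grid_thw.shape == (2, 3)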
@@ -1070,7 +1070,8 @@ class ModelRunner:
         rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
         if rope_scaling is None:
             return False
-        return rope_scaling.get("type", None) == "mrope"
+        is_mrope_enabled = "mrope_section" in rope_scaling
+        return is_mrope_enabled

     def save_remote_model(self, url: str):
         from sglang.srt.model_loader.loader import RemoteModelLoader
...
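Keying on "mrope_section" is more robust than comparing rope_scaling["type"] to "mrope": recent transformers versions normalize the type field (e.g. to "default") while keeping "mrope_section", which is presumably how the mrope positions were "not picked up". A sketch of the two checks against such a config (the dict values are illustrative, in the shape transformers can produce for Qwen2-VL):

    # Illustrative rope_scaling dict after transformers normalizes "type".
    rope_scaling = {"type": "default", "mrope_section": [16, 24, 24]}

    old_check = rope_scaling.get("type", None) == "mrope"  # False -> mrope silently off
    new_check = "mrope_section" in rope_scaling            # True  -> mrope enabled
    assert not old_check and new_check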
@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)

 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x


-class Qwen2_5_VisionPatchEmbed(nn.Module):
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
-
 class Qwen2_5_VisionPatchMerger(nn.Module):
     def __init__(
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out


-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen2_5_VisionTransformer(nn.Module):
     def __init__(
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_channels
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-            in_chans=in_chans,
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )
@@ -469,7 +431,7 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def __init__(
         self,
-        config: Qwen2VLConfig,
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -553,14 +515,15 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         otherwise it will be `(seq_len,).
         (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+        if is_mrope_enabled:
             positions = forward_batch.mrope_positions
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
...
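The asserted shape reflects how mrope differs from plain RoPE: instead of one 1-D position sequence, it carries a separate index row per axis (temporal, height, width). A toy illustration of the expected layout:

    # Toy illustration of the (3, seq_len) mrope positions layout; for
    # text-only spans all three rows are identical, as in the assertion above.
    import torch

    seq_len = 8
    mrope_positions = torch.arange(seq_len).repeat(3, 1)  # shape (3, seq_len)
    assert mrope_positions.ndim == 2 and mrope_positions.size(0) == 3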
@@ -521,14 +521,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         otherwise it will be `(seq_len,).
         (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+        if is_mrope_enabled:
             positions = forward_batch.mrope_positions
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
...
@@ -983,6 +983,8 @@ def v1_chat_generate_request(
         ):
             encoded = encoded[1:]
         prompt_ids += encoded
+    if tokenizer_manager.model_config.is_multimodal:
+        prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
     stop = request.stop
     image_data = None
     audio_data = None
...
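For multimodal models the downstream processor needs the prompt as text, not just token ids, so the chat request path re-derives it by decoding prompt_ids. A standalone sketch of that round trip with a HF tokenizer (the model name is illustrative; any HF tokenizer decodes the same way):

    # Standalone sketch: recover the prompt string from token ids so a
    # multimodal processor can re-tokenize text together with images.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    prompt_ids = tokenizer.encode("Describe this image.")
    prompt = tokenizer.decode(prompt_ids)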