Unverified Commit 542a4059 authored by YunzhuLu's avatar YunzhuLu Committed by GitHub
Browse files

[Model] Use mm_position to compute mrope positions for Qwen2-VL/2.5-VL (#32126)


Signed-off-by: default avatarYunzhuLu <lucia.yunzhu@gmail.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent df7e1271
...@@ -26,11 +26,12 @@ ...@@ -26,11 +26,12 @@
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
import einops import einops
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -1044,121 +1045,82 @@ class Qwen2_5_VLForConditionalGeneration( ...@@ -1044,121 +1045,82 @@ class Qwen2_5_VLForConditionalGeneration(
supports_encoder_tp_data = True supports_encoder_tp_data = True
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int, float]]:
"""
Iterate over multimodal features and yield grid information.
Args:
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 1.0
if mm_feature.data.get("second_per_grid_ts", None):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
t_factor = second_per_grid_ts * tokens_per_second
yield (
offset,
t,
h // spatial_merge_size,
w // spatial_merge_size,
t_factor,
)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec], mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = [] llm_pos_ids_list: list = []
st = 0 st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
t_factor,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
* video_second_per_grid_t
* tokens_per_second
)
.long()
.flatten()
) )
h_index = ( grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
torch.arange(llm_grid_h) if t_factor != 1.0:
.view(1, -1, 1) grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
.expand(llm_grid_t, -1, llm_grid_w) llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
.flatten() st = offset + llm_grid_t * llm_grid_h * llm_grid_w
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens): if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st text_len = len(input_tokens) - st
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
import math import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -1137,121 +1137,82 @@ class Qwen2VLForConditionalGeneration( ...@@ -1137,121 +1137,82 @@ class Qwen2VLForConditionalGeneration(
supports_encoder_tp_data = True supports_encoder_tp_data = True
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int, float]]:
"""
Iterate over multimodal features and yield grid information.
Args:
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 1.0
if mm_feature.data.get("second_per_grid_ts", None):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
t_factor = second_per_grid_ts * tokens_per_second
yield (
offset,
t,
h // spatial_merge_size,
w // spatial_merge_size,
t_factor,
)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec], mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = [] llm_pos_ids_list: list = []
st = 0 st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
t_factor,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
* video_second_per_grid_t
* tokens_per_second
)
.long()
.flatten()
) )
h_index = ( grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
torch.arange(llm_grid_h) if t_factor != 1.0:
.view(1, -1, 1) grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
.expand(llm_grid_t, -1, llm_grid_w) llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
.flatten() st = offset + llm_grid_t * llm_grid_h * llm_grid_w
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens): if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st text_len = len(input_tokens) - st
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment