Unverified Commit 199e3cb4 authored by Yang Liu's avatar Yang Liu Committed by GitHub
Browse files

[Model] Use mm_position to compute mrope positions for GLM-4.xV (#33039)


Signed-off-by: default avatarYang <lymailforjob@gmail.com>
parent 9f8cb81b
...@@ -1283,6 +1283,42 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -1283,6 +1283,42 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
) )
# GLM-4.1V
def load_glm4_1v(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "zai-org/GLM-4.1V-9B-Thinking"
engine_args = EngineArgs(
model=model_name,
max_model_len=45082,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
enforce_eager=True,
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
# GLM-4.5V # GLM-4.5V
def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData: def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "zai-org/GLM-4.5V" model_name = "zai-org/GLM-4.5V"
...@@ -1430,6 +1466,7 @@ model_example_map = { ...@@ -1430,6 +1466,7 @@ model_example_map = {
"stepvl": load_step_vl, "stepvl": load_step_vl,
"tarsier": load_tarsier, "tarsier": load_tarsier,
"tarsier2": load_tarsier2, "tarsier2": load_tarsier2,
"glm4_1v": load_glm4_1v,
"glm4_5v": load_glm4_5v, "glm4_5v": load_glm4_5v,
"glm4_5v_fp8": load_glm4_5v_fp8, "glm4_5v_fp8": load_glm4_5v_fp8,
} }
......
...@@ -27,9 +27,8 @@ ...@@ -27,9 +27,8 @@
"""Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model """Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model
compatible with HuggingFace weights.""" compatible with HuggingFace weights."""
import itertools
import math import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -1580,138 +1579,62 @@ class Glm4vForConditionalGeneration( ...@@ -1580,138 +1579,62 @@ class Glm4vForConditionalGeneration(
multimodal_embeddings += tuple(video_embeddings) multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int]]:
hf_config = self.config
spatial_merge_size = hf_config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, t, h // spatial_merge_size, w // spatial_merge_size
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
yield (
offset,
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec], mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
hf_config = self.config
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = [] llm_pos_ids_list: list = []
st = 0
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
3, -1
)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st = offset + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
text_len = len(input_tokens) - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
if image_grid_thw or video_grid_thw: llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = image_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
*image_grid_thw[mm_data_idx][1:],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
def forward( def forward(
self, self,
......
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
# https://github.com/zai-org/CogAgent # https://github.com/zai-org/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights.""" """Inference-only CogAgent model compatible with THUDM weights."""
import itertools
from argparse import Namespace from argparse import Namespace
from collections.abc import Mapping, Sequence from collections.abc import Iterator, Mapping, Sequence
from typing import Annotated, Literal from typing import Annotated, Literal
import numpy as np
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
...@@ -624,138 +624,56 @@ class GLM4VForCausalLM( ...@@ -624,138 +624,56 @@ class GLM4VForCausalLM(
return self.transformer.vision(pixel_values) return self.transformer.vision(pixel_values)
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int]]:
hf_config = self.config
spatial_merge_size = hf_config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, t, h // spatial_merge_size, w // spatial_merge_size
else:
# glm4v only supports image modality
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec], mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
hf_config = self.config
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = [] llm_pos_ids_list: list = []
st = 0
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
3, -1
)
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
# EVA2CLIPModel has embeddings for boi and eoi tokens as well
st = offset + 1 + llm_grid_t * llm_grid_h * llm_grid_w + 1
if st < len(input_tokens):
text_len = len(input_tokens) - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
if image_grid_thw or video_grid_thw: llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
input_token_type: list[str] = []
video_check_flg = False
for token in input_tokens:
if token == video_start_token_id:
video_check_flg = True
elif token == video_end_token_id:
video_check_flg = False
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = image_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = (
video_frame_num,
*image_grid_thw[mm_data_idx][1:],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
embed_input_ids = SupportsMultiModal.embed_input_ids embed_input_ids = SupportsMultiModal.embed_input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment