Unverified Commit 199e3cb4 authored by Yang Liu's avatar Yang Liu Committed by GitHub
Browse files

[Model] Use mm_position to compute mrope positions for GLM-4.xV (#33039)


Signed-off-by: default avatarYang <lymailforjob@gmail.com>
parent 9f8cb81b
...@@ -1283,6 +1283,42 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -1283,6 +1283,42 @@ def load_tarsier2(question: str, image_urls: list[str]) -> ModelRequestData:
) )
# GLM-4.1V
def load_glm4_1v(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "zai-org/GLM-4.1V-9B-Thinking"
engine_args = EngineArgs(
model=model_name,
max_model_len=45082,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
enforce_eager=True,
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=image_data,
)
# GLM-4.5V # GLM-4.5V
def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData: def load_glm4_5v(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "zai-org/GLM-4.5V" model_name = "zai-org/GLM-4.5V"
...@@ -1430,6 +1466,7 @@ model_example_map = { ...@@ -1430,6 +1466,7 @@ model_example_map = {
"stepvl": load_step_vl, "stepvl": load_step_vl,
"tarsier": load_tarsier, "tarsier": load_tarsier,
"tarsier2": load_tarsier2, "tarsier2": load_tarsier2,
"glm4_1v": load_glm4_1v,
"glm4_5v": load_glm4_5v, "glm4_5v": load_glm4_5v,
"glm4_5v_fp8": load_glm4_5v_fp8, "glm4_5v_fp8": load_glm4_5v_fp8,
} }
......
...@@ -27,9 +27,8 @@ ...@@ -27,9 +27,8 @@
"""Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model """Inference-only GLM-4.1V & GLM-4.6V-Flash, AutoGLM-Phone-9B model
compatible with HuggingFace weights.""" compatible with HuggingFace weights."""
import itertools
import math import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -1580,138 +1579,62 @@ class Glm4vForConditionalGeneration( ...@@ -1580,138 +1579,62 @@ class Glm4vForConditionalGeneration(
multimodal_embeddings += tuple(video_embeddings) multimodal_embeddings += tuple(video_embeddings)
return multimodal_embeddings return multimodal_embeddings
def get_mrope_input_positions( def iter_mm_grid_thw(
self, self, mm_features: list[MultiModalFeatureSpec]
input_tokens: list[int], ) -> Iterator[tuple[int, int, int, int]]:
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
hf_config = self.config hf_config = self.config
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = [] for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if image_grid_thw or video_grid_thw: if mm_feature.modality == "image":
input_token_type: list[str] = [] t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
video_check_flg = False assert t == 1, f"Image must have 1 frame, got {t}"
for token in input_tokens: yield offset, t, h // spatial_merge_size, w // spatial_merge_size
if token == video_start_token_id: elif mm_feature.modality == "video":
video_check_flg = True t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
elif token == video_end_token_id: yield (
video_check_flg = False offset,
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = image_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t, t,
h // spatial_merge_size, h // spatial_merge_size,
w // spatial_merge_size, w // spatial_merge_size,
) )
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
t_index = ( def get_mrope_input_positions(
torch.arange(llm_grid_t) self,
.view(-1, 1) input_tokens: list[int],
.expand(-1, llm_grid_h * llm_grid_w) mm_features: list[MultiModalFeatureSpec],
.flatten() ) -> tuple[torch.Tensor, int]:
) llm_pos_ids_list: list = []
h_index = ( st = 0
torch.arange(llm_grid_h) for (
.view(1, -1, 1) offset,
.expand(llm_grid_t, -1, llm_grid_w) llm_grid_t,
.flatten() llm_grid_h,
) llm_grid_w,
w_index = ( ) in self.iter_mm_grid_thw(mm_features):
torch.arange(llm_grid_w) text_len = offset - st
.view(1, 1, -1) st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
mm_data_idx += 1 grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
3, -1
elif modality_type == "video":
t, h, w = (
video_frame_num,
*image_grid_thw[mm_data_idx][1:],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
) )
llm_pos_ids_list.append(grid_indices + text_len + st_idx)
st = offset + llm_grid_t * llm_grid_h * llm_grid_w
for t_idx in range(llm_grid_t): if st < len(input_tokens):
t_index = ( text_len = len(input_tokens) - st
torch.tensor(t_idx) st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
def forward( def forward(
self, self,
......
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
# https://github.com/zai-org/CogAgent # https://github.com/zai-org/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights.""" """Inference-only CogAgent model compatible with THUDM weights."""
import itertools
from argparse import Namespace from argparse import Namespace
from collections.abc import Mapping, Sequence from collections.abc import Iterator, Mapping, Sequence
from typing import Annotated, Literal from typing import Annotated, Literal
import numpy as np
import torch import torch
from torch import nn from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
...@@ -624,138 +624,56 @@ class GLM4VForCausalLM( ...@@ -624,138 +624,56 @@ class GLM4VForCausalLM(
return self.transformer.vision(pixel_values) return self.transformer.vision(pixel_values)
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int]]:
hf_config = self.config
spatial_merge_size = hf_config.vision_config.spatial_merge_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, t, h // spatial_merge_size, w // spatial_merge_size
else:
# glm4v only supports image modality
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec], mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
hf_config = self.config
image_token_id = hf_config.image_token_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
llm_pos_ids_list: list = [] llm_pos_ids_list: list = []
st = 0
if image_grid_thw or video_grid_thw: for (
input_token_type: list[str] = [] offset,
video_check_flg = False llm_grid_t,
for token in input_tokens: llm_grid_h,
if token == video_start_token_id: llm_grid_w,
video_check_flg = True ) in self.iter_mm_grid_thw(mm_features):
elif token == video_end_token_id: text_len = offset - st
video_check_flg = False st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
if (token == image_token_id) and (video_check_flg is False):
input_token_type.append("image")
elif (token == image_token_id) and (video_check_flg is True):
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = image_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
mm_data_idx += 1 grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
3, -1
elif modality_type == "video":
t, h, w = (
video_frame_num,
*image_grid_thw[mm_data_idx][1:],
)
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
) )
h_index = ( llm_pos_ids_list.append(grid_indices + text_len + st_idx)
torch.arange(llm_grid_h) # EVA2CLIPModel has embeddings for boi and eoi tokens as well
.view(1, -1, 1) st = offset + 1 + llm_grid_t * llm_grid_h * llm_grid_w + 1
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else: if st < len(input_tokens):
text_len = end_idx - start_idx text_len = len(input_tokens) - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
embed_input_ids = SupportsMultiModal.embed_input_ids embed_input_ids = SupportsMultiModal.embed_input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment