Unverified Commit 0008729a authored by lalit10's avatar lalit10 Committed by GitHub
Browse files

[Model] Use mm_features for Ernie-4.5 VL M-RoPE (#39753)


Signed-off-by: default avatarLalit Laxminarayan Bangad <lalitbangad@gmail.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent d3af8c18
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import pytest
import torch
from vllm.model_executor.models.ernie45_vl import (
Ernie4_5_VLMoeForConditionalGeneration,
)
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
pytestmark = pytest.mark.skip_global_cleanup
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
original = torch.get_default_device()
torch.set_default_device("cpu")
yield
torch.set_default_device(original)
@dataclass
class DummyConfig:
spatial_conv_size: int = 2
temporal_conv_size: int = 2
def make_model(config: DummyConfig) -> Ernie4_5_VLMoeForConditionalGeneration:
model = object.__new__(Ernie4_5_VLMoeForConditionalGeneration)
model.config = config
return model
def make_mm_feature(
*,
modality: str,
offset: int,
length: int,
grid_thw: tuple[int, int, int],
) -> MultiModalFeatureSpec:
field_name = "image_grid_thw" if modality == "image" else "video_grid_thw"
return MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
field_name: MultiModalFieldElem(
data=torch.tensor(grid_thw),
field=None, # HACK.
),
}
),
modality=modality,
identifier="DUMMY",
mm_position=PlaceholderRange(offset=offset, length=length),
)
def test_get_mrope_input_positions_text_only():
model = make_model(DummyConfig())
positions, delta = model.get_mrope_input_positions(
input_tokens=[11, 12, 13, 14, 15],
mm_features=[],
)
expected = torch.tensor(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == 0
def test_get_mrope_input_positions_single_image():
model = make_model(DummyConfig())
mm_features = [
make_mm_feature(
modality="image",
offset=1,
length=4,
grid_thw=(1, 4, 4),
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=[10, 20, 21, 22, 23, 30, 31],
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 1, 1, 1, 3, 4],
[0, 1, 1, 2, 2, 3, 4],
[0, 1, 2, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == -2
def test_get_mrope_input_positions_interleaved_image_and_video():
model = make_model(DummyConfig())
mm_features = [
make_mm_feature(
modality="image",
offset=1,
length=4,
grid_thw=(1, 4, 4),
),
make_mm_feature(
modality="video",
offset=7,
length=2,
grid_thw=(2, 4, 2),
),
]
positions, delta = model.get_mrope_input_positions(
input_tokens=[10, 20, 21, 22, 23, 30, 31, 40, 41, 50, 51],
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 1, 1, 1, 3, 4, 5, 5, 7, 8],
[0, 1, 1, 2, 2, 3, 4, 5, 6, 7, 8],
[0, 1, 2, 1, 2, 3, 4, 5, 5, 7, 8],
]
)
assert torch.equal(positions, expected)
assert delta == -2
...@@ -23,9 +23,8 @@ ...@@ -23,9 +23,8 @@
# limitations under the License. # limitations under the License.
"""Inference-only Ernie VL model compatible with HuggingFace weights.""" """Inference-only Ernie VL model compatible with HuggingFace weights."""
import itertools
import math import math
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
...@@ -1401,131 +1400,62 @@ class Ernie4_5_VLMoeForConditionalGeneration( ...@@ -1401,131 +1400,62 @@ class Ernie4_5_VLMoeForConditionalGeneration(
input_tokens: list[int], input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec], mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
hf_config = self.config
image_token_id = hf_config.im_patch_id
video_start_token_id = hf_config.video_start_token_id
video_end_token_id = hf_config.video_end_token_id
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
llm_pos_ids_list: list = [] llm_pos_ids_list: list = []
st = 0
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
if image_grid_thw or video_grid_thw: grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
input_token_type: list[str] = [] 3, -1
video_check_flg = False )
for token in input_tokens: llm_pos_ids_list.append(grid_indices + text_len + st_idx)
if token == video_start_token_id: st = offset + llm_grid_t * llm_grid_h * llm_grid_w
video_check_flg = True
elif token == video_end_token_id: if st < len(input_tokens):
video_check_flg = False text_len = len(input_tokens) - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
if (token == image_token_id) and (video_check_flg is False): llm_pos_ids_list.append(
input_token_type.append("image") np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
elif (token == image_token_id) and (video_check_flg is True): )
input_token_type.append("video")
else:
input_token_type.append("text")
input_type_group: list[tuple[str, int, int]] = []
for key, group_iter in itertools.groupby(
enumerate(input_token_type), lambda x: x[1]
):
group_list = list(group_iter)
start_index = group_list[0][0]
end_index = group_list[-1][0] + 1
input_type_group.append((key, start_index, end_index))
video_frame_num = 1
mm_data_idx = 0
for modality_type, start_idx, end_idx in input_type_group:
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
if modality_type == "image":
t, h, w = image_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_conv_size,
w // spatial_conv_size,
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
elif modality_type == "video":
t, h, w = video_grid_thw[mm_data_idx]
llm_grid_t, llm_grid_h, llm_grid_w = (
t // temporal_conv_size,
h // spatial_conv_size,
w // spatial_conv_size,
)
for t_idx in range(llm_grid_t):
t_index = (
torch.tensor(t_idx)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(1, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(1, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + st_idx
)
mm_data_idx += 1
video_frame_num += 1
else:
text_len = end_idx - start_idx
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
video_frame_num = 1
else:
text_len = len(input_tokens)
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int]]:
spatial_conv_size = self.config.spatial_conv_size
temporal_conv_size = self.config.temporal_conv_size
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
if mm_feature.data is None:
raise ValueError("M-RoPE calculation requires multimodal feature data")
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
yield offset, t, h // spatial_conv_size, w // spatial_conv_size
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
yield (
offset,
t // temporal_conv_size,
h // spatial_conv_size,
w // spatial_conv_size,
)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment