Commit 17d87168 (unverified), authored by lalit10, committed by GitHub
Browse files

[Model] Use mm_features for Keye-VL and Keye-1.5-VL M-RoPE (#39869)


Signed-off-by: Lalit Laxminarayan Bangad <lalitbangad@gmail.com>
parent 98700c61
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import pytest
import torch
from vllm.model_executor.models.keye import KeyeForConditionalGeneration
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
# A module-level `pytestmark` applies the given marker to every test in this
# file. NOTE(review): `skip_global_cleanup` is a project-defined marker whose
# effect is not visible here -- confirm its meaning in the repo's conftest.
pytestmark = pytest.mark.skip_global_cleanup
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
    """Pin torch's default device to CPU while this module's tests run.

    The fixture is autouse and module-scoped: it records the default device
    that was active on entry, switches to ``"cpu"``, and puts the recorded
    device back once control returns after the ``yield``.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cpu")
    yield
    # Teardown: undo the override so later modules see the prior default.
    torch.set_default_device(previous_device)
@dataclass
class DummyVisionConfig:
    """Stand-in for the vision sub-config read by the M-RoPE code under test."""

    # Divisor applied to the h and w grid sizes by the model's grid iterator.
    spatial_merge_size: int = 2
@dataclass
class DummyConfig:
    """Stand-in model config exposing only the ``vision_config`` attribute."""

    # A fresh DummyVisionConfig is created for every DummyConfig instance.
    vision_config: DummyVisionConfig = field(default_factory=DummyVisionConfig)
def make_model(config: DummyConfig) -> KeyeForConditionalGeneration:
    """Return a ``KeyeForConditionalGeneration`` shell carrying only ``config``.

    The instance is allocated with ``object.__new__`` so the class's
    ``__init__`` is never executed; the only attribute set is ``config``.
    """
    shell = object.__new__(KeyeForConditionalGeneration)
    shell.config = config
    return shell
def make_mm_feature(
    *,
    modality: str,
    offset: int,
    length: int,
    grid_thw: tuple[int, int, int] | list[tuple[int, int, int]],
) -> MultiModalFeatureSpec:
    """Build a minimal multimodal feature spec for the M-RoPE tests.

    Args:
        modality: ``"image"`` stores the grid under ``"image_grid_thw"``;
            any other value stores it under ``"video_grid_thw"``.
        offset: Start index of the placeholder range.
        length: Number of positions covered by the placeholder range.
        grid_thw: Grid description, converted with ``torch.tensor``.

    Returns:
        A ``MultiModalFeatureSpec`` with identifier ``"DUMMY"``.
    """
    grid_key = "image_grid_thw" if modality == "image" else "video_grid_thw"
    grid_elem = MultiModalFieldElem(
        data=torch.tensor(grid_thw),
        field=None,  # HACK.
    )
    placeholder = PlaceholderRange(offset=offset, length=length)
    return MultiModalFeatureSpec(
        data=MultiModalKwargsItem({grid_key: grid_elem}),
        modality=modality,
        identifier="DUMMY",
        mm_position=placeholder,
    )
def test_get_mrope_input_positions_text_only():
    """With no multimodal features, every position row is 0..n-1 and delta is 0."""
    token_ids = [11, 12, 13, 14, 15]
    keye = make_model(DummyConfig())

    positions, delta = keye.get_mrope_input_positions(
        input_tokens=token_ids,
        mm_features=[],
    )

    # Three identical rows, each counting 0, 1, 2, 3, 4.
    want = torch.arange(5).repeat(3, 1)
    assert torch.equal(positions, want)
    assert delta == 0
def test_get_mrope_input_positions_single_image():
    """One image placeholder (grid 1x4x4, merge size 2) at tokens 1..4."""
    keye = make_model(DummyConfig())
    image = make_mm_feature(
        modality="image",
        offset=1,
        length=4,
        grid_thw=(1, 4, 4),
    )
    token_ids = [10, 20, 21, 22, 23, 30, 31]

    positions, delta = keye.get_mrope_input_positions(
        input_tokens=token_ids,
        mm_features=[image],
    )

    # Columns 1..4 are the image block; columns 5..6 are the trailing text,
    # which resumes one past the largest image position.
    want_rows = [
        [0, 1, 1, 1, 1, 3, 4],
        [0, 1, 1, 2, 2, 3, 4],
        [0, 1, 2, 1, 2, 3, 4],
    ]
    want = torch.tensor(want_rows)
    assert torch.equal(positions, want)
    assert delta == -2
def test_get_mrope_input_positions_interleaved_image_and_video():
    """An image at tokens 1..4 followed, after text, by a video at tokens 7..10."""
    keye = make_model(DummyConfig())
    image = make_mm_feature(
        modality="image",
        offset=1,
        length=4,
        grid_thw=(1, 4, 4),
    )
    video = make_mm_feature(
        modality="video",
        offset=7,
        length=4,
        grid_thw=[(2, 4, 2)],
    )
    token_ids = [10, 20, 21, 22, 23, 30, 31, 40, 41, 42, 43, 50, 51]

    positions, delta = keye.get_mrope_input_positions(
        input_tokens=token_ids,
        mm_features=[image, video],
    )

    # Columns 7..8 and 9..10 are the video's two placeholder pairs; each pair
    # begins at a fresh base position (5, then 7).
    want_rows = [
        [0, 1, 1, 1, 1, 3, 4, 5, 5, 7, 7, 9, 10],
        [0, 1, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [0, 1, 2, 1, 2, 3, 4, 5, 5, 7, 7, 9, 10],
    ]
    want = torch.tensor(want_rows)
    assert torch.equal(positions, want)
    assert delta == -2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import pytest
import torch
from vllm.model_executor.models.keye_vl1_5 import KeyeVL1_5ForConditionalGeneration
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
# A module-level `pytestmark` applies the given marker to every test in this
# file. NOTE(review): `skip_global_cleanup` is a project-defined marker whose
# effect is not visible here -- confirm its meaning in the repo's conftest.
pytestmark = pytest.mark.skip_global_cleanup
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
    """Pin torch's default device to CPU while this module's tests run.

    The fixture is autouse and module-scoped: it records the default device
    that was active on entry, switches to ``"cpu"``, and puts the recorded
    device back once control returns after the ``yield``.
    """
    previous_device = torch.get_default_device()
    torch.set_default_device("cpu")
    yield
    # Teardown: undo the override so later modules see the prior default.
    torch.set_default_device(previous_device)
@dataclass
class DummyVisionConfig:
    """Stand-in for the vision sub-config read by the M-RoPE code under test."""

    # Divisor applied to the h and w grid sizes by the model's grid iterator.
    spatial_merge_size: int = 2
@dataclass
class DummyConfig:
    """Stand-in model config exposing only the ``vision_config`` attribute."""

    # A fresh DummyVisionConfig is created for every DummyConfig instance.
    vision_config: DummyVisionConfig = field(default_factory=DummyVisionConfig)
def make_model(config: DummyConfig) -> KeyeVL1_5ForConditionalGeneration:
    """Return a ``KeyeVL1_5ForConditionalGeneration`` shell carrying only ``config``.

    The instance is allocated with ``object.__new__`` so the class's
    ``__init__`` is never executed; the only attribute set is ``config``.
    """
    shell = object.__new__(KeyeVL1_5ForConditionalGeneration)
    shell.config = config
    return shell
def make_mm_feature(
    *,
    modality: str,
    offset: int,
    length: int,
    grid_thw: tuple[int, int, int] | list[tuple[int, int, int]],
    is_embed: list[bool] | None = None,
) -> MultiModalFeatureSpec:
    """Build a minimal multimodal feature spec for the M-RoPE tests.

    Args:
        modality: ``"image"`` stores the grid under ``"image_grid_thw"``;
            any other value stores it under ``"video_grid_thw"``.
        offset: Start index of the placeholder range.
        length: Number of positions covered by the placeholder range.
        grid_thw: Grid description, converted with ``torch.tensor``.
        is_embed: Optional per-position flags for the placeholder range;
            converted with ``torch.tensor`` when given, otherwise passed
            through as ``None``.

    Returns:
        A ``MultiModalFeatureSpec`` with identifier ``"DUMMY"``.
    """
    grid_key = "image_grid_thw" if modality == "image" else "video_grid_thw"
    grid_elem = MultiModalFieldElem(
        data=torch.tensor(grid_thw),
        field=None,  # HACK.
    )

    if is_embed is None:
        embed_mask = None
    else:
        embed_mask = torch.tensor(is_embed)

    placeholder = PlaceholderRange(
        offset=offset,
        length=length,
        is_embed=embed_mask,
    )
    return MultiModalFeatureSpec(
        data=MultiModalKwargsItem({grid_key: grid_elem}),
        modality=modality,
        identifier="DUMMY",
        mm_position=placeholder,
    )
def test_get_mrope_input_positions_text_only():
    """With no multimodal features, every position row is 0..n-1 and delta is 0."""
    token_ids = [11, 12, 13, 14, 15]
    keye = make_model(DummyConfig())

    positions, delta = keye.get_mrope_input_positions(
        input_tokens=token_ids,
        mm_features=[],
    )

    # Three identical rows, each counting 0, 1, 2, 3, 4.
    want = torch.arange(5).repeat(3, 1)
    assert torch.equal(positions, want)
    assert delta == 0
def test_get_mrope_input_positions_single_image():
    """One image placeholder (grid 1x4x4, merge size 2) at tokens 1..4."""
    keye = make_model(DummyConfig())
    image = make_mm_feature(
        modality="image",
        offset=1,
        length=4,
        grid_thw=(1, 4, 4),
    )
    token_ids = [10, 20, 21, 22, 23, 30, 31]

    positions, delta = keye.get_mrope_input_positions(
        input_tokens=token_ids,
        mm_features=[image],
    )

    # Columns 1..4 are the image block; columns 5..6 are the trailing text,
    # which resumes one past the largest image position.
    want_rows = [
        [0, 1, 1, 1, 1, 3, 4],
        [0, 1, 1, 2, 2, 3, 4],
        [0, 1, 2, 1, 2, 3, 4],
    ]
    want = torch.tensor(want_rows)
    assert torch.equal(positions, want)
    assert delta == -2
def test_get_mrope_input_positions_video_uses_embed_ranges():
    """Video positions are placed at the ``is_embed`` runs, not at the raw offset."""
    keye = make_model(DummyConfig())
    # Placeholder covers tokens 1..8; only the two True runs are embeddings.
    embed_flags = [False, False, True, True, False, False, True, True]
    video = make_mm_feature(
        modality="video",
        offset=1,
        length=8,
        grid_thw=[(2, 4, 2)],
        is_embed=embed_flags,
    )
    token_ids = [10, 101, 102, 20, 21, 103, 104, 30, 31, 40, 41]

    positions, delta = keye.get_mrope_input_positions(
        input_tokens=token_ids,
        mm_features=[video],
    )

    # Columns 3..4 and 7..8 are the two embedded runs; everything else is
    # treated as ordinary text.
    want_rows = [
        [0, 1, 2, 3, 3, 5, 6, 7, 7, 9, 10],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [0, 1, 2, 3, 3, 5, 6, 7, 7, 9, 10],
    ]
    want = torch.tensor(want_rows)
    assert torch.equal(positions, want)
    assert delta == 0
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias, TypeVar
......@@ -1595,91 +1595,92 @@ class KeyeForConditionalGeneration(
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
)
@staticmethod
def _split_video_grid_thw(
grid_thw: torch.Tensor | list[list[int]] | list[int],
) -> list[list[int]]:
"""
Split video grid_thw along the t dimension into per-frame rows.
This preserves Keye's current M-RoPE behavior, where a video is emitted
as consecutive frame-level multimodal blocks rather than a single block
spanning the whole video.
"""
if isinstance(grid_thw, list):
if len(grid_thw) == 0:
return []
if isinstance(grid_thw[0], int):
grid_thw = torch.tensor([grid_thw], dtype=torch.long)
else:
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
elif grid_thw.ndim == 1:
grid_thw = grid_thw.unsqueeze(0)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1])
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
def iter_mm_grid_thw(
    self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int]]:
    """
    Yield ``(offset, t, h, w)`` for each multimodal block, ordered by offset.

    ``h`` and ``w`` are the grid sizes floor-divided by the vision config's
    ``spatial_merge_size``. An image yields one tuple at its placeholder
    offset. A video yields one tuple per frame row returned by
    ``_split_video_grid_thw``; each frame starts where the previous one
    ended (offset advanced by ``t * h * w``).

    Raises:
        ValueError: if a feature carries no data or has a modality other
            than ``"image"`` or ``"video"``.
    """
    merge = self.config.vision_config.spatial_merge_size
    ordered = sorted(mm_features, key=lambda feat: feat.mm_position.offset)

    for feature in ordered:
        if feature.data is None:
            raise ValueError("M-RoPE calculation requires multimodal feature data")

        modality = feature.modality
        if modality not in ("image", "video"):
            raise ValueError(f"Unsupported modality: {feature.modality}")

        if modality == "video":
            cursor = feature.mm_position.offset
            frame_rows = self._split_video_grid_thw(
                feature.data["video_grid_thw"].data
            )
            for t, h, w in frame_rows:
                grid_h, grid_w = h // merge, w // merge
                yield (cursor, t, grid_h, grid_w)
                # Frames are laid out back to back from the video offset.
                cursor += t * grid_h * grid_w
            continue

        # Image: accept a flat triple or a single-row 2-D container.
        raw = feature.data["image_grid_thw"].data
        if isinstance(raw, torch.Tensor):
            if raw.ndim == 2:
                assert raw.shape[0] == 1
                raw = raw[0]
            t, h, w = raw.tolist()
        else:
            if isinstance(raw[0], list):
                assert len(raw) == 1
                raw = raw[0]
            t, h, w = raw

        yield (feature.mm_position.offset, t, h // merge, w // merge)
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
......@@ -1711,7 +1712,7 @@ class KeyeForConditionalGeneration(
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
st = offset + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Mapping, Sequence
from collections.abc import Iterator, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias
......@@ -608,91 +608,67 @@ class KeyeVL1_5ForConditionalGeneration(
new_video_embeds.append(video_embeds[start:end])
return tuple(new_video_embeds)
def iter_mm_grid_thw(
    self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int]]:
    """
    Yield ``(start, t, h, w)`` for each multimodal block, ordered by offset.

    Start indices come from the placeholder's ``extract_embeds_range()``
    rather than from the raw offset. ``h`` and ``w`` are the grid sizes
    floor-divided by the vision config's ``spatial_merge_size``. An image
    must have exactly one embed range. A video must have one embed range per
    row returned by ``split_thw``, and each range (treated as inclusive at
    both ends by the assertion below) must span ``t * h * w`` positions.

    Raises:
        ValueError: if a feature carries no data or has a modality other
            than ``"image"`` or ``"video"``.
    """
    merge = self.config.vision_config.spatial_merge_size
    ordered = sorted(mm_features, key=lambda feat: feat.mm_position.offset)

    for feature in ordered:
        if feature.data is None:
            raise ValueError("M-RoPE calculation requires multimodal feature data")

        spans = feature.mm_position.extract_embeds_range()
        modality = feature.modality

        if modality == "video":
            frame_grids = split_thw(feature.data["video_grid_thw"].data)
            assert len(spans) == frame_grids.shape[0]
            for (first, last), (t, h, w) in zip(spans, frame_grids.tolist()):
                grid_h = h // merge
                grid_w = w // merge
                expected_tokens = t * grid_h * grid_w
                assert last - first + 1 == expected_tokens
                yield (first, t, grid_h, grid_w)
            continue

        if modality != "image":
            raise ValueError(f"Unsupported modality: {feature.modality}")

        assert len(spans) == 1
        # Image: accept a flat triple or a single-row 2-D container.
        raw = feature.data["image_grid_thw"].data
        if isinstance(raw, torch.Tensor):
            if raw.ndim == 2:
                assert raw.shape[0] == 1
                raw = raw[0]
            t, h, w = raw.tolist()
        else:
            if isinstance(raw[0], list):
                assert len(raw) == 1
                raw = raw[0]
            t, h, w = raw

        yield (spans[0][0], t, h // merge, w // merge)
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
"""
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
......@@ -724,7 +700,7 @@ class KeyeVL1_5ForConditionalGeneration(
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
st = offset + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment