Unverified Commit 17d87168 authored by lalit10's avatar lalit10 Committed by GitHub
Browse files

[Model] Use mm_features for Keye-VL and Keye-1.5-VL M-RoPE (#39869)


Signed-off-by: default avatarLalit Laxminarayan Bangad <lalitbangad@gmail.com>
parent 98700c61
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import pytest
import torch
from vllm.model_executor.models.keye import KeyeForConditionalGeneration
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
pytestmark = pytest.mark.skip_global_cleanup
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
original = torch.get_default_device()
torch.set_default_device("cpu")
yield
torch.set_default_device(original)
@dataclass
class DummyVisionConfig:
spatial_merge_size: int = 2
@dataclass
class DummyConfig:
vision_config: DummyVisionConfig = field(default_factory=DummyVisionConfig)
def make_model(config: DummyConfig) -> KeyeForConditionalGeneration:
model = object.__new__(KeyeForConditionalGeneration)
model.config = config
return model
def make_mm_feature(
*,
modality: str,
offset: int,
length: int,
grid_thw: tuple[int, int, int] | list[tuple[int, int, int]],
) -> MultiModalFeatureSpec:
field_name = "image_grid_thw" if modality == "image" else "video_grid_thw"
return MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
field_name: MultiModalFieldElem(
data=torch.tensor(grid_thw),
field=None, # HACK.
),
}
),
modality=modality,
identifier="DUMMY",
mm_position=PlaceholderRange(offset=offset, length=length),
)
def test_get_mrope_input_positions_text_only():
model = make_model(DummyConfig())
positions, delta = model.get_mrope_input_positions(
input_tokens=[11, 12, 13, 14, 15],
mm_features=[],
)
expected = torch.tensor(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == 0
def test_get_mrope_input_positions_single_image():
model = make_model(DummyConfig())
mm_features = [
make_mm_feature(
modality="image",
offset=1,
length=4,
grid_thw=(1, 4, 4),
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=[10, 20, 21, 22, 23, 30, 31],
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 1, 1, 1, 3, 4],
[0, 1, 1, 2, 2, 3, 4],
[0, 1, 2, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == -2
def test_get_mrope_input_positions_interleaved_image_and_video():
model = make_model(DummyConfig())
mm_features = [
make_mm_feature(
modality="image",
offset=1,
length=4,
grid_thw=(1, 4, 4),
),
make_mm_feature(
modality="video",
offset=7,
length=4,
grid_thw=[(2, 4, 2)],
),
]
positions, delta = model.get_mrope_input_positions(
input_tokens=[10, 20, 21, 22, 23, 30, 31, 40, 41, 42, 43, 50, 51],
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 1, 1, 1, 3, 4, 5, 5, 7, 7, 9, 10],
[0, 1, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[0, 1, 2, 1, 2, 3, 4, 5, 5, 7, 7, 9, 10],
]
)
assert torch.equal(positions, expected)
assert delta == -2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import pytest
import torch
from vllm.model_executor.models.keye_vl1_5 import KeyeVL1_5ForConditionalGeneration
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
pytestmark = pytest.mark.skip_global_cleanup
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
original = torch.get_default_device()
torch.set_default_device("cpu")
yield
torch.set_default_device(original)
@dataclass
class DummyVisionConfig:
spatial_merge_size: int = 2
@dataclass
class DummyConfig:
vision_config: DummyVisionConfig = field(default_factory=DummyVisionConfig)
def make_model(config: DummyConfig) -> KeyeVL1_5ForConditionalGeneration:
model = object.__new__(KeyeVL1_5ForConditionalGeneration)
model.config = config
return model
def make_mm_feature(
*,
modality: str,
offset: int,
length: int,
grid_thw: tuple[int, int, int] | list[tuple[int, int, int]],
is_embed: list[bool] | None = None,
) -> MultiModalFeatureSpec:
field_name = "image_grid_thw" if modality == "image" else "video_grid_thw"
return MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
field_name: MultiModalFieldElem(
data=torch.tensor(grid_thw),
field=None, # HACK.
),
}
),
modality=modality,
identifier="DUMMY",
mm_position=PlaceholderRange(
offset=offset,
length=length,
is_embed=None if is_embed is None else torch.tensor(is_embed),
),
)
def test_get_mrope_input_positions_text_only():
model = make_model(DummyConfig())
positions, delta = model.get_mrope_input_positions(
input_tokens=[11, 12, 13, 14, 15],
mm_features=[],
)
expected = torch.tensor(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == 0
def test_get_mrope_input_positions_single_image():
model = make_model(DummyConfig())
mm_features = [
make_mm_feature(
modality="image",
offset=1,
length=4,
grid_thw=(1, 4, 4),
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=[10, 20, 21, 22, 23, 30, 31],
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 1, 1, 1, 3, 4],
[0, 1, 1, 2, 2, 3, 4],
[0, 1, 2, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == -2
def test_get_mrope_input_positions_video_uses_embed_ranges():
model = make_model(DummyConfig())
mm_features = [
make_mm_feature(
modality="video",
offset=1,
length=8,
grid_thw=[(2, 4, 2)],
is_embed=[False, False, True, True, False, False, True, True],
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=[10, 101, 102, 20, 21, 103, 104, 30, 31, 40, 41],
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 2, 3, 3, 5, 6, 7, 7, 9, 10],
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[0, 1, 2, 3, 3, 5, 6, 7, 7, 9, 10],
]
)
assert torch.equal(positions, expected)
assert delta == 0
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, TypeAlias, TypeVar from typing import Annotated, Any, Literal, TypeAlias, TypeVar
...@@ -1595,91 +1595,92 @@ class KeyeForConditionalGeneration( ...@@ -1595,91 +1595,92 @@ class KeyeForConditionalGeneration(
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos) self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
) )
def get_mrope_input_positions( @staticmethod
self, def _split_video_grid_thw(
input_tokens: list[int], grid_thw: torch.Tensor | list[list[int]] | list[int],
mm_features: list[MultiModalFeatureSpec], ) -> list[list[int]]:
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
video_grid_thw = video_grid_thw[0]
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
""" """
Split grid_thw along the t dimension. Split video grid_thw along the t dimension into per-frame rows.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns: This preserves Keye's current M-RoPE behavior, where a video is emitted
List of [1, h, w] rows, repeated t times for each original row. as consecutive frame-level multimodal blocks rather than a single block
spanning the whole video.
""" """
if isinstance(grid_thw, list): if isinstance(grid_thw, list):
if len(grid_thw) == 0:
return []
if isinstance(grid_thw[0], int):
grid_thw = torch.tensor([grid_thw], dtype=torch.long)
else:
grid_thw = torch.tensor(grid_thw, dtype=torch.long) grid_thw = torch.tensor(grid_thw, dtype=torch.long)
elif grid_thw.ndim == 1:
grid_thw = grid_thw.unsqueeze(0)
if grid_thw.numel() == 0: if grid_thw.numel() == 0:
return [] return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:] t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1] ones = torch.ones_like(hw[:, :1])
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist() return out.tolist()
video_grid_thw = split_thw(video_grid_thw) def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
hf_config = self.config ) -> Iterator[tuple[int, int, int, int]]:
image_token_id = hf_config.image_token_id spatial_merge_size = self.config.vision_config.spatial_merge_size
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
if mm_feature.data is None:
image_nums = len(image_grid_thw) raise ValueError("M-RoPE calculation requires multimodal feature data")
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = [] if mm_feature.modality == "image":
grid_thw = mm_feature.data["image_grid_thw"].data
st = 0 if isinstance(grid_thw, torch.Tensor):
remain_images, remain_frames = image_nums, frame_nums if grid_thw.ndim == 2:
assert grid_thw.shape[0] == 1
image_index, video_index = 0, 0 t, h, w = grid_thw[0].tolist()
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else: else:
ed_image = len(input_tokens) + 1 t, h, w = grid_thw.tolist()
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else: else:
ed_video = len(input_tokens) + 1 if isinstance(grid_thw[0], list):
assert len(grid_thw) == 1
if ed_image < ed_video: t, h, w = grid_thw[0]
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else: else:
t, h, w = video_grid_thw[video_index] t, h, w = grid_thw
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = ( yield (
mm_feature.mm_position.offset,
t, t,
h // spatial_merge_size, h // spatial_merge_size,
w // spatial_merge_size, w // spatial_merge_size,
) )
text_len = ed - st elif mm_feature.modality == "video":
current_offset = mm_feature.mm_position.offset
for t, h, w in self._split_video_grid_thw(
mm_feature.data["video_grid_thw"].data
):
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
yield (current_offset, t, llm_grid_h, llm_grid_w)
current_offset += t * llm_grid_h * llm_grid_w
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
llm_pos_ids_list: list = []
st = 0
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append( llm_pos_ids_list.append(
...@@ -1711,7 +1712,7 @@ class KeyeForConditionalGeneration( ...@@ -1711,7 +1712,7 @@ class KeyeForConditionalGeneration(
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx torch.stack([t_index, h_index, w_index]) + text_len + st_idx
) )
st = ed + llm_grid_t * llm_grid_h * llm_grid_w st = offset + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens): if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
from collections.abc import Mapping, Sequence from collections.abc import Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -608,91 +608,67 @@ class KeyeVL1_5ForConditionalGeneration( ...@@ -608,91 +608,67 @@ class KeyeVL1_5ForConditionalGeneration(
new_video_embeds.append(video_embeds[start:end]) new_video_embeds.append(video_embeds[start:end])
return tuple(new_video_embeds) return tuple(new_video_embeds)
def get_mrope_input_positions( def iter_mm_grid_thw(
self, self, mm_features: list[MultiModalFeatureSpec]
input_tokens: list[int], ) -> Iterator[tuple[int, int, int, int]]:
mm_features: list[MultiModalFeatureSpec], spatial_merge_size = self.config.vision_config.spatial_merge_size
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs( for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
mm_features, if mm_feature.data is None:
{"image_grid_thw", "video_grid_thw"}, raise ValueError("M-RoPE calculation requires multimodal feature data")
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])] embed_ranges = mm_feature.mm_position.extract_embeds_range()
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])] if mm_feature.modality == "image":
assert len(embed_ranges) == 1
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: grid_thw = mm_feature.data["image_grid_thw"].data
video_grid_thw = video_grid_thw[0] if isinstance(grid_thw, torch.Tensor):
if grid_thw.ndim == 2:
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]: assert grid_thw.shape[0] == 1
""" t, h, w = grid_thw[0].tolist()
Split grid_thw along the t dimension.
Args:
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
Returns:
List of [1, h, w] rows, repeated t times for each original row.
"""
if isinstance(grid_thw, list):
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
if grid_thw.numel() == 0:
return []
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
ones = torch.ones_like(hw[:, :1]) # [N,1]
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
return out.tolist()
video_grid_thw = split_thw(video_grid_thw)
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
image_nums = len(image_grid_thw)
frame_nums = len(video_grid_thw)
llm_pos_ids_list: list = []
st = 0
remain_images, remain_frames = image_nums, frame_nums
image_index, video_index = 0, 0
for _ in range(image_nums + frame_nums):
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else: else:
ed_image = len(input_tokens) + 1 t, h, w = grid_thw.tolist()
if remain_frames > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else: else:
ed_video = len(input_tokens) + 1 if isinstance(grid_thw[0], list):
assert len(grid_thw) == 1
if ed_image < ed_video: t, h, w = grid_thw[0]
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else: else:
t, h, w = video_grid_thw[video_index] t, h, w = grid_thw
video_index += 1
remain_frames -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = ( yield (
embed_ranges[0][0],
t, t,
h // spatial_merge_size, h // spatial_merge_size,
w // spatial_merge_size, w // spatial_merge_size,
) )
text_len = ed - st elif mm_feature.modality == "video":
split_video_grids = split_thw(mm_feature.data["video_grid_thw"].data)
assert len(embed_ranges) == split_video_grids.shape[0]
for (start_idx, end_idx), (t, h, w) in zip(
embed_ranges, split_video_grids.tolist()
):
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size
num_mm_tokens = t * llm_grid_h * llm_grid_w
assert end_idx - start_idx + 1 == num_mm_tokens
yield (start_idx, t, llm_grid_h, llm_grid_w)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
llm_pos_ids_list: list = []
st = 0
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append( llm_pos_ids_list.append(
...@@ -724,7 +700,7 @@ class KeyeVL1_5ForConditionalGeneration( ...@@ -724,7 +700,7 @@ class KeyeVL1_5ForConditionalGeneration(
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx torch.stack([t_index, h_index, w_index]) + text_len + st_idx
) )
st = ed + llm_grid_t * llm_grid_h * llm_grid_w st = offset + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens): if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment