Unverified Commit 4269b794 authored by grYe99's avatar grYe99 Committed by GitHub
Browse files

[Model] Use mm_features to compute mrope positions for PaddleOCR-VL (#39888)


Signed-off-by: default avatargrYe99 <guorongye99@gmail.com>
Co-authored-by: default avatargrYe99 <guorongye99@gmail.com>
parent edc36489
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import pytest
import torch
from vllm.model_executor.models.paddleocr_vl import (
PaddleOCRVLForConditionalGeneration,
)
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
pytestmark = pytest.mark.skip_global_cleanup
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
original = torch.get_default_device()
torch.set_default_device("cpu")
yield
torch.set_default_device(original)
@dataclass
class DummyVisionConfig:
spatial_merge_size: int = 2
patch_size: int = 14
@dataclass
class DummyConfig:
image_token_id: int = 151655
video_token_id: int = 151654
vision_start_token_id: int = 151652
vision_end_token_id: int = 151653
vision_config: DummyVisionConfig = field(default_factory=DummyVisionConfig)
def make_model(config: DummyConfig) -> PaddleOCRVLForConditionalGeneration:
model = object.__new__(PaddleOCRVLForConditionalGeneration)
model.config = config
return model
def make_mm_feature(
*,
offset: int,
length: int,
image_grid_thw: tuple[int, int, int],
) -> MultiModalFeatureSpec:
return MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
"image_grid_thw": MultiModalFieldElem(
data=torch.tensor(image_grid_thw),
field=None,
),
}
),
modality="image",
identifier="DUMMY",
mm_position=PlaceholderRange(offset=offset, length=length),
)
def test_get_mrope_input_positions_text_only():
model = make_model(DummyConfig())
input_tokens = [11, 12, 13, 14, 15]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=[],
)
expected = torch.tensor(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == 0
def test_get_mrope_input_positions_single_image():
model = make_model(DummyConfig())
spatial_merge_size = model.config.vision_config.spatial_merge_size
t, h, w = 1, 2, 2
num_image_tokens = t * h * w
input_tokens = (
[10]
+ [model.config.vision_start_token_id]
+ [model.config.image_token_id] * num_image_tokens
+ [model.config.vision_end_token_id]
+ [30, 31]
)
mm_features = [
make_mm_feature(
offset=2, # 1 (text) + 1 (vision_start)
length=num_image_tokens,
image_grid_thw=(t, h * spatial_merge_size, w * spatial_merge_size),
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 2, 2, 2, 2, 4, 5, 6],
[0, 1, 2, 2, 3, 3, 4, 5, 6],
[0, 1, 2, 3, 2, 3, 4, 5, 6],
]
)
assert torch.equal(positions, expected)
expected_delta = (positions.max().item() + 1) - len(input_tokens)
assert delta == expected_delta
def test_get_mrope_input_positions_multiple_images():
model = make_model(DummyConfig())
spatial_merge_size = model.config.vision_config.spatial_merge_size
t1, h1, w1 = 1, 2, 2
num1 = t1 * h1 * w1
t2, h2, w2 = 1, 1, 3
num2 = t2 * h2 * w2
input_tokens = (
[10]
+ [model.config.vision_start_token_id]
+ [model.config.image_token_id] * num1
+ [model.config.vision_end_token_id]
+ [20, 21]
+ [model.config.vision_start_token_id]
+ [model.config.image_token_id] * num2
+ [model.config.vision_end_token_id]
+ [30]
)
mm_features = [
make_mm_feature(
offset=2,
length=num1,
image_grid_thw=(t1, h1 * spatial_merge_size, w1 * spatial_merge_size),
),
make_mm_feature(
offset=2 + num1 + 1 + 2 + 1,
length=num2,
image_grid_thw=(t2, h2 * spatial_merge_size, w2 * spatial_merge_size),
),
]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
)
assert positions.shape == (3, 15)
assert not torch.equal(positions[:, 2:6], torch.arange(4).expand(3, 4) + 2)
assert not torch.equal(positions[:, 10:13], torch.arange(3).expand(3, 3) + 10)
def test_get_mrope_input_positions_image_at_start():
model = make_model(DummyConfig())
spatial_merge_size = model.config.vision_config.spatial_merge_size
t, h, w = 1, 2, 2
num_tokens = t * h * w
input_tokens = (
[model.config.vision_start_token_id]
+ [model.config.image_token_id] * num_tokens
+ [model.config.vision_end_token_id]
+ [10, 11]
)
mm_features = [
make_mm_feature(
offset=1, # start token at index 0
length=num_tokens,
image_grid_thw=(t, h * spatial_merge_size, w * spatial_merge_size),
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 1, 1, 1, 3, 4, 5],
[0, 1, 1, 2, 2, 3, 4, 5],
[0, 1, 2, 1, 2, 3, 4, 5],
]
)
assert torch.equal(positions, expected)
......@@ -15,7 +15,7 @@
# limitations under the License.
import math
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import Annotated, Literal
......@@ -1056,121 +1056,83 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int, float]]:
"""
Iterate over multimodal features and yield grid information.
Args:
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 1.0
if mm_feature.data.get("second_per_grid_ts", None):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
t_factor = second_per_grid_ts * tokens_per_second
yield (
offset,
t,
h // spatial_merge_size,
w // spatial_merge_size,
t_factor,
)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
t_factor,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
t_index = (
(
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
* video_second_per_grid_t
* tokens_per_second
)
.long()
.flatten()
)
grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
if t_factor != 1.0:
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
st = offset + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta
return torch.from_numpy(llm_positions), mrope_position_delta
def _parse_and_validate_image_input(
self, **kwargs: object
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment