Unverified Commit 4269b794 authored by grYe99's avatar grYe99 Committed by GitHub
Browse files

[Model] Use mm_features to compute mrope positions for PaddleOCR-VL (#39888)


Signed-off-by: default avatargrYe99 <guorongye99@gmail.com>
Co-authored-by: default avatargrYe99 <guorongye99@gmail.com>
parent edc36489
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
import pytest
import torch
from vllm.model_executor.models.paddleocr_vl import (
PaddleOCRVLForConditionalGeneration,
)
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem,
MultiModalKwargsItem,
PlaceholderRange,
)
pytestmark = pytest.mark.skip_global_cleanup
@pytest.fixture(autouse=True, scope="module")
def _force_cpu_default_device():
original = torch.get_default_device()
torch.set_default_device("cpu")
yield
torch.set_default_device(original)
@dataclass
class DummyVisionConfig:
spatial_merge_size: int = 2
patch_size: int = 14
@dataclass
class DummyConfig:
image_token_id: int = 151655
video_token_id: int = 151654
vision_start_token_id: int = 151652
vision_end_token_id: int = 151653
vision_config: DummyVisionConfig = field(default_factory=DummyVisionConfig)
def make_model(config: DummyConfig) -> PaddleOCRVLForConditionalGeneration:
model = object.__new__(PaddleOCRVLForConditionalGeneration)
model.config = config
return model
def make_mm_feature(
*,
offset: int,
length: int,
image_grid_thw: tuple[int, int, int],
) -> MultiModalFeatureSpec:
return MultiModalFeatureSpec(
data=MultiModalKwargsItem(
{
"image_grid_thw": MultiModalFieldElem(
data=torch.tensor(image_grid_thw),
field=None,
),
}
),
modality="image",
identifier="DUMMY",
mm_position=PlaceholderRange(offset=offset, length=length),
)
def test_get_mrope_input_positions_text_only():
model = make_model(DummyConfig())
input_tokens = [11, 12, 13, 14, 15]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=[],
)
expected = torch.tensor(
[
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
]
)
assert torch.equal(positions, expected)
assert delta == 0
def test_get_mrope_input_positions_single_image():
model = make_model(DummyConfig())
spatial_merge_size = model.config.vision_config.spatial_merge_size
t, h, w = 1, 2, 2
num_image_tokens = t * h * w
input_tokens = (
[10]
+ [model.config.vision_start_token_id]
+ [model.config.image_token_id] * num_image_tokens
+ [model.config.vision_end_token_id]
+ [30, 31]
)
mm_features = [
make_mm_feature(
offset=2, # 1 (text) + 1 (vision_start)
length=num_image_tokens,
image_grid_thw=(t, h * spatial_merge_size, w * spatial_merge_size),
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 2, 2, 2, 2, 4, 5, 6],
[0, 1, 2, 2, 3, 3, 4, 5, 6],
[0, 1, 2, 3, 2, 3, 4, 5, 6],
]
)
assert torch.equal(positions, expected)
expected_delta = (positions.max().item() + 1) - len(input_tokens)
assert delta == expected_delta
def test_get_mrope_input_positions_multiple_images():
model = make_model(DummyConfig())
spatial_merge_size = model.config.vision_config.spatial_merge_size
t1, h1, w1 = 1, 2, 2
num1 = t1 * h1 * w1
t2, h2, w2 = 1, 1, 3
num2 = t2 * h2 * w2
input_tokens = (
[10]
+ [model.config.vision_start_token_id]
+ [model.config.image_token_id] * num1
+ [model.config.vision_end_token_id]
+ [20, 21]
+ [model.config.vision_start_token_id]
+ [model.config.image_token_id] * num2
+ [model.config.vision_end_token_id]
+ [30]
)
mm_features = [
make_mm_feature(
offset=2,
length=num1,
image_grid_thw=(t1, h1 * spatial_merge_size, w1 * spatial_merge_size),
),
make_mm_feature(
offset=2 + num1 + 1 + 2 + 1,
length=num2,
image_grid_thw=(t2, h2 * spatial_merge_size, w2 * spatial_merge_size),
),
]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
)
assert positions.shape == (3, 15)
assert not torch.equal(positions[:, 2:6], torch.arange(4).expand(3, 4) + 2)
assert not torch.equal(positions[:, 10:13], torch.arange(3).expand(3, 3) + 10)
def test_get_mrope_input_positions_image_at_start():
model = make_model(DummyConfig())
spatial_merge_size = model.config.vision_config.spatial_merge_size
t, h, w = 1, 2, 2
num_tokens = t * h * w
input_tokens = (
[model.config.vision_start_token_id]
+ [model.config.image_token_id] * num_tokens
+ [model.config.vision_end_token_id]
+ [10, 11]
)
mm_features = [
make_mm_feature(
offset=1, # start token at index 0
length=num_tokens,
image_grid_thw=(t, h * spatial_merge_size, w * spatial_merge_size),
)
]
positions, delta = model.get_mrope_input_positions(
input_tokens=input_tokens,
mm_features=mm_features,
)
expected = torch.tensor(
[
[0, 1, 1, 1, 1, 3, 4, 5],
[0, 1, 1, 2, 2, 3, 4, 5],
[0, 1, 2, 1, 2, 3, 4, 5],
]
)
assert torch.equal(positions, expected)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Literal from typing import Annotated, Literal
...@@ -1056,121 +1056,83 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support ...@@ -1056,121 +1056,83 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
) -> torch.Tensor | None: ) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states) return self.language_model.compute_logits(hidden_states)
def iter_mm_grid_thw(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int, float]]:
"""
Iterate over multimodal features and yield grid information.
Args:
mm_features: List of multimodal feature specifications
Yields:
Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
offset = mm_feature.mm_position.offset
if mm_feature.modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
assert t == 1, f"Image must have 1 frame, got {t}"
yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
elif mm_feature.modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 1.0
if mm_feature.data.get("second_per_grid_ts", None):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
t_factor = second_per_grid_ts * tokens_per_second
yield (
offset,
t,
h // spatial_merge_size,
w // spatial_merge_size,
t_factor,
)
else:
raise ValueError(f"Unsupported modality: {mm_feature.modality}")
def get_mrope_input_positions( def get_mrope_input_positions(
self, self,
input_tokens: list[int], input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec], mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]: ) -> tuple[torch.Tensor, int]:
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{"image_grid_thw", "video_grid_thw", "second_per_grid_ts"},
)
image_grid_thw = [item.tolist() for item in kwargs.get("image_grid_thw", [])]
video_grid_thw = [item.tolist() for item in kwargs.get("video_grid_thw", [])]
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
hf_config = self.config
image_token_id = hf_config.image_token_id
video_token_id = hf_config.video_token_id
vision_start_token_id = hf_config.vision_start_token_id
spatial_merge_size = hf_config.vision_config.spatial_merge_size
tokens_per_second = getattr(hf_config.vision_config, "tokens_per_second", 1.0)
input_tokens_tensor = torch.tensor(input_tokens)
vision_start_indices = torch.argwhere(
input_tokens_tensor == vision_start_token_id
).squeeze(1)
vision_tokens = input_tokens_tensor[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
llm_pos_ids_list: list = [] llm_pos_ids_list: list = []
st = 0 st = 0
remain_images, remain_videos = image_nums, video_nums
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
video_second_per_grid_t = 0.0
if remain_images > 0:
try:
ed_image = input_tokens.index(image_token_id, st)
except ValueError:
ed_image = len(input_tokens) + 1
else:
ed_image = len(input_tokens) + 1
if remain_videos > 0:
try:
ed_video = input_tokens.index(video_token_id, st)
except ValueError:
ed_video = len(input_tokens) + 1
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = image_grid_thw[image_index]
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = video_grid_thw[video_index]
video_second_per_grid_t = 1.0
if second_per_grid_ts:
video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t,
h // spatial_merge_size,
w // spatial_merge_size,
)
text_len = ed - st
for (
offset,
llm_grid_t,
llm_grid_h,
llm_grid_w,
t_factor,
) in self.iter_mm_grid_thw(mm_features):
text_len = offset - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
t_index = ( grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
( if t_factor != 1.0:
torch.arange(llm_grid_t) grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
* video_second_per_grid_t
* tokens_per_second
)
.long()
.flatten()
)
h_index = ( llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
torch.arange(llm_grid_h) st = offset + llm_grid_t * llm_grid_h * llm_grid_w
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens): if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st text_len = len(input_tokens) - st
llm_pos_ids_list.append( llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
) )
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
return llm_positions, mrope_position_delta return torch.from_numpy(llm_positions), mrope_position_delta
def _parse_and_validate_image_input( def _parse_and_validate_image_input(
self, **kwargs: object self, **kwargs: object
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment