"tests/compile/distributed/test_sequence_parallelism.py" did not exist on "6c04638214d413dfafa2ab1bf3f16069878e60f9"
Unverified Commit 372b2e76 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Standardize getting number of image patches/tokens (#34358)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 6afa587d
......@@ -4,8 +4,6 @@ from typing import NamedTuple
import pytest
import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
......@@ -46,31 +44,13 @@ class MRoPETestInfo(NamedTuple):
marks: list[pytest.MarkDecorator] = []
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
],
),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
],
),
MRoPETestInfo(model_name="Qwen/Qwen3-VL-4B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen3-VL-30B-A3B-Instruct"),
]
num_tokens_list = [11, 8192]
......
......@@ -961,12 +961,6 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 4},
)
],
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) == Version("4.57.1"),
reason="This model is broken in Transformers v4.57.1",
)
],
),
# regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention": VLMTestInfo(
......
......@@ -168,6 +168,7 @@ def test_get_image_size_with_most_features(
image_width=max_image_size.width,
image_height=max_image_size.height,
processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
)
prompt = "<start_of_image>"
......
......@@ -3,7 +3,9 @@
"""Tests for Idefics3's multimodal preprocessing kwargs."""
import pytest
from packaging.version import Version
from transformers import Idefics3Config
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets
from ...utils import build_model_context
@pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) < Version("5.2.0"),
reason="See https://github.com/huggingface/transformers/pull/43948",
)
@pytest.mark.parametrize("model_id", ["HuggingFaceM4/Idefics3-8B-Llama3"])
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
......@@ -63,7 +69,11 @@ def test_processor_override(
# Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
hf_processed_inputs = hf_processor(
text=prompt,
images=mm_data["image"],
**processor.info.ctx.get_merged_mm_kwargs(hf_processor_mm_kwargs),
)
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0]
# Ensure we have the right number of placeholders per num_crops size
......
......@@ -82,6 +82,7 @@ def test_get_image_size_with_most_features(
image_width=max_image_size.width,
image_height=max_image_size.height,
image_processor=hf_processor.image_processor,
mm_kwargs=hf_processor_mm_kwargs,
)
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
......
......@@ -3,7 +3,9 @@
"""Tests for smolvlm's multimodal preprocessing kwargs."""
import pytest
from packaging.version import Version
from transformers import SmolVLMConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets
from ...utils import build_model_context
@pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) < Version("5.2.0"),
reason="See https://github.com/huggingface/transformers/pull/43948",
)
@pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"])
@pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"),
......@@ -63,7 +69,11 @@ def test_processor_override(
# Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"])
hf_processed_inputs = hf_processor(
text=prompt,
images=mm_data["image"],
**processor.info.ctx.get_merged_mm_kwargs(hf_processor_mm_kwargs),
)
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0]
# Ensure we have the right number of placeholders per num_crops size
......
......@@ -11,7 +11,7 @@ from torch import nn
from transformers import BatchFeature, PretrainedConfig
from transformers.models.cohere2_vision import Cohere2VisionConfig
from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501
get_optimal_tiled_canvas,
Cohere2VisionImageProcessorFast,
)
from transformers.models.cohere2_vision.processing_cohere2_vision import (
Cohere2VisionProcessor,
......@@ -166,43 +166,20 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Cohere2VisionProcessor | None,
processor: Cohere2VisionProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
"""
Calculate the number of image patches for a given image.
Uses the HF processor to determine the actual number of patches.
"""
if processor is None:
processor = self.get_hf_processor()
image_processor = processor.image_processor
image_processor: Cohere2VisionImageProcessorFast = processor.image_processor
# The current implementation of get_number_of_image_patches
# is incorrect, so we patch it here.
# TODO: Revert once
# https://github.com/huggingface/transformers/pull/40312 is released.
# return image_processor.get_number_of_image_patches(image_height,
# image_width, {})
min_patches = image_processor.min_patches
max_patches = image_processor.max_patches
patch_size = image_processor.size
crop_to_patches = image_processor.crop_to_patches
if not crop_to_patches:
return 1
num_columns, num_rows = get_optimal_tiled_canvas(
(image_height, image_width),
(patch_size["height"], patch_size["width"]),
min_patches,
max_patches,
return image_processor.get_number_of_image_patches(
image_height,
image_width,
self.ctx.get_merged_mm_kwargs(mm_kwargs),
)
num_patches = num_columns * num_rows
if num_patches > 1:
num_patches += 1 # Thumbnail image
return num_patches
class Cohere2VisionDummyInputsBuilder(
......@@ -271,6 +248,7 @@ class Cohere2VisionMultiModalProcessor(
image_width=parsed_images.get_image_size(i).width,
image_height=parsed_images.get_image_size(i).height,
processor=hf_processor,
mm_kwargs=mm_kwargs,
)
for i in range(len(parsed_images))
]
......@@ -311,6 +289,7 @@ class Cohere2VisionMultiModalProcessor(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
)
patch_tokens = image_token * img_tokens_per_tile + img_line_break_token
repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}"
......
......@@ -34,7 +34,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature
from transformers import BaseImageProcessor, BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
......@@ -818,10 +818,9 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
image_processor: Any | None,
image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
......@@ -829,13 +828,16 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * spatial_conv_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
min_pixels=size["min_pixels"],
max_pixels=size["max_pixels"],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
......@@ -855,12 +857,14 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
image_processor: Any | None,
image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
mm_kwargs=mm_kwargs,
)
return num_image_tokens
......@@ -870,35 +874,43 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
num_frames: int,
image_processor: Any | None,
image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, num_video_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=num_frames,
image_processor=image_processor,
mm_kwargs=mm_kwargs,
)
return num_video_tokens
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
max_image_size, _ = self._get_vision_info(
image_width=9999999,
image_height=9999999,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
return max_image_size
def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
num_image_tokens = self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
return num_image_tokens
def _get_max_video_frames(self, max_tokens: int) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
......@@ -909,7 +921,8 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
if next_max_tokens > max_tokens:
......@@ -942,13 +955,15 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
......
......@@ -7,6 +7,7 @@ from typing import Annotated, Any, Literal
import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.image_processing_gemma3 import Gemma3ImageProcessor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from vllm.config import VllmConfig
......@@ -84,55 +85,36 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def _resolve_image_kwargs(
self,
processor: Gemma3Processor,
keys: set[str],
) -> dict[str, Any]:
image_processor = processor.image_processor
kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
)
images_kwargs = kwargs["images_kwargs"]
def _resolve_kw(key: str):
val = getattr(image_processor, key)
if val is None:
val = images_kwargs[key]
return val
return {k: _resolve_kw(k) for k in keys}
def get_num_crops(
self,
*,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
processor: Gemma3Processor,
mm_kwargs: Mapping[str, object],
) -> int:
if processor is None:
processor = self.get_hf_processor()
image_processor: Gemma3ImageProcessor = processor.image_processor
images_kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
**self.ctx.get_merged_mm_kwargs(mm_kwargs),
)["images_kwargs"]
images_kwargs = self._resolve_image_kwargs(
processor,
{
"do_pan_and_scan",
"pan_and_scan_min_crop_size",
"pan_and_scan_max_num_crops",
do_pan_and_scan = images_kwargs.get(
"do_pan_and_scan", image_processor.do_pan_and_scan
)
pan_and_scan_min_crop_size = images_kwargs.get(
"pan_and_scan_min_crop_size", image_processor.pan_and_scan_min_crop_size
)
pan_and_scan_max_num_crops = images_kwargs.get(
"pan_and_scan_max_num_crops", image_processor.pan_and_scan_max_num_crops
)
pan_and_scan_min_ratio_to_activate = images_kwargs.get(
"pan_and_scan_min_ratio_to_activate",
},
image_processor.pan_and_scan_min_ratio_to_activate,
)
do_pan_and_scan = images_kwargs["do_pan_and_scan"]
pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
pan_and_scan_min_ratio_to_activate = images_kwargs[
"pan_and_scan_min_ratio_to_activate"
]
if not do_pan_and_scan:
return 0
......@@ -180,17 +162,16 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
processor: Gemma3Processor,
mm_kwargs: Mapping[str, object],
) -> PromptUpdateDetails[str]:
if processor is None:
processor = self.get_hf_processor()
boi_token = processor.boi_token
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
mm_kwargs=mm_kwargs,
)
if num_crops == 0:
......@@ -215,15 +196,14 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
processor: Gemma3Processor,
mm_kwargs: Mapping[str, object],
) -> int:
if processor is None:
processor = self.get_hf_processor()
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
mm_kwargs=mm_kwargs,
)
image_seq_len = processor.image_seq_length
......@@ -231,11 +211,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
image_processor: Gemma3ImageProcessor = processor.image_processor
images_kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
**self.ctx.get_merged_mm_kwargs({}),
)["images_kwargs"]
images_kwargs = self._resolve_image_kwargs(
processor, {"pan_and_scan_max_num_crops"}
max_num_crops = images_kwargs.get(
"pan_and_scan_max_num_crops", image_processor.pan_and_scan_max_num_crops
)
max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
vision_config = self.get_hf_config().vision_config
native_size = vision_config.image_size
......@@ -303,6 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_width=size.width,
image_height=size.height,
processor=hf_processor,
mm_kwargs=mm_kwargs,
)
for size in image_sizes
]
......@@ -339,6 +326,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
)
return [
......
......@@ -131,7 +131,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Gemma3nProcessor | None,
processor: Gemma3nProcessor,
) -> str:
"""
Get the replacement text for image tokens.
......@@ -139,9 +139,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
For Gemma3n, this should return the full_image_sequence which includes
BOI token, repeated image tokens, and EOI token.
"""
if processor is None:
processor = self.get_hf_processor()
return PromptUpdateDetails.select_token_id(
processor.full_image_sequence, processor.image_token_id
)
......@@ -149,7 +146,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_audio_repl(
self,
*,
processor: Gemma3nProcessor | None,
processor: Gemma3nProcessor,
) -> str:
"""
Get the replacement text for audio tokens.
......@@ -157,9 +154,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
For Gemma3n, this should return the full_audio_sequence which includes
BOA token, repeated audio tokens, and EOA token.
"""
if processor is None:
processor = self.get_hf_processor()
# Return the full audio sequence as defined by the processor
return PromptUpdateDetails.select_token_id(
processor.full_audio_sequence, processor.audio_token_id
......
......@@ -424,12 +424,9 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
*,
image_width: int,
image_height: int,
processor: H2OVLProcessor | None,
processor: H2OVLProcessor,
use_msac: bool | None = None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
......
......@@ -78,7 +78,10 @@ from vllm.transformers_utils.configs.hunyuan_vl import (
HunYuanVLVisionConfig,
)
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize
from vllm.transformers_utils.processors.hunyuan_vl_image import (
HunYuanVLImageProcessor,
smart_resize,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
......@@ -596,7 +599,7 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
def get_image_processor(
self,
**kwargs: object,
) -> HunYuanVLProcessor:
) -> HunYuanVLImageProcessor:
return self.get_hf_processor(**kwargs).image_processor
def get_data_parser(self):
......@@ -624,23 +627,24 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
image_processor: HunYuanVLProcessor | None,
image_processor: HunYuanVLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
spatial_merge_size = vision_config.spatial_merge_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * spatial_merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
min_pixels=size["shortest_edge"],
max_pixels=size["longest_edge"],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
......@@ -662,29 +666,37 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
image_processor: HunYuanVLProcessor | None,
image_processor: HunYuanVLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
mm_kwargs=mm_kwargs,
)
return num_image_tokens
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
max_image_size, _ = self._get_vision_info(
image_width=512,
image_height=8192,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
return max_image_size
def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
......
......@@ -16,7 +16,6 @@
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, TypeAlias
......@@ -168,54 +167,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Idefics3Processor | None,
) -> tuple[int, int]:
if processor is None:
processor = self.get_hf_processor()
processor: Idefics3Processor,
mm_kwargs: Mapping[str, object],
) -> tuple[int, int, int]:
image_processor: Idefics3ImageProcessor = processor.image_processor
max_image_size = image_processor.max_image_size["longest_edge"]
size = image_processor.size["longest_edge"]
assert size % max_image_size == 0, (
"`longest_edge` in image_processor's `size` must be divisible by "
"`longest_edge` in `max_image_size`, this may be caused by "
"incorrect mm_kwargs override."
)
resized_height, resized_width = self._get_resize_output_image_size(
image_width=image_width,
image_height=image_height,
resolution_max_side=size,
return image_processor.get_number_of_image_patches(
image_height,
image_width,
self.ctx.get_merged_mm_kwargs(mm_kwargs),
)
if resized_height > max_image_size or resized_width > max_image_size:
grid_h = math.ceil(resized_height / max_image_size)
grid_w = math.ceil(resized_width / max_image_size)
else:
grid_h = grid_w = 0
return grid_w, grid_h
def get_num_patches(
self,
*,
image_width: int,
image_height: int,
processor: Idefics3Processor | None,
processor: Idefics3Processor,
mm_kwargs: Mapping[str, object],
) -> int:
grid_w, grid_h = self._get_image_feature_grid_size(
num_patches, _, _ = self._get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
processor=processor,
mm_kwargs=mm_kwargs,
)
return grid_w * grid_h + 1
def _get_image_token(
self, processor: Idefics3Processor | None
) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
return num_patches
def _get_image_token(self, processor: Idefics3Processor) -> tuple[str, str, str]:
image_token = processor.image_token
fake_image_token = processor.fake_image_token
global_image_token = processor.global_image_tag
......@@ -226,11 +206,9 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Idefics3Processor | None,
processor: Idefics3Processor,
mm_kwargs: Mapping[str, object],
) -> str:
if processor is None:
processor = self.get_hf_processor()
image_token, fake_image_token, global_img_token = self._get_image_token(
processor
)
......@@ -241,10 +219,11 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
global_img_placeholder = fake_image_token + global_img_token + p_img
tile_img_placeholder = fake_image_token + grid_placeholder + p_img
grid_w, grid_h = self._get_image_feature_grid_size(
_, grid_h, grid_w = self._get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
processor=processor,
mm_kwargs=mm_kwargs,
)
if grid_w == 0 and grid_h == 0:
return global_img_placeholder + fake_image_token
......@@ -272,15 +251,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Idefics3Processor | None,
processor: Idefics3Processor,
mm_kwargs: Mapping[str, object],
) -> int:
if processor is None:
processor = self.get_hf_processor()
num_patches = self.get_num_patches(
image_width=image_width,
image_height=image_height,
processor=processor,
mm_kwargs=mm_kwargs,
)
return num_patches * processor.image_seq_len
......@@ -353,6 +331,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
image_width=size.width,
image_height=size.height,
processor=hf_processor,
mm_kwargs=mm_kwargs,
)
for size in image_sizes
]
......@@ -398,6 +377,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
)
return PromptUpdateDetails.select_text(
......
......@@ -197,20 +197,18 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: GotOcr2ImageProcessorFast | None = None,
processor: InternVLProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
if processor is None:
processor = self.get_hf_processor().image_processor
image_processor: GotOcr2ImageProcessorFast = processor.image_processor
if not isinstance(processor, GotOcr2ImageProcessorFast):
raise ValueError(
f"GotOcr2ImageProcessorFast is expected but got {type(processor)}"
)
num_image_patches = processor.get_number_of_image_patches(
image_height, image_width, images_kwargs=dict()
num_image_patches = image_processor.get_number_of_image_patches(
image_height,
image_width,
self.ctx.get_merged_mm_kwargs(mm_kwargs),
)
num_image_tokens = self.get_hf_processor().image_seq_length * num_image_patches
return num_image_tokens
return processor.image_seq_length * num_image_patches
def resolve_target_ratios(self, use_thumbnail: bool | None = None):
image_processor = self.get_hf_processor().image_processor
......@@ -243,7 +241,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor.image_processor,
processor=processor,
mm_kwargs={},
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
......@@ -262,7 +261,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=processor.image_processor,
processor=processor,
mm_kwargs={},
)
def get_num_frames_with_most_features(
......
......@@ -705,11 +705,8 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: BaseInternVLProcessor | None,
processor: BaseInternVLProcessor,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
......
......@@ -10,7 +10,7 @@ import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import PretrainedConfig
from transformers import BaseImageProcessor, PretrainedConfig
from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
......@@ -1011,24 +1011,25 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
image_processor,
image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = 1
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
min_pixels=size["min_pixels"],
max_pixels=size["max_pixels"],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
......@@ -1050,12 +1051,14 @@ class KeyeProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
image_processor,
image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
mm_kwargs=mm_kwargs,
)
return num_image_tokens
......@@ -1065,36 +1068,42 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
num_frames: int,
image_processor,
image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, num_video_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=num_frames,
image_processor=image_processor,
mm_kwargs=mm_kwargs,
)
return num_video_tokens
def get_image_size_with_most_features(
self,
) -> ImageSize:
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
max_image_size, _ = self._get_vision_info(
image_width=self.get_max_image_size(),
image_height=self.get_max_image_size(),
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
return max_image_size
def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
def _get_max_video_frames(self, max_tokens: int) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0
......@@ -1105,7 +1114,8 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
if next_max_tokens > max_tokens:
......@@ -1130,13 +1140,15 @@ class KeyeProcessingInfo(BaseProcessingInfo):
return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len),
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
......
......@@ -176,7 +176,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
min_tiles: int,
max_tiles: int,
tile_size: int,
) -> tuple[int, int]:
) -> tuple[int, int, int]:
aspect_ratio = width / height
target_ratios = self._target_ratios(min_tiles, max_tiles)
# find best matching grid configuration
......@@ -190,18 +190,27 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
self,
image_width: int,
image_height: int,
processor: Lfm2VlProcessor | None,
) -> tuple[int, int]:
if processor is None:
processor = self.get_image_processor()
processor: Lfm2VlProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[int, int, int]:
image_processor: Lfm2VlImageProcessorFast = processor.image_processor
downsample_factor = processor.image_processor.downsample_factor
encoder_patch_size = processor.image_processor.encoder_patch_size
max_pixels_tolerance = processor.image_processor.max_pixels_tolerance
min_tiles = processor.image_processor.min_tiles
max_tiles = processor.image_processor.max_tiles
max_image_tokens = processor.image_processor.max_image_tokens
tile_size = processor.image_processor.tile_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
downsample_factor = mm_kwargs.get(
"downsample_factor", image_processor.downsample_factor
)
encoder_patch_size = mm_kwargs.get(
"encoder_patch_size", image_processor.encoder_patch_size
)
max_pixels_tolerance = mm_kwargs.get(
"max_pixels_tolerance", image_processor.max_pixels_tolerance
)
min_tiles = mm_kwargs.get("min_tiles", image_processor.min_tiles)
max_tiles = mm_kwargs.get("max_tiles", image_processor.max_tiles)
max_image_tokens = mm_kwargs.get(
"max_image_tokens", image_processor.max_image_tokens
)
tile_size = mm_kwargs.get("tile_size", image_processor.tile_size)
do_image_splitting = not min_tiles == max_tiles == 1
is_image_large = self._is_image_too_large(
......@@ -235,12 +244,14 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Lfm2VlProcessor | None,
processor: Lfm2VlProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, _, total_patches = self._get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
processor=processor,
mm_kwargs=mm_kwargs,
)
return total_patches
......@@ -249,11 +260,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
spatial_shapes: torch.Tensor,
processor: Lfm2VlProcessor | None,
processor: Lfm2VlProcessor,
mm_kwargs: Mapping[str, object],
) -> str:
if processor is None:
processor = self.get_hf_processor()
grid_placeholder = "<|img_row_{n_h}_col_{n_w}|>"
image_token = processor.image_token
image_start_token = processor.image_start_token
......@@ -263,6 +272,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
num_thumbnail_tokens, num_tokens_per_tile = self.get_num_image_tokens(
spatial_shapes=spatial_shapes,
processor=processor,
mm_kwargs=mm_kwargs,
)
tile_img_placeholder = grid_placeholder + (image_token * num_tokens_per_tile)
......@@ -270,6 +280,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
image_width=image_width,
image_height=image_height,
processor=processor,
mm_kwargs=mm_kwargs,
)
if grid_w > 1 or grid_h > 1:
......@@ -295,15 +306,25 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
self,
*,
spatial_shapes: torch.Tensor,
processor: Lfm2VlProcessor | None,
processor: Lfm2VlProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[int, int]:
tile_size = processor.image_processor.tile_size
downsample_factor = processor.image_processor.downsample_factor
encoder_patch_size = processor.image_processor.encoder_patch_size
image_processor: Lfm2VlImageProcessorFast = processor.image_processor
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
downsample_factor = mm_kwargs.get(
"downsample_factor", image_processor.downsample_factor
)
encoder_patch_size = mm_kwargs.get(
"encoder_patch_size", image_processor.encoder_patch_size
)
tile_size = mm_kwargs.get("tile_size", image_processor.tile_size)
num_thumbnail_tokens = spatial_shapes[-1].prod() // (downsample_factor**2)
num_patches_tile = tile_size // encoder_patch_size
dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor)
num_tiles_tokens = dwn_num_patches_tile * dwn_num_patches_tile
return num_thumbnail_tokens, num_tiles_tokens
......@@ -372,6 +393,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
image_width=size.width,
image_height=size.height,
processor=hf_processor,
mm_kwargs=mm_kwargs,
)
for size in image_sizes
]
......@@ -414,6 +436,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
image_height=image_size.height,
spatial_shapes=spatial_shapes,
processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
)
return PromptUpdateDetails.select_text(
image_repl,
......
......@@ -1224,11 +1224,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: MolmoProcessorWrapper | None,
processor: MolmoProcessorWrapper,
) -> int:
if processor is None:
processor = self.get_hf_processor()
ncols, nrows = processor.get_patches_grid_size(
image_width=image_width,
image_height=image_height,
......
......@@ -1869,12 +1869,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
*,
image_height: int,
image_width: int,
processor: Molmo2ProcessorWrapper | None = None,
processor: Molmo2ProcessorWrapper,
) -> int:
if processor is None:
processor = self.get_hf_processor()
hf_processor = processor.processor # type: ignore
hf_processor = processor.processor
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
# start/end tokens + image patch token + col tokens
......@@ -1897,11 +1894,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
self,
*,
num_frames: int,
processor: Molmo2ProcessorWrapper | None = None,
processor: Molmo2ProcessorWrapper,
) -> int:
if processor is None:
processor = self.get_hf_processor()
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True)
# start/end tokens
extra = 2 + resize_nrows * (
......@@ -1929,7 +1923,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
width = wr * crop_window_size + total_margin_pixels
feat_size = self.get_num_image_tokens(
image_height=height, image_width=width, processor=processor
image_height=height,
image_width=width,
processor=processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
......@@ -1940,8 +1936,15 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint
def _get_max_video_frames(self, max_tokens: int) -> int:
num_tokens_per_frame = self.get_num_video_tokens(num_frames=1)
def _get_max_video_frames(
self,
max_tokens: int,
processor: Molmo2ProcessorWrapper,
) -> int:
num_tokens_per_frame = self.get_num_video_tokens(
num_frames=1,
processor=processor,
)
max_frames = max_tokens // num_tokens_per_frame
return max(max_frames, 1)
......@@ -1950,10 +1953,11 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
video_processor = self.get_hf_processor().processor.video_processor
processor = self.get_hf_processor()
video_processor = processor.processor.video_processor
num_frames = video_processor.num_frames
max_videos = mm_counts.get("video", 0)
max_total_frames = self._get_max_video_frames(seq_len)
max_total_frames = self._get_max_video_frames(seq_len, processor)
max_frames_per_video = min(
max_total_frames // max(max_videos, 1),
num_frames,
......
......@@ -215,7 +215,7 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
num_frames: int = 1,
) -> tuple[ImageSize, int]:
) -> int:
hf_config = self.get_hf_config()
vit_config = hf_config.vit_config
patch_size = vit_config.patch_size
......@@ -245,7 +245,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
)
if next_max_tokens > max_tokens:
break
......@@ -270,7 +269,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
num_frames: int,
image_processor: BaseImageProcessor | None,
) -> int:
num_video_tokens = self.get_num_image_tokens(
image_width=image_width, image_height=image_height, num_frames=num_frames
......@@ -287,7 +285,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment