Unverified Commit 372b2e76 authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Standardize getting number of image patches/tokens (#34358)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 6afa587d
...@@ -4,8 +4,6 @@ from typing import NamedTuple ...@@ -4,8 +4,6 @@ from typing import NamedTuple
import pytest import pytest
import torch import torch
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -46,31 +44,13 @@ class MRoPETestInfo(NamedTuple): ...@@ -46,31 +44,13 @@ class MRoPETestInfo(NamedTuple):
marks: list[pytest.MarkDecorator] = [] marks: list[pytest.MarkDecorator] = []
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
MODELS_TO_TEST = [ MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"), MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"), MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"), MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
MRoPETestInfo( MRoPETestInfo(model_name="Qwen/Qwen3-VL-4B-Instruct"),
model_name="Qwen/Qwen3-VL-4B-Instruct", MRoPETestInfo(model_name="Qwen/Qwen3-VL-30B-A3B-Instruct"),
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
],
),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
],
),
] ]
num_tokens_list = [11, 8192] num_tokens_list = [11, 8192]
......
...@@ -961,12 +961,6 @@ VLM_TEST_SETTINGS = { ...@@ -961,12 +961,6 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 4}, limit_mm_per_prompt={"image": 4},
) )
], ],
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) == Version("4.57.1"),
reason="This model is broken in Transformers v4.57.1",
)
],
), ),
# regression test for https://github.com/vllm-project/vllm/issues/15122 # regression test for https://github.com/vllm-project/vllm/issues/15122
"qwen2_5_vl-windows-attention": VLMTestInfo( "qwen2_5_vl-windows-attention": VLMTestInfo(
......
...@@ -168,6 +168,7 @@ def test_get_image_size_with_most_features( ...@@ -168,6 +168,7 @@ def test_get_image_size_with_most_features(
image_width=max_image_size.width, image_width=max_image_size.width,
image_height=max_image_size.height, image_height=max_image_size.height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
) )
prompt = "<start_of_image>" prompt = "<start_of_image>"
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
"""Tests for Idefics3's multimodal preprocessing kwargs.""" """Tests for Idefics3's multimodal preprocessing kwargs."""
import pytest import pytest
from packaging.version import Version
from transformers import Idefics3Config from transformers import Idefics3Config
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets ...@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets
from ...utils import build_model_context from ...utils import build_model_context
@pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) < Version("5.2.0"),
reason="See https://github.com/huggingface/transformers/pull/43948",
)
@pytest.mark.parametrize("model_id", ["HuggingFaceM4/Idefics3-8B-Llama3"]) @pytest.mark.parametrize("model_id", ["HuggingFaceM4/Idefics3-8B-Llama3"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"), ("mm_processor_kwargs", "expected_toks_per_img"),
...@@ -63,7 +69,11 @@ def test_processor_override( ...@@ -63,7 +69,11 @@ def test_processor_override(
# Ensure the placeholders format are correct # Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) hf_processed_inputs = hf_processor(
text=prompt,
images=mm_data["image"],
**processor.info.ctx.get_merged_mm_kwargs(hf_processor_mm_kwargs),
)
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0] assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0]
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
......
...@@ -82,6 +82,7 @@ def test_get_image_size_with_most_features( ...@@ -82,6 +82,7 @@ def test_get_image_size_with_most_features(
image_width=max_image_size.width, image_width=max_image_size.width,
image_height=max_image_size.height, image_height=max_image_size.height,
image_processor=hf_processor.image_processor, image_processor=hf_processor.image_processor,
mm_kwargs=hf_processor_mm_kwargs,
) )
prompt = "<|vision_start|><|image_pad|><|vision_end|>" prompt = "<|vision_start|><|image_pad|><|vision_end|>"
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
"""Tests for smolvlm's multimodal preprocessing kwargs.""" """Tests for smolvlm's multimodal preprocessing kwargs."""
import pytest import pytest
from packaging.version import Version
from transformers import SmolVLMConfig from transformers import SmolVLMConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets ...@@ -11,6 +13,10 @@ from ....conftest import ImageTestAssets
from ...utils import build_model_context from ...utils import build_model_context
@pytest.mark.skipif(
Version(TRANSFORMERS_VERSION) < Version("5.2.0"),
reason="See https://github.com/huggingface/transformers/pull/43948",
)
@pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"]) @pytest.mark.parametrize("model_id", ["HuggingFaceTB/SmolVLM2-2.2B-Instruct"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("mm_processor_kwargs", "expected_toks_per_img"), ("mm_processor_kwargs", "expected_toks_per_img"),
...@@ -63,7 +69,11 @@ def test_processor_override( ...@@ -63,7 +69,11 @@ def test_processor_override(
# Ensure the placeholders format are correct # Ensure the placeholders format are correct
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_processed_inputs = hf_processor(text=prompt, images=mm_data["image"]) hf_processed_inputs = hf_processor(
text=prompt,
images=mm_data["image"],
**processor.info.ctx.get_merged_mm_kwargs(hf_processor_mm_kwargs),
)
assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0] assert processed_inputs["prompt_token_ids"] == hf_processed_inputs["input_ids"][0]
# Ensure we have the right number of placeholders per num_crops size # Ensure we have the right number of placeholders per num_crops size
......
...@@ -11,7 +11,7 @@ from torch import nn ...@@ -11,7 +11,7 @@ from torch import nn
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from transformers.models.cohere2_vision import Cohere2VisionConfig from transformers.models.cohere2_vision import Cohere2VisionConfig
from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501 from transformers.models.cohere2_vision.image_processing_cohere2_vision_fast import ( # noqa: E501
get_optimal_tiled_canvas, Cohere2VisionImageProcessorFast,
) )
from transformers.models.cohere2_vision.processing_cohere2_vision import ( from transformers.models.cohere2_vision.processing_cohere2_vision import (
Cohere2VisionProcessor, Cohere2VisionProcessor,
...@@ -166,43 +166,20 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo): ...@@ -166,43 +166,20 @@ class Cohere2VisionProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Cohere2VisionProcessor | None, processor: Cohere2VisionProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
""" """
Calculate the number of image patches for a given image. Calculate the number of image patches for a given image.
Uses the HF processor to determine the actual number of patches. Uses the HF processor to determine the actual number of patches.
""" """
if processor is None: image_processor: Cohere2VisionImageProcessorFast = processor.image_processor
processor = self.get_hf_processor()
image_processor = processor.image_processor
# The current implementation of get_number_of_image_patches return image_processor.get_number_of_image_patches(
# is incorrect, so we patch it here. image_height,
# TODO: Revert once image_width,
# https://github.com/huggingface/transformers/pull/40312 is released. self.ctx.get_merged_mm_kwargs(mm_kwargs),
# return image_processor.get_number_of_image_patches(image_height,
# image_width, {})
min_patches = image_processor.min_patches
max_patches = image_processor.max_patches
patch_size = image_processor.size
crop_to_patches = image_processor.crop_to_patches
if not crop_to_patches:
return 1
num_columns, num_rows = get_optimal_tiled_canvas(
(image_height, image_width),
(patch_size["height"], patch_size["width"]),
min_patches,
max_patches,
) )
num_patches = num_columns * num_rows
if num_patches > 1:
num_patches += 1 # Thumbnail image
return num_patches
class Cohere2VisionDummyInputsBuilder( class Cohere2VisionDummyInputsBuilder(
...@@ -271,6 +248,7 @@ class Cohere2VisionMultiModalProcessor( ...@@ -271,6 +248,7 @@ class Cohere2VisionMultiModalProcessor(
image_width=parsed_images.get_image_size(i).width, image_width=parsed_images.get_image_size(i).width,
image_height=parsed_images.get_image_size(i).height, image_height=parsed_images.get_image_size(i).height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=mm_kwargs,
) )
for i in range(len(parsed_images)) for i in range(len(parsed_images))
] ]
...@@ -311,6 +289,7 @@ class Cohere2VisionMultiModalProcessor( ...@@ -311,6 +289,7 @@ class Cohere2VisionMultiModalProcessor(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
) )
patch_tokens = image_token * img_tokens_per_tile + img_line_break_token patch_tokens = image_token * img_tokens_per_tile + img_line_break_token
repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}" repl = f"{boi_token}{patch_tokens * num_patches}{eoi_token}"
......
...@@ -34,7 +34,7 @@ import torch ...@@ -34,7 +34,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature from transformers import BaseImageProcessor, BatchFeature
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
...@@ -818,10 +818,9 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -818,10 +818,9 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
num_frames: int = 1, num_frames: int = 1,
do_resize: bool = True, do_resize: bool = True,
image_processor: Any | None, image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]: ) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
...@@ -829,13 +828,16 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -829,13 +828,16 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
spatial_conv_size = hf_config.spatial_conv_size spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size temporal_conv_size = hf_config.temporal_conv_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize: if do_resize:
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
height=image_height, height=image_height,
width=image_width, width=image_width,
factor=patch_size * spatial_conv_size, factor=patch_size * spatial_conv_size,
min_pixels=image_processor.min_pixels, min_pixels=size["min_pixels"],
max_pixels=image_processor.max_pixels, max_pixels=size["max_pixels"],
) )
preprocessed_size = ImageSize(width=resized_width, height=resized_height) preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else: else:
...@@ -855,12 +857,14 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -855,12 +857,14 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor: Any | None, image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, num_image_tokens = self._get_vision_info( _, num_image_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=mm_kwargs,
) )
return num_image_tokens return num_image_tokens
...@@ -870,35 +874,43 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -870,35 +874,43 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_width: int, image_width: int,
image_height: int, image_height: int,
num_frames: int, num_frames: int,
image_processor: Any | None, image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, num_video_tokens = self._get_vision_info( _, num_video_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
num_frames=num_frames, num_frames=num_frames,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=mm_kwargs,
) )
return num_video_tokens return num_video_tokens
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
max_image_size, _ = self._get_vision_info( max_image_size, _ = self._get_vision_info(
image_width=9999999, image_width=9999999,
image_height=9999999, image_height=9999999,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
return max_image_size return max_image_size
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
num_image_tokens = self.get_num_image_tokens( num_image_tokens = self.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
return num_image_tokens return num_image_tokens
def _get_max_video_frames(self, max_tokens: int) -> int: def _get_max_video_frames(self, max_tokens: int) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0 num_frames = 0
...@@ -909,7 +921,8 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -909,7 +921,8 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=next_num_frames, num_frames=next_num_frames,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
if next_max_tokens > max_tokens: if next_max_tokens > max_tokens:
...@@ -942,13 +955,15 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): ...@@ -942,13 +955,15 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> int: ) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
......
...@@ -7,6 +7,7 @@ from typing import Annotated, Any, Literal ...@@ -7,6 +7,7 @@ from typing import Annotated, Any, Literal
import torch import torch
from torch import nn from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.image_processing_gemma3 import Gemma3ImageProcessor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -84,55 +85,36 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -84,55 +85,36 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
def _resolve_image_kwargs(
self,
processor: Gemma3Processor,
keys: set[str],
) -> dict[str, Any]:
image_processor = processor.image_processor
kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
)
images_kwargs = kwargs["images_kwargs"]
def _resolve_kw(key: str):
val = getattr(image_processor, key)
if val is None:
val = images_kwargs[key]
return val
return {k: _resolve_kw(k) for k in keys}
def get_num_crops( def get_num_crops(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Gemma3Processor | None, processor: Gemma3Processor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
if processor is None: image_processor: Gemma3ImageProcessor = processor.image_processor
processor = self.get_hf_processor()
images_kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
**self.ctx.get_merged_mm_kwargs(mm_kwargs),
)["images_kwargs"]
images_kwargs = self._resolve_image_kwargs( do_pan_and_scan = images_kwargs.get(
processor, "do_pan_and_scan", image_processor.do_pan_and_scan
{ )
"do_pan_and_scan", pan_and_scan_min_crop_size = images_kwargs.get(
"pan_and_scan_min_crop_size", "pan_and_scan_min_crop_size", image_processor.pan_and_scan_min_crop_size
"pan_and_scan_max_num_crops", )
pan_and_scan_max_num_crops = images_kwargs.get(
"pan_and_scan_max_num_crops", image_processor.pan_and_scan_max_num_crops
)
pan_and_scan_min_ratio_to_activate = images_kwargs.get(
"pan_and_scan_min_ratio_to_activate", "pan_and_scan_min_ratio_to_activate",
}, image_processor.pan_and_scan_min_ratio_to_activate,
) )
do_pan_and_scan = images_kwargs["do_pan_and_scan"]
pan_and_scan_min_crop_size = images_kwargs["pan_and_scan_min_crop_size"]
pan_and_scan_max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
pan_and_scan_min_ratio_to_activate = images_kwargs[
"pan_and_scan_min_ratio_to_activate"
]
if not do_pan_and_scan: if not do_pan_and_scan:
return 0 return 0
...@@ -180,17 +162,16 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -180,17 +162,16 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Gemma3Processor | None, processor: Gemma3Processor,
mm_kwargs: Mapping[str, object],
) -> PromptUpdateDetails[str]: ) -> PromptUpdateDetails[str]:
if processor is None:
processor = self.get_hf_processor()
boi_token = processor.boi_token boi_token = processor.boi_token
num_crops = self.get_num_crops( num_crops = self.get_num_crops(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
if num_crops == 0: if num_crops == 0:
...@@ -215,15 +196,14 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -215,15 +196,14 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Gemma3Processor | None, processor: Gemma3Processor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
num_crops = self.get_num_crops( num_crops = self.get_num_crops(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
image_seq_len = processor.image_seq_length image_seq_len = processor.image_seq_length
...@@ -231,11 +211,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -231,11 +211,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
image_processor: Gemma3ImageProcessor = processor.image_processor
images_kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
**self.ctx.get_merged_mm_kwargs({}),
)["images_kwargs"]
images_kwargs = self._resolve_image_kwargs( max_num_crops = images_kwargs.get(
processor, {"pan_and_scan_max_num_crops"} "pan_and_scan_max_num_crops", image_processor.pan_and_scan_max_num_crops
) )
max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
vision_config = self.get_hf_config().vision_config vision_config = self.get_hf_config().vision_config
native_size = vision_config.image_size native_size = vision_config.image_size
...@@ -303,6 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -303,6 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_width=size.width, image_width=size.width,
image_height=size.height, image_height=size.height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=mm_kwargs,
) )
for size in image_sizes for size in image_sizes
] ]
...@@ -339,6 +326,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -339,6 +326,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
) )
return [ return [
......
...@@ -131,7 +131,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): ...@@ -131,7 +131,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Gemma3nProcessor | None, processor: Gemma3nProcessor,
) -> str: ) -> str:
""" """
Get the replacement text for image tokens. Get the replacement text for image tokens.
...@@ -139,9 +139,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): ...@@ -139,9 +139,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
For Gemma3n, this should return the full_image_sequence which includes For Gemma3n, this should return the full_image_sequence which includes
BOI token, repeated image tokens, and EOI token. BOI token, repeated image tokens, and EOI token.
""" """
if processor is None:
processor = self.get_hf_processor()
return PromptUpdateDetails.select_token_id( return PromptUpdateDetails.select_token_id(
processor.full_image_sequence, processor.image_token_id processor.full_image_sequence, processor.image_token_id
) )
...@@ -149,7 +146,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): ...@@ -149,7 +146,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_audio_repl( def get_audio_repl(
self, self,
*, *,
processor: Gemma3nProcessor | None, processor: Gemma3nProcessor,
) -> str: ) -> str:
""" """
Get the replacement text for audio tokens. Get the replacement text for audio tokens.
...@@ -157,9 +154,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo): ...@@ -157,9 +154,6 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
For Gemma3n, this should return the full_audio_sequence which includes For Gemma3n, this should return the full_audio_sequence which includes
BOA token, repeated audio tokens, and EOA token. BOA token, repeated audio tokens, and EOA token.
""" """
if processor is None:
processor = self.get_hf_processor()
# Return the full audio sequence as defined by the processor # Return the full audio sequence as defined by the processor
return PromptUpdateDetails.select_token_id( return PromptUpdateDetails.select_token_id(
processor.full_audio_sequence, processor.audio_token_id processor.full_audio_sequence, processor.audio_token_id
......
...@@ -424,12 +424,9 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo): ...@@ -424,12 +424,9 @@ class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: H2OVLProcessor | None, processor: H2OVLProcessor,
use_msac: bool | None = None, use_msac: bool | None = None,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens( return processor.get_num_image_tokens(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
......
...@@ -78,7 +78,10 @@ from vllm.transformers_utils.configs.hunyuan_vl import ( ...@@ -78,7 +78,10 @@ from vllm.transformers_utils.configs.hunyuan_vl import (
HunYuanVLVisionConfig, HunYuanVLVisionConfig,
) )
from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor
from vllm.transformers_utils.processors.hunyuan_vl_image import smart_resize from vllm.transformers_utils.processors.hunyuan_vl_image import (
HunYuanVLImageProcessor,
smart_resize,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
...@@ -596,7 +599,7 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo): ...@@ -596,7 +599,7 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
def get_image_processor( def get_image_processor(
self, self,
**kwargs: object, **kwargs: object,
) -> HunYuanVLProcessor: ) -> HunYuanVLImageProcessor:
return self.get_hf_processor(**kwargs).image_processor return self.get_hf_processor(**kwargs).image_processor
def get_data_parser(self): def get_data_parser(self):
...@@ -624,23 +627,24 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo): ...@@ -624,23 +627,24 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
num_frames: int = 1, num_frames: int = 1,
do_resize: bool = True, do_resize: bool = True,
image_processor: HunYuanVLProcessor | None, image_processor: HunYuanVLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]: ) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
spatial_merge_size = vision_config.spatial_merge_size spatial_merge_size = vision_config.spatial_merge_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize: if do_resize:
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
height=image_height, height=image_height,
width=image_width, width=image_width,
factor=patch_size * spatial_merge_size, factor=patch_size * spatial_merge_size,
min_pixels=image_processor.min_pixels, min_pixels=size["shortest_edge"],
max_pixels=image_processor.max_pixels, max_pixels=size["longest_edge"],
) )
preprocessed_size = ImageSize(width=resized_width, height=resized_height) preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else: else:
...@@ -662,29 +666,37 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo): ...@@ -662,29 +666,37 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor: HunYuanVLProcessor | None, image_processor: HunYuanVLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, num_image_tokens = self._get_vision_info( _, num_image_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=mm_kwargs,
) )
return num_image_tokens return num_image_tokens
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor()
max_image_size, _ = self._get_vision_info( max_image_size, _ = self._get_vision_info(
image_width=512, image_width=512,
image_height=8192, image_height=8192,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
return max_image_size return max_image_size
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens( return self.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
# limitations under the License. # limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights.""" """Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, TypeAlias from typing import Annotated, Literal, TypeAlias
...@@ -168,54 +167,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -168,54 +167,35 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Idefics3Processor | None, processor: Idefics3Processor,
) -> tuple[int, int]: mm_kwargs: Mapping[str, object],
if processor is None: ) -> tuple[int, int, int]:
processor = self.get_hf_processor()
image_processor: Idefics3ImageProcessor = processor.image_processor image_processor: Idefics3ImageProcessor = processor.image_processor
max_image_size = image_processor.max_image_size["longest_edge"] return image_processor.get_number_of_image_patches(
size = image_processor.size["longest_edge"] image_height,
assert size % max_image_size == 0, ( image_width,
"`longest_edge` in image_processor's `size` must be divisible by " self.ctx.get_merged_mm_kwargs(mm_kwargs),
"`longest_edge` in `max_image_size`, this may be caused by "
"incorrect mm_kwargs override."
)
resized_height, resized_width = self._get_resize_output_image_size(
image_width=image_width,
image_height=image_height,
resolution_max_side=size,
) )
if resized_height > max_image_size or resized_width > max_image_size:
grid_h = math.ceil(resized_height / max_image_size)
grid_w = math.ceil(resized_width / max_image_size)
else:
grid_h = grid_w = 0
return grid_w, grid_h
def get_num_patches( def get_num_patches(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Idefics3Processor | None, processor: Idefics3Processor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
grid_w, grid_h = self._get_image_feature_grid_size( num_patches, _, _ = self._get_image_feature_grid_size(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
return grid_w * grid_h + 1 return num_patches
def _get_image_token(
self, processor: Idefics3Processor | None
) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
def _get_image_token(self, processor: Idefics3Processor) -> tuple[str, str, str]:
image_token = processor.image_token image_token = processor.image_token
fake_image_token = processor.fake_image_token fake_image_token = processor.fake_image_token
global_image_token = processor.global_image_tag global_image_token = processor.global_image_tag
...@@ -226,11 +206,9 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -226,11 +206,9 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Idefics3Processor | None, processor: Idefics3Processor,
mm_kwargs: Mapping[str, object],
) -> str: ) -> str:
if processor is None:
processor = self.get_hf_processor()
image_token, fake_image_token, global_img_token = self._get_image_token( image_token, fake_image_token, global_img_token = self._get_image_token(
processor processor
) )
...@@ -241,10 +219,11 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -241,10 +219,11 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
global_img_placeholder = fake_image_token + global_img_token + p_img global_img_placeholder = fake_image_token + global_img_token + p_img
tile_img_placeholder = fake_image_token + grid_placeholder + p_img tile_img_placeholder = fake_image_token + grid_placeholder + p_img
grid_w, grid_h = self._get_image_feature_grid_size( _, grid_h, grid_w = self._get_image_feature_grid_size(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
if grid_w == 0 and grid_h == 0: if grid_w == 0 and grid_h == 0:
return global_img_placeholder + fake_image_token return global_img_placeholder + fake_image_token
...@@ -272,15 +251,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -272,15 +251,14 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Idefics3Processor | None, processor: Idefics3Processor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
num_patches = self.get_num_patches( num_patches = self.get_num_patches(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
return num_patches * processor.image_seq_len return num_patches * processor.image_seq_len
...@@ -353,6 +331,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo ...@@ -353,6 +331,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
image_width=size.width, image_width=size.width,
image_height=size.height, image_height=size.height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=mm_kwargs,
) )
for size in image_sizes for size in image_sizes
] ]
...@@ -398,6 +377,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo ...@@ -398,6 +377,7 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
) )
return PromptUpdateDetails.select_text( return PromptUpdateDetails.select_text(
......
...@@ -197,20 +197,18 @@ class InternS1ProcessingInfo(BaseProcessingInfo): ...@@ -197,20 +197,18 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: GotOcr2ImageProcessorFast | None = None, processor: InternVLProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
if processor is None: image_processor: GotOcr2ImageProcessorFast = processor.image_processor
processor = self.get_hf_processor().image_processor
if not isinstance(processor, GotOcr2ImageProcessorFast): num_image_patches = image_processor.get_number_of_image_patches(
raise ValueError( image_height,
f"GotOcr2ImageProcessorFast is expected but got {type(processor)}" image_width,
) self.ctx.get_merged_mm_kwargs(mm_kwargs),
num_image_patches = processor.get_number_of_image_patches(
image_height, image_width, images_kwargs=dict()
) )
num_image_tokens = self.get_hf_processor().image_seq_length * num_image_patches
return num_image_tokens return processor.image_seq_length * num_image_patches
def resolve_target_ratios(self, use_thumbnail: bool | None = None): def resolve_target_ratios(self, use_thumbnail: bool | None = None):
image_processor = self.get_hf_processor().image_processor image_processor = self.get_hf_processor().image_processor
...@@ -243,7 +241,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo): ...@@ -243,7 +241,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
feat_size = self.get_num_image_tokens( feat_size = self.get_num_image_tokens(
image_width=width, image_width=width,
image_height=height, image_height=height,
processor=processor.image_processor, processor=processor,
mm_kwargs={},
) )
if feat_size > largest_feature_size: if feat_size > largest_feature_size:
largest_feature_size = feat_size largest_feature_size = feat_size
...@@ -262,7 +261,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo): ...@@ -262,7 +261,8 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
return self.get_num_image_tokens( return self.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
processor=processor.image_processor, processor=processor,
mm_kwargs={},
) )
def get_num_frames_with_most_features( def get_num_frames_with_most_features(
......
...@@ -705,11 +705,8 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo): ...@@ -705,11 +705,8 @@ class BaseInternVLProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: BaseInternVLProcessor | None, processor: BaseInternVLProcessor,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens( return processor.get_num_image_tokens(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange from einops import rearrange
from transformers import PretrainedConfig from transformers import BaseImageProcessor, PretrainedConfig
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
...@@ -1011,24 +1011,25 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1011,24 +1011,25 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
num_frames: int = 1, num_frames: int = 1,
do_resize: bool = True, do_resize: bool = True,
image_processor, image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]: ) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
temporal_patch_size = 1 temporal_patch_size = 1
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize: if do_resize:
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
height=image_height, height=image_height,
width=image_width, width=image_width,
factor=patch_size * merge_size, factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels, min_pixels=size["min_pixels"],
max_pixels=image_processor.max_pixels, max_pixels=size["max_pixels"],
) )
preprocessed_size = ImageSize(width=resized_width, height=resized_height) preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else: else:
...@@ -1050,12 +1051,14 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1050,12 +1051,14 @@ class KeyeProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor, image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, num_image_tokens = self._get_vision_info( _, num_image_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=mm_kwargs,
) )
return num_image_tokens return num_image_tokens
...@@ -1065,36 +1068,42 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1065,36 +1068,42 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_width: int, image_width: int,
image_height: int, image_height: int,
num_frames: int, num_frames: int,
image_processor, image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, num_video_tokens = self._get_vision_info( _, num_video_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
num_frames=num_frames, num_frames=num_frames,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=mm_kwargs,
) )
return num_video_tokens return num_video_tokens
def get_image_size_with_most_features( def get_image_size_with_most_features(self) -> ImageSize:
self, image_processor = self.get_image_processor()
) -> ImageSize:
max_image_size, _ = self._get_vision_info( max_image_size, _ = self._get_vision_info(
image_width=self.get_max_image_size(), image_width=self.get_max_image_size(),
image_height=self.get_max_image_size(), image_height=self.get_max_image_size(),
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
return max_image_size return max_image_size
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens( return self.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
def _get_max_video_frames(self, max_tokens: int) -> int: def _get_max_video_frames(self, max_tokens: int) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
num_frames = 0 num_frames = 0
...@@ -1105,7 +1114,8 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1105,7 +1114,8 @@ class KeyeProcessingInfo(BaseProcessingInfo):
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=next_num_frames, num_frames=next_num_frames,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
if next_max_tokens > max_tokens: if next_max_tokens > max_tokens:
...@@ -1130,13 +1140,15 @@ class KeyeProcessingInfo(BaseProcessingInfo): ...@@ -1130,13 +1140,15 @@ class KeyeProcessingInfo(BaseProcessingInfo):
return max(max_frames_per_video, 1) return max(max_frames_per_video, 1)
def get_max_video_tokens(self, seq_len: int) -> int: def get_max_video_tokens(self, seq_len: int) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len), num_frames=self.get_num_frames_with_most_features(seq_len),
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
......
...@@ -176,7 +176,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -176,7 +176,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
min_tiles: int, min_tiles: int,
max_tiles: int, max_tiles: int,
tile_size: int, tile_size: int,
) -> tuple[int, int]: ) -> tuple[int, int, int]:
aspect_ratio = width / height aspect_ratio = width / height
target_ratios = self._target_ratios(min_tiles, max_tiles) target_ratios = self._target_ratios(min_tiles, max_tiles)
# find best matching grid configuration # find best matching grid configuration
...@@ -190,18 +190,27 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -190,18 +190,27 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
self, self,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Lfm2VlProcessor | None, processor: Lfm2VlProcessor,
) -> tuple[int, int]: mm_kwargs: Mapping[str, object],
if processor is None: ) -> tuple[int, int, int]:
processor = self.get_image_processor() image_processor: Lfm2VlImageProcessorFast = processor.image_processor
downsample_factor = processor.image_processor.downsample_factor mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
encoder_patch_size = processor.image_processor.encoder_patch_size downsample_factor = mm_kwargs.get(
max_pixels_tolerance = processor.image_processor.max_pixels_tolerance "downsample_factor", image_processor.downsample_factor
min_tiles = processor.image_processor.min_tiles )
max_tiles = processor.image_processor.max_tiles encoder_patch_size = mm_kwargs.get(
max_image_tokens = processor.image_processor.max_image_tokens "encoder_patch_size", image_processor.encoder_patch_size
tile_size = processor.image_processor.tile_size )
max_pixels_tolerance = mm_kwargs.get(
"max_pixels_tolerance", image_processor.max_pixels_tolerance
)
min_tiles = mm_kwargs.get("min_tiles", image_processor.min_tiles)
max_tiles = mm_kwargs.get("max_tiles", image_processor.max_tiles)
max_image_tokens = mm_kwargs.get(
"max_image_tokens", image_processor.max_image_tokens
)
tile_size = mm_kwargs.get("tile_size", image_processor.tile_size)
do_image_splitting = not min_tiles == max_tiles == 1 do_image_splitting = not min_tiles == max_tiles == 1
is_image_large = self._is_image_too_large( is_image_large = self._is_image_too_large(
...@@ -235,12 +244,14 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -235,12 +244,14 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: Lfm2VlProcessor | None, processor: Lfm2VlProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, _, total_patches = self._get_image_feature_grid_size( _, _, total_patches = self._get_image_feature_grid_size(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
return total_patches return total_patches
...@@ -249,11 +260,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -249,11 +260,9 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
image_width: int, image_width: int,
image_height: int, image_height: int,
spatial_shapes: torch.Tensor, spatial_shapes: torch.Tensor,
processor: Lfm2VlProcessor | None, processor: Lfm2VlProcessor,
mm_kwargs: Mapping[str, object],
) -> str: ) -> str:
if processor is None:
processor = self.get_hf_processor()
grid_placeholder = "<|img_row_{n_h}_col_{n_w}|>" grid_placeholder = "<|img_row_{n_h}_col_{n_w}|>"
image_token = processor.image_token image_token = processor.image_token
image_start_token = processor.image_start_token image_start_token = processor.image_start_token
...@@ -263,6 +272,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -263,6 +272,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
num_thumbnail_tokens, num_tokens_per_tile = self.get_num_image_tokens( num_thumbnail_tokens, num_tokens_per_tile = self.get_num_image_tokens(
spatial_shapes=spatial_shapes, spatial_shapes=spatial_shapes,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
tile_img_placeholder = grid_placeholder + (image_token * num_tokens_per_tile) tile_img_placeholder = grid_placeholder + (image_token * num_tokens_per_tile)
...@@ -270,6 +280,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -270,6 +280,7 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
mm_kwargs=mm_kwargs,
) )
if grid_w > 1 or grid_h > 1: if grid_w > 1 or grid_h > 1:
...@@ -295,15 +306,25 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo): ...@@ -295,15 +306,25 @@ class Lfm2VLProcessingInfo(BaseProcessingInfo):
self, self,
*, *,
spatial_shapes: torch.Tensor, spatial_shapes: torch.Tensor,
processor: Lfm2VlProcessor | None, processor: Lfm2VlProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[int, int]: ) -> tuple[int, int]:
tile_size = processor.image_processor.tile_size image_processor: Lfm2VlImageProcessorFast = processor.image_processor
downsample_factor = processor.image_processor.downsample_factor
encoder_patch_size = processor.image_processor.encoder_patch_size mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
downsample_factor = mm_kwargs.get(
"downsample_factor", image_processor.downsample_factor
)
encoder_patch_size = mm_kwargs.get(
"encoder_patch_size", image_processor.encoder_patch_size
)
tile_size = mm_kwargs.get("tile_size", image_processor.tile_size)
num_thumbnail_tokens = spatial_shapes[-1].prod() // (downsample_factor**2) num_thumbnail_tokens = spatial_shapes[-1].prod() // (downsample_factor**2)
num_patches_tile = tile_size // encoder_patch_size num_patches_tile = tile_size // encoder_patch_size
dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor) dwn_num_patches_tile = math.ceil(num_patches_tile / downsample_factor)
num_tiles_tokens = dwn_num_patches_tile * dwn_num_patches_tile num_tiles_tokens = dwn_num_patches_tile * dwn_num_patches_tile
return num_thumbnail_tokens, num_tiles_tokens return num_thumbnail_tokens, num_tiles_tokens
...@@ -372,6 +393,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]): ...@@ -372,6 +393,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
image_width=size.width, image_width=size.width,
image_height=size.height, image_height=size.height,
processor=hf_processor, processor=hf_processor,
mm_kwargs=mm_kwargs,
) )
for size in image_sizes for size in image_sizes
] ]
...@@ -414,6 +436,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]): ...@@ -414,6 +436,7 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]):
image_height=image_size.height, image_height=image_size.height,
spatial_shapes=spatial_shapes, spatial_shapes=spatial_shapes,
processor=hf_processor, processor=hf_processor,
mm_kwargs=hf_processor_mm_kwargs,
) )
return PromptUpdateDetails.select_text( return PromptUpdateDetails.select_text(
image_repl, image_repl,
......
...@@ -1224,11 +1224,8 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1224,11 +1224,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: MolmoProcessorWrapper | None, processor: MolmoProcessorWrapper,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
ncols, nrows = processor.get_patches_grid_size( ncols, nrows = processor.get_patches_grid_size(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
......
...@@ -1869,12 +1869,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): ...@@ -1869,12 +1869,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
*, *,
image_height: int, image_height: int,
image_width: int, image_width: int,
processor: Molmo2ProcessorWrapper | None = None, processor: Molmo2ProcessorWrapper,
) -> int: ) -> int:
if processor is None: hf_processor = processor.processor
processor = self.get_hf_processor()
hf_processor = processor.processor # type: ignore
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False) resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
# start/end tokens + image patch token + col tokens # start/end tokens + image patch token + col tokens
...@@ -1897,11 +1894,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): ...@@ -1897,11 +1894,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
self, self,
*, *,
num_frames: int, num_frames: int,
processor: Molmo2ProcessorWrapper | None = None, processor: Molmo2ProcessorWrapper,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True) resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True)
# start/end tokens # start/end tokens
extra = 2 + resize_nrows * ( extra = 2 + resize_nrows * (
...@@ -1929,7 +1923,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): ...@@ -1929,7 +1923,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
width = wr * crop_window_size + total_margin_pixels width = wr * crop_window_size + total_margin_pixels
feat_size = self.get_num_image_tokens( feat_size = self.get_num_image_tokens(
image_height=height, image_width=width, processor=processor image_height=height,
image_width=width,
processor=processor,
) )
if feat_size > largest_feature_size: if feat_size > largest_feature_size:
largest_feature_size = feat_size largest_feature_size = feat_size
...@@ -1940,8 +1936,15 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): ...@@ -1940,8 +1936,15 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
return largest_feature_pinpoint return largest_feature_pinpoint
def _get_max_video_frames(self, max_tokens: int) -> int: def _get_max_video_frames(
num_tokens_per_frame = self.get_num_video_tokens(num_frames=1) self,
max_tokens: int,
processor: Molmo2ProcessorWrapper,
) -> int:
num_tokens_per_frame = self.get_num_video_tokens(
num_frames=1,
processor=processor,
)
max_frames = max_tokens // num_tokens_per_frame max_frames = max_tokens // num_tokens_per_frame
return max(max_frames, 1) return max(max_frames, 1)
...@@ -1950,10 +1953,11 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): ...@@ -1950,10 +1953,11 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> int: ) -> int:
video_processor = self.get_hf_processor().processor.video_processor processor = self.get_hf_processor()
video_processor = processor.processor.video_processor
num_frames = video_processor.num_frames num_frames = video_processor.num_frames
max_videos = mm_counts.get("video", 0) max_videos = mm_counts.get("video", 0)
max_total_frames = self._get_max_video_frames(seq_len) max_total_frames = self._get_max_video_frames(seq_len, processor)
max_frames_per_video = min( max_frames_per_video = min(
max_total_frames // max(max_videos, 1), max_total_frames // max(max_videos, 1),
num_frames, num_frames,
......
...@@ -215,7 +215,7 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): ...@@ -215,7 +215,7 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width: int, image_width: int,
image_height: int, image_height: int,
num_frames: int = 1, num_frames: int = 1,
) -> tuple[ImageSize, int]: ) -> int:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vit_config = hf_config.vit_config vit_config = hf_config.vit_config
patch_size = vit_config.patch_size patch_size = vit_config.patch_size
...@@ -245,7 +245,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): ...@@ -245,7 +245,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=next_num_frames, num_frames=next_num_frames,
image_processor=None,
) )
if next_max_tokens > max_tokens: if next_max_tokens > max_tokens:
break break
...@@ -270,7 +269,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): ...@@ -270,7 +269,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width: int, image_width: int,
image_height: int, image_height: int,
num_frames: int, num_frames: int,
image_processor: BaseImageProcessor | None,
) -> int: ) -> int:
num_video_tokens = self.get_num_image_tokens( num_video_tokens = self.get_num_image_tokens(
image_width=image_width, image_height=image_height, num_frames=num_frames image_width=image_width, image_height=image_height, num_frames=num_frames
...@@ -287,7 +285,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo): ...@@ -287,7 +285,6 @@ class Ovis2_5ProcessingInfo(BaseProcessingInfo):
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment