Unverified Commit 372b2e76 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Standardize getting number of image patches/tokens (#34358)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 6afa587d
...@@ -23,7 +23,7 @@ import numpy as np ...@@ -23,7 +23,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange from einops import rearrange
from transformers import BatchFeature, PretrainedConfig from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
from transformers.modeling_outputs import ( from transformers.modeling_outputs import (
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
...@@ -147,21 +147,23 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo): ...@@ -147,21 +147,23 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor, image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
height=image_height, height=image_height,
width=image_width, width=image_width,
factor=patch_size * merge_size, factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels, min_pixels=size["min_pixels"],
max_pixels=image_processor.max_pixels, max_pixels=size["max_pixels"],
) )
preprocessed_size = ImageSize(width=resized_width, height=resized_height) preprocessed_size = ImageSize(width=resized_width, height=resized_height)
...@@ -176,12 +178,13 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo): ...@@ -176,12 +178,13 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
image_processor = self.get_image_processor()
# See `smart_resize` for the calculation of the image size. # See `smart_resize` for the calculation of the image size.
merge_size = hf_config.vision_config.spatial_merge_size merge_size = hf_config.vision_config.spatial_merge_size
patch_size = hf_config.vision_config.patch_size patch_size = hf_config.vision_config.patch_size
factor = merge_size * patch_size factor = merge_size * patch_size
max_num_tokens = self.get_image_processor().max_pixels // (factor**2) max_num_tokens = image_processor.max_pixels // (factor**2)
# Find factors of max_num_tokens close to its square root # Find factors of max_num_tokens close to its square root
# to create a dummy image with a reasonable aspect ratio. # to create a dummy image with a reasonable aspect ratio.
h_patches = int(math.sqrt(max_num_tokens)) h_patches = int(math.sqrt(max_num_tokens))
...@@ -276,6 +279,7 @@ class PaddleOCRVLMultiModalProcessor( ...@@ -276,6 +279,7 @@ class PaddleOCRVLMultiModalProcessor(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=hf_processor_mm_kwargs,
) )
return [image_token_id] * num_image_tokens return [image_token_id] * num_image_tokens
......
...@@ -351,11 +351,8 @@ class Phi3VProcessingInfo(BaseProcessingInfo): ...@@ -351,11 +351,8 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: ProcessorMixin | None = None, processor: ProcessorMixin,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width, width=image_width,
height=image_height, height=image_height,
......
...@@ -558,10 +558,8 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): ...@@ -558,10 +558,8 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_dynamic_hd( def get_dynamic_hd(
self, self,
processor: ProcessorMixin | None = None, processor: ProcessorMixin,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
image_processor = processor.image_processor image_processor = processor.image_processor
return image_processor.dynamic_hd return image_processor.dynamic_hd
...@@ -715,7 +713,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): ...@@ -715,7 +713,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: ProcessorMixin | None = None, processor: ProcessorMixin,
) -> int: ) -> int:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_encoder_name = hf_config.img_processor vision_encoder_name = hf_config.img_processor
...@@ -739,10 +737,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo): ...@@ -739,10 +737,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
return image_num_tokens return image_num_tokens
def get_image_size_with_most_features( def get_image_size_with_most_features(self) -> ImageSize:
self, processor = self.get_hf_processor()
processor: ProcessorMixin | None = None,
) -> ImageSize:
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_encoder_name = hf_config.img_processor vision_encoder_name = hf_config.img_processor
if vision_encoder_name is None: if vision_encoder_name is None:
...@@ -874,9 +871,12 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -874,9 +871,12 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt, mm_data, mm_kwargs, tok_kwargs prompt, mm_data, mm_kwargs, tok_kwargs
) )
hf_processor = self.info.get_hf_processor(**mm_kwargs)
num_img_tokens = [ num_img_tokens = [
self.info.get_num_image_tokens( self.info.get_num_image_tokens(
image_width=img_size[0], image_height=img_size[1] image_width=img_size[0],
image_height=img_size[1],
processor=hf_processor,
) )
for img_size in processed_outputs["image_sizes"] for img_size in processed_outputs["image_sizes"]
] ]
......
...@@ -217,28 +217,13 @@ class PixtralProcessingInfo(BaseProcessingInfo): ...@@ -217,28 +217,13 @@ class PixtralProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]: def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None} return {"image": None}
def get_vision_config(
self,
processor: PixtralProcessorAdapter | None = None,
):
if processor is None:
processor = self.get_hf_processor()
return PixtralVisionConfig(
image_size=processor.image_size,
patch_size=processor.patch_size,
)
def get_num_image_tokens( def get_num_image_tokens(
self, self,
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: PixtralProcessorAdapter | None = None, processor: PixtralProcessorAdapter,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
ncols, nrows = processor.image_processor._image_to_num_tokens( ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height)) Image.new("RGB", (image_width, image_height))
) )
......
...@@ -832,24 +832,25 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -832,24 +832,25 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
num_frames: int = 1, num_frames: int = 1,
do_resize: bool = True, do_resize: bool = True,
image_processor: Qwen2VLImageProcessor | None, image_processor: Qwen2VLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]: ) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
patch_size = vision_config.patch_size patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size temporal_patch_size = vision_config.temporal_patch_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize: if do_resize:
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
height=image_height, height=image_height,
width=image_width, width=image_width,
factor=patch_size * merge_size, factor=patch_size * merge_size,
min_pixels=image_processor.size["shortest_edge"], min_pixels=size["shortest_edge"],
max_pixels=image_processor.size["longest_edge"], max_pixels=size["longest_edge"],
) )
preprocessed_size = ImageSize(width=resized_width, height=resized_height) preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else: else:
...@@ -873,13 +874,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -873,13 +874,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
image_processor: Qwen2VLImageProcessor | None, image_processor: Qwen2VLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, num_image_tokens = self._get_vision_info( _, num_image_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
num_frames=1, num_frames=1,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=mm_kwargs,
) )
return num_image_tokens return num_image_tokens
...@@ -889,13 +892,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -889,13 +892,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_width: int, image_width: int,
image_height: int, image_height: int,
num_frames: int, num_frames: int,
image_processor: Qwen2VLImageProcessor | None, image_processor: Qwen2VLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int: ) -> int:
_, num_video_tokens = self._get_vision_info( _, num_video_tokens = self._get_vision_info(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
num_frames=num_frames, num_frames=num_frames,
image_processor=image_processor, image_processor=image_processor,
mm_kwargs=mm_kwargs,
) )
return num_video_tokens return num_video_tokens
...@@ -941,15 +946,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -941,15 +946,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return ImageSize(width=unit * width_factor, height=unit * height_factor) return ImageSize(width=unit * width_factor, height=unit * height_factor)
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens( return self.get_num_image_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int: def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
num_frames = start_num_frames num_frames = start_num_frames
...@@ -960,7 +968,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -960,7 +968,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=next_num_frames, num_frames=next_num_frames,
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
if next_max_tokens > max_tokens: if next_max_tokens > max_tokens:
...@@ -990,13 +999,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo): ...@@ -990,13 +999,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> int: ) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens( return self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None, image_processor=image_processor,
mm_kwargs={},
) )
......
...@@ -642,13 +642,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -642,13 +642,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_height: int, image_height: int,
num_frames: int = 2, num_frames: int = 2,
do_resize: bool = True, do_resize: bool = True,
image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None, image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]: ) -> tuple[ImageSize, int]:
if image_processor is None and num_frames > 1:
image_processor = self.get_video_processor()
elif image_processor is None:
image_processor = self.get_image_processor()
is_video = isinstance(image_processor, Qwen3VLVideoProcessor) is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
hf_config = self.get_hf_config() hf_config = self.get_hf_config()
...@@ -657,6 +653,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -657,6 +653,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
merge_size = vision_config.spatial_merge_size merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size temporal_patch_size = vision_config.temporal_patch_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize: if do_resize:
if is_video: if is_video:
smart_resize = video_smart_resize smart_resize = video_smart_resize
...@@ -667,12 +666,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -667,12 +666,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
else: else:
smart_resize = image_smart_resize smart_resize = image_smart_resize
extra_kwargs = {} extra_kwargs = {}
resized_height, resized_width = smart_resize( resized_height, resized_width = smart_resize(
height=image_height, height=image_height,
width=image_width, width=image_width,
factor=patch_size * merge_size, factor=patch_size * merge_size,
min_pixels=image_processor.size["shortest_edge"], min_pixels=size["shortest_edge"],
max_pixels=image_processor.size["longest_edge"], max_pixels=size["longest_edge"],
**extra_kwargs, **extra_kwargs,
) )
preprocessed_size = ImageSize(width=resized_width, height=resized_height) preprocessed_size = ImageSize(width=resized_width, height=resized_height)
...@@ -720,7 +720,8 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo): ...@@ -720,7 +720,8 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=2, num_frames=2,
image_processor=None, image_processor=video_processor,
mm_kwargs={},
) )
return num_video_soft_tokens return num_video_soft_tokens
...@@ -846,6 +847,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]): ...@@ -846,6 +847,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
image_height=target_video_height, image_height=target_video_height,
num_frames=target_num_frames, num_frames=target_num_frames,
image_processor=video_processor, image_processor=video_processor,
mm_kwargs={},
) )
# NOTE: we need to do this check here since Qwen3-VL resizes video # NOTE: we need to do this check here since Qwen3-VL resizes video
# frames depending on how many frames there are. # frames depending on how many frames there are.
......
...@@ -487,11 +487,8 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo): ...@@ -487,11 +487,8 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
*, *,
image_width: int, image_width: int,
image_height: int, image_height: int,
processor: SkyworkR1VProcessor | None, processor: SkyworkR1VProcessor,
) -> int: ) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens( return processor.get_num_image_tokens(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
......
...@@ -16,9 +16,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo): ...@@ -16,9 +16,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor: def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor:
return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs) return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)
def _get_image_token(self, processor: SmolVLMProcessor | None) -> tuple[str, str]: def _get_image_token(self, processor: SmolVLMProcessor) -> tuple[str, str, str]:
if processor is None:
processor = self.get_hf_processor()
image_token = processor.image_token image_token = processor.image_token
fake_image_token = processor.fake_image_token fake_image_token = processor.fake_image_token
global_image_token = processor.global_image_token global_image_token = processor.global_image_token
......
...@@ -409,6 +409,10 @@ class InputProcessingContext: ...@@ -409,6 +409,10 @@ class InputProcessingContext:
return json_map_leaves(_postprocess_one, output) return json_map_leaves(_postprocess_one, output)
def get_merged_mm_kwargs(self, kwargs: Mapping[str, object]):
mm_config = self.model_config.get_multimodal_config()
return mm_config.merge_mm_processor_kwargs(kwargs)
def call_hf_processor( def call_hf_processor(
self, self,
hf_processor: ProcessorMixin, hf_processor: ProcessorMixin,
...@@ -424,8 +428,7 @@ class InputProcessingContext: ...@@ -424,8 +428,7 @@ class InputProcessingContext:
""" """
assert callable(hf_processor) assert callable(hf_processor)
mm_config = self.model_config.get_multimodal_config() merged_kwargs = self.get_merged_mm_kwargs(kwargs)
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
allowed_kwargs = get_allowed_kwarg_only_overrides( allowed_kwargs = get_allowed_kwarg_only_overrides(
hf_processor, hf_processor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment