"vscode:/vscode.git/clone" did not exist on "b668055a114086b8968d9ff4a53586f1d8ea0b47"
Unverified Commit 372b2e76 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Standardize getting number of image patches/tokens (#34358)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 6afa587d
......@@ -23,7 +23,7 @@ import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
from transformers.modeling_outputs import (
BaseModelOutputWithPooling,
......@@ -147,21 +147,23 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
image_processor,
image_processor: BaseImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.min_pixels,
max_pixels=image_processor.max_pixels,
min_pixels=size["min_pixels"],
max_pixels=size["max_pixels"],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
......@@ -176,12 +178,13 @@ class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize:
hf_config = self.get_hf_config()
image_processor = self.get_image_processor()
# See `smart_resize` for the calculation of the image size.
merge_size = hf_config.vision_config.spatial_merge_size
patch_size = hf_config.vision_config.patch_size
factor = merge_size * patch_size
max_num_tokens = self.get_image_processor().max_pixels // (factor**2)
max_num_tokens = image_processor.max_pixels // (factor**2)
# Find factors of max_num_tokens close to its square root
# to create a dummy image with a reasonable aspect ratio.
h_patches = int(math.sqrt(max_num_tokens))
......@@ -276,6 +279,7 @@ class PaddleOCRVLMultiModalProcessor(
image_width=image_size.width,
image_height=image_size.height,
image_processor=image_processor,
mm_kwargs=hf_processor_mm_kwargs,
)
return [image_token_id] * num_image_tokens
......
......@@ -351,11 +351,8 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: ProcessorMixin | None = None,
processor: ProcessorMixin,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.calc_num_image_tokens_from_image_size( # type: ignore
width=image_width,
height=image_height,
......
......@@ -558,10 +558,8 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
def get_dynamic_hd(
self,
processor: ProcessorMixin | None = None,
processor: ProcessorMixin,
) -> int:
if processor is None:
processor = self.get_hf_processor()
image_processor = processor.image_processor
return image_processor.dynamic_hd
......@@ -715,7 +713,7 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: ProcessorMixin | None = None,
processor: ProcessorMixin,
) -> int:
hf_config = self.get_hf_config()
vision_encoder_name = hf_config.img_processor
......@@ -739,10 +737,9 @@ class Phi4MMProcessingInfo(BaseProcessingInfo):
return image_num_tokens
def get_image_size_with_most_features(
self,
processor: ProcessorMixin | None = None,
) -> ImageSize:
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
hf_config = self.get_hf_config()
vision_encoder_name = hf_config.img_processor
if vision_encoder_name is None:
......@@ -874,9 +871,12 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
prompt, mm_data, mm_kwargs, tok_kwargs
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
num_img_tokens = [
self.info.get_num_image_tokens(
image_width=img_size[0], image_height=img_size[1]
image_width=img_size[0],
image_height=img_size[1],
processor=hf_processor,
)
for img_size in processed_outputs["image_sizes"]
]
......
......@@ -217,28 +217,13 @@ class PixtralProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_vision_config(
self,
processor: PixtralProcessorAdapter | None = None,
):
if processor is None:
processor = self.get_hf_processor()
return PixtralVisionConfig(
image_size=processor.image_size,
patch_size=processor.patch_size,
)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: PixtralProcessorAdapter | None = None,
processor: PixtralProcessorAdapter,
) -> int:
if processor is None:
processor = self.get_hf_processor()
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height))
)
......
......@@ -832,24 +832,25 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_height: int,
num_frames: int = 1,
do_resize: bool = True,
image_processor: Qwen2VLImageProcessor | None,
image_processor: Qwen2VLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]:
if image_processor is None:
image_processor = self.get_image_processor()
hf_config = self.get_hf_config()
vision_config = hf_config.vision_config
patch_size = vision_config.patch_size
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.size["shortest_edge"],
max_pixels=image_processor.size["longest_edge"],
min_pixels=size["shortest_edge"],
max_pixels=size["longest_edge"],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
......@@ -873,13 +874,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
image_processor: Qwen2VLImageProcessor | None,
image_processor: Qwen2VLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, num_image_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=1,
image_processor=image_processor,
mm_kwargs=mm_kwargs,
)
return num_image_tokens
......@@ -889,13 +892,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
num_frames: int,
image_processor: Qwen2VLImageProcessor | None,
image_processor: Qwen2VLImageProcessor,
mm_kwargs: Mapping[str, object],
) -> int:
_, num_video_tokens = self._get_vision_info(
image_width=image_width,
image_height=image_height,
num_frames=num_frames,
image_processor=image_processor,
mm_kwargs=mm_kwargs,
)
return num_video_tokens
......@@ -941,15 +946,18 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
return ImageSize(width=unit * width_factor, height=unit * height_factor)
def get_max_image_tokens(self) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
def _get_max_video_frames(self, max_tokens: int, start_num_frames: int = 1) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
num_frames = start_num_frames
......@@ -960,7 +968,8 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
image_width=target_width,
image_height=target_height,
num_frames=next_num_frames,
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
if next_max_tokens > max_tokens:
......@@ -990,13 +999,15 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
seq_len: int,
mm_counts: Mapping[str, int],
) -> int:
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
image_processor=image_processor,
mm_kwargs={},
)
......
......@@ -642,13 +642,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_height: int,
num_frames: int = 2,
do_resize: bool = True,
image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None,
image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor,
mm_kwargs: Mapping[str, object],
) -> tuple[ImageSize, int]:
if image_processor is None and num_frames > 1:
image_processor = self.get_video_processor()
elif image_processor is None:
image_processor = self.get_image_processor()
is_video = isinstance(image_processor, Qwen3VLVideoProcessor)
hf_config = self.get_hf_config()
......@@ -657,6 +653,9 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
merge_size = vision_config.spatial_merge_size
temporal_patch_size = vision_config.temporal_patch_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
size = mm_kwargs.get("size", image_processor.size)
if do_resize:
if is_video:
smart_resize = video_smart_resize
......@@ -667,12 +666,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
else:
smart_resize = image_smart_resize
extra_kwargs = {}
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * merge_size,
min_pixels=image_processor.size["shortest_edge"],
max_pixels=image_processor.size["longest_edge"],
min_pixels=size["shortest_edge"],
max_pixels=size["longest_edge"],
**extra_kwargs,
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
......@@ -720,7 +720,8 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_width=target_width,
image_height=target_height,
num_frames=2,
image_processor=None,
image_processor=video_processor,
mm_kwargs={},
)
return num_video_soft_tokens
......@@ -846,6 +847,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
image_height=target_video_height,
num_frames=target_num_frames,
image_processor=video_processor,
mm_kwargs={},
)
# NOTE: we need to do this check here since Qwen3-VL resizes video
# frames depending on how many frames there are.
......
......@@ -487,11 +487,8 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: SkyworkR1VProcessor | None,
processor: SkyworkR1VProcessor,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
......
......@@ -16,9 +16,7 @@ class SmolVLMProcessingInfo(Idefics3ProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> SmolVLMProcessor:
return self.ctx.get_hf_processor(SmolVLMProcessor, **kwargs)
def _get_image_token(self, processor: SmolVLMProcessor | None) -> tuple[str, str]:
if processor is None:
processor = self.get_hf_processor()
def _get_image_token(self, processor: SmolVLMProcessor) -> tuple[str, str, str]:
image_token = processor.image_token
fake_image_token = processor.fake_image_token
global_image_token = processor.global_image_token
......
......@@ -409,6 +409,10 @@ class InputProcessingContext:
return json_map_leaves(_postprocess_one, output)
def get_merged_mm_kwargs(self, kwargs: Mapping[str, object]):
mm_config = self.model_config.get_multimodal_config()
return mm_config.merge_mm_processor_kwargs(kwargs)
def call_hf_processor(
self,
hf_processor: ProcessorMixin,
......@@ -424,8 +428,7 @@ class InputProcessingContext:
"""
assert callable(hf_processor)
mm_config = self.model_config.get_multimodal_config()
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
allowed_kwargs = get_allowed_kwarg_only_overrides(
hf_processor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment