Unverified commit f192ca90, authored by Michael Goin, committed by GitHub
Browse files

Fix PixtralHF missing spatial_merge_size (#17571)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent f89d0e11
......@@ -354,9 +354,8 @@ class PixtralHFMultiModalProcessor(
image_token_id = hf_config.image_token_index
image_end_id = vocab[processor.image_end_token]
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
assert isinstance(hf_config.vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(hf_config)
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
......
......@@ -272,12 +272,8 @@ class Mistral3MultiModalProcessor(
image_token_id = hf_config.image_token_index
image_end_id = vocab[processor.image_end_token]
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
# Need to sneak in spatial_merge_size for Mistral3
vision_config.spatial_merge_size = getattr(hf_config,
"spatial_merge_size", 1)
encoder_info = PixtralHFEncoderInfo(vision_config)
assert isinstance(hf_config.vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(hf_config)
def get_replacement(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
......
......@@ -916,8 +916,9 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
return self.vision_config.image_size
def get_patch_size(self) -> int:
return (self.vision_config.patch_size *
self.vision_config.spatial_merge_size)
# spatial_merge_size is needed for Mistral3
spatial_merge_size = getattr(self.hf_config, "spatial_merge_size", 1)
return self.vision_config.patch_size * spatial_merge_size
def get_patch_grid_length(self) -> int:
image_size, patch_size = self.get_image_size(), self.get_patch_size()
......
......@@ -19,10 +19,11 @@ _C = TypeVar("_C", bound=PretrainedConfig)
class VisionEncoderInfo(ABC, Generic[_C]):
def __init__(self, vision_config: _C) -> None:
def __init__(self, hf_config: _C) -> None:
super().__init__()
self.vision_config = vision_config
self.hf_config = hf_config
self.vision_config = hf_config.vision_config
@abstractmethod
def get_num_image_tokens(
......@@ -57,18 +58,14 @@ def get_vision_encoder_info(
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
from .siglip import SiglipEncoderInfo, SiglipVisionConfig
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return CLIPEncoderInfo(vision_config)
if isinstance(vision_config, PixtralVisionConfig):
# Need to sneak in spatial_merge_size for Mistral3
vision_config.spatial_merge_size = getattr(hf_config,
"spatial_merge_size", 1)
return PixtralHFEncoderInfo(vision_config)
if isinstance(vision_config, SiglipVisionConfig):
return SiglipEncoderInfo(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
if isinstance(hf_config.vision_config, CLIPVisionConfig):
return CLIPEncoderInfo(hf_config)
if isinstance(hf_config.vision_config, PixtralVisionConfig):
return PixtralHFEncoderInfo(hf_config)
if isinstance(hf_config.vision_config, SiglipVisionConfig):
return SiglipEncoderInfo(hf_config)
msg = f"Unsupported vision config: {type(hf_config.vision_config)}"
raise NotImplementedError(msg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment