Unverified commit cb3f2d8d, authored by Michael Goin and committed by GitHub
Browse files

[Bugfix] Fix Mistral3 spatial merge error (#17270)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent c12df53b
...@@ -272,6 +272,9 @@ class Mistral3MultiModalProcessor( ...@@ -272,6 +272,9 @@ class Mistral3MultiModalProcessor(
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig) assert isinstance(vision_config, PixtralVisionConfig)
# Need to sneak in spatial_merge_size for Mistral3
vision_config.spatial_merge_size = getattr(hf_config,
"spatial_merge_size", 1)
encoder_info = PixtralHFEncoderInfo(vision_config) encoder_info = PixtralHFEncoderInfo(vision_config)
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
......
...@@ -911,9 +911,8 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]): ...@@ -911,9 +911,8 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
return self.vision_config.image_size return self.vision_config.image_size
def get_patch_size(self) -> int: def get_patch_size(self) -> int:
spatial_merge_size = getattr(self.vision_config, "spatial_merge_size", return (self.vision_config.patch_size *
1) self.vision_config.spatial_merge_size)
return (self.vision_config.patch_size * spatial_merge_size)
def get_patch_grid_length(self) -> int: def get_patch_grid_length(self) -> int:
image_size, patch_size = self.get_image_size(), self.get_patch_size() image_size, patch_size = self.get_image_size(), self.get_patch_size()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment