Unverified Commit 8e2ad97a authored by Julien Denize, committed by GitHub
Browse files

[BUGFIX] Pixtral cannot be loaded with --limit-mm-per-prompt 0 (#33406)


Signed-off-by: juliendenize <julien.denize@mistral.ai>
parent 10152d21
...@@ -70,7 +70,7 @@ from .interfaces import ( ...@@ -70,7 +70,7 @@ from .interfaces import (
SupportsPP, SupportsPP,
) )
from .module_mapping import MultiModelKeys from .module_mapping import MultiModelKeys
from .utils import init_vllm_registered_model, maybe_prefix from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix
from .vision import ( from .vision import (
VisionEncoderInfo, VisionEncoderInfo,
VisionFeatureSelectStrategy, VisionFeatureSelectStrategy,
...@@ -93,6 +93,10 @@ except ImportError: ...@@ -93,6 +93,10 @@ except ImportError:
PATCH_MERGE = "patch_merge" PATCH_MERGE = "patch_merge"
def _is_layer_none_or_staged(layer: nn.Module) -> bool:
return layer is None or isinstance(layer, StageMissingLayer)
class PixtralImagePixelInputs(TensorSchema): class PixtralImagePixelInputs(TensorSchema):
""" """
Dimensions: Dimensions:
...@@ -542,7 +546,7 @@ class PixtralForConditionalGeneration( ...@@ -542,7 +546,7 @@ class PixtralForConditionalGeneration(
# Single pass over weights # Single pass over weights
for name, w in weights: for name, w in weights:
if is_vision_encoder_weights((name, w)): if is_vision_encoder_weights((name, w)):
if self.vision_encoder is None: if _is_layer_none_or_staged(self.vision_encoder):
continue continue
# Load vision encoder weights directly # Load vision encoder weights directly
trimmed_name = ".".join(name.split(".")[1:]) trimmed_name = ".".join(name.split(".")[1:])
...@@ -551,7 +555,7 @@ class PixtralForConditionalGeneration( ...@@ -551,7 +555,7 @@ class PixtralForConditionalGeneration(
with torch.no_grad(): with torch.no_grad():
default_weight_loader(param, w) default_weight_loader(param, w)
elif is_patch_merger((name, w)): elif is_patch_merger((name, w)):
if self.patch_merger is None: if _is_layer_none_or_staged(self.patch_merger):
continue continue
# Load vision patch merger weights directly # Load vision patch merger weights directly
trimmed_name = ".".join(name.split(".")[1:]) trimmed_name = ".".join(name.split(".")[1:])
...@@ -559,7 +563,7 @@ class PixtralForConditionalGeneration( ...@@ -559,7 +563,7 @@ class PixtralForConditionalGeneration(
with torch.no_grad(): with torch.no_grad():
default_weight_loader(param, w) default_weight_loader(param, w)
elif is_pre_mm_projector_norm((name, w)): elif is_pre_mm_projector_norm((name, w)):
if self.pre_mm_projector_norm is None: if _is_layer_none_or_staged(self.pre_mm_projector_norm):
continue continue
# Load vision pre_mm_projector_norm weights directly # Load vision pre_mm_projector_norm weights directly
trimmed_name = ".".join(name.split(".")[1:]) trimmed_name = ".".join(name.split(".")[1:])
...@@ -567,7 +571,7 @@ class PixtralForConditionalGeneration( ...@@ -567,7 +571,7 @@ class PixtralForConditionalGeneration(
with torch.no_grad(): with torch.no_grad():
default_weight_loader(param, w) default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)): elif is_vision_lang_adapter_weights((name, w)):
if self.vision_language_adapter is None: if _is_layer_none_or_staged(self.vision_language_adapter):
continue continue
# Load vision-language adapter weights directly # Load vision-language adapter weights directly
trimmed_name = ".".join(name.split(".")[1:]) trimmed_name = ".".join(name.split(".")[1:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment